159 lines
3.4 KiB
Go
159 lines
3.4 KiB
Go
|
|
package starssh
|
||
|
|
|
||
|
|
import (
|
||
|
|
"errors"
|
||
|
|
"fmt"
|
||
|
|
"net"
|
||
|
|
"os"
|
||
|
|
"strings"
|
||
|
|
"time"
|
||
|
|
)
|
||
|
|
|
||
|
|
var ErrSSHAgentTimeout = errors.New("ssh-agent timeout")
|
||
|
|
var dialResolvedSSHAgentFunc = dialResolvedSSHAgent
|
||
|
|
|
||
|
|
// sshAgentDialOptions holds the caller-supplied inputs for dialSSHAgent.
type sshAgentDialOptions struct {
	// Endpoint is an explicit agent address. resolveSSHAgentEndpoint trims
	// surrounding whitespace; a blank value means "fall back to SSH_AUTH_SOCK,
	// then to defaultSSHAgentEndpoint".
	Endpoint string

	// Timeout is passed unchanged to dialResolvedSSHAgentFunc together with
	// the resolved endpoint.
	Timeout time.Duration
}
|
||
|
|
|
||
|
|
// resolvedSSHAgentEndpoint describes the agent address picked by
// resolveSSHAgentEndpoint.
//
// Endpoint is the whitespace-trimmed address. Source records where it came
// from ("identity-agent" for an explicit option, "SSH_AUTH_SOCK" for the
// environment variable; the fallback path fills it via defaultSSHAgentEndpoint).
// Network is the value produced by defaultSSHAgentNetwork for that address,
// presumably the network name handed to the dialer (NOTE(review): confirm).
type resolvedSSHAgentEndpoint struct {
	Endpoint, Source, Network string
}
|
||
|
|
|
||
|
|
// deadlineAgentConn wraps a net.Conn so that its Read and Write methods first
// re-arm a deadline of timeout (see setDeadline). Each individual I/O call is
// therefore bounded, not the lifetime of the connection.
type deadlineAgentConn struct {
	net.Conn // underlying connection; methods not overridden here are promoted as-is

	timeout time.Duration // per-call deadline; setDeadline ignores values <= 0
}
|
||
|
|
|
||
|
|
func resolveSSHAgentEndpoint(options sshAgentDialOptions) (resolvedSSHAgentEndpoint, error) {
|
||
|
|
endpoint := strings.TrimSpace(options.Endpoint)
|
||
|
|
if endpoint != "" {
|
||
|
|
return resolvedSSHAgentEndpoint{
|
||
|
|
Endpoint: endpoint,
|
||
|
|
Source: "identity-agent",
|
||
|
|
Network: defaultSSHAgentNetwork(endpoint),
|
||
|
|
}, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
endpoint = strings.TrimSpace(os.Getenv("SSH_AUTH_SOCK"))
|
||
|
|
if endpoint != "" {
|
||
|
|
return resolvedSSHAgentEndpoint{
|
||
|
|
Endpoint: endpoint,
|
||
|
|
Source: "SSH_AUTH_SOCK",
|
||
|
|
Network: defaultSSHAgentNetwork(endpoint),
|
||
|
|
}, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
return defaultSSHAgentEndpoint()
|
||
|
|
}
|
||
|
|
|
||
|
|
func dialSSHAgent(options sshAgentDialOptions) (net.Conn, resolvedSSHAgentEndpoint, error) {
|
||
|
|
resolved, err := resolveSSHAgentEndpoint(options)
|
||
|
|
if err != nil {
|
||
|
|
return nil, resolvedSSHAgentEndpoint{}, err
|
||
|
|
}
|
||
|
|
|
||
|
|
conn, err := dialResolvedSSHAgentFunc(resolved, options.Timeout)
|
||
|
|
if isTimeoutError(err) {
|
||
|
|
err = fmt.Errorf("%w: %v", ErrSSHAgentTimeout, err)
|
||
|
|
}
|
||
|
|
if err != nil {
|
||
|
|
return nil, resolved, err
|
||
|
|
}
|
||
|
|
return conn, resolved, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func dialSSHAgentWithDebug(step string, timeouts sshAgentTimeouts) (net.Conn, resolvedSSHAgentEndpoint, error) {
|
||
|
|
options := sshAgentDialOptions{
|
||
|
|
Endpoint: timeouts.Endpoint,
|
||
|
|
Timeout: timeouts.Dial,
|
||
|
|
}
|
||
|
|
started := time.Now()
|
||
|
|
conn, resolved, err := dialSSHAgent(options)
|
||
|
|
logSSHAgentDebug(timeouts.Debug, SSHAgentDebugEvent{
|
||
|
|
Step: step,
|
||
|
|
Source: resolved.Source,
|
||
|
|
Endpoint: resolved.Endpoint,
|
||
|
|
Network: resolved.Network,
|
||
|
|
Phase: "dial",
|
||
|
|
Status: debugStatus(err),
|
||
|
|
Duration: time.Since(started),
|
||
|
|
Err: err,
|
||
|
|
})
|
||
|
|
return conn, resolved, err
|
||
|
|
}
|
||
|
|
|
||
|
|
func logSSHAgentDebug(debug SSHAgentDebugFunc, event SSHAgentDebugEvent) {
|
||
|
|
if debug == nil {
|
||
|
|
return
|
||
|
|
}
|
||
|
|
debug(event)
|
||
|
|
}
|
||
|
|
|
||
|
|
// debugStatus maps an operation's error to the status string recorded in
// debug events: "ok" for nil and "error" otherwise.
func debugStatus(err error) string {
	status := "ok"
	if err != nil {
		status = "error"
	}
	return status
}
|
||
|
|
|
||
|
|
func wrapSSHAgentConnWithDeadline(conn net.Conn, timeout time.Duration) net.Conn {
|
||
|
|
if conn == nil || timeout <= 0 {
|
||
|
|
return conn
|
||
|
|
}
|
||
|
|
return &deadlineAgentConn{Conn: conn, timeout: timeout}
|
||
|
|
}
|
||
|
|
|
||
|
|
func (c *deadlineAgentConn) Read(p []byte) (int, error) {
|
||
|
|
c.setDeadline()
|
||
|
|
n, err := c.Conn.Read(p)
|
||
|
|
return n, wrapSSHAgentConnError(err)
|
||
|
|
}
|
||
|
|
|
||
|
|
func (c *deadlineAgentConn) Write(p []byte) (int, error) {
|
||
|
|
c.setDeadline()
|
||
|
|
n, err := c.Conn.Write(p)
|
||
|
|
return n, wrapSSHAgentConnError(err)
|
||
|
|
}
|
||
|
|
|
||
|
|
func (c *deadlineAgentConn) setDeadline() {
|
||
|
|
if c == nil || c.timeout <= 0 || c.Conn == nil {
|
||
|
|
return
|
||
|
|
}
|
||
|
|
_ = c.Conn.SetDeadline(time.Now().Add(c.timeout))
|
||
|
|
}
|
||
|
|
|
||
|
|
// isTimeoutError reports whether err represents a timeout. That holds when
// os.ErrDeadlineExceeded is anywhere in err's chain, or when the first
// net.Error found in the chain reports Timeout() == true. A nil error is
// never a timeout.
func isTimeoutError(err error) bool {
	switch {
	case err == nil:
		return false
	case errors.Is(err, os.ErrDeadlineExceeded):
		return true
	}

	var ne net.Error
	if !errors.As(err, &ne) {
		return false
	}
	return ne.Timeout()
}
|
||
|
|
|
||
|
|
func wrapSSHAgentConnError(err error) error {
|
||
|
|
if isTimeoutError(err) {
|
||
|
|
return fmt.Errorf("%w: %v", ErrSSHAgentTimeout, err)
|
||
|
|
}
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
|
||
|
|
func normalizeSSHAgentError(err error) error {
|
||
|
|
if err == nil {
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
if errors.Is(err, ErrSSHAgentTimeout) {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
if strings.Contains(err.Error(), ErrSSHAgentTimeout.Error()) {
|
||
|
|
return fmt.Errorf("%w: %v", ErrSSHAgentTimeout, err)
|
||
|
|
}
|
||
|
|
return err
|
||
|
|
}
|