package starssh import ( "errors" "fmt" "net" "os" "strings" "time" ) var ErrSSHAgentTimeout = errors.New("ssh-agent timeout") var dialResolvedSSHAgentFunc = dialResolvedSSHAgent type sshAgentDialOptions struct { Endpoint string Timeout time.Duration } type resolvedSSHAgentEndpoint struct { Endpoint string Source string Network string } type deadlineAgentConn struct { net.Conn timeout time.Duration } func resolveSSHAgentEndpoint(options sshAgentDialOptions) (resolvedSSHAgentEndpoint, error) { endpoint := strings.TrimSpace(options.Endpoint) if endpoint != "" { return resolvedSSHAgentEndpoint{ Endpoint: endpoint, Source: "identity-agent", Network: defaultSSHAgentNetwork(endpoint), }, nil } endpoint = strings.TrimSpace(os.Getenv("SSH_AUTH_SOCK")) if endpoint != "" { return resolvedSSHAgentEndpoint{ Endpoint: endpoint, Source: "SSH_AUTH_SOCK", Network: defaultSSHAgentNetwork(endpoint), }, nil } return defaultSSHAgentEndpoint() } func dialSSHAgent(options sshAgentDialOptions) (net.Conn, resolvedSSHAgentEndpoint, error) { resolved, err := resolveSSHAgentEndpoint(options) if err != nil { return nil, resolvedSSHAgentEndpoint{}, err } conn, err := dialResolvedSSHAgentFunc(resolved, options.Timeout) if isTimeoutError(err) { err = fmt.Errorf("%w: %v", ErrSSHAgentTimeout, err) } if err != nil { return nil, resolved, err } return conn, resolved, nil } func dialSSHAgentWithDebug(step string, timeouts sshAgentTimeouts) (net.Conn, resolvedSSHAgentEndpoint, error) { options := sshAgentDialOptions{ Endpoint: timeouts.Endpoint, Timeout: timeouts.Dial, } started := time.Now() conn, resolved, err := dialSSHAgent(options) logSSHAgentDebug(timeouts.Debug, SSHAgentDebugEvent{ Step: step, Source: resolved.Source, Endpoint: resolved.Endpoint, Network: resolved.Network, Phase: "dial", Status: debugStatus(err), Duration: time.Since(started), Err: err, }) return conn, resolved, err } func logSSHAgentDebug(debug SSHAgentDebugFunc, event SSHAgentDebugEvent) { if debug == nil { return } debug(event) } func debugStatus(err error) string { if err != nil { return "error" } return "ok" } func wrapSSHAgentConnWithDeadline(conn net.Conn, timeout time.Duration) net.Conn { if conn == nil || timeout <= 0 { return conn } return &deadlineAgentConn{Conn: conn, timeout: timeout} } func (c *deadlineAgentConn) Read(p []byte) (int, error) { c.setDeadline() n, err := c.Conn.Read(p) return n, wrapSSHAgentConnError(err) } func (c *deadlineAgentConn) Write(p []byte) (int, error) { c.setDeadline() n, err := c.Conn.Write(p) return n, wrapSSHAgentConnError(err) } func (c *deadlineAgentConn) setDeadline() { if c == nil || c.timeout <= 0 || c.Conn == nil { return } _ = c.Conn.SetDeadline(time.Now().Add(c.timeout)) } func isTimeoutError(err error) bool { if err == nil { return false } if errors.Is(err, os.ErrDeadlineExceeded) { return true } var netErr net.Error return errors.As(err, &netErr) && netErr.Timeout() } func wrapSSHAgentConnError(err error) error { if isTimeoutError(err) { return fmt.Errorf("%w: %v", ErrSSHAgentTimeout, err) } return err } func normalizeSSHAgentError(err error) error { if err == nil { return nil } if errors.Is(err, ErrSSHAgentTimeout) { return err } if strings.Contains(err.Error(), ErrSSHAgentTimeout.Error()) { return fmt.Errorf("%w: %v", ErrSSHAgentTimeout, err) } return err }