diff --git a/agent_forward.go b/agent_forward.go index 502efb2..f279557 100644 --- a/agent_forward.go +++ b/agent_forward.go @@ -19,12 +19,12 @@ var requestSSHAgentForwarding = func(session *ssh.Session) error { const sshAgentChannelType = "auth-agent@openssh.com" -var routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) { - return startSSHAgentForwardProxy(client, timeout) +var routeSSHAgentForwarding = func(client *ssh.Client, timeouts sshAgentTimeouts) (io.Closer, error) { + return startSSHAgentForwardProxy(client, timeouts) } -var probeSSHAgentForwarding = func(timeout time.Duration) error { - conn, err := dialSSHAgent(timeout) +var probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error { + conn, _, err := dialSSHAgentWithDebug("forward-probe", timeouts) if err != nil { return wrapSSHAgentForwardingUnavailable(err) } @@ -57,11 +57,15 @@ func (p *sshAgentForwardProxy) Close() error { } type sshAgentForwardBridge struct { - proxy *sshAgentForwardProxy - channel ssh.Channel - conn net.Conn + proxy *sshAgentForwardProxy + channel ssh.Channel + conn net.Conn + idleTimeout time.Duration - closeOnce sync.Once + closeOnce sync.Once + signalOnce sync.Once + done chan struct{} + activity chan struct{} } func (s *StarSSH) RequestAgentForwarding(session *ssh.Session) error { @@ -111,14 +115,14 @@ func (s *StarSSH) ensureAgentForwarding() error { return err } - timeout := effectiveDialTimeout(s.LoginInfo) - if err := probeSSHAgentForwarding(timeout); err != nil { + timeouts := effectiveSSHAgentTimeouts(s.LoginInfo) + if err := probeSSHAgentForwarding(timeouts); err != nil { return wrapSSHAgentForwardingUnavailable(err) } if s.closing.Load() { return errSSHClientClosing } - closer, err := routeSSHAgentForwarding(client, timeout) + closer, err := routeSSHAgentForwarding(client, timeouts) if err != nil { return err } @@ -182,7 +186,7 @@ func wrapSSHAgentForwardingUnavailable(err error) error { return fmt.Errorf("%w: %v", 
errSSHAgentForwardingUnavailable, err) } -func startSSHAgentForwardProxy(client *ssh.Client, timeout time.Duration) (io.Closer, error) { +func startSSHAgentForwardProxy(client *ssh.Client, timeouts sshAgentTimeouts) (io.Closer, error) { if client == nil { return nil, errors.New("ssh client is nil") } @@ -204,18 +208,18 @@ func startSSHAgentForwardProxy(client *ssh.Client, timeout time.Duration) (io.Cl if !ok { return } - go handleSSHAgentForwardChannel(proxy, ch, timeout) + go handleSSHAgentForwardChannel(proxy, ch, timeouts) } } }() return proxy, nil } -func handleSSHAgentForwardChannel(proxy *sshAgentForwardProxy, ch ssh.NewChannel, timeout time.Duration) { +func handleSSHAgentForwardChannel(proxy *sshAgentForwardProxy, ch ssh.NewChannel, timeouts sshAgentTimeouts) { if ch == nil { return } - conn, err := dialSSHAgent(timeout) + conn, _, err := dialSSHAgentWithDebug("forward-channel", timeouts) if err != nil { _ = ch.Reject(ssh.ConnectionFailed, err.Error()) return @@ -224,7 +228,6 @@ func handleSSHAgentForwardChannel(proxy *sshAgentForwardProxy, ch ssh.NewChannel _ = ch.Reject(ssh.ConnectionFailed, "ssh-agent connection unavailable") return } - channel, reqs, err := ch.Accept() if err != nil { _ = conn.Close() @@ -233,9 +236,10 @@ func handleSSHAgentForwardChannel(proxy *sshAgentForwardProxy, ch ssh.NewChannel go ssh.DiscardRequests(reqs) bridge := &sshAgentForwardBridge{ - proxy: proxy, - channel: channel, - conn: conn, + proxy: proxy, + channel: channel, + conn: conn, + idleTimeout: timeouts.Forward, } if !proxy.registerBridge(bridge) { bridge.close() @@ -256,18 +260,27 @@ func (b *sshAgentForwardBridge) run() { if b == nil { return } + b.ensureSignals() + stopWatchdog := b.startIdleWatchdog() + defer stopWatchdog() defer b.unregister() var wg sync.WaitGroup wg.Add(2) go func() { defer wg.Done() - _, _ = io.Copy(b.channel, b.conn) + _, _ = io.Copy( + sshAgentForwardActivityWriter{Writer: b.channel, touch: b.touch}, + sshAgentForwardActivityReader{Reader: 
b.conn, touch: b.touch}, + ) b.close() }() go func() { defer wg.Done() - _, _ = io.Copy(b.conn, b.channel) + _, _ = io.Copy( + sshAgentForwardActivityWriter{Writer: b.conn, touch: b.touch}, + sshAgentForwardActivityReader{Reader: b.channel, touch: b.touch}, + ) b.close() }() wg.Wait() @@ -278,6 +291,8 @@ func (b *sshAgentForwardBridge) close() { return } b.closeOnce.Do(func() { + b.ensureSignals() + close(b.done) closeWriter(b.channel) closeWriter(b.conn) if b.channel != nil { @@ -289,6 +304,90 @@ func (b *sshAgentForwardBridge) close() { }) } +func (b *sshAgentForwardBridge) ensureSignals() { + if b == nil { + return + } + b.signalOnce.Do(func() { + b.done = make(chan struct{}) + b.activity = make(chan struct{}, 1) + }) +} + +func (b *sshAgentForwardBridge) startIdleWatchdog() func() { + if b == nil || b.idleTimeout <= 0 { + return func() {} + } + b.ensureSignals() + timer := time.NewTimer(b.idleTimeout) + stopped := make(chan struct{}) + go func() { + defer timer.Stop() + for { + select { + case <-timer.C: + b.close() + return + case <-b.activity: + resetTimer(timer, b.idleTimeout) + case <-b.done: + return + case <-stopped: + return + } + } + }() + return func() { + close(stopped) + } +} + +func (b *sshAgentForwardBridge) touch() { + if b == nil || b.idleTimeout <= 0 || b.activity == nil { + return + } + select { + case b.activity <- struct{}{}: + default: + } +} + +type sshAgentForwardActivityReader struct { + io.Reader + touch func() +} + +func (r sshAgentForwardActivityReader) Read(p []byte) (int, error) { + n, err := r.Reader.Read(p) + if n > 0 && r.touch != nil { + r.touch() + } + return n, err +} + +type sshAgentForwardActivityWriter struct { + io.Writer + touch func() +} + +func (w sshAgentForwardActivityWriter) Write(p []byte) (int, error) { + n, err := w.Writer.Write(p) + if n > 0 && w.touch != nil { + w.touch() + } + return n, err +} + +func resetTimer(timer *time.Timer, timeout time.Duration) { + if !timer.Stop() { + select { + case <-timer.C: + 
default: + } + } + timer.Reset(timeout) +} + func (b *sshAgentForwardBridge) unregister() { if b == nil || b.proxy == nil { return diff --git a/agent_forward_test.go b/agent_forward_test.go index fdbf011..5a61402 100644 --- a/agent_forward_test.go +++ b/agent_forward_test.go @@ -44,6 +44,32 @@ type testSSHChannel struct { closeCh chan struct{} } +type testNewChannel struct { + channel ssh.Channel + accepted atomic.Bool + rejected atomic.Bool +} + +func (c *testNewChannel) Accept() (ssh.Channel, <-chan *ssh.Request, error) { + c.accepted.Store(true) + requests := make(chan *ssh.Request) + close(requests) + return c.channel, requests, nil +} + +func (c *testNewChannel) Reject(reason ssh.RejectionReason, message string) error { + c.rejected.Store(true) + return nil +} + +func (c *testNewChannel) ChannelType() string { + return sshAgentChannelType +} + +func (c *testNewChannel) ExtraData() []byte { + return nil +} + func newTestSSHChannel(readFunc func([]byte) (int, error)) *testSSHChannel { return &testSSHChannel{ readFunc: readFunc, @@ -114,8 +140,10 @@ func TestNewExecSessionEnablesAgentForwardingOnce(t *testing.T) { baseClient := &ssh.Client{} star := &StarSSH{ LoginInfo: LoginInput{ - ForwardSSHAgent: true, - Timeout: time.Second, + ForwardSSHAgent: true, + Timeout: time.Second, + SSHAgentTimeout: 3 * time.Second, + SSHAgentForwardTimeout: 4 * time.Second, }, } star.setTransport(baseClient, nil) @@ -129,22 +157,34 @@ func TestNewExecSessionEnablesAgentForwardingOnce(t *testing.T) { var probeCalls atomic.Int32 closer := &testCloser{} - probeSSHAgentForwarding = func(timeout time.Duration) error { + probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error { probeCalls.Add(1) - if timeout != time.Second { - t.Fatalf("unexpected forwarding timeout: %v", timeout) + if timeouts.Dial != time.Second { + t.Fatalf("unexpected forwarding dial timeout: %v", timeouts.Dial) + } + if timeouts.Operation != 3*time.Second { + t.Fatalf("unexpected forwarding operation 
timeout: %v", timeouts.Operation) + } + if timeouts.Forward != 4*time.Second { + t.Fatalf("unexpected forwarding idle timeout: %v", timeouts.Forward) } return nil } var routeCalls atomic.Int32 - routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) { + routeSSHAgentForwarding = func(client *ssh.Client, timeouts sshAgentTimeouts) (io.Closer, error) { routeCalls.Add(1) if client != baseClient { t.Fatalf("unexpected routed client %p", client) } - if timeout != time.Second { - t.Fatalf("unexpected routed timeout: %v", timeout) + if timeouts.Dial != time.Second { + t.Fatalf("unexpected routed dial timeout: %v", timeouts.Dial) + } + if timeouts.Operation != 3*time.Second { + t.Fatalf("unexpected routed operation timeout: %v", timeouts.Operation) + } + if timeouts.Forward != 4*time.Second { + t.Fatalf("unexpected routed idle timeout: %v", timeouts.Forward) } return closer, nil } @@ -215,10 +255,10 @@ func TestNewPTYSessionEnablesAgentForwardingWhenConfigured(t *testing.T) { return nil } - probeSSHAgentForwarding = func(timeout time.Duration) error { + probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error { return nil } - routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) { + routeSSHAgentForwarding = func(client *ssh.Client, timeouts sshAgentTimeouts) (io.Closer, error) { return &testCloser{}, nil } @@ -255,7 +295,7 @@ func TestNewExecSessionSkipsAgentForwardingWhenDisabled(t *testing.T) { newSSHSession = func(client *ssh.Client) (*ssh.Session, error) { return &ssh.Session{}, nil } - probeSSHAgentForwarding = func(timeout time.Duration) error { + probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error { t.Fatal("agent forwarding probe should not run when disabled") return nil } @@ -280,7 +320,7 @@ func TestRequestAgentForwardingReturnsUnavailableError(t *testing.T) { star := &StarSSH{} star.setTransport(&ssh.Client{}, nil) - probeSSHAgentForwarding = func(timeout 
time.Duration) error { + probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error { return errors.New("ssh-agent forwarding unavailable: ssh-agent unavailable") } requestSSHAgentForwarding = func(session *ssh.Session) error { @@ -303,7 +343,7 @@ func TestRequestAgentForwardingWrapsSetupErrorAsUnavailable(t *testing.T) { star := &StarSSH{} star.setTransport(&ssh.Client{}, nil) - probeSSHAgentForwarding = func(timeout time.Duration) error { + probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error { return errors.New("dial unix /tmp/ssh-broken.sock: connect: permission denied") } @@ -326,10 +366,10 @@ func TestRequestAgentForwardingReturnsDeniedError(t *testing.T) { star := &StarSSH{} star.setTransport(&ssh.Client{}, nil) - probeSSHAgentForwarding = func(timeout time.Duration) error { + probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error { return nil } - routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) { + routeSSHAgentForwarding = func(client *ssh.Client, timeouts sshAgentTimeouts) (io.Closer, error) { return &testCloser{}, nil } requestSSHAgentForwarding = func(session *ssh.Session) error { @@ -364,10 +404,10 @@ func TestNewExecSessionIgnoresAgentForwardingDenied(t *testing.T) { newSSHSession = func(client *ssh.Client) (*ssh.Session, error) { return &ssh.Session{}, nil } - probeSSHAgentForwarding = func(timeout time.Duration) error { + probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error { return nil } - routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) { + routeSSHAgentForwarding = func(client *ssh.Client, timeouts sshAgentTimeouts) (io.Closer, error) { return &testCloser{}, nil } requestSSHAgentForwarding = func(session *ssh.Session) error { @@ -397,7 +437,7 @@ func TestNewExecSessionIgnoresAgentForwardingUnavailable(t *testing.T) { newSSHSession = func(client *ssh.Client) (*ssh.Session, error) { return &ssh.Session{}, nil } - 
probeSSHAgentForwarding = func(timeout time.Duration) error { + probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error { return errors.New("ssh-agent forwarding unavailable: ssh-agent unavailable") } @@ -424,7 +464,7 @@ func TestNewExecSessionIgnoresAgentForwardingSetupError(t *testing.T) { newSSHSession = func(client *ssh.Client) (*ssh.Session, error) { return &ssh.Session{}, nil } - probeSSHAgentForwarding = func(timeout time.Duration) error { + probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error { return errors.New("dial unix /tmp/ssh-broken.sock: connect: connection refused") } @@ -453,10 +493,10 @@ func TestEnsureAgentForwardingClosesNewForwarderWhenCloseStarts(t *testing.T) { started := make(chan struct{}) release := make(chan struct{}) closer := &testCloser{} - probeSSHAgentForwarding = func(timeout time.Duration) error { + probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error { return nil } - routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) { + routeSSHAgentForwarding = func(client *ssh.Client, timeouts sshAgentTimeouts) (io.Closer, error) { close(started) <-release return closer, nil @@ -570,3 +610,56 @@ func TestSSHAgentForwardProxyCloseClosesActiveBridges(t *testing.T) { t.Fatal("expected proxy close to close ssh channel") } } + +func TestHandleSSHAgentForwardChannelUsesForwardTimeout(t *testing.T) { + oldDialResolvedSSHAgent := dialResolvedSSHAgentFunc + t.Cleanup(func() { + dialResolvedSSHAgentFunc = oldDialResolvedSSHAgent + }) + + agentConn, peerConn := net.Pipe() + defer peerConn.Close() + tracked := &trackedConn{Conn: agentConn} + dialResolvedSSHAgentFunc = func(resolved resolvedSSHAgentEndpoint, timeout time.Duration) (net.Conn, error) { + return tracked, nil + } + + channel := newBlockingTestSSHChannel() + newChannel := &testNewChannel{ + channel: channel, + } + proxy := &sshAgentForwardProxy{ + stopCh: make(chan struct{}), + active: 
make(map[*sshAgentForwardBridge]struct{}), + } + handleSSHAgentForwardChannel(proxy, newChannel, sshAgentTimeouts{ + Endpoint: "/tmp/agent.sock", + Forward: 20 * time.Millisecond, + }) + + if !newChannel.accepted.Load() { + t.Fatal("expected channel to be accepted") + } + + waitUntil(t, time.Second, func() bool { + return tracked.closed.Load() > 0 && channel.closed.Load() > 0 + }, "forwarded agent bridge did not close both sides after idle timeout") + + waitUntil(t, time.Second, func() bool { + proxy.activeMu.Lock() + defer proxy.activeMu.Unlock() + return len(proxy.active) == 0 + }, "forwarded agent bridge did not unregister after idle timeout") +} + +func waitUntil(t *testing.T, timeout time.Duration, condition func() bool, message string) { + t.Helper() + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + if condition() { + return + } + time.Sleep(time.Millisecond) + } + t.Fatal(message) +} diff --git a/login.go b/login.go index 19f739f..2cc4fe1 100644 --- a/login.go +++ b/login.go @@ -4,26 +4,14 @@ import ( "context" "encoding/base64" "errors" - "fmt" "net" "os" - "strings" "time" "golang.org/x/crypto/ssh" - "golang.org/x/crypto/ssh/agent" ) var ErrHostKeyCallbackRequired = errors.New("host key callback is required; use DefaultAllowHostKeyCallback to explicitly allow any host key") -var errSSHAgentUnavailable = errors.New("ssh-agent unavailable") -var buildSSHAgentAuthMethodFunc = buildSSHAgentAuthMethod - -var defaultAuthOrder = []AuthMethodKind{ - AuthMethodSSHAgent, - AuthMethodPrivateKey, - AuthMethodPassword, - AuthMethodKeyboardInteractive, -} func DefaultAllowHostKeyCallback(hostname string, remote net.Addr, key ssh.PublicKey) error { return nil @@ -47,11 +35,35 @@ func loginWithContext(ctx context.Context, info LoginInput) (*StarSSH, error) { loginCtx, cancel := contextWithLoginTimeout(ctx, authTimeout) defer cancel() + order, err := normalizeAuthOrder(info.AuthOrder) + if err != nil { + return nil, err + } + + if 
shouldRetrySSHAgentAuth(info, order) { + agentAttempt := newSSHAgentAuthAttempt() + for { + agentAttempt.begin() + sshInfo, err := loginOnceWithContext(loginCtx, info, authTimeout, agentAttempt) + if err == nil { + return sshInfo, nil + } + if errors.Is(err, errRetrySSHAgentAuth) && loginCtx.Err() == nil { + continue + } + return sshInfo, err + } + } + + return loginOnceWithContext(loginCtx, info, authTimeout, nil) +} + +func loginOnceWithContext(ctx context.Context, info LoginInput, authTimeout time.Duration, agentAttempt *sshAgentAuthAttempt) (*StarSSH, error) { sshInfo := &StarSSH{ LoginInfo: info, } - auth, authCleanup, err := buildAuthMethods(info) + auth, authCleanup, err := buildAuthMethodsWithAgentAttempt(info, agentAttempt) if err != nil { return nil, err } @@ -91,11 +103,11 @@ func loginWithContext(ctx context.Context, info LoginInput) (*StarSSH, error) { } targetAddr := joinHostPort(info.Addr, info.Port) - rawConn, upstream, err := dialTargetConn(loginCtx, info) + rawConn, upstream, err := dialTargetConn(ctx, info) if err != nil { return sshInfo, err } - restoreDeadline := applyConnDeadline(rawConn, loginCtx, authTimeout) + restoreDeadline := applyConnDeadline(rawConn, ctx, authTimeout) defer restoreDeadline() clientConn, chans, reqs, err := ssh.NewClientConn(rawConn, targetAddr, clientConfig) @@ -179,204 +191,3 @@ func effectiveDialTimeout(info LoginInput) time.Duration { return defaultLoginTimeout } } - -func buildAuthMethods(info LoginInput) ([]ssh.AuthMethod, func(), error) { - order, err := normalizeAuthOrder(info.AuthOrder) - if err != nil { - return nil, nil, err - } - - auth := make([]ssh.AuthMethod, 0, len(order)) - var agentErr error - var cleanupFuncs []func() - - for _, methodKind := range order { - switch methodKind { - case AuthMethodPrivateKey: - method, err := buildPrivateKeyAuthMethod(info) - if err != nil { - return nil, nil, err - } - if method != nil { - auth = append(auth, method) - } - case AuthMethodPassword: - method := 
buildPasswordAuthMethod(info.Password, info.PasswordCallback) - if method != nil { - auth = append(auth, method) - } - case AuthMethodKeyboardInteractive: - method := buildKeyboardInteractiveAuthMethod(info.Password, info.PasswordCallback, info.KeyboardInteractiveCallback) - if method != nil { - auth = append(auth, method) - } - case AuthMethodSSHAgent: - if info.DisableSSHAgent { - continue - } - agentMethod, cleanup, err := buildSSHAgentAuthMethodFunc(effectiveDialTimeout(info)) - if err != nil { - agentErr = err - continue - } - if agentMethod != nil { - auth = append(auth, agentMethod) - } - if cleanup != nil { - cleanupFuncs = append(cleanupFuncs, cleanup) - } - } - } - - if len(auth) == 0 { - if agentErr != nil { - return nil, nil, fmt.Errorf("no authentication method provided; ssh-agent unavailable: %w", agentErr) - } - return nil, nil, errors.New("no authentication method provided: password, private key, or ssh-agent is required") - } - - return auth, composeCleanup(cleanupFuncs...), nil -} - -func normalizeAuthOrder(order []AuthMethodKind) ([]AuthMethodKind, error) { - if len(order) == 0 { - return append([]AuthMethodKind(nil), defaultAuthOrder...), nil - } - - normalized := make([]AuthMethodKind, 0, len(order)) - seen := make(map[AuthMethodKind]struct{}, len(order)) - for _, raw := range order { - kind := AuthMethodKind(strings.ToLower(strings.TrimSpace(string(raw)))) - if kind == "" { - return nil, errors.New("auth order contains an empty auth method") - } - if !isSupportedAuthMethodKind(kind) { - return nil, fmt.Errorf("unsupported auth method %q", raw) - } - if _, exists := seen[kind]; exists { - continue - } - seen[kind] = struct{}{} - normalized = append(normalized, kind) - } - - if len(normalized) == 0 { - return nil, errors.New("auth order is empty") - } - return normalized, nil -} - -func isSupportedAuthMethodKind(kind AuthMethodKind) bool { - switch kind { - case AuthMethodPrivateKey, AuthMethodPassword, AuthMethodKeyboardInteractive, 
AuthMethodSSHAgent: - return true - default: - return false - } -} - -func buildPrivateKeyAuthMethod(info LoginInput) (ssh.AuthMethod, error) { - if strings.TrimSpace(info.Prikey) == "" { - return nil, nil - } - - pemBytes := []byte(info.Prikey) - if info.PrikeyPwd == "" { - signer, err := ssh.ParsePrivateKey(pemBytes) - if err != nil { - return nil, err - } - return ssh.PublicKeys(signer), nil - } - - signer, err := ssh.ParsePrivateKeyWithPassphrase(pemBytes, []byte(info.PrikeyPwd)) - if err != nil { - return nil, err - } - return ssh.PublicKeys(signer), nil -} - -func buildPasswordAuthMethod(password string, callback func() (string, error)) ssh.AuthMethod { - if password != "" { - return ssh.Password(password) - } - if callback != nil { - return ssh.PasswordCallback(callback) - } - return nil -} - -func buildKeyboardInteractiveAuthMethod( - password string, - passwordCallback func() (string, error), - challenge ssh.KeyboardInteractiveChallenge, -) ssh.AuthMethod { - if challenge != nil { - return ssh.KeyboardInteractive(challenge) - } - if password == "" && passwordCallback == nil { - return nil - } - - keyboardInteractiveChallenge := func(user, instruction string, questions []string, echos []bool) ([]string, error) { - if len(questions) == 0 { - return []string{}, nil - } - - answer := password - if answer == "" { - var err error - answer, err = passwordCallback() - if err != nil { - return nil, err - } - } - - answers := make([]string, len(questions)) - for i := range questions { - answers[i] = answer - } - return answers, nil - } - return ssh.KeyboardInteractive(keyboardInteractiveChallenge) -} - -func buildSSHAgentAuthMethod(timeout time.Duration) (ssh.AuthMethod, func(), error) { - conn, err := dialSSHAgent(timeout) - if err != nil { - if errors.Is(err, errSSHAgentUnavailable) { - return nil, nil, nil - } - return nil, nil, err - } - if conn == nil { - return nil, nil, nil - } - - signers, err := agent.NewClient(conn).Signers() - if err != nil { - _ = 
conn.Close() - return nil, nil, err - } - if len(signers) == 0 { - _ = conn.Close() - return nil, nil, errors.New("ssh-agent has no loaded keys") - } - - return ssh.PublicKeys(signers...), func() { - _ = conn.Close() - }, nil -} - -func composeCleanup(funcs ...func()) func() { - if len(funcs) == 0 { - return nil - } - return func() { - for i := len(funcs) - 1; i >= 0; i-- { - if funcs[i] != nil { - funcs[i]() - } - } - } -} diff --git a/login_timeout_test.go b/login_timeout_test.go index 3f407fa..9161b67 100644 --- a/login_timeout_test.go +++ b/login_timeout_test.go @@ -1,10 +1,19 @@ package starssh import ( + "crypto/ed25519" + "crypto/rand" + "crypto/rsa" + "errors" + "io" + "net" + "os" + "sync" "testing" "time" "golang.org/x/crypto/ssh" + sshagent "golang.org/x/crypto/ssh/agent" ) func TestNormalizeLoginInputKeepsZeroAuthTimeout(t *testing.T) { @@ -18,6 +27,12 @@ func TestNormalizeLoginInputKeepsZeroAuthTimeout(t *testing.T) { if info.DialTimeout != 0 { t.Fatalf("DialTimeout=%v want 0", info.DialTimeout) } + if info.SSHAgentTimeout != 0 { + t.Fatalf("SSHAgentTimeout=%v want 0", info.SSHAgentTimeout) + } + if info.SSHAgentForwardTimeout != 0 { + t.Fatalf("SSHAgentForwardTimeout=%v want 0", info.SSHAgentForwardTimeout) + } } func TestEffectiveLoginTimeout(t *testing.T) { @@ -66,22 +81,71 @@ func TestEffectiveDialTimeout(t *testing.T) { } } -func TestBuildAuthMethodsUsesDialTimeoutInsteadOfAuthTimeout(t *testing.T) { +func TestEffectiveSSHAgentTimeout(t *testing.T) { + tests := []struct { + name string + info LoginInput + want time.Duration + }{ + { + name: "default fallback without auth timeout", + info: LoginInput{}, + want: defaultSSHAgentTimeout, + }, + { + name: "auth timeout does not cap default", + info: LoginInput{Timeout: 9 * time.Second}, + want: defaultSSHAgentTimeout, + }, + { + name: "explicit agent timeout wins", + info: LoginInput{Timeout: 9 * time.Second, DialTimeout: 3 * time.Second, SSHAgentTimeout: 90 * time.Second}, + want: 90 * time.Second, + 
}, + { + name: "negative agent timeout disables operation deadline", + info: LoginInput{SSHAgentTimeout: -1}, + want: 0, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if got := effectiveSSHAgentTimeout(tc.info); got != tc.want { + t.Fatalf("effectiveSSHAgentTimeout(%+v)=%v want %v", tc.info, got, tc.want) + } + }) + } +} + +func TestEffectiveSSHAgentForwardTimeout(t *testing.T) { + if got := effectiveSSHAgentForwardTimeout(LoginInput{}); got != 0 { + t.Fatalf("zero forward timeout should stay zero, got %v", got) + } + if got := effectiveSSHAgentForwardTimeout(LoginInput{SSHAgentForwardTimeout: 4 * time.Second}); got != 4*time.Second { + t.Fatalf("expected explicit forward timeout, got %v", got) + } +} + +func TestBuildAuthMethodsUsesSeparateSSHAgentTimeouts(t *testing.T) { oldBuilder := buildSSHAgentAuthMethodFunc t.Cleanup(func() { buildSSHAgentAuthMethodFunc = oldBuilder }) - captured := time.Duration(-2) - buildSSHAgentAuthMethodFunc = func(timeout time.Duration) (ssh.AuthMethod, func(), error) { - captured = timeout + captured := sshAgentTimeouts{Dial: -2, Operation: -2, Forward: -2} + buildSSHAgentAuthMethodFunc = func(timeouts sshAgentTimeouts) (ssh.AuthMethod, func(), error) { + captured = timeouts return ssh.Password("agent"), nil, nil } info := LoginInput{ - Timeout: 0, - DialTimeout: 11 * time.Second, - AuthOrder: []AuthMethodKind{AuthMethodSSHAgent}, + Timeout: 0, + DialTimeout: 11 * time.Second, + SSHAgentTimeout: 90 * time.Second, + SSHAgentForwardTimeout: 4 * time.Second, + IdentityAgent: "/tmp/custom-agent.sock", + AuthOrder: []AuthMethodKind{AuthMethodSSHAgent}, } auth, cleanup, err := buildAuthMethods(info) if err != nil { @@ -93,7 +157,607 @@ func TestBuildAuthMethodsUsesDialTimeoutInsteadOfAuthTimeout(t *testing.T) { if len(auth) != 1 { t.Fatalf("expected one auth method, got %d", len(auth)) } - if captured != 11*time.Second { - t.Fatalf("agent auth builder timeout=%v want %v", captured, 11*time.Second) + if 
captured.Dial != 11*time.Second { + t.Fatalf("agent auth builder dial timeout=%v want %v", captured.Dial, 11*time.Second) + } + if captured.Operation != 90*time.Second { + t.Fatalf("agent auth builder operation timeout=%v want %v", captured.Operation, 90*time.Second) + } + if captured.Forward != 4*time.Second { + t.Fatalf("agent auth builder forward timeout=%v want %v", captured.Forward, 4*time.Second) + } + if captured.Endpoint != "/tmp/custom-agent.sock" { + t.Fatalf("agent auth builder endpoint=%q want custom endpoint", captured.Endpoint) } } + +func TestBuildAuthMethodsUsesSingleAgentAuthMethod(t *testing.T) { + oldBuilder := buildSSHAgentAuthMethodFunc + t.Cleanup(func() { + buildSSHAgentAuthMethodFunc = oldBuilder + }) + + buildSSHAgentAuthMethodFunc = func(timeouts sshAgentTimeouts) (ssh.AuthMethod, func(), error) { + return ssh.Password("agent"), nil, nil + } + + auth, cleanup, err := buildAuthMethods(LoginInput{ + AuthOrder: []AuthMethodKind{AuthMethodSSHAgent}, + }) + if err != nil { + t.Fatalf("buildAuthMethods: %v", err) + } + if cleanup != nil { + cleanup() + } + if len(auth) != 1 { + t.Fatalf("auth methods=%d, want 1", len(auth)) + } +} + +func TestShouldRetrySSHAgentAuthWhenAgentIsNotFirst(t *testing.T) { + order := []AuthMethodKind{AuthMethodPassword, AuthMethodSSHAgent} + if !shouldRetrySSHAgentAuth(LoginInput{}, order) { + t.Fatal("expected ssh-agent retry when ssh-agent is present after password") + } + if shouldRetrySSHAgentAuth(LoginInput{DisableSSHAgent: true}, order) { + t.Fatal("expected ssh-agent retry disabled when DisableSSHAgent is true") + } + if shouldRetrySSHAgentAuth(LoginInput{}, []AuthMethodKind{AuthMethodPassword}) { + t.Fatal("expected no ssh-agent retry when ssh-agent auth is absent") + } +} + +func TestBuildAuthMethodsWithAgentAttemptMarksNonFirstAgentForRetry(t *testing.T) { + oldBuilder := buildSSHAgentAuthMethodFunc + t.Cleanup(func() { + buildSSHAgentAuthMethodFunc = oldBuilder + }) + + buildSSHAgentAuthMethodFunc = 
func(timeouts sshAgentTimeouts) (ssh.AuthMethod, func(), error) { + if timeouts.SignFailure == nil { + t.Fatal("expected SignFailure callback for non-first ssh-agent auth") + } + if timeouts.SkipFingerprints != nil { + t.Fatalf("unexpected initial skip fingerprints: %#v", timeouts.SkipFingerprints) + } + return ssh.Password("agent"), nil, nil + } + + auth, cleanup, err := buildAuthMethodsWithAgentAttempt(LoginInput{ + Password: "secret", + AuthOrder: []AuthMethodKind{AuthMethodPassword, AuthMethodSSHAgent}, + }, newSSHAgentAuthAttempt()) + if err != nil { + t.Fatalf("buildAuthMethodsWithAgentAttempt: %v", err) + } + if cleanup != nil { + cleanup() + } + if len(auth) != 2 { + t.Fatalf("auth methods=%d want 2", len(auth)) + } +} + +func TestAgentRetryPendingBlocksFallbackAuthThenResets(t *testing.T) { + attempt := newSSHAgentAuthAttempt() + attempt.skipFingerprint("SHA256:test") + if err := checkSSHAgentRetryPending(attempt); !errors.Is(err, errRetrySSHAgentAuth) { + t.Fatalf("retry pending err=%v want errRetrySSHAgentAuth", err) + } + attempt.begin() + if err := checkSSHAgentRetryPending(attempt); err != nil { + t.Fatalf("retry should reset on next attempt: %v", err) + } +} + +func TestAgentRetryPendingBlocksPrivateKeyAuth(t *testing.T) { + signer := mustGenerateTestSigner(t) + attempt := newSSHAgentAuthAttempt() + callback := privateKeySignersCallback(signer, attempt) + + signers, err := callback() + if err != nil { + t.Fatalf("private key callback before retry: %v", err) + } + if len(signers) != 1 || signers[0] != signer { + t.Fatalf("private key callback returned %#v, want original signer", signers) + } + + attempt.skipFingerprint("SHA256:test") + signers, err = callback() + if !errors.Is(err, errRetrySSHAgentAuth) { + t.Fatalf("private key callback err=%v want errRetrySSHAgentAuth", err) + } + if signers != nil { + t.Fatalf("private key callback signers=%#v want nil while retry pending", signers) + } + + attempt.begin() + signers, err = callback() + if err != 
nil { + t.Fatalf("private key callback after retry reset: %v", err) + } + if len(signers) != 1 || signers[0] != signer { + t.Fatalf("private key callback after retry returned %#v, want original signer", signers) + } +} + +func TestFilterSSHAgentSignersSkipsSignerAfterSignFailure(t *testing.T) { + firstSigner := mustGenerateTestSigner(t) + secondSigner := mustGenerateTestSigner(t) + failingFirstSigner := &testFailingSigner{Signer: firstSigner, err: errors.New("first agent key cannot sign")} + + attempt := newSSHAgentAuthAttempt() + firstMethods := filterSSHAgentSignersForRetry([]ssh.Signer{failingFirstSigner, secondSigner}, sshAgentTimeouts{ + SignFailure: attempt.recordSignFailure, + SkipFingerprints: attempt.skipSnapshot(), + }) + if len(firstMethods) != 2 { + t.Fatalf("first auth method signers=%d want 2", len(firstMethods)) + } + if _, err := firstMethods[0].Sign(nil, []byte("challenge")); !errors.Is(err, errRetrySSHAgentAuth) { + t.Fatalf("first signer err=%v want errRetrySSHAgentAuth", err) + } + secondMethods := filterSSHAgentSignersForRetry([]ssh.Signer{failingFirstSigner, secondSigner}, sshAgentTimeouts{ + SignFailure: attempt.recordSignFailure, + SkipFingerprints: attempt.skipSnapshot(), + }) + if len(secondMethods) != 1 { + t.Fatalf("second auth method signers=%d want 1", len(secondMethods)) + } + if string(secondMethods[0].PublicKey().Marshal()) != string(secondSigner.PublicKey().Marshal()) { + t.Fatalf("second auth method did not skip failed first key") + } + signature, err := secondMethods[0].Sign(nil, []byte("challenge")) + if err != nil { + t.Fatalf("second signer Sign: %v", err) + } + if signature == nil { + t.Fatal("second signer returned nil signature") + } +} + +func TestBuildAuthMethodsSkipsFailedAgentSignerOnRetry(t *testing.T) { + firstSigner := mustGenerateTestSigner(t) + secondSigner := mustGenerateTestSigner(t) + wantErr := errors.New("first agent key cannot sign") + failingFirstSigner := &testFailingSigner{Signer: firstSigner, err: 
wantErr} + + oldBuilder := buildSSHAgentAuthMethodFunc + t.Cleanup(func() { + buildSSHAgentAuthMethodFunc = oldBuilder + }) + + var buildCalls int + buildSSHAgentAuthMethodFunc = func(timeouts sshAgentTimeouts) (ssh.AuthMethod, func(), error) { + buildCalls++ + filteredSigners := filterSSHAgentSignersForRetry([]ssh.Signer{failingFirstSigner, secondSigner}, timeouts) + if buildCalls == 1 { + if len(filteredSigners) != 2 { + t.Fatalf("first build signers=%d want 2", len(filteredSigners)) + } + return ssh.PublicKeys(filteredSigners...), nil, nil + } + if len(filteredSigners) != 1 { + t.Fatalf("retry build signers=%d want 1", len(filteredSigners)) + } + if string(filteredSigners[0].PublicKey().Marshal()) != string(secondSigner.PublicKey().Marshal()) { + t.Fatal("retry build did not skip failed signer") + } + return ssh.PublicKeys(filteredSigners...), nil, nil + } + + attempt := newSSHAgentAuthAttempt() + attempt.begin() + auth, cleanup, err := buildAuthMethodsWithAgentAttempt(LoginInput{ + AuthOrder: []AuthMethodKind{AuthMethodSSHAgent}, + }, attempt) + if err != nil { + t.Fatalf("first buildAuthMethodsWithAgentAttempt: %v", err) + } + if cleanup != nil { + cleanup() + } + if len(auth) != 1 { + t.Fatalf("first auth methods=%d want 1", len(auth)) + } + if _, err := failingFirstSigner.Sign(rand.Reader, []byte("challenge")); !errors.Is(err, wantErr) { + t.Fatalf("raw failing signer err=%v", err) + } + firstWrapped := filterSSHAgentSignersForRetry([]ssh.Signer{failingFirstSigner}, sshAgentTimeouts{ + SignFailure: attempt.recordSignFailure, + })[0] + if _, err := firstWrapped.Sign(rand.Reader, []byte("challenge")); !errors.Is(err, errRetrySSHAgentAuth) { + t.Fatalf("wrapped failing signer err=%v want errRetrySSHAgentAuth", err) + } + + attempt.begin() + auth, cleanup, err = buildAuthMethodsWithAgentAttempt(LoginInput{ + AuthOrder: []AuthMethodKind{AuthMethodSSHAgent}, + }, attempt) + if err != nil { + t.Fatalf("retry buildAuthMethodsWithAgentAttempt: %v", err) + } + if 
cleanup != nil { + cleanup() + } + if len(auth) != 1 { + t.Fatalf("retry auth methods=%d want 1", len(auth)) + } + if buildCalls != 2 { + t.Fatalf("build calls=%d want 2", buildCalls) + } +} + +func TestOrderSSHAgentSignersPrefersPriorityComment(t *testing.T) { + plainSigner := mustGenerateTestSigner(t) + prioritySigner := mustGenerateCommentedTestSigner(t, "priority=40") + + ordered := orderSSHAgentSigners([]ssh.Signer{plainSigner, prioritySigner}) + if len(ordered) != 2 { + t.Fatalf("ordered signers=%d want 2", len(ordered)) + } + if string(ordered[0].PublicKey().Marshal()) != string(prioritySigner.PublicKey().Marshal()) { + t.Fatalf("priority signer should be first, got %s", sshAgentSignerComment(ordered[0])) + } +} + +func TestOrderSSHAgentSignersPrefersCardKeys(t *testing.T) { + plainSigner := mustGenerateTestSigner(t) + cardSigner := mustGenerateCommentedTestSigner(t, "cardno:26_865_673") + + ordered := orderSSHAgentSigners([]ssh.Signer{plainSigner, cardSigner}) + if len(ordered) != 2 { + t.Fatalf("ordered signers=%d want 2", len(ordered)) + } + if string(ordered[0].PublicKey().Marshal()) != string(cardSigner.PublicKey().Marshal()) { + t.Fatalf("card signer should be first, got %s", sshAgentSignerComment(ordered[0])) + } +} + +func TestOrderSSHAgentSignersKeepsStableOrderWithoutHints(t *testing.T) { + firstSigner := mustGenerateTestSigner(t) + secondSigner := mustGenerateTestSigner(t) + + ordered := orderSSHAgentSigners([]ssh.Signer{firstSigner, secondSigner}) + if len(ordered) != 2 { + t.Fatalf("ordered signers=%d want 2", len(ordered)) + } + if string(ordered[0].PublicKey().Marshal()) != string(firstSigner.PublicKey().Marshal()) { + t.Fatalf("first signer changed order without hints") + } + if string(ordered[1].PublicKey().Marshal()) != string(secondSigner.PublicKey().Marshal()) { + t.Fatalf("second signer changed order without hints") + } +} + +func TestSSHAgentSignerEmitsSignDebugWithoutChangingError(t *testing.T) { + signer := mustGenerateTestSigner(t) + 
wantErr := errors.New("agent refused operation") + var debugCalls int + wrapped := wrapSSHAgentSigner(&testFailingSigner{Signer: signer, err: wantErr}, sshAgentSignerOptions{ + Resolved: resolvedSSHAgentEndpoint{ + Endpoint: "/tmp/debug-agent.sock", + Source: "identity-agent", + Network: "unix", + }, + Debug: func(event SSHAgentDebugEvent) { + debugCalls++ + if event.Step != "auth" || event.Phase != "sign" { + t.Fatalf("unexpected debug event: %+v", event) + } + if event.Endpoint != "/tmp/debug-agent.sock" || event.Source != "identity-agent" || event.Network != "unix" { + t.Fatalf("unexpected endpoint details: %+v", event) + } + if event.Status != "error" || !errors.Is(event.Err, wantErr) { + t.Fatalf("unexpected sign status: %+v", event) + } + }, + }) + + _, err := wrapped.Sign(rand.Reader, []byte("challenge")) + if !errors.Is(err, wantErr) { + t.Fatalf("Sign err=%v want original signer error", err) + } + if debugCalls != 1 { + t.Fatalf("debug calls=%d want 1", debugCalls) + } +} + +func TestSSHAgentRetrySignerPrefersRSASHA2(t *testing.T) { + signer := mustGenerateRSATestSigner(t) + spy := &testAlgorithmSpySigner{Signer: signer} + wrapped, ok := wrapSSHAgentSignerForRetry(spy, func(ssh.PublicKey, error) {}).(ssh.AlgorithmSigner) + if !ok { + t.Fatal("wrapped signer does not implement AlgorithmSigner") + } + + signature, err := wrapped.SignWithAlgorithm(rand.Reader, []byte("challenge"), ssh.KeyAlgoRSA) + if err != nil { + t.Fatalf("SignWithAlgorithm: %v", err) + } + if spy.lastAlgorithm != ssh.KeyAlgoRSASHA256 { + t.Fatalf("last algorithm=%q want %q", spy.lastAlgorithm, ssh.KeyAlgoRSASHA256) + } + if signature.Format != ssh.KeyAlgoRSASHA256 { + t.Fatalf("signature format=%q want %q", signature.Format, ssh.KeyAlgoRSASHA256) + } +} + +func TestSSHAgentRetrySignerKeepsRestrictedRSA(t *testing.T) { + signer := mustGenerateRSATestSigner(t) + restricted, err := ssh.NewSignerWithAlgorithms(signer.(ssh.AlgorithmSigner), []string{ssh.KeyAlgoRSA}) + if err != nil { + 
t.Fatalf("NewSignerWithAlgorithms: %v", err) + } + spy := &testMultiAlgorithmSpySigner{ + testAlgorithmSpySigner: &testAlgorithmSpySigner{Signer: restricted}, + } + wrapped, ok := wrapSSHAgentSignerForRetry(spy, func(ssh.PublicKey, error) {}).(ssh.AlgorithmSigner) + if !ok { + t.Fatal("wrapped signer does not implement AlgorithmSigner") + } + + signature, err := wrapped.SignWithAlgorithm(rand.Reader, []byte("challenge"), ssh.KeyAlgoRSA) + if err != nil { + t.Fatalf("SignWithAlgorithm: %v", err) + } + if spy.lastAlgorithm != ssh.KeyAlgoRSA { + t.Fatalf("last algorithm=%q want %q", spy.lastAlgorithm, ssh.KeyAlgoRSA) + } + if signature.Format != ssh.KeyAlgoRSA { + t.Fatalf("signature format=%q want %q", signature.Format, ssh.KeyAlgoRSA) + } +} + +type deadlineSpyConn struct { + net.Conn + mu sync.Mutex + deadlines []time.Time + readErr error + writeErr error +} + +type testFailingSigner struct { + ssh.Signer + err error +} + +func (s *testFailingSigner) Sign(rand io.Reader, data []byte) (*ssh.Signature, error) { + return nil, s.err +} + +func (s *testFailingSigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*ssh.Signature, error) { + return nil, s.err +} + +type testAlgorithmSpySigner struct { + ssh.Signer + lastAlgorithm string +} + +func (s *testAlgorithmSpySigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*ssh.Signature, error) { + s.lastAlgorithm = algorithm + return s.Signer.(ssh.AlgorithmSigner).SignWithAlgorithm(rand, data, algorithm) +} + +type testMultiAlgorithmSpySigner struct { + *testAlgorithmSpySigner +} + +func (s *testMultiAlgorithmSpySigner) Algorithms() []string { + if multiAlgorithmSigner, ok := s.Signer.(ssh.MultiAlgorithmSigner); ok { + return multiAlgorithmSigner.Algorithms() + } + return nil +} + +func mustGenerateTestSigner(t *testing.T) ssh.Signer { + t.Helper() + _, key, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatalf("generate test private key: %v", err) + } + signer, err := 
ssh.NewSignerFromKey(key) + if err != nil { + t.Fatalf("new test signer: %v", err) + } + return signer +} + +func mustGenerateCommentedTestSigner(t *testing.T, comment string) ssh.Signer { + t.Helper() + baseSigner := mustGenerateTestSigner(t) + publicKey := baseSigner.PublicKey() + return &commentedTestSigner{ + Signer: baseSigner, + publicKey: &sshagent.Key{ + Format: publicKey.Type(), + Blob: publicKey.Marshal(), + Comment: comment, + }, + } +} + +type commentedTestSigner struct { + ssh.Signer + publicKey ssh.PublicKey +} + +func (s *commentedTestSigner) PublicKey() ssh.PublicKey { + return s.publicKey +} + +func mustGenerateRSATestSigner(t *testing.T) ssh.Signer { + t.Helper() + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("generate rsa test private key: %v", err) + } + signer, err := ssh.NewSignerFromKey(key) + if err != nil { + t.Fatalf("new rsa test signer: %v", err) + } + return signer +} + +func (c *deadlineSpyConn) SetDeadline(deadline time.Time) error { + c.mu.Lock() + defer c.mu.Unlock() + c.deadlines = append(c.deadlines, deadline) + return nil +} + +func (c *deadlineSpyConn) deadlineCount() int { + c.mu.Lock() + defer c.mu.Unlock() + return len(c.deadlines) +} + +func (c *deadlineSpyConn) firstDeadline() time.Time { + c.mu.Lock() + defer c.mu.Unlock() + return c.deadlines[0] +} + +func (c *deadlineSpyConn) Read(p []byte) (int, error) { + if c.readErr != nil { + return 0, c.readErr + } + return 0, nil +} + +func (c *deadlineSpyConn) Write(p []byte) (int, error) { + if c.writeErr != nil { + return 0, c.writeErr + } + return len(p), nil +} + +func TestWrapSSHAgentConnWithDeadlineSetsReadDeadline(t *testing.T) { + spy := &deadlineSpyConn{readErr: io.EOF} + conn := wrapSSHAgentConnWithDeadline(spy, 2*time.Second) + buf := make([]byte, 1) + if _, err := conn.Read(buf); !errors.Is(err, io.EOF) { + t.Fatalf("Read err=%v", err) + } + if spy.deadlineCount() != 1 { + t.Fatalf("deadlines=%d want 1", spy.deadlineCount()) + } + if 
firstDeadline := spy.firstDeadline(); time.Until(firstDeadline) <= 0 { + t.Fatalf("deadline=%v should be in the future", firstDeadline) + } +} + +func TestWrapSSHAgentConnWithDeadlineSetsWriteDeadline(t *testing.T) { + spy := &deadlineSpyConn{} + conn := wrapSSHAgentConnWithDeadline(spy, 2*time.Second) + if _, err := conn.Write([]byte("x")); err != nil { + t.Fatalf("Write err=%v", err) + } + if spy.deadlineCount() != 1 { + t.Fatalf("deadlines=%d want 1", spy.deadlineCount()) + } +} + +func TestResolveSSHAgentEndpointUsesIdentityAgent(t *testing.T) { + t.Setenv("SSH_AUTH_SOCK", "/tmp/env-agent.sock") + resolved, err := resolveSSHAgentEndpoint(sshAgentDialOptions{Endpoint: " /tmp/identity-agent.sock "}) + if err != nil { + t.Fatalf("resolveSSHAgentEndpoint: %v", err) + } + if resolved.Endpoint != "/tmp/identity-agent.sock" { + t.Fatalf("endpoint=%q", resolved.Endpoint) + } + if resolved.Source != "identity-agent" { + t.Fatalf("source=%q", resolved.Source) + } +} + +func TestResolveSSHAgentEndpointUsesSSHAuthSock(t *testing.T) { + t.Setenv("SSH_AUTH_SOCK", "/tmp/env-agent.sock") + resolved, err := resolveSSHAgentEndpoint(sshAgentDialOptions{}) + if err != nil { + t.Fatalf("resolveSSHAgentEndpoint: %v", err) + } + if resolved.Endpoint != "/tmp/env-agent.sock" { + t.Fatalf("endpoint=%q", resolved.Endpoint) + } + if resolved.Source != "SSH_AUTH_SOCK" { + t.Fatalf("source=%q", resolved.Source) + } +} + +func TestBuildSSHAgentAuthMethodTimesOutWhenAgentDoesNotRespond(t *testing.T) { + server, client := net.Pipe() + defer server.Close() + + oldDialResolvedSSHAgent := dialResolvedSSHAgentFunc + t.Cleanup(func() { + dialResolvedSSHAgentFunc = oldDialResolvedSSHAgent + }) + dialResolvedSSHAgentFunc = func(resolved resolvedSSHAgentEndpoint, timeout time.Duration) (net.Conn, error) { + return client, nil + } + + _, cleanup, err := buildSSHAgentAuthMethod(sshAgentTimeouts{ + Operation: 20 * time.Millisecond, + Endpoint: "/tmp/hung-agent.sock", + }) + if cleanup != nil { + 
// tempUnixSocketPath returns a path ending in "/agent.sock" inside a fresh
// per-test temporary directory obtained from t.TempDir. Nothing is created at
// that path here. A cleanup hook is registered that removes the path once the
// test finishes; the removal error is ignored on purpose because the path may
// never have been created.
func tempUnixSocketPath(t *testing.T) string {
	t.Helper()

	socketPath := t.TempDir() + "/agent.sock"
	removeSocket := func() {
		// Best effort: a missing file is not a test failure.
		_ = os.Remove(socketPath)
	}
	t.Cleanup(removeSocket)

	return socketPath
}
SignFailure func(ssh.PublicKey, error) +} + +type sshAgentAuthAttempt struct { + mu sync.Mutex + skipFingerprints map[string]struct{} + retryRequested bool +} + +var defaultAuthOrder = []AuthMethodKind{ + AuthMethodSSHAgent, + AuthMethodPrivateKey, + AuthMethodPassword, + AuthMethodKeyboardInteractive, +} + +func effectiveSSHAgentTimeout(info LoginInput) time.Duration { + switch { + case info.SSHAgentTimeout < 0: + return 0 + case info.SSHAgentTimeout > 0: + return info.SSHAgentTimeout + default: + return defaultSSHAgentTimeout + } +} + +func effectiveSSHAgentTimeouts(info LoginInput) sshAgentTimeouts { + return sshAgentTimeouts{ + Dial: effectiveDialTimeout(info), + Operation: effectiveSSHAgentTimeout(info), + Forward: effectiveSSHAgentForwardTimeout(info), + Endpoint: info.IdentityAgent, + Debug: info.SSHAgentDebug, + } +} + +func effectiveSSHAgentForwardTimeout(info LoginInput) time.Duration { + if info.SSHAgentForwardTimeout > 0 { + return info.SSHAgentForwardTimeout + } + return 0 +} + +func buildAuthMethods(info LoginInput) ([]ssh.AuthMethod, func(), error) { + return buildAuthMethodsWithAgentAttempt(info, nil) +} + +func buildAuthMethodsWithAgentAttempt(info LoginInput, agentAttempt *sshAgentAuthAttempt) ([]ssh.AuthMethod, func(), error) { + order, err := normalizeAuthOrder(info.AuthOrder) + if err != nil { + return nil, nil, err + } + + auth := make([]ssh.AuthMethod, 0, len(order)) + var agentErr error + var cleanupFuncs []func() + + for _, methodKind := range order { + switch methodKind { + case AuthMethodPrivateKey: + method, err := buildPrivateKeyAuthMethod(info, agentAttempt) + if err != nil { + return nil, nil, err + } + if method != nil { + auth = append(auth, method) + } + case AuthMethodPassword: + method := buildPasswordAuthMethod(info.Password, info.PasswordCallback, agentAttempt) + if method != nil { + auth = append(auth, method) + } + case AuthMethodKeyboardInteractive: + method := buildKeyboardInteractiveAuthMethod(info.Password, 
info.PasswordCallback, info.KeyboardInteractiveCallback, agentAttempt) + if method != nil { + auth = append(auth, method) + } + case AuthMethodSSHAgent: + if info.DisableSSHAgent { + continue + } + timeouts := effectiveSSHAgentTimeouts(info) + if agentAttempt != nil { + timeouts.SkipFingerprints = agentAttempt.skipSnapshot() + timeouts.SignFailure = agentAttempt.recordSignFailure + } + agentMethod, cleanup, err := buildSSHAgentAuthMethodFunc(timeouts) + if err != nil { + agentErr = err + continue + } + if agentMethod != nil { + auth = append(auth, agentMethod) + } + if cleanup != nil { + cleanupFuncs = append(cleanupFuncs, cleanup) + } + } + } + + if len(auth) == 0 { + if agentErr != nil { + return nil, nil, fmt.Errorf("no authentication method provided; ssh-agent unavailable: %w", agentErr) + } + return nil, nil, errors.New("no authentication method provided: password, private key, or ssh-agent is required") + } + + return auth, composeCleanup(cleanupFuncs...), nil +} + +func normalizeAuthOrder(order []AuthMethodKind) ([]AuthMethodKind, error) { + if len(order) == 0 { + return append([]AuthMethodKind(nil), defaultAuthOrder...), nil + } + + normalized := make([]AuthMethodKind, 0, len(order)) + seen := make(map[AuthMethodKind]struct{}, len(order)) + for _, raw := range order { + kind := AuthMethodKind(strings.ToLower(strings.TrimSpace(string(raw)))) + if kind == "" { + return nil, errors.New("auth order contains an empty auth method") + } + if !isSupportedAuthMethodKind(kind) { + return nil, fmt.Errorf("unsupported auth method %q", raw) + } + if _, exists := seen[kind]; exists { + continue + } + seen[kind] = struct{}{} + normalized = append(normalized, kind) + } + + if len(normalized) == 0 { + return nil, errors.New("auth order is empty") + } + return normalized, nil +} + +func isSupportedAuthMethodKind(kind AuthMethodKind) bool { + switch kind { + case AuthMethodPrivateKey, AuthMethodPassword, AuthMethodKeyboardInteractive, AuthMethodSSHAgent: + return true + 
default: + return false + } +} + +func shouldRetrySSHAgentAuth(info LoginInput, order []AuthMethodKind) bool { + if info.DisableSSHAgent { + return false + } + for _, methodKind := range order { + if methodKind == AuthMethodSSHAgent { + return true + } + } + return false +} + +func buildPrivateKeyAuthMethod(info LoginInput, agentAttempt *sshAgentAuthAttempt) (ssh.AuthMethod, error) { + if strings.TrimSpace(info.Prikey) == "" { + return nil, nil + } + + pemBytes := []byte(info.Prikey) + if info.PrikeyPwd == "" { + signer, err := ssh.ParsePrivateKey(pemBytes) + if err != nil { + return nil, err + } + return ssh.PublicKeysCallback(privateKeySignersCallback(signer, agentAttempt)), nil + } + + signer, err := ssh.ParsePrivateKeyWithPassphrase(pemBytes, []byte(info.PrikeyPwd)) + if err != nil { + return nil, err + } + return ssh.PublicKeysCallback(privateKeySignersCallback(signer, agentAttempt)), nil +} + +func privateKeySignersCallback(signer ssh.Signer, agentAttempt *sshAgentAuthAttempt) func() ([]ssh.Signer, error) { + return func() ([]ssh.Signer, error) { + if err := checkSSHAgentRetryPending(agentAttempt); err != nil { + return nil, err + } + return []ssh.Signer{signer}, nil + } +} + +func buildPasswordAuthMethod(password string, callback func() (string, error), agentAttempt *sshAgentAuthAttempt) ssh.AuthMethod { + if password == "" && callback == nil { + return nil + } + return ssh.PasswordCallback(func() (string, error) { + if err := checkSSHAgentRetryPending(agentAttempt); err != nil { + return "", err + } + if password != "" { + return password, nil + } + return callback() + }) +} + +func buildKeyboardInteractiveAuthMethod( + password string, + passwordCallback func() (string, error), + challenge ssh.KeyboardInteractiveChallenge, + agentAttempt *sshAgentAuthAttempt, +) ssh.AuthMethod { + if challenge != nil { + return ssh.KeyboardInteractive(func(user, instruction string, questions []string, echos []bool) ([]string, error) { + if err := 
checkSSHAgentRetryPending(agentAttempt); err != nil { + return nil, err + } + return challenge(user, instruction, questions, echos) + }) + } + if password == "" && passwordCallback == nil { + return nil + } + + keyboardInteractiveChallenge := func(user, instruction string, questions []string, echos []bool) ([]string, error) { + if err := checkSSHAgentRetryPending(agentAttempt); err != nil { + return nil, err + } + if len(questions) == 0 { + return []string{}, nil + } + + answer := password + if answer == "" { + var err error + answer, err = passwordCallback() + if err != nil { + return nil, err + } + } + + answers := make([]string, len(questions)) + for i := range questions { + answers[i] = answer + } + return answers, nil + } + return ssh.KeyboardInteractive(keyboardInteractiveChallenge) +} + +func buildSSHAgentAuthMethod(timeouts sshAgentTimeouts) (ssh.AuthMethod, func(), error) { + conn, resolved, err := dialSSHAgentWithDebug("auth", timeouts) + if err != nil { + if errors.Is(err, errSSHAgentUnavailable) { + return nil, nil, nil + } + return nil, nil, err + } + if conn == nil { + return nil, nil, nil + } + conn = wrapSSHAgentConnWithDeadline(conn, timeouts.Operation) + + started := time.Now() + signers, err := sshagent.NewClient(conn).Signers() + err = normalizeSSHAgentError(err) + logSSHAgentDebug(timeouts.Debug, SSHAgentDebugEvent{ + Step: "auth", + Source: resolved.Source, + Endpoint: resolved.Endpoint, + Network: resolved.Network, + Phase: "list", + Status: debugStatus(err), + Duration: time.Since(started), + KeyCount: len(signers), + Err: err, + }) + if err != nil { + _ = conn.Close() + return nil, nil, err + } + if len(signers) == 0 { + _ = conn.Close() + return nil, nil, errors.New("ssh-agent has no loaded keys") + } + + timeouts.Resolved = resolved + orderedSigners := orderSSHAgentSigners(signers) + filteredSigners := filterSSHAgentSignersForRetry(orderedSigners, timeouts) + if len(filteredSigners) == 0 { + _ = conn.Close() + return nil, nil, 
// parseSSHAgentSignerPriority extracts an explicit ordering hint of the form
// "priority=<integer>" from an agent key comment.
//
// The marker is matched ASCII case-insensitively and may appear anywhere in
// the comment; only its first occurrence is considered. Whitespace after the
// '=' is skipped, then the longest run of sign characters and decimal digits
// is handed to strconv.Atoi. It returns (0, false) when the marker is absent,
// when nothing numeric follows it, or when the run is not a valid int (for
// example "1-2" or an out-of-range value).
func parseSSHAgentSignerPriority(comment string) (int, bool) {
	const marker = "priority="

	// Search the ORIGINAL bytes. Lower-casing the whole comment and reusing the
	// resulting index is unsafe: strings.ToLower can change the byte length
	// (non-ASCII case mappings, invalid UTF-8 replaced by U+FFFD), so that
	// index may point at the wrong place in comment or past its end. Key
	// comments come from the agent and are not guaranteed to be ASCII.
	start := -1
	for i := 0; i+len(marker) <= len(comment); i++ {
		if strings.EqualFold(comment[i:i+len(marker)], marker) {
			start = i + len(marker)
			break
		}
	}
	if start < 0 {
		return 0, false
	}

	value := strings.TrimSpace(comment[start:])

	// Take the leading run of sign/digit bytes; anything after it (units,
	// closing brackets, free text) is ignored.
	end := 0
	for end < len(value) {
		ch := value[end]
		if ch != '+' && ch != '-' && (ch < '0' || ch > '9') {
			break
		}
		end++
	}
	if end == 0 {
		return 0, false
	}

	priority, err := strconv.Atoi(value[:end])
	if err != nil {
		return 0, false
	}
	return priority, true
}
a.skipFingerprint(ssh.FingerprintSHA256(publicKey)) +} + +func (a *sshAgentAuthAttempt) skipFingerprint(fingerprint string) { + if a == nil { + return + } + a.mu.Lock() + defer a.mu.Unlock() + a.retryRequested = true + if fingerprint != "" { + a.skipFingerprints[fingerprint] = struct{}{} + } +} + +func (a *sshAgentAuthAttempt) shouldRetry() bool { + if a == nil { + return false + } + a.mu.Lock() + defer a.mu.Unlock() + return a.retryRequested +} + +func checkSSHAgentRetryPending(agentAttempt *sshAgentAuthAttempt) error { + if agentAttempt != nil && agentAttempt.shouldRetry() { + return errRetrySSHAgentAuth + } + return nil +} + +type sshAgentRetrySigner struct { + signer ssh.Signer + publicKey ssh.PublicKey + options sshAgentSignerOptions +} + +type sshAgentRetryAlgorithmSigner struct { + sshAgentRetrySigner + algorithmSigner ssh.AlgorithmSigner +} + +type sshAgentRetryMultiAlgorithmSigner struct { + sshAgentRetryAlgorithmSigner + multiAlgorithmSigner ssh.MultiAlgorithmSigner +} + +type sshAgentSignerOptions struct { + Resolved resolvedSSHAgentEndpoint + Debug SSHAgentDebugFunc + SignFailure func(ssh.PublicKey, error) +} + +func wrapSSHAgentSignerForRetry(signer ssh.Signer, onFailure func(ssh.PublicKey, error)) ssh.Signer { + return wrapSSHAgentSigner(signer, sshAgentSignerOptions{SignFailure: onFailure}) +} + +func wrapSSHAgentSigner(signer ssh.Signer, options sshAgentSignerOptions) ssh.Signer { + publicKey := signer.PublicKey() + base := sshAgentRetrySigner{ + signer: signer, + publicKey: publicKey, + options: options, + } + if multiAlgorithmSigner, ok := signer.(ssh.MultiAlgorithmSigner); ok { + return &sshAgentRetryMultiAlgorithmSigner{ + sshAgentRetryAlgorithmSigner: sshAgentRetryAlgorithmSigner{ + sshAgentRetrySigner: base, + algorithmSigner: multiAlgorithmSigner, + }, + multiAlgorithmSigner: multiAlgorithmSigner, + } + } + if algorithmSigner, ok := signer.(ssh.AlgorithmSigner); ok { + return &sshAgentRetryAlgorithmSigner{ + sshAgentRetrySigner: base, + 
algorithmSigner: algorithmSigner, + } + } + return &base +} + +func (s *sshAgentRetrySigner) PublicKey() ssh.PublicKey { + return s.publicKey +} + +func (s *sshAgentRetrySigner) Sign(rand io.Reader, data []byte) (*ssh.Signature, error) { + started := time.Now() + signature, err := s.signer.Sign(rand, data) + return signature, s.finishSign(started, err) +} + +func (s *sshAgentRetrySigner) finishSign(started time.Time, err error) error { + err = normalizeSSHAgentError(err) + s.logSignDebug(started, err) + if err == nil { + return nil + } + if s.options.SignFailure != nil { + s.options.SignFailure(s.publicKey, err) + return wrapSSHAgentSignError(err) + } + return err +} + +func (s *sshAgentRetrySigner) logSignDebug(started time.Time, err error) { + if s == nil || s.options.Debug == nil { + return + } + logSSHAgentDebug(s.options.Debug, SSHAgentDebugEvent{ + Step: "auth", + Source: s.options.Resolved.Source, + Endpoint: s.options.Resolved.Endpoint, + Network: s.options.Resolved.Network, + Phase: "sign", + Status: debugStatus(err), + Duration: time.Since(started), + Err: err, + }) +} + +func (s *sshAgentRetryAlgorithmSigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*ssh.Signature, error) { + algorithm = preferredSSHAgentSignAlgorithm(s.publicKey, algorithm, nil) + started := time.Now() + signature, err := s.algorithmSigner.SignWithAlgorithm(rand, data, algorithm) + return signature, s.finishSign(started, err) +} + +func (s *sshAgentRetryMultiAlgorithmSigner) Algorithms() []string { + return s.multiAlgorithmSigner.Algorithms() +} + +func (s *sshAgentRetryMultiAlgorithmSigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*ssh.Signature, error) { + algorithm = preferredSSHAgentSignAlgorithm(s.publicKey, algorithm, s.multiAlgorithmSigner.Algorithms()) + started := time.Now() + signature, err := s.multiAlgorithmSigner.SignWithAlgorithm(rand, data, algorithm) + return signature, s.finishSign(started, err) +} + +func 
// composeCleanup folds several cleanup callbacks into one. It returns nil when
// called with no callbacks. The returned function invokes the callbacks in
// reverse registration order (last registered runs first) and skips nil
// entries.
func composeCleanup(funcs ...func()) func() {
	if len(funcs) == 0 {
		return nil
	}
	return func() {
		for remaining := len(funcs); remaining > 0; remaining-- {
			cleanup := funcs[remaining-1]
			if cleanup == nil {
				continue
			}
			cleanup()
		}
	}
}
// debugStatus maps an operation result to the status label used in debug
// events: "ok" for a nil error and "error" for anything else.
func debugStatus(err error) string {
	status := "ok"
	if err != nil {
		status = "error"
	}
	return status
}
nil { + return false + } + if errors.Is(err, os.ErrDeadlineExceeded) { + return true + } + var netErr net.Error + return errors.As(err, &netErr) && netErr.Timeout() +} + +func wrapSSHAgentConnError(err error) error { + if isTimeoutError(err) { + return fmt.Errorf("%w: %v", ErrSSHAgentTimeout, err) + } + return err +} + +func normalizeSSHAgentError(err error) error { + if err == nil { + return nil + } + if errors.Is(err, ErrSSHAgentTimeout) { + return err + } + if strings.Contains(err.Error(), ErrSSHAgentTimeout.Error()) { + return fmt.Errorf("%w: %v", ErrSSHAgentTimeout, err) + } + return err +} diff --git a/sshagent_unix.go b/sshagent_unix.go index e6c6fb6..1e98882 100644 --- a/sshagent_unix.go +++ b/sshagent_unix.go @@ -4,16 +4,19 @@ package starssh import ( "net" - "os" - "strings" "time" ) -func dialSSHAgent(timeout time.Duration) (net.Conn, error) { - agentSock := strings.TrimSpace(os.Getenv("SSH_AUTH_SOCK")) - if agentSock == "" { - return nil, errSSHAgentUnavailable - } +func defaultSSHAgentEndpoint() (resolvedSSHAgentEndpoint, error) { + return resolvedSSHAgentEndpoint{}, errSSHAgentUnavailable +} + +func defaultSSHAgentNetwork(endpoint string) string { + return "unix" +} + +func dialResolvedSSHAgent(resolved resolvedSSHAgentEndpoint, timeout time.Duration) (net.Conn, error) { + agentSock := resolved.Endpoint if timeout > 0 { return net.DialTimeout("unix", agentSock, timeout) } diff --git a/sshagent_windows.go b/sshagent_windows.go index f6b25b5..2de6be0 100644 --- a/sshagent_windows.go +++ b/sshagent_windows.go @@ -3,10 +3,16 @@ package starssh import ( + "bytes" "context" + "encoding/binary" "errors" + "fmt" + "io" "net" "os" + "path/filepath" + "strconv" "strings" "time" @@ -16,22 +22,40 @@ import ( const defaultWindowsSSHAgentPipe = `\\.\pipe\openssh-ssh-agent` -func dialSSHAgent(timeout time.Duration) (net.Conn, error) { - agentSock := strings.TrimSpace(os.Getenv("SSH_AUTH_SOCK")) - if agentSock != "" { - return dialWindowsSSHAgentEndpoint(agentSock, 
timeout) - } - return dialWindowsNamedPipe(defaultWindowsSSHAgentPipe, timeout, true) +var errInvalidGPGSocketInfo = errors.New("invalid gpg agent socket file") + +type gpgSocketInfo struct { + port uint16 + nonce []byte + cygwin bool } -func dialWindowsSSHAgentEndpoint(endpoint string, timeout time.Duration) (net.Conn, error) { - if pipePath, ok := normalizeWindowsSSHAgentPipe(endpoint); ok { - return dialWindowsNamedPipe(pipePath, timeout, false) +func defaultSSHAgentEndpoint() (resolvedSSHAgentEndpoint, error) { + return resolvedSSHAgentEndpoint{ + Endpoint: defaultWindowsSSHAgentPipe, + Source: "platform-default", + Network: "windows-pipe", + }, nil +} + +func defaultSSHAgentNetwork(endpoint string) string { + if _, ok := normalizeWindowsSSHAgentPipe(endpoint); ok { + return "windows-pipe" } - if timeout > 0 { - return net.DialTimeout("unix", endpoint, timeout) + if isAgentSSHSocketPath(endpoint) { + return "gpg-socket" } - return net.Dial("unix", endpoint) + return "unix" +} + +func dialResolvedSSHAgent(resolved resolvedSSHAgentEndpoint, timeout time.Duration) (net.Conn, error) { + if pipePath, ok := normalizeWindowsSSHAgentPipe(resolved.Endpoint); ok { + return dialWindowsNamedPipe(pipePath, timeout, resolved.Source == "platform-default") + } + if isAgentSSHSocketPath(resolved.Endpoint) { + return dialWindowsGPGSocketFile(resolved.Endpoint, timeout) + } + return dialWindowsUnixAgent(resolved.Endpoint, timeout) } func dialWindowsNamedPipe(path string, timeout time.Duration, unavailableOnNotFound bool) (net.Conn, error) { @@ -42,11 +66,7 @@ func dialWindowsNamedPipe(path string, timeout time.Duration, unavailableOnNotFo } defer cancel() - conn, err := winio.DialPipeContext(ctx, path) - if err != nil && unavailableOnNotFound && isWindowsPipeUnavailable(err) { - return nil, errSSHAgentUnavailable - } - return conn, err + return dialWindowsNamedPipeContext(ctx, path, unavailableOnNotFound) } func normalizeWindowsSSHAgentPipe(endpoint string) (string, bool) { @@ 
-68,3 +88,184 @@ func normalizeWindowsSSHAgentPipe(endpoint string) (string, bool) { func isWindowsPipeUnavailable(err error) bool { return errors.Is(err, windows.ERROR_FILE_NOT_FOUND) || errors.Is(err, windows.ERROR_PATH_NOT_FOUND) } + +func dialWindowsUnixAgent(endpoint string, timeout time.Duration) (net.Conn, error) { + if timeout > 0 { + return net.DialTimeout("unix", endpoint, timeout) + } + return net.Dial("unix", endpoint) +} + +func dialWindowsGPGSocketFile(path string, timeout time.Duration) (net.Conn, error) { + ctx := context.Background() + cancel := func() {} + if timeout > 0 { + ctx, cancel = context.WithTimeout(ctx, timeout) + } + defer cancel() + return dialWindowsGPGSocketFileDepth(ctx, strings.TrimSpace(path), 0) +} + +func dialWindowsGPGSocketFileDepth(ctx context.Context, path string, depth int) (net.Conn, error) { + if path == "" { + return nil, fmt.Errorf("gpg agent endpoint is empty") + } + if depth > 8 { + return nil, fmt.Errorf("gpg agent socket redirect loop at %s", path) + } + + data, err := os.ReadFile(path) + if err != nil { + return nil, err + } + if target, ok := parseGPGAssuanSocketRedirect(data); ok { + target = resolveGPGSocketRedirectTarget(path, target) + if pipePath, ok := normalizeWindowsSSHAgentPipe(target); ok { + return dialWindowsNamedPipeContext(ctx, pipePath, false) + } + return dialWindowsGPGSocketFileDepth(ctx, target, depth+1) + } + + info, err := parseGPGSocketInfo(path, data) + if err != nil { + return nil, err + } + return dialWindowsGPGSocketInfo(ctx, info) +} + +func dialWindowsGPGSocketInfo(ctx context.Context, info gpgSocketInfo) (net.Conn, error) { + var dialer net.Dialer + conn, err := dialer.DialContext(ctx, "tcp", net.JoinHostPort("127.0.0.1", strconv.Itoa(int(info.port)))) + if err != nil { + return nil, err + } + if deadline, ok := ctx.Deadline(); ok { + if err := conn.SetDeadline(deadline); err != nil { + _ = conn.Close() + return nil, err + } + } + if _, err := conn.Write(info.nonce); err != nil { + _ = 
conn.Close() + return nil, err + } + if info.cygwin { + var nonce [16]byte + if _, err := io.ReadFull(conn, nonce[:]); err != nil { + _ = conn.Close() + return nil, err + } + var credential [8]byte + binary.LittleEndian.PutUint32(credential[:4], uint32(os.Getpid())) + if _, err := conn.Write(credential[:]); err != nil { + _ = conn.Close() + return nil, err + } + if _, err := io.ReadFull(conn, credential[:]); err != nil { + _ = conn.Close() + return nil, err + } + } + _ = conn.SetDeadline(time.Time{}) + return conn, nil +} + +func resolveGPGSocketRedirectTarget(source string, target string) string { + target = strings.TrimSpace(target) + if target == "" || filepath.IsAbs(target) { + return target + } + if _, ok := normalizeWindowsSSHAgentPipe(target); ok { + return target + } + return filepath.Join(filepath.Dir(source), target) +} + +func parseGPGSocketInfo(path string, data []byte) (gpgSocketInfo, error) { + if info, ok := parseGPGAssuanSocketInfo(data); ok { + return info, nil + } + if info, ok := parseGPGCygwinSocketInfo(data); ok { + return info, nil + } + return gpgSocketInfo{}, fmt.Errorf("%w %s: expected GnuPG port/nonce socket file; if SSH_AUTH_SOCK was set to this file, restart gpg-agent to recreate it", errInvalidGPGSocketInfo, path) +} + +func parseGPGAssuanSocketRedirect(data []byte) (string, bool) { + text := strings.ReplaceAll(string(data), "\r\n", "\n") + text = strings.TrimSuffix(text, "\n") + lines := strings.Split(text, "\n") + if len(lines) != 2 || lines[0] != "%Assuan%" { + return "", false + } + target, ok := strings.CutPrefix(lines[1], "socket=") + if !ok || strings.TrimSpace(target) == "" { + return "", false + } + return os.ExpandEnv(target), true +} + +func parseGPGAssuanSocketInfo(data []byte) (gpgSocketInfo, bool) { + newline := bytes.IndexByte(data, '\n') + if newline <= 0 || len(data)-newline-1 != 16 { + return gpgSocketInfo{}, false + } + port64, err := strconv.ParseUint(strings.TrimSpace(string(data[:newline])), 10, 16) + if err != nil 
|| port64 == 0 { + return gpgSocketInfo{}, false + } + nonce := make([]byte, 16) + copy(nonce, data[newline+1:]) + return gpgSocketInfo{port: uint16(port64), nonce: nonce}, true +} + +func parseGPGCygwinSocketInfo(data []byte) (gpgSocketInfo, bool) { + if !bytes.HasPrefix(data, []byte("!")) { + return gpgSocketInfo{}, false + } + fields := strings.Fields(strings.TrimRight(string(data[10:]), "\x00")) + if len(fields) != 3 || fields[1] != "s" { + return gpgSocketInfo{}, false + } + port64, err := strconv.ParseUint(fields[0], 10, 16) + if err != nil || port64 == 0 { + return gpgSocketInfo{}, false + } + hexParts := strings.Split(fields[2], "-") + if len(hexParts) != 4 { + return gpgSocketInfo{}, false + } + nonce := make([]byte, 0, 16) + for _, part := range hexParts { + if len(part) != 8 { + return gpgSocketInfo{}, false + } + value, err := strconv.ParseUint(part, 16, 32) + if err != nil { + return gpgSocketInfo{}, false + } + var chunk [4]byte + binary.LittleEndian.PutUint32(chunk[:], uint32(value)) + nonce = append(nonce, chunk[:]...) 
+ } + return gpgSocketInfo{port: uint16(port64), nonce: nonce, cygwin: true}, true +} + +func isAgentSSHSocketPath(endpoint string) bool { + normalized := strings.ToLower(strings.TrimSpace(endpoint)) + return strings.HasSuffix(normalized, "s.gpg-agent.ssh") +} + +func dialWindowsNamedPipeContext(ctx context.Context, path string, unavailableOnNotFound bool) (net.Conn, error) { + if ctx == nil { + ctx = context.Background() + } + conn, err := winio.DialPipeContext(ctx, path) + if err != nil && unavailableOnNotFound && isWindowsPipeUnavailable(err) { + return nil, errSSHAgentUnavailable + } + if err != nil { + return nil, err + } + return conn, nil +} diff --git a/sshagent_windows_test.go b/sshagent_windows_test.go new file mode 100644 index 0000000..fbf661e --- /dev/null +++ b/sshagent_windows_test.go @@ -0,0 +1,152 @@ +//go:build windows + +package starssh + +import ( + "bytes" + "errors" + "io" + "net" + "os" + "path/filepath" + "strconv" + "testing" + "time" +) + +func TestParseGPGAssuanSocketInfo(t *testing.T) { + info, ok := parseGPGAssuanSocketInfo([]byte("7247\n0123456789abcdef")) + if !ok { + t.Fatal("expected Assuan socket info to parse") + } + if info.port != 7247 || string(info.nonce) != "0123456789abcdef" || info.cygwin { + t.Fatalf("info=%+v nonce=%x", info, info.nonce) + } +} + +func TestParseGPGCygwinSocketInfo(t *testing.T) { + info, ok := parseGPGCygwinSocketInfo([]byte("!7247 s 00000001-02030405-06070809-0a0b0c0d\x00")) + if !ok { + t.Fatal("expected Cygwin socket info to parse") + } + want := []byte{1, 0, 0, 0, 5, 4, 3, 2, 9, 8, 7, 6, 13, 12, 11, 10} + if info.port != 7247 || string(info.nonce) != string(want) || !info.cygwin { + t.Fatalf("info=%+v nonce=%x", info, info.nonce) + } +} + +func TestParseGPGAssuanSocketRedirect(t *testing.T) { + t.Setenv("STARSSH_TEST_PIPE", `\\.\pipe\openssh-ssh-agent`) + target, ok := parseGPGAssuanSocketRedirect([]byte("%Assuan%\r\nsocket=${STARSSH_TEST_PIPE}\r\n")) + if !ok { + t.Fatal("expected Assuan redirect to 
parse") + } + if target != `\\.\pipe\openssh-ssh-agent` { + t.Fatalf("target=%q", target) + } +} + +func TestReadInvalidAgentSSHSocketReturnsGPGSocketError(t *testing.T) { + path := t.TempDir() + "/S.gpg-agent.ssh" + if err := os.WriteFile(path, []byte("not a socket info file"), 0o600); err != nil { + t.Fatalf("write socket file: %v", err) + } + + _, err := dialResolvedSSHAgent(resolvedSSHAgentEndpoint{ + Endpoint: path, + Source: "SSH_AUTH_SOCK", + Network: defaultSSHAgentNetwork(path), + }, 0) + if !errors.Is(err, errInvalidGPGSocketInfo) { + t.Fatalf("err=%v want errInvalidGPGSocketInfo", err) + } +} + +func TestMissingAgentSSHSocketReturnsReadError(t *testing.T) { + path := filepath.Join(t.TempDir(), "S.gpg-agent.ssh") + + _, err := dialResolvedSSHAgent(resolvedSSHAgentEndpoint{ + Endpoint: path, + Source: "identity-agent", + Network: defaultSSHAgentNetwork(path), + }, 0) + if err == nil { + t.Fatal("expected missing GPG socket file error") + } + if !errors.Is(err, os.ErrNotExist) { + t.Fatalf("err=%v want os.ErrNotExist", err) + } +} + +func TestUnreadableAgentSSHSocketReturnsReadError(t *testing.T) { + path := filepath.Join(t.TempDir(), "S.gpg-agent.ssh") + if err := os.Mkdir(path, 0o700); err != nil { + t.Fatalf("mkdir socket path: %v", err) + } + + _, err := dialResolvedSSHAgent(resolvedSSHAgentEndpoint{ + Endpoint: path, + Source: "identity-agent", + Network: defaultSSHAgentNetwork(path), + }, 0) + if err == nil { + t.Fatal("expected unreadable GPG socket file error") + } + if errors.Is(err, errInvalidGPGSocketInfo) { + t.Fatalf("err=%v should expose read failure before parse", err) + } +} + +func TestDialWindowsGPGSocketFilePerformsNonceHandshake(t *testing.T) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen tcp: %v", err) + } + defer listener.Close() + + type handshakeResult struct { + nonce []byte + err error + } + resultCh := make(chan handshakeResult, 1) + go func() { + conn, err := listener.Accept() + if err 
!= nil { + resultCh <- handshakeResult{err: err} + return + } + defer conn.Close() + + nonce := make([]byte, 16) + if _, err := io.ReadFull(conn, nonce); err != nil { + resultCh <- handshakeResult{err: err} + return + } + resultCh <- handshakeResult{nonce: append([]byte(nil), nonce...)} + }() + + socketPath := filepath.Join(t.TempDir(), "S.gpg-agent.ssh") + if err := os.WriteFile(socketPath, []byte(strconv.Itoa(listener.Addr().(*net.TCPAddr).Port)+"\n0123456789abcdef"), 0o600); err != nil { + t.Fatalf("write socket file: %v", err) + } + + conn, err := dialWindowsGPGSocketFile(socketPath, time.Second) + if err != nil { + t.Fatalf("dialWindowsGPGSocketFile: %v", err) + } + _ = conn.Close() + + var result handshakeResult + select { + case result = <-resultCh: + case <-time.After(time.Second): + t.Fatal("listener did not accept GPG socket connection") + } + if result.err != nil { + t.Fatalf("listener handshake error: %v", result.err) + } + + if !bytes.Equal(result.nonce, []byte("0123456789abcdef")) { + t.Fatalf("nonce=%q", result.nonce) + } +} diff --git a/types.go b/types.go index acac5a1..ddbb4c4 100644 --- a/types.go +++ b/types.go @@ -16,6 +16,7 @@ import ( const ( defaultSSHPort = 22 defaultLoginTimeout = 5 * time.Second + defaultSSHAgentTimeout = 2 * time.Minute defaultKeepAliveTimeout = 3 * time.Second defaultShellPollInterval = 120 * time.Millisecond defaultShellSetupDelay = 200 * time.Millisecond @@ -58,6 +59,20 @@ const ( AuthMethodSSHAgent AuthMethodKind = "ssh_agent" ) +type SSHAgentDebugFunc func(SSHAgentDebugEvent) + +type SSHAgentDebugEvent struct { + Step string + Source string + Endpoint string + Network string + Phase string + Status string + Duration time.Duration + KeyCount int + Err error +} + type StarSSH struct { stateMu sync.RWMutex Client *ssh.Client @@ -92,15 +107,30 @@ type LoginInput struct { DisableSSHAgent bool ForwardSSHAgent bool AuthOrder []AuthMethodKind - Addr string - Port int + // IdentityAgent overrides the local ssh-agent endpoint 
used for authentication + // and agent forwarding. Empty uses SSH_AUTH_SOCK, or the platform default where + // one exists. + IdentityAgent string + Addr string + Port int // Timeout limits the SSH handshake/authentication phase after a TCP connection has // already been established. Zero means no authentication timeout. Timeout time.Duration // DialTimeout limits outbound dial steps such as TCP connect, proxy connect, and // local ssh-agent socket connect. Zero falls back to Timeout when set, otherwise // uses the package default dial timeout. Negative disables the default dial timeout. - DialTimeout time.Duration + DialTimeout time.Duration + // SSHAgentTimeout limits ssh-agent protocol operations such as listing keys and + // signing challenges. Zero uses the package default, and negative disables the + // per-operation deadline. This is intentionally separate from Timeout and + // DialTimeout because hardware-backed agents may require a PIN or touch confirmation. + SSHAgentTimeout time.Duration + // SSHAgentForwardTimeout limits idle reads and writes on forwarded agent + // channels. Zero or negative leaves forwarded channels without an idle deadline. + SSHAgentForwardTimeout time.Duration + // SSHAgentDebug receives structured ssh-agent dial/protocol events. It is nil by + // default and must not log private key material. + SSHAgentDebug SSHAgentDebugFunc DialContext DialContextFunc Proxy *ProxyConfig Jump *LoginInput