package starssh import ( "errors" "io" "sync/atomic" "testing" "time" "golang.org/x/crypto/ssh" sshagent "golang.org/x/crypto/ssh/agent" ) type testCloser struct { closed atomic.Int32 } func (c *testCloser) Close() error { c.closed.Add(1) return nil } func TestNewExecSessionEnablesAgentForwardingOnce(t *testing.T) { oldNewSSHSession := newSSHSession oldNewSSHAgentForwarder := newSSHAgentForwarder oldRouteSSHAgentForwarding := routeSSHAgentForwarding oldRequestSSHAgentForwarding := requestSSHAgentForwarding oldCloseSSHClient := closeSSHClient t.Cleanup(func() { newSSHSession = oldNewSSHSession newSSHAgentForwarder = oldNewSSHAgentForwarder routeSSHAgentForwarding = oldRouteSSHAgentForwarding requestSSHAgentForwarding = oldRequestSSHAgentForwarding closeSSHClient = oldCloseSSHClient }) baseClient := &ssh.Client{} star := &StarSSH{ LoginInfo: LoginInput{ ForwardSSHAgent: true, Timeout: time.Second, }, } star.setTransport(baseClient, nil) newSSHSession = func(client *ssh.Client) (*ssh.Session, error) { if client != baseClient { t.Fatalf("unexpected ssh client %p", client) } return &ssh.Session{}, nil } var agentInitCalls atomic.Int32 closer := &testCloser{} newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) { agentInitCalls.Add(1) if timeout != time.Second { t.Fatalf("unexpected forwarding timeout: %v", timeout) } return sshagent.NewKeyring(), closer, nil } var routeCalls atomic.Int32 routeSSHAgentForwarding = func(client *ssh.Client, keyring sshagent.Agent) error { routeCalls.Add(1) if client != baseClient { t.Fatalf("unexpected routed client %p", client) } if keyring == nil { t.Fatal("expected non-nil forwarded agent keyring") } return nil } var requestCalls atomic.Int32 requestSSHAgentForwarding = func(session *ssh.Session) error { requestCalls.Add(1) if session == nil { t.Fatal("expected non-nil ssh session") } return nil } if _, err := star.NewExecSession(); err != nil { t.Fatalf("first exec session: %v", err) } if _, err := star.NewExecSession(); err != nil { t.Fatalf("second exec session: %v", err) } if agentInitCalls.Load() != 1 { t.Fatalf("expected one agent forwarder init, got %d", agentInitCalls.Load()) } if routeCalls.Load() != 1 { t.Fatalf("expected one agent route registration, got %d", routeCalls.Load()) } if requestCalls.Load() != 2 { t.Fatalf("expected agent forwarding request on each session, got %d", requestCalls.Load()) } closeSSHClient = func(client sshClientRequester) error { return nil } if err := star.Close(); err != nil { t.Fatalf("close starssh: %v", err) } if closer.closed.Load() != 1 { t.Fatalf("expected forwarded agent closer to run once, got %d", closer.closed.Load()) } } func TestNewPTYSessionEnablesAgentForwardingWhenConfigured(t *testing.T) { oldNewSSHSession := newSSHSession oldRequestSessionPTY := requestSessionPTY oldNewSSHAgentForwarder := newSSHAgentForwarder oldRouteSSHAgentForwarding := routeSSHAgentForwarding oldRequestSSHAgentForwarding := requestSSHAgentForwarding t.Cleanup(func() { newSSHSession = oldNewSSHSession requestSessionPTY = oldRequestSessionPTY newSSHAgentForwarder = oldNewSSHAgentForwarder routeSSHAgentForwarding = oldRouteSSHAgentForwarding requestSSHAgentForwarding = oldRequestSSHAgentForwarding }) star := &StarSSH{ LoginInfo: LoginInput{ ForwardSSHAgent: true, }, } star.setTransport(&ssh.Client{}, nil) newSSHSession = func(client *ssh.Client) (*ssh.Session, error) { return &ssh.Session{}, nil } var ptyCalls atomic.Int32 requestSessionPTY = func(session *ssh.Session, config TerminalConfig) error { ptyCalls.Add(1) return nil } newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) { return sshagent.NewKeyring(), &testCloser{}, nil } routeSSHAgentForwarding = func(client *ssh.Client, keyring sshagent.Agent) error { return nil } var requestCalls atomic.Int32 requestSSHAgentForwarding = func(session *ssh.Session) error { requestCalls.Add(1) return nil } if _, err := star.NewPTYSession(nil); err != nil { t.Fatalf("new pty session: %v", err) } if ptyCalls.Load() != 1 { t.Fatalf("expected one PTY request, got %d", ptyCalls.Load()) } if requestCalls.Load() != 1 { t.Fatalf("expected one agent forwarding request, got %d", requestCalls.Load()) } } func TestNewExecSessionSkipsAgentForwardingWhenDisabled(t *testing.T) { oldNewSSHSession := newSSHSession oldNewSSHAgentForwarder := newSSHAgentForwarder oldRequestSSHAgentForwarding := requestSSHAgentForwarding t.Cleanup(func() { newSSHSession = oldNewSSHSession newSSHAgentForwarder = oldNewSSHAgentForwarder requestSSHAgentForwarding = oldRequestSSHAgentForwarding }) star := &StarSSH{} star.setTransport(&ssh.Client{}, nil) newSSHSession = func(client *ssh.Client) (*ssh.Session, error) { return &ssh.Session{}, nil } newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) { t.Fatal("agent forwarder should not initialize when disabled") return nil, nil, nil } requestSSHAgentForwarding = func(session *ssh.Session) error { t.Fatal("agent forwarding should not be requested when disabled") return nil } if _, err := star.NewExecSession(); err != nil { t.Fatalf("new exec session without forwarding: %v", err) } } func TestRequestAgentForwardingReturnsUnavailableError(t *testing.T) { oldNewSSHAgentForwarder := newSSHAgentForwarder oldRequestSSHAgentForwarding := requestSSHAgentForwarding t.Cleanup(func() { newSSHAgentForwarder = oldNewSSHAgentForwarder requestSSHAgentForwarding = oldRequestSSHAgentForwarding }) star := &StarSSH{} star.setTransport(&ssh.Client{}, nil) newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) { return nil, nil, errors.New("ssh-agent forwarding unavailable: ssh-agent unavailable") } requestSSHAgentForwarding = func(session *ssh.Session) error { t.Fatal("session request should not run when agent forwarder init fails") return nil } err := star.RequestAgentForwarding(&ssh.Session{}) if err == nil { t.Fatal("expected agent forwarding init error") } } func TestRequestAgentForwardingWrapsSetupErrorAsUnavailable(t *testing.T) { oldNewSSHAgentForwarder := newSSHAgentForwarder t.Cleanup(func() { newSSHAgentForwarder = oldNewSSHAgentForwarder }) star := &StarSSH{} star.setTransport(&ssh.Client{}, nil) newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) { return nil, nil, errors.New("dial unix /tmp/ssh-broken.sock: connect: permission denied") } err := star.RequestAgentForwarding(&ssh.Session{}) if !isSSHAgentForwardingUnavailableError(err) { t.Fatalf("expected unavailable error, got %v", err) } } func TestRequestAgentForwardingReturnsDeniedError(t *testing.T) { oldNewSSHAgentForwarder := newSSHAgentForwarder oldRouteSSHAgentForwarding := routeSSHAgentForwarding oldRequestSSHAgentForwarding := requestSSHAgentForwarding t.Cleanup(func() { newSSHAgentForwarder = oldNewSSHAgentForwarder routeSSHAgentForwarding = oldRouteSSHAgentForwarding requestSSHAgentForwarding = oldRequestSSHAgentForwarding }) star := &StarSSH{} star.setTransport(&ssh.Client{}, nil) newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) { return sshagent.NewKeyring(), &testCloser{}, nil } routeSSHAgentForwarding = func(client *ssh.Client, keyring sshagent.Agent) error { return nil } requestSSHAgentForwarding = func(session *ssh.Session) error { return errors.New("forwarding request denied") } err := star.RequestAgentForwarding(&ssh.Session{}) if !isSSHAgentForwardingDeniedError(err) { t.Fatalf("expected forwarding denied error, got %v", err) } } func TestNewExecSessionIgnoresAgentForwardingDenied(t *testing.T) { oldNewSSHSession := newSSHSession oldNewSSHAgentForwarder := newSSHAgentForwarder oldRouteSSHAgentForwarding := routeSSHAgentForwarding oldRequestSSHAgentForwarding := requestSSHAgentForwarding t.Cleanup(func() { newSSHSession = oldNewSSHSession newSSHAgentForwarder = oldNewSSHAgentForwarder routeSSHAgentForwarding = oldRouteSSHAgentForwarding requestSSHAgentForwarding = oldRequestSSHAgentForwarding }) star := &StarSSH{ LoginInfo: LoginInput{ ForwardSSHAgent: true, }, } star.setTransport(&ssh.Client{}, nil) newSSHSession = func(client *ssh.Client) (*ssh.Session, error) { return &ssh.Session{}, nil } newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) { return sshagent.NewKeyring(), &testCloser{}, nil } routeSSHAgentForwarding = func(client *ssh.Client, keyring sshagent.Agent) error { return nil } requestSSHAgentForwarding = func(session *ssh.Session) error { return errors.New("forwarding request denied") } if _, err := star.NewExecSession(); err != nil { t.Fatalf("new exec session should ignore denied agent forwarding: %v", err) } } func TestNewExecSessionIgnoresAgentForwardingUnavailable(t *testing.T) { oldNewSSHSession := newSSHSession oldNewSSHAgentForwarder := newSSHAgentForwarder t.Cleanup(func() { newSSHSession = oldNewSSHSession newSSHAgentForwarder = oldNewSSHAgentForwarder }) star := &StarSSH{ LoginInfo: LoginInput{ ForwardSSHAgent: true, }, } star.setTransport(&ssh.Client{}, nil) newSSHSession = func(client *ssh.Client) (*ssh.Session, error) { return &ssh.Session{}, nil } newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) { return nil, nil, errors.New("ssh-agent forwarding unavailable: ssh-agent unavailable") } if _, err := star.NewExecSession(); err != nil { t.Fatalf("new exec session should ignore unavailable agent forwarding: %v", err) } } func TestNewExecSessionIgnoresAgentForwardingSetupError(t *testing.T) { oldNewSSHSession := newSSHSession oldNewSSHAgentForwarder := newSSHAgentForwarder t.Cleanup(func() { newSSHSession = oldNewSSHSession newSSHAgentForwarder = oldNewSSHAgentForwarder }) star := &StarSSH{ LoginInfo: LoginInput{ ForwardSSHAgent: true, }, } star.setTransport(&ssh.Client{}, nil) newSSHSession = func(client *ssh.Client) (*ssh.Session, error) { return &ssh.Session{}, nil } newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) { return nil, nil, errors.New("dial unix /tmp/ssh-broken.sock: connect: connection refused") } if _, err := star.NewExecSession(); err != nil { t.Fatalf("new exec session should ignore agent setup error: %v", err) } } func TestEnsureAgentForwardingClosesNewForwarderWhenCloseStarts(t *testing.T) { oldNewSSHAgentForwarder := newSSHAgentForwarder oldRouteSSHAgentForwarding := routeSSHAgentForwarding oldCloseSSHClient := closeSSHClient t.Cleanup(func() { newSSHAgentForwarder = oldNewSSHAgentForwarder routeSSHAgentForwarding = oldRouteSSHAgentForwarding closeSSHClient = oldCloseSSHClient }) star := &StarSSH{ LoginInfo: LoginInput{ ForwardSSHAgent: true, }, } star.setTransport(&ssh.Client{}, nil) started := make(chan struct{}) release := make(chan struct{}) closer := &testCloser{} newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) { close(started) <-release return sshagent.NewKeyring(), closer, nil } routeSSHAgentForwarding = func(client *ssh.Client, keyring sshagent.Agent) error { return nil } closeSSHClient = func(client sshClientRequester) error { return nil } errCh := make(chan error, 1) go func() { errCh <- star.ensureAgentForwarding() }() <-started closeDone := make(chan struct{}) go func() { _ = star.Close() close(closeDone) }() deadline := time.Now().Add(time.Second) for !star.closing.Load() { if time.Now().After(deadline) { t.Fatal("close did not enter closing state in time") } time.Sleep(time.Millisecond) } close(release) err := <-errCh if !errors.Is(err, errSSHClientClosing) { t.Fatalf("expected closing error, got %v", err) } <-closeDone if closer.closed.Load() != 1 { t.Fatalf("expected new forwarder closer to be closed once, got %d", closer.closed.Load()) } if got := star.takeAgentForwarder(); got != nil { t.Fatal("expected no leaked agent forwarder after close race") } }