package starssh

import (
	"bytes"
	"errors"
	"io"
	"net"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"golang.org/x/crypto/ssh"
)

// testCloser is an io.Closer that only counts how many times Close is called.
type testCloser struct {
	closed atomic.Int32
}

// Close records the call and always succeeds.
func (c *testCloser) Close() error {
	c.closed.Add(1)
	return nil
}

// trackedConn wraps a net.Conn and counts Close calls so tests can assert
// that the wrapped connection was torn down.
type trackedConn struct {
	net.Conn
	closed atomic.Int32
}

// Close records the call and then closes the wrapped connection, if any.
func (c *trackedConn) Close() error {
	c.closed.Add(1)
	if c.Conn == nil {
		return nil
	}
	return c.Conn.Close()
}

// testSSHChannel is a fake ssh.Channel. Reads are delegated to readFunc
// (io.EOF when readFunc is nil), writes are discarded, and Close is
// idempotent: only the first call is counted and it closes closeCh so that
// blocked readers can observe the close.
type testSSHChannel struct {
	readFunc  func([]byte) (int, error)
	stderr    bytes.Buffer
	closed    atomic.Int32
	closeOnce sync.Once
	closeCh   chan struct{}
}

// Compile-time check that the fake satisfies the real channel interface.
var _ ssh.Channel = (*testSSHChannel)(nil)

// newTestSSHChannel returns a fake channel whose Read delegates to readFunc.
func newTestSSHChannel(readFunc func([]byte) (int, error)) *testSSHChannel {
	return &testSSHChannel{
		readFunc: readFunc,
		closeCh:  make(chan struct{}),
	}
}

// newBlockingTestSSHChannel returns a fake channel whose Read blocks until
// the channel is closed and then reports io.EOF.
func newBlockingTestSSHChannel() *testSSHChannel {
	ch := newTestSSHChannel(nil)
	ch.readFunc = func([]byte) (int, error) {
		<-ch.closeCh
		return 0, io.EOF
	}
	return ch
}

// Read delegates to readFunc; a nil receiver or nil readFunc reports io.EOF.
func (c *testSSHChannel) Read(p []byte) (int, error) {
	if c == nil {
		return 0, io.EOF
	}
	if c.readFunc != nil {
		return c.readFunc(p)
	}
	return 0, io.EOF
}

// Write discards p and reports it as fully written.
func (c *testSSHChannel) Write(p []byte) (int, error) {
	return len(p), nil
}

// Close counts the first call and closes closeCh; later calls are no-ops.
func (c *testSSHChannel) Close() error {
	if c == nil {
		return nil
	}
	c.closeOnce.Do(func() {
		c.closed.Add(1)
		close(c.closeCh)
	})
	return nil
}

// CloseWrite is a no-op.
func (c *testSSHChannel) CloseWrite() error {
	return nil
}

// SendRequest rejects every request without error.
func (c *testSSHChannel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) {
	return false, nil
}

// Stderr returns an in-memory buffer.
func (c *testSSHChannel) Stderr() io.ReadWriter {
	return &c.stderr
}

// The tests below swap package-level hooks (newSSHSession,
// probeSSHAgentForwarding, routeSSHAgentForwarding, ...) and restore them via
// t.Cleanup. Because those hooks are shared globals, none of these tests may
// call t.Parallel.

// TestNewExecSessionEnablesAgentForwardingOnce checks that, with
// ForwardSSHAgent enabled, the agent probe and route registration run only
// once across several exec sessions (receiving LoginInfo.Timeout), the
// per-session forwarding request is sent on every session, and Close releases
// the routed forwarder exactly once.
func TestNewExecSessionEnablesAgentForwardingOnce(t *testing.T) {
	oldNewSSHSession := newSSHSession
	oldProbeSSHAgentForwarding := probeSSHAgentForwarding
	oldRouteSSHAgentForwarding := routeSSHAgentForwarding
	oldRequestSSHAgentForwarding := requestSSHAgentForwarding
	oldCloseSSHClient := closeSSHClient
	t.Cleanup(func() {
		newSSHSession = oldNewSSHSession
		probeSSHAgentForwarding = oldProbeSSHAgentForwarding
		routeSSHAgentForwarding = oldRouteSSHAgentForwarding
		requestSSHAgentForwarding = oldRequestSSHAgentForwarding
		closeSSHClient = oldCloseSSHClient
	})

	baseClient := &ssh.Client{}
	star := &StarSSH{
		LoginInfo: LoginInput{
			ForwardSSHAgent: true,
			Timeout:         time.Second,
		},
	}
	star.setTransport(baseClient, nil)

	newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
		if client != baseClient {
			t.Fatalf("unexpected ssh client %p", client)
		}
		return &ssh.Session{}, nil
	}

	var probeCalls atomic.Int32
	closer := &testCloser{}
	probeSSHAgentForwarding = func(timeout time.Duration) error {
		probeCalls.Add(1)
		if timeout != time.Second {
			t.Fatalf("unexpected forwarding timeout: %v", timeout)
		}
		return nil
	}

	var routeCalls atomic.Int32
	routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) {
		routeCalls.Add(1)
		if client != baseClient {
			t.Fatalf("unexpected routed client %p", client)
		}
		if timeout != time.Second {
			t.Fatalf("unexpected routed timeout: %v", timeout)
		}
		return closer, nil
	}

	var requestCalls atomic.Int32
	requestSSHAgentForwarding = func(session *ssh.Session) error {
		requestCalls.Add(1)
		if session == nil {
			t.Fatal("expected non-nil ssh session")
		}
		return nil
	}

	if _, err := star.NewExecSession(); err != nil {
		t.Fatalf("first exec session: %v", err)
	}
	if _, err := star.NewExecSession(); err != nil {
		t.Fatalf("second exec session: %v", err)
	}
	if got := probeCalls.Load(); got != 1 {
		t.Fatalf("expected one agent probe, got %d", got)
	}
	if got := routeCalls.Load(); got != 1 {
		t.Fatalf("expected one agent route registration, got %d", got)
	}
	if got := requestCalls.Load(); got != 2 {
		t.Fatalf("expected agent forwarding request on each session, got %d", got)
	}

	closeSSHClient = func(client sshClientRequester) error {
		return nil
	}
	if err := star.Close(); err != nil {
		t.Fatalf("close starssh: %v", err)
	}
	if got := closer.closed.Load(); got != 1 {
		t.Fatalf("expected forwarded agent closer to run once, got %d", got)
	}
}

// TestNewPTYSessionEnablesAgentForwardingWhenConfigured checks that
// NewPTYSession requests a PTY and, with ForwardSSHAgent set, also requests
// agent forwarding on the new session.
func TestNewPTYSessionEnablesAgentForwardingWhenConfigured(t *testing.T) {
	oldNewSSHSession := newSSHSession
	oldRequestSessionPTY := requestSessionPTY
	oldProbeSSHAgentForwarding := probeSSHAgentForwarding
	oldRouteSSHAgentForwarding := routeSSHAgentForwarding
	oldRequestSSHAgentForwarding := requestSSHAgentForwarding
	t.Cleanup(func() {
		newSSHSession = oldNewSSHSession
		requestSessionPTY = oldRequestSessionPTY
		probeSSHAgentForwarding = oldProbeSSHAgentForwarding
		routeSSHAgentForwarding = oldRouteSSHAgentForwarding
		requestSSHAgentForwarding = oldRequestSSHAgentForwarding
	})

	star := &StarSSH{
		LoginInfo: LoginInput{
			ForwardSSHAgent: true,
		},
	}
	star.setTransport(&ssh.Client{}, nil)

	newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
		return &ssh.Session{}, nil
	}
	var ptyCalls atomic.Int32
	requestSessionPTY = func(session *ssh.Session, config TerminalConfig) error {
		ptyCalls.Add(1)
		return nil
	}
	probeSSHAgentForwarding = func(timeout time.Duration) error {
		return nil
	}
	routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) {
		return &testCloser{}, nil
	}
	var requestCalls atomic.Int32
	requestSSHAgentForwarding = func(session *ssh.Session) error {
		requestCalls.Add(1)
		return nil
	}

	if _, err := star.NewPTYSession(nil); err != nil {
		t.Fatalf("new pty session: %v", err)
	}
	if got := ptyCalls.Load(); got != 1 {
		t.Fatalf("expected one PTY request, got %d", got)
	}
	if got := requestCalls.Load(); got != 1 {
		t.Fatalf("expected one agent forwarding request, got %d", got)
	}
}

// TestNewExecSessionSkipsAgentForwardingWhenDisabled checks that neither the
// agent probe nor the forwarding request runs when ForwardSSHAgent is unset.
func TestNewExecSessionSkipsAgentForwardingWhenDisabled(t *testing.T) {
	oldNewSSHSession := newSSHSession
	oldProbeSSHAgentForwarding := probeSSHAgentForwarding
	oldRequestSSHAgentForwarding := requestSSHAgentForwarding
	t.Cleanup(func() {
		newSSHSession = oldNewSSHSession
		probeSSHAgentForwarding = oldProbeSSHAgentForwarding
		requestSSHAgentForwarding = oldRequestSSHAgentForwarding
	})

	star := &StarSSH{}
	star.setTransport(&ssh.Client{}, nil)

	newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
		return &ssh.Session{}, nil
	}
	probeSSHAgentForwarding = func(timeout time.Duration) error {
		t.Fatal("agent forwarding probe should not run when disabled")
		return nil
	}
	requestSSHAgentForwarding = func(session *ssh.Session) error {
		t.Fatal("agent forwarding should not be requested when disabled")
		return nil
	}

	if _, err := star.NewExecSession(); err != nil {
		t.Fatalf("new exec session without forwarding: %v", err)
	}
}

// TestRequestAgentForwardingReturnsUnavailableError checks that a probe
// failure is surfaced as an "unavailable" error and that the per-session
// forwarding request is never sent.
func TestRequestAgentForwardingReturnsUnavailableError(t *testing.T) {
	oldProbeSSHAgentForwarding := probeSSHAgentForwarding
	oldRequestSSHAgentForwarding := requestSSHAgentForwarding
	t.Cleanup(func() {
		probeSSHAgentForwarding = oldProbeSSHAgentForwarding
		requestSSHAgentForwarding = oldRequestSSHAgentForwarding
	})

	star := &StarSSH{}
	star.setTransport(&ssh.Client{}, nil)

	probeSSHAgentForwarding = func(timeout time.Duration) error {
		return errors.New("ssh-agent forwarding unavailable: ssh-agent unavailable")
	}
	requestSSHAgentForwarding = func(session *ssh.Session) error {
		t.Fatal("session request should not run when agent forwarder init fails")
		return nil
	}

	err := star.RequestAgentForwarding(&ssh.Session{})
	if err == nil {
		t.Fatal("expected agent forwarding init error")
	}
	// The test name promises an "unavailable" classification, not just any
	// error; check it the same way the setup-error test below does.
	if !isSSHAgentForwardingUnavailableError(err) {
		t.Fatalf("expected unavailable error, got %v", err)
	}
}

// TestRequestAgentForwardingWrapsSetupErrorAsUnavailable checks that an
// arbitrary probe error (here a socket permission failure) is classified as
// "unavailable".
func TestRequestAgentForwardingWrapsSetupErrorAsUnavailable(t *testing.T) {
	oldProbeSSHAgentForwarding := probeSSHAgentForwarding
	t.Cleanup(func() {
		probeSSHAgentForwarding = oldProbeSSHAgentForwarding
	})

	star := &StarSSH{}
	star.setTransport(&ssh.Client{}, nil)

	probeSSHAgentForwarding = func(timeout time.Duration) error {
		return errors.New("dial unix /tmp/ssh-broken.sock: connect: permission denied")
	}

	err := star.RequestAgentForwarding(&ssh.Session{})
	if !isSSHAgentForwardingUnavailableError(err) {
		t.Fatalf("expected unavailable error, got %v", err)
	}
}

// TestRequestAgentForwardingReturnsDeniedError checks that a failed
// per-session forwarding request is classified as "denied".
func TestRequestAgentForwardingReturnsDeniedError(t *testing.T) {
	oldProbeSSHAgentForwarding := probeSSHAgentForwarding
	oldRouteSSHAgentForwarding := routeSSHAgentForwarding
	oldRequestSSHAgentForwarding := requestSSHAgentForwarding
	t.Cleanup(func() {
		probeSSHAgentForwarding = oldProbeSSHAgentForwarding
		routeSSHAgentForwarding = oldRouteSSHAgentForwarding
		requestSSHAgentForwarding = oldRequestSSHAgentForwarding
	})

	star := &StarSSH{}
	star.setTransport(&ssh.Client{}, nil)

	probeSSHAgentForwarding = func(timeout time.Duration) error {
		return nil
	}
	routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) {
		return &testCloser{}, nil
	}
	requestSSHAgentForwarding = func(session *ssh.Session) error {
		return errors.New("forwarding request denied")
	}

	err := star.RequestAgentForwarding(&ssh.Session{})
	if !isSSHAgentForwardingDeniedError(err) {
		t.Fatalf("expected forwarding denied error, got %v", err)
	}
}

// TestNewExecSessionIgnoresAgentForwardingDenied checks that session creation
// still succeeds when the server denies the forwarding request.
func TestNewExecSessionIgnoresAgentForwardingDenied(t *testing.T) {
	oldNewSSHSession := newSSHSession
	oldProbeSSHAgentForwarding := probeSSHAgentForwarding
	oldRouteSSHAgentForwarding := routeSSHAgentForwarding
	oldRequestSSHAgentForwarding := requestSSHAgentForwarding
	t.Cleanup(func() {
		newSSHSession = oldNewSSHSession
		probeSSHAgentForwarding = oldProbeSSHAgentForwarding
		routeSSHAgentForwarding = oldRouteSSHAgentForwarding
		requestSSHAgentForwarding = oldRequestSSHAgentForwarding
	})

	star := &StarSSH{
		LoginInfo: LoginInput{
			ForwardSSHAgent: true,
		},
	}
	star.setTransport(&ssh.Client{}, nil)

	newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
		return &ssh.Session{}, nil
	}
	probeSSHAgentForwarding = func(timeout time.Duration) error {
		return nil
	}
	routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) {
		return &testCloser{}, nil
	}
	requestSSHAgentForwarding = func(session *ssh.Session) error {
		return errors.New("forwarding request denied")
	}

	if _, err := star.NewExecSession(); err != nil {
		t.Fatalf("new exec session should ignore denied agent forwarding: %v", err)
	}
}

// TestNewExecSessionIgnoresAgentForwardingUnavailable checks that session
// creation still succeeds when the local agent is reported unavailable.
func TestNewExecSessionIgnoresAgentForwardingUnavailable(t *testing.T) {
	oldNewSSHSession := newSSHSession
	oldProbeSSHAgentForwarding := probeSSHAgentForwarding
	t.Cleanup(func() {
		newSSHSession = oldNewSSHSession
		probeSSHAgentForwarding = oldProbeSSHAgentForwarding
	})

	star := &StarSSH{
		LoginInfo: LoginInput{
			ForwardSSHAgent: true,
		},
	}
	star.setTransport(&ssh.Client{}, nil)

	newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
		return &ssh.Session{}, nil
	}
	probeSSHAgentForwarding = func(timeout time.Duration) error {
		return errors.New("ssh-agent forwarding unavailable: ssh-agent unavailable")
	}

	if _, err := star.NewExecSession(); err != nil {
		t.Fatalf("new exec session should ignore unavailable agent forwarding: %v", err)
	}
}

// TestNewExecSessionIgnoresAgentForwardingSetupError checks that session
// creation still succeeds when probing the local agent fails with an
// arbitrary setup error.
func TestNewExecSessionIgnoresAgentForwardingSetupError(t *testing.T) {
	oldNewSSHSession := newSSHSession
	oldProbeSSHAgentForwarding := probeSSHAgentForwarding
	t.Cleanup(func() {
		newSSHSession = oldNewSSHSession
		probeSSHAgentForwarding = oldProbeSSHAgentForwarding
	})

	star := &StarSSH{
		LoginInfo: LoginInput{
			ForwardSSHAgent: true,
		},
	}
	star.setTransport(&ssh.Client{}, nil)

	newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
		return &ssh.Session{}, nil
	}
	probeSSHAgentForwarding = func(timeout time.Duration) error {
		return errors.New("dial unix /tmp/ssh-broken.sock: connect: connection refused")
	}

	if _, err := star.NewExecSession(); err != nil {
		t.Fatalf("new exec session should ignore agent setup error: %v", err)
	}
}

// TestEnsureAgentForwardingClosesNewForwarderWhenCloseStarts covers the race
// where Close begins while ensureAgentForwarding is still registering the
// forwarder route: ensureAgentForwarding must return errSSHClientClosing, the
// freshly created forwarder must be closed exactly once, and no forwarder may
// remain stored on the StarSSH.
func TestEnsureAgentForwardingClosesNewForwarderWhenCloseStarts(t *testing.T) {
	oldProbeSSHAgentForwarding := probeSSHAgentForwarding
	oldRouteSSHAgentForwarding := routeSSHAgentForwarding
	oldCloseSSHClient := closeSSHClient
	t.Cleanup(func() {
		probeSSHAgentForwarding = oldProbeSSHAgentForwarding
		routeSSHAgentForwarding = oldRouteSSHAgentForwarding
		closeSSHClient = oldCloseSSHClient
	})

	star := &StarSSH{
		LoginInfo: LoginInput{
			ForwardSSHAgent: true,
		},
	}
	star.setTransport(&ssh.Client{}, nil)

	started := make(chan struct{})
	release := make(chan struct{})
	var releaseOnce sync.Once
	releaseRoute := func() {
		releaseOnce.Do(func() { close(release) })
	}
	// Cleanups run last-registered-first, so this runs before the hooks are
	// restored: if the test fails before releasing the stubbed route call,
	// the background goroutines are unblocked instead of staying parked.
	t.Cleanup(releaseRoute)
	closer := &testCloser{}

	probeSSHAgentForwarding = func(timeout time.Duration) error {
		return nil
	}
	routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) {
		close(started)
		<-release
		return closer, nil
	}
	closeSSHClient = func(client sshClientRequester) error {
		return nil
	}

	errCh := make(chan error, 1)
	go func() {
		errCh <- star.ensureAgentForwarding()
	}()

	select {
	case <-started:
	case <-time.After(time.Second):
		t.Fatal("agent forwarding route registration did not start in time")
	}

	closeDone := make(chan struct{})
	go func() {
		// Close's own result is not under test here; only the forwarder
		// outcome below is.
		_ = star.Close()
		close(closeDone)
	}()

	deadline := time.Now().Add(time.Second)
	for !star.closing.Load() {
		if time.Now().After(deadline) {
			t.Fatal("close did not enter closing state in time")
		}
		time.Sleep(time.Millisecond)
	}

	releaseRoute()

	var err error
	select {
	case err = <-errCh:
	case <-time.After(time.Second):
		t.Fatal("ensureAgentForwarding did not return after route release")
	}
	if !errors.Is(err, errSSHClientClosing) {
		t.Fatalf("expected closing error, got %v", err)
	}

	select {
	case <-closeDone:
	case <-time.After(time.Second):
		t.Fatal("close did not finish after route release")
	}

	if got := closer.closed.Load(); got != 1 {
		t.Fatalf("expected new forwarder closer to be closed once, got %d", got)
	}
	if got := star.takeAgentForwarder(); got != nil {
		t.Fatal("expected no leaked agent forwarder after close race")
	}
}

// TestProxySSHAgentChannelClosesBlockedAgentConnWhenRemoteChannelEnds checks
// that proxySSHAgentChannel returns once the remote channel reports EOF even
// though the local agent connection (a net.Pipe end whose peer never writes)
// has nothing to read, and that both sides are closed.
func TestProxySSHAgentChannelClosesBlockedAgentConnWhenRemoteChannelEnds(t *testing.T) {
	agentConn, peerConn := net.Pipe()
	defer peerConn.Close()

	tracked := &trackedConn{Conn: agentConn}
	channel := newTestSSHChannel(func(p []byte) (int, error) {
		return 0, io.EOF
	})

	done := make(chan struct{})
	go func() {
		proxySSHAgentChannel(channel, tracked)
		close(done)
	}()

	select {
	case <-done:
	case <-time.After(time.Second):
		t.Fatal("proxySSHAgentChannel did not exit after remote EOF")
	}

	if tracked.closed.Load() == 0 {
		t.Fatal("expected local agent connection to be closed")
	}
	if channel.closed.Load() == 0 {
		t.Fatal("expected ssh channel to be closed")
	}
}

// TestSSHAgentForwardProxyCloseClosesActiveBridges checks that closing the
// proxy closes the channel and local connection of every registered bridge,
// letting a bridge blocked in run return.
func TestSSHAgentForwardProxyCloseClosesActiveBridges(t *testing.T) {
	agentConn, peerConn := net.Pipe()
	defer peerConn.Close()

	tracked := &trackedConn{Conn: agentConn}
	channel := newBlockingTestSSHChannel()
	proxy := &sshAgentForwardProxy{
		stopCh: make(chan struct{}),
		active: make(map[*sshAgentForwardBridge]struct{}),
	}
	bridge := &sshAgentForwardBridge{
		proxy:   proxy,
		channel: channel,
		conn:    tracked,
	}
	if !proxy.registerBridge(bridge) {
		t.Fatal("expected bridge registration to succeed")
	}

	done := make(chan struct{})
	go func() {
		bridge.run()
		close(done)
	}()

	if err := proxy.Close(); err != nil {
		t.Fatalf("close proxy: %v", err)
	}

	select {
	case <-done:
	case <-time.After(time.Second):
		t.Fatal("bridge did not exit after proxy close")
	}

	if tracked.closed.Load() == 0 {
		t.Fatal("expected proxy close to close local agent connection")
	}
	if channel.closed.Load() == 0 {
		t.Fatal("expected proxy close to close ssh channel")
	}
}