diff --git a/agent_forward.go b/agent_forward.go index dc0069f..502efb2 100644 --- a/agent_forward.go +++ b/agent_forward.go @@ -4,7 +4,9 @@ import ( "errors" "fmt" "io" + "net" "strings" + "sync" "time" "golang.org/x/crypto/ssh" @@ -15,24 +17,53 @@ var requestSSHAgentForwarding = func(session *ssh.Session) error { return sshagent.RequestAgentForwarding(session) } -var routeSSHAgentForwarding = func(client *ssh.Client, keyring sshagent.Agent) error { - return sshagent.ForwardToAgent(client, keyring) +const sshAgentChannelType = "auth-agent@openssh.com" + +var routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) { + return startSSHAgentForwardProxy(client, timeout) } -var newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) { +var probeSSHAgentForwarding = func(timeout time.Duration) error { conn, err := dialSSHAgent(timeout) if err != nil { - return nil, nil, wrapSSHAgentForwardingUnavailable(err) + return wrapSSHAgentForwardingUnavailable(err) } if conn == nil { - return nil, nil, wrapSSHAgentForwardingUnavailable(errors.New("empty agent connection")) + return wrapSSHAgentForwardingUnavailable(errors.New("empty agent connection")) } - return sshagent.NewClient(conn), conn, nil + return conn.Close() } var errSSHAgentForwardingDenied = errors.New("ssh-agent forwarding request denied") var errSSHAgentForwardingUnavailable = errors.New("ssh-agent forwarding unavailable") +type sshAgentForwardProxy struct { + stopOnce sync.Once + stopCh chan struct{} + + activeMu sync.Mutex + active map[*sshAgentForwardBridge]struct{} +} + +func (p *sshAgentForwardProxy) Close() error { + if p == nil { + return nil + } + p.stopOnce.Do(func() { + close(p.stopCh) + }) + p.closeActive() + return nil +} + +type sshAgentForwardBridge struct { + proxy *sshAgentForwardProxy + channel ssh.Channel + conn net.Conn + + closeOnce sync.Once +} + func (s *StarSSH) RequestAgentForwarding(session *ssh.Session) error { if s == nil { return errors.New("ssh client is nil") @@ -80,20 +111,21 @@ func (s *StarSSH) ensureAgentForwarding() error { return err } - keyring, closer, err := newSSHAgentForwarder(effectiveDialTimeout(s.LoginInfo)) - if err != nil { + timeout := effectiveDialTimeout(s.LoginInfo) + if err := probeSSHAgentForwarding(timeout); err != nil { return wrapSSHAgentForwardingUnavailable(err) } if s.closing.Load() { - _ = closer.Close() return errSSHClientClosing } - if err := routeSSHAgentForwarding(client, keyring); err != nil { - _ = closer.Close() + closer, err := routeSSHAgentForwarding(client, timeout) + if err != nil { return err } if !s.canAttachAgentForwarder(client) { - _ = closer.Close() + if closer != nil { + _ = closer.Close() + } return errSSHClientClosing } s.agentForwarder = closer @@ -149,3 +181,175 @@ func wrapSSHAgentForwardingUnavailable(err error) error { } return fmt.Errorf("%w: %v", errSSHAgentForwardingUnavailable, err) } + +func startSSHAgentForwardProxy(client *ssh.Client, timeout time.Duration) (io.Closer, error) { + if client == nil { + return nil, errors.New("ssh client is nil") + } + channels := client.HandleChannelOpen(sshAgentChannelType) + if channels == nil { + return nil, errors.New("agent: already have handler for " + sshAgentChannelType) + } + + proxy := &sshAgentForwardProxy{ + stopCh: make(chan struct{}), + active: make(map[*sshAgentForwardBridge]struct{}), + } + go func() { + for { + select { + case <-proxy.stopCh: + return + case ch, ok := <-channels: + if !ok { + return + } + go handleSSHAgentForwardChannel(proxy, ch, timeout) + } + } + }() + return proxy, nil +} + +func handleSSHAgentForwardChannel(proxy *sshAgentForwardProxy, ch ssh.NewChannel, timeout time.Duration) { + if ch == nil { + return + } + conn, err := dialSSHAgent(timeout) + if err != nil { + _ = ch.Reject(ssh.ConnectionFailed, err.Error()) + return + } + if conn == nil { + _ = ch.Reject(ssh.ConnectionFailed, "ssh-agent connection unavailable") + return + } + + channel, reqs, err := ch.Accept() + if err != nil { + _ = conn.Close() + return + } + go ssh.DiscardRequests(reqs) + + bridge := &sshAgentForwardBridge{ + proxy: proxy, + channel: channel, + conn: conn, + } + if !proxy.registerBridge(bridge) { + bridge.close() + return + } + go bridge.run() +} + +func proxySSHAgentChannel(channel ssh.Channel, conn net.Conn) { + bridge := &sshAgentForwardBridge{ + channel: channel, + conn: conn, + } + bridge.run() +} + +func (b *sshAgentForwardBridge) run() { + if b == nil { + return + } + defer b.unregister() + + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + _, _ = io.Copy(b.channel, b.conn) + b.close() + }() + go func() { + defer wg.Done() + _, _ = io.Copy(b.conn, b.channel) + b.close() + }() + wg.Wait() +} + +func (b *sshAgentForwardBridge) close() { + if b == nil { + return + } + b.closeOnce.Do(func() { + closeWriter(b.channel) + closeWriter(b.conn) + if b.channel != nil { + _ = b.channel.Close() + } + if b.conn != nil { + _ = b.conn.Close() + } + }) +} + +func (b *sshAgentForwardBridge) unregister() { + if b == nil || b.proxy == nil { + return + } + b.proxy.unregisterBridge(b) +} + +func (p *sshAgentForwardProxy) registerBridge(bridge *sshAgentForwardBridge) bool { + if p == nil || bridge == nil { + return false + } + + p.activeMu.Lock() + defer p.activeMu.Unlock() + + select { + case <-p.stopCh: + return false + default: + } + if p.active == nil { + p.active = make(map[*sshAgentForwardBridge]struct{}) + } + p.active[bridge] = struct{}{} + return true +} + +func (p *sshAgentForwardProxy) unregisterBridge(bridge *sshAgentForwardBridge) { + if p == nil || bridge == nil { + return + } + + p.activeMu.Lock() + defer p.activeMu.Unlock() + + delete(p.active, bridge) +} + +func (p *sshAgentForwardProxy) closeActive() { + if p == nil { + return + } + + p.activeMu.Lock() + active := make([]*sshAgentForwardBridge, 0, len(p.active)) + for bridge := range p.active { + active = append(active, bridge) + } + p.active = make(map[*sshAgentForwardBridge]struct{}) + p.activeMu.Unlock() + + for _, bridge := range active { + bridge.close() + } +} + +func closeWriter(value any) { + type closeWriter interface { + CloseWrite() error + } + if cw, ok := value.(closeWriter); ok { + _ = cw.CloseWrite() + } +} diff --git a/agent_forward_test.go b/agent_forward_test.go index ae07de4..fdbf011 100644 --- a/agent_forward_test.go +++ b/agent_forward_test.go @@ -1,14 +1,16 @@ package starssh import ( + "bytes" "errors" "io" + "net" + "sync" "sync/atomic" "testing" "time" "golang.org/x/crypto/ssh" - sshagent "golang.org/x/crypto/ssh/agent" ) type testCloser struct { @@ -20,15 +22,90 @@ func (c *testCloser) Close() error { return nil } +type trackedConn struct { + net.Conn + closed atomic.Int32 +} + +func (c *trackedConn) Close() error { + c.closed.Add(1) + if c.Conn == nil { + return nil + } + return c.Conn.Close() +} + +type testSSHChannel struct { + readFunc func([]byte) (int, error) + + stderr bytes.Buffer + closed atomic.Int32 + closeOnce sync.Once + closeCh chan struct{} +} + +func newTestSSHChannel(readFunc func([]byte) (int, error)) *testSSHChannel { + return &testSSHChannel{ + readFunc: readFunc, + closeCh: make(chan struct{}), + } +} + +func newBlockingTestSSHChannel() *testSSHChannel { + ch := newTestSSHChannel(nil) + ch.readFunc = func(p []byte) (int, error) { + <-ch.closeCh + return 0, io.EOF + } + return ch +} + +func (c *testSSHChannel) Read(p []byte) (int, error) { + if c == nil { + return 0, io.EOF + } + if c.readFunc != nil { + return c.readFunc(p) + } + return 0, io.EOF +} + +func (c *testSSHChannel) Write(p []byte) (int, error) { + return len(p), nil +} + +func (c *testSSHChannel) Close() error { + if c == nil { + return nil + } + c.closeOnce.Do(func() { + c.closed.Add(1) + close(c.closeCh) + }) + return nil +} + +func (c *testSSHChannel) CloseWrite() error { + return nil +} + +func (c *testSSHChannel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) { + return false, nil +} + +func (c *testSSHChannel) Stderr() io.ReadWriter { + return &c.stderr +} + func TestNewExecSessionEnablesAgentForwardingOnce(t *testing.T) { oldNewSSHSession := newSSHSession - oldNewSSHAgentForwarder := newSSHAgentForwarder + oldProbeSSHAgentForwarding := probeSSHAgentForwarding oldRouteSSHAgentForwarding := routeSSHAgentForwarding oldRequestSSHAgentForwarding := requestSSHAgentForwarding oldCloseSSHClient := closeSSHClient t.Cleanup(func() { newSSHSession = oldNewSSHSession - newSSHAgentForwarder = oldNewSSHAgentForwarder + probeSSHAgentForwarding = oldProbeSSHAgentForwarding routeSSHAgentForwarding = oldRouteSSHAgentForwarding requestSSHAgentForwarding = oldRequestSSHAgentForwarding closeSSHClient = oldCloseSSHClient @@ -50,26 +127,26 @@ func TestNewExecSessionEnablesAgentForwardingOnce(t *testing.T) { return &ssh.Session{}, nil } - var agentInitCalls atomic.Int32 + var probeCalls atomic.Int32 closer := &testCloser{} - newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) { - agentInitCalls.Add(1) + probeSSHAgentForwarding = func(timeout time.Duration) error { + probeCalls.Add(1) if timeout != time.Second { t.Fatalf("unexpected forwarding timeout: %v", timeout) } - return sshagent.NewKeyring(), closer, nil + return nil } var routeCalls atomic.Int32 - routeSSHAgentForwarding = func(client *ssh.Client, keyring sshagent.Agent) error { + routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) { routeCalls.Add(1) if client != baseClient { t.Fatalf("unexpected routed client %p", client) } - if keyring == nil { - t.Fatal("expected non-nil forwarded agent keyring") + if timeout != time.Second { + t.Fatalf("unexpected routed timeout: %v", timeout) } - return nil + return closer, nil } var requestCalls atomic.Int32 @@ -88,8 +165,8 @@ func TestNewExecSessionEnablesAgentForwardingOnce(t *testing.T) { t.Fatalf("second exec session: %v", err) } - if agentInitCalls.Load() != 1 { - t.Fatalf("expected one agent forwarder init, got %d", agentInitCalls.Load()) + if probeCalls.Load() != 1 { + t.Fatalf("expected one agent probe, got %d", probeCalls.Load()) } if routeCalls.Load() != 1 { t.Fatalf("expected one agent route registration, got %d", routeCalls.Load()) @@ -110,13 +187,13 @@ func TestNewExecSessionEnablesAgentForwardingOnce(t *testing.T) { func TestNewPTYSessionEnablesAgentForwardingWhenConfigured(t *testing.T) { oldNewSSHSession := newSSHSession oldRequestSessionPTY := requestSessionPTY - oldNewSSHAgentForwarder := newSSHAgentForwarder + oldProbeSSHAgentForwarding := probeSSHAgentForwarding oldRouteSSHAgentForwarding := routeSSHAgentForwarding oldRequestSSHAgentForwarding := requestSSHAgentForwarding t.Cleanup(func() { newSSHSession = oldNewSSHSession requestSessionPTY = oldRequestSessionPTY - newSSHAgentForwarder = oldNewSSHAgentForwarder + probeSSHAgentForwarding = oldProbeSSHAgentForwarding routeSSHAgentForwarding = oldRouteSSHAgentForwarding requestSSHAgentForwarding = oldRequestSSHAgentForwarding }) @@ -138,10 +215,12 @@ func TestNewPTYSessionEnablesAgentForwardingWhenConfigured(t *testing.T) { return nil } - newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) { - return sshagent.NewKeyring(), &testCloser{}, nil + probeSSHAgentForwarding = func(timeout time.Duration) error { + return nil + } + routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) { + return &testCloser{}, nil } - routeSSHAgentForwarding = func(client *ssh.Client, keyring sshagent.Agent) error { return nil } var requestCalls atomic.Int32 requestSSHAgentForwarding = func(session *ssh.Session) error { @@ -162,11 +241,11 @@ func TestNewPTYSessionEnablesAgentForwardingWhenConfigured(t *testing.T) { func TestNewExecSessionSkipsAgentForwardingWhenDisabled(t *testing.T) { oldNewSSHSession := newSSHSession - oldNewSSHAgentForwarder := newSSHAgentForwarder + oldProbeSSHAgentForwarding := probeSSHAgentForwarding oldRequestSSHAgentForwarding := requestSSHAgentForwarding t.Cleanup(func() { newSSHSession = oldNewSSHSession - newSSHAgentForwarder = oldNewSSHAgentForwarder + probeSSHAgentForwarding = oldProbeSSHAgentForwarding requestSSHAgentForwarding = oldRequestSSHAgentForwarding }) @@ -176,9 +255,9 @@ func TestNewExecSessionSkipsAgentForwardingWhenDisabled(t *testing.T) { newSSHSession = func(client *ssh.Client) (*ssh.Session, error) { return &ssh.Session{}, nil } - newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) { - t.Fatal("agent forwarder should not initialize when disabled") - return nil, nil, nil + probeSSHAgentForwarding = func(timeout time.Duration) error { + t.Fatal("agent forwarding probe should not run when disabled") + return nil } requestSSHAgentForwarding = func(session *ssh.Session) error { t.Fatal("agent forwarding should not be requested when disabled") @@ -191,18 +270,18 @@ func TestNewExecSessionSkipsAgentForwardingWhenDisabled(t *testing.T) { } func TestRequestAgentForwardingReturnsUnavailableError(t *testing.T) { - oldNewSSHAgentForwarder := newSSHAgentForwarder + oldProbeSSHAgentForwarding := probeSSHAgentForwarding oldRequestSSHAgentForwarding := requestSSHAgentForwarding t.Cleanup(func() { - newSSHAgentForwarder = oldNewSSHAgentForwarder + probeSSHAgentForwarding = oldProbeSSHAgentForwarding requestSSHAgentForwarding = oldRequestSSHAgentForwarding }) star := &StarSSH{} star.setTransport(&ssh.Client{}, nil) - newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) { - return nil, nil, errors.New("ssh-agent forwarding unavailable: ssh-agent unavailable") + probeSSHAgentForwarding = func(timeout time.Duration) error { + return errors.New("ssh-agent forwarding unavailable: ssh-agent unavailable") } requestSSHAgentForwarding = func(session *ssh.Session) error { t.Fatal("session request should not run when agent forwarder init fails") @@ -216,16 +295,16 @@ func TestRequestAgentForwardingReturnsUnavailableError(t *testing.T) { } func TestRequestAgentForwardingWrapsSetupErrorAsUnavailable(t *testing.T) { - oldNewSSHAgentForwarder := newSSHAgentForwarder + oldProbeSSHAgentForwarding := probeSSHAgentForwarding t.Cleanup(func() { - newSSHAgentForwarder = oldNewSSHAgentForwarder + probeSSHAgentForwarding = oldProbeSSHAgentForwarding }) star := &StarSSH{} star.setTransport(&ssh.Client{}, nil) - newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) { - return nil, nil, errors.New("dial unix /tmp/ssh-broken.sock: connect: permission denied") + probeSSHAgentForwarding = func(timeout time.Duration) error { + return errors.New("dial unix /tmp/ssh-broken.sock: connect: permission denied") } err := star.RequestAgentForwarding(&ssh.Session{}) @@ -235,11 +314,11 @@ func TestRequestAgentForwardingWrapsSetupErrorAsUnavailable(t *testing.T) { } func TestRequestAgentForwardingReturnsDeniedError(t *testing.T) { - oldNewSSHAgentForwarder := newSSHAgentForwarder + oldProbeSSHAgentForwarding := probeSSHAgentForwarding oldRouteSSHAgentForwarding := routeSSHAgentForwarding oldRequestSSHAgentForwarding := requestSSHAgentForwarding t.Cleanup(func() { - newSSHAgentForwarder = oldNewSSHAgentForwarder + probeSSHAgentForwarding = oldProbeSSHAgentForwarding routeSSHAgentForwarding = oldRouteSSHAgentForwarding requestSSHAgentForwarding = oldRequestSSHAgentForwarding }) @@ -247,10 +326,12 @@ func TestRequestAgentForwardingReturnsDeniedError(t *testing.T) { star := &StarSSH{} star.setTransport(&ssh.Client{}, nil) - newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) { - return sshagent.NewKeyring(), &testCloser{}, nil + probeSSHAgentForwarding = func(timeout time.Duration) error { + return nil + } + routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) { + return &testCloser{}, nil } - routeSSHAgentForwarding = func(client *ssh.Client, keyring sshagent.Agent) error { return nil } requestSSHAgentForwarding = func(session *ssh.Session) error { return errors.New("forwarding request denied") } @@ -263,12 +344,12 @@ func TestRequestAgentForwardingReturnsDeniedError(t *testing.T) { func TestNewExecSessionIgnoresAgentForwardingDenied(t *testing.T) { oldNewSSHSession := newSSHSession - oldNewSSHAgentForwarder := newSSHAgentForwarder + oldProbeSSHAgentForwarding := probeSSHAgentForwarding oldRouteSSHAgentForwarding := routeSSHAgentForwarding oldRequestSSHAgentForwarding := requestSSHAgentForwarding t.Cleanup(func() { newSSHSession = oldNewSSHSession - newSSHAgentForwarder = oldNewSSHAgentForwarder + probeSSHAgentForwarding = oldProbeSSHAgentForwarding routeSSHAgentForwarding = oldRouteSSHAgentForwarding requestSSHAgentForwarding = oldRequestSSHAgentForwarding }) @@ -283,10 +364,12 @@ func TestNewExecSessionIgnoresAgentForwardingDenied(t *testing.T) { newSSHSession = func(client *ssh.Client) (*ssh.Session, error) { return &ssh.Session{}, nil } - newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) { - return sshagent.NewKeyring(), &testCloser{}, nil + probeSSHAgentForwarding = func(timeout time.Duration) error { + return nil + } + routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) { + return &testCloser{}, nil } - routeSSHAgentForwarding = func(client *ssh.Client, keyring sshagent.Agent) error { return nil } requestSSHAgentForwarding = func(session *ssh.Session) error { return errors.New("forwarding request denied") } @@ -298,10 +381,10 @@ func TestNewExecSessionIgnoresAgentForwardingDenied(t *testing.T) { func TestNewExecSessionIgnoresAgentForwardingUnavailable(t *testing.T) { oldNewSSHSession := newSSHSession - oldNewSSHAgentForwarder := newSSHAgentForwarder + oldProbeSSHAgentForwarding := probeSSHAgentForwarding t.Cleanup(func() { newSSHSession = oldNewSSHSession - newSSHAgentForwarder = oldNewSSHAgentForwarder + probeSSHAgentForwarding = oldProbeSSHAgentForwarding }) star := &StarSSH{ @@ -314,8 +397,8 @@ func TestNewExecSessionIgnoresAgentForwardingUnavailable(t *testing.T) { newSSHSession = func(client *ssh.Client) (*ssh.Session, error) { return &ssh.Session{}, nil } - newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) { - return nil, nil, errors.New("ssh-agent forwarding unavailable: ssh-agent unavailable") + probeSSHAgentForwarding = func(timeout time.Duration) error { + return errors.New("ssh-agent forwarding unavailable: ssh-agent unavailable") } if _, err := star.NewExecSession(); err != nil { @@ -325,10 +408,10 @@ func TestNewExecSessionIgnoresAgentForwardingUnavailable(t *testing.T) { func TestNewExecSessionIgnoresAgentForwardingSetupError(t *testing.T) { oldNewSSHSession := newSSHSession - oldNewSSHAgentForwarder := newSSHAgentForwarder + oldProbeSSHAgentForwarding := probeSSHAgentForwarding t.Cleanup(func() { newSSHSession = oldNewSSHSession - newSSHAgentForwarder = oldNewSSHAgentForwarder + probeSSHAgentForwarding = oldProbeSSHAgentForwarding }) star := &StarSSH{ @@ -341,8 +424,8 @@ func TestNewExecSessionIgnoresAgentForwardingSetupError(t *testing.T) { newSSHSession = func(client *ssh.Client) (*ssh.Session, error) { return &ssh.Session{}, nil } - newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) { - return nil, nil, errors.New("dial unix /tmp/ssh-broken.sock: connect: connection refused") + probeSSHAgentForwarding = func(timeout time.Duration) error { + return errors.New("dial unix /tmp/ssh-broken.sock: connect: connection refused") } if _, err := star.NewExecSession(); err != nil { @@ -351,11 +434,11 @@ func TestNewExecSessionIgnoresAgentForwardingSetupError(t *testing.T) { } func TestEnsureAgentForwardingClosesNewForwarderWhenCloseStarts(t *testing.T) { - oldNewSSHAgentForwarder := newSSHAgentForwarder + oldProbeSSHAgentForwarding := probeSSHAgentForwarding oldRouteSSHAgentForwarding := routeSSHAgentForwarding oldCloseSSHClient := closeSSHClient t.Cleanup(func() { - newSSHAgentForwarder = oldNewSSHAgentForwarder + probeSSHAgentForwarding = oldProbeSSHAgentForwarding routeSSHAgentForwarding = oldRouteSSHAgentForwarding closeSSHClient = oldCloseSSHClient }) @@ -370,13 +453,13 @@ func TestEnsureAgentForwardingClosesNewForwarderWhenCloseStarts(t *testing.T) { started := make(chan struct{}) release := make(chan struct{}) closer := &testCloser{} - newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) { + probeSSHAgentForwarding = func(timeout time.Duration) error { + return nil + } + routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) { close(started) <-release - return sshagent.NewKeyring(), closer, nil - } - routeSSHAgentForwarding = func(client *ssh.Client, keyring sshagent.Agent) error { - return nil + return closer, nil } closeSSHClient = func(client sshClientRequester) error { return nil } @@ -415,3 +498,75 @@ func TestEnsureAgentForwardingClosesNewForwarderWhenCloseStarts(t *testing.T) { t.Fatal("expected no leaked agent forwarder after close race") } } + +func TestProxySSHAgentChannelClosesBlockedAgentConnWhenRemoteChannelEnds(t *testing.T) { + agentConn, peerConn := net.Pipe() + defer peerConn.Close() + + tracked := &trackedConn{Conn: agentConn} + channel := newTestSSHChannel(func(p []byte) (int, error) { + return 0, io.EOF + }) + + done := make(chan struct{}) + go func() { + proxySSHAgentChannel(channel, tracked) + close(done) + }() + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("proxySSHAgentChannel did not exit after remote EOF") + } + + if tracked.closed.Load() == 0 { + t.Fatal("expected local agent connection to be closed") + } + if channel.closed.Load() == 0 { + t.Fatal("expected ssh channel to be closed") + } +} + +func TestSSHAgentForwardProxyCloseClosesActiveBridges(t *testing.T) { + agentConn, peerConn := net.Pipe() + defer peerConn.Close() + + tracked := &trackedConn{Conn: agentConn} + channel := newBlockingTestSSHChannel() + proxy := &sshAgentForwardProxy{ + stopCh: make(chan struct{}), + active: make(map[*sshAgentForwardBridge]struct{}), + } + bridge := &sshAgentForwardBridge{ + proxy: proxy, + channel: channel, + conn: tracked, + } + if !proxy.registerBridge(bridge) { + t.Fatal("expected bridge registration to succeed") + } + + done := make(chan struct{}) + go func() { + bridge.run() + close(done) + }() + + if err := proxy.Close(); err != nil { + t.Fatalf("close proxy: %v", err) + } + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("bridge did not exit after proxy close") + } + + if tracked.closed.Load() == 0 { + t.Fatal("expected proxy close to close local agent connection") + } + if channel.closed.Load() == 0 { + t.Fatal("expected proxy close to close ssh channel") + } +}