fix: 重构 ssh agent forwarding 转发代理并修复资源残留

- 将 agent forwarding 从长持有本地 agent 连接改为按 uth-agent@openssh.com channel 动态建桥
- 新增 ssh-agent 可用性探测流程,区分 forwarding 探测与实际转发注册
- 重构 forwarding 注册接口,按连接超时创建本地 agent bridge,不再复用单个长期占用的 agent 连接
- 新增 sshAgentForwardProxy 与 sshAgentForwardBridge,显式管理活跃 agent bridge 生命周期
- 在远端 channel 单侧 EOF、bridge 关闭和 proxy.Close() 路径上主动关闭本地 agent 连接与 SSH channel,避免 goroutine 和 agent 句柄残留
- 保留现有 denied / unavailable / close-race 语义,并继续保证自动 forwarding 的 best-effort 行为
- 扩充 agent forwarding 回归测试,覆盖单次启用、禁用、denied、unavailable、close race、单侧 EOF 释放以及 proxy Close 主动回收活跃 bridge 等关键场景
This commit is contained in:
2026-04-27 00:06:32 +08:00
parent 1625997d8f
commit ad7c8b0587
2 changed files with 427 additions and 68 deletions
+211 -56
View File
@@ -1,14 +1,16 @@
package starssh
import (
"bytes"
"errors"
"io"
"net"
"sync"
"sync/atomic"
"testing"
"time"
"golang.org/x/crypto/ssh"
sshagent "golang.org/x/crypto/ssh/agent"
)
type testCloser struct {
@@ -20,15 +22,90 @@ func (c *testCloser) Close() error {
return nil
}
type trackedConn struct {
net.Conn
closed atomic.Int32
}
func (c *trackedConn) Close() error {
c.closed.Add(1)
if c.Conn == nil {
return nil
}
return c.Conn.Close()
}
type testSSHChannel struct {
readFunc func([]byte) (int, error)
stderr bytes.Buffer
closed atomic.Int32
closeOnce sync.Once
closeCh chan struct{}
}
func newTestSSHChannel(readFunc func([]byte) (int, error)) *testSSHChannel {
return &testSSHChannel{
readFunc: readFunc,
closeCh: make(chan struct{}),
}
}
func newBlockingTestSSHChannel() *testSSHChannel {
ch := newTestSSHChannel(nil)
ch.readFunc = func(p []byte) (int, error) {
<-ch.closeCh
return 0, io.EOF
}
return ch
}
func (c *testSSHChannel) Read(p []byte) (int, error) {
if c == nil {
return 0, io.EOF
}
if c.readFunc != nil {
return c.readFunc(p)
}
return 0, io.EOF
}
func (c *testSSHChannel) Write(p []byte) (int, error) {
return len(p), nil
}
func (c *testSSHChannel) Close() error {
if c == nil {
return nil
}
c.closeOnce.Do(func() {
c.closed.Add(1)
close(c.closeCh)
})
return nil
}
func (c *testSSHChannel) CloseWrite() error {
return nil
}
func (c *testSSHChannel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) {
return false, nil
}
func (c *testSSHChannel) Stderr() io.ReadWriter {
return &c.stderr
}
func TestNewExecSessionEnablesAgentForwardingOnce(t *testing.T) {
oldNewSSHSession := newSSHSession
oldNewSSHAgentForwarder := newSSHAgentForwarder
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
oldRouteSSHAgentForwarding := routeSSHAgentForwarding
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
oldCloseSSHClient := closeSSHClient
t.Cleanup(func() {
newSSHSession = oldNewSSHSession
newSSHAgentForwarder = oldNewSSHAgentForwarder
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
routeSSHAgentForwarding = oldRouteSSHAgentForwarding
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
closeSSHClient = oldCloseSSHClient
@@ -50,26 +127,26 @@ func TestNewExecSessionEnablesAgentForwardingOnce(t *testing.T) {
return &ssh.Session{}, nil
}
var agentInitCalls atomic.Int32
var probeCalls atomic.Int32
closer := &testCloser{}
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
agentInitCalls.Add(1)
probeSSHAgentForwarding = func(timeout time.Duration) error {
probeCalls.Add(1)
if timeout != time.Second {
t.Fatalf("unexpected forwarding timeout: %v", timeout)
}
return sshagent.NewKeyring(), closer, nil
return nil
}
var routeCalls atomic.Int32
routeSSHAgentForwarding = func(client *ssh.Client, keyring sshagent.Agent) error {
routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) {
routeCalls.Add(1)
if client != baseClient {
t.Fatalf("unexpected routed client %p", client)
}
if keyring == nil {
t.Fatal("expected non-nil forwarded agent keyring")
if timeout != time.Second {
t.Fatalf("unexpected routed timeout: %v", timeout)
}
return nil
return closer, nil
}
var requestCalls atomic.Int32
@@ -88,8 +165,8 @@ func TestNewExecSessionEnablesAgentForwardingOnce(t *testing.T) {
t.Fatalf("second exec session: %v", err)
}
if agentInitCalls.Load() != 1 {
t.Fatalf("expected one agent forwarder init, got %d", agentInitCalls.Load())
if probeCalls.Load() != 1 {
t.Fatalf("expected one agent probe, got %d", probeCalls.Load())
}
if routeCalls.Load() != 1 {
t.Fatalf("expected one agent route registration, got %d", routeCalls.Load())
@@ -110,13 +187,13 @@ func TestNewExecSessionEnablesAgentForwardingOnce(t *testing.T) {
func TestNewPTYSessionEnablesAgentForwardingWhenConfigured(t *testing.T) {
oldNewSSHSession := newSSHSession
oldRequestSessionPTY := requestSessionPTY
oldNewSSHAgentForwarder := newSSHAgentForwarder
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
oldRouteSSHAgentForwarding := routeSSHAgentForwarding
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
t.Cleanup(func() {
newSSHSession = oldNewSSHSession
requestSessionPTY = oldRequestSessionPTY
newSSHAgentForwarder = oldNewSSHAgentForwarder
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
routeSSHAgentForwarding = oldRouteSSHAgentForwarding
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
})
@@ -138,10 +215,12 @@ func TestNewPTYSessionEnablesAgentForwardingWhenConfigured(t *testing.T) {
return nil
}
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
return sshagent.NewKeyring(), &testCloser{}, nil
probeSSHAgentForwarding = func(timeout time.Duration) error {
return nil
}
routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) {
return &testCloser{}, nil
}
routeSSHAgentForwarding = func(client *ssh.Client, keyring sshagent.Agent) error { return nil }
var requestCalls atomic.Int32
requestSSHAgentForwarding = func(session *ssh.Session) error {
@@ -162,11 +241,11 @@ func TestNewPTYSessionEnablesAgentForwardingWhenConfigured(t *testing.T) {
func TestNewExecSessionSkipsAgentForwardingWhenDisabled(t *testing.T) {
oldNewSSHSession := newSSHSession
oldNewSSHAgentForwarder := newSSHAgentForwarder
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
t.Cleanup(func() {
newSSHSession = oldNewSSHSession
newSSHAgentForwarder = oldNewSSHAgentForwarder
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
})
@@ -176,9 +255,9 @@ func TestNewExecSessionSkipsAgentForwardingWhenDisabled(t *testing.T) {
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
return &ssh.Session{}, nil
}
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
t.Fatal("agent forwarder should not initialize when disabled")
return nil, nil, nil
probeSSHAgentForwarding = func(timeout time.Duration) error {
t.Fatal("agent forwarding probe should not run when disabled")
return nil
}
requestSSHAgentForwarding = func(session *ssh.Session) error {
t.Fatal("agent forwarding should not be requested when disabled")
@@ -191,18 +270,18 @@ func TestNewExecSessionSkipsAgentForwardingWhenDisabled(t *testing.T) {
}
func TestRequestAgentForwardingReturnsUnavailableError(t *testing.T) {
oldNewSSHAgentForwarder := newSSHAgentForwarder
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
t.Cleanup(func() {
newSSHAgentForwarder = oldNewSSHAgentForwarder
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
})
star := &StarSSH{}
star.setTransport(&ssh.Client{}, nil)
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
return nil, nil, errors.New("ssh-agent forwarding unavailable: ssh-agent unavailable")
probeSSHAgentForwarding = func(timeout time.Duration) error {
return errors.New("ssh-agent forwarding unavailable: ssh-agent unavailable")
}
requestSSHAgentForwarding = func(session *ssh.Session) error {
t.Fatal("session request should not run when agent forwarder init fails")
@@ -216,16 +295,16 @@ func TestRequestAgentForwardingReturnsUnavailableError(t *testing.T) {
}
func TestRequestAgentForwardingWrapsSetupErrorAsUnavailable(t *testing.T) {
oldNewSSHAgentForwarder := newSSHAgentForwarder
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
t.Cleanup(func() {
newSSHAgentForwarder = oldNewSSHAgentForwarder
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
})
star := &StarSSH{}
star.setTransport(&ssh.Client{}, nil)
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
return nil, nil, errors.New("dial unix /tmp/ssh-broken.sock: connect: permission denied")
probeSSHAgentForwarding = func(timeout time.Duration) error {
return errors.New("dial unix /tmp/ssh-broken.sock: connect: permission denied")
}
err := star.RequestAgentForwarding(&ssh.Session{})
@@ -235,11 +314,11 @@ func TestRequestAgentForwardingWrapsSetupErrorAsUnavailable(t *testing.T) {
}
func TestRequestAgentForwardingReturnsDeniedError(t *testing.T) {
oldNewSSHAgentForwarder := newSSHAgentForwarder
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
oldRouteSSHAgentForwarding := routeSSHAgentForwarding
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
t.Cleanup(func() {
newSSHAgentForwarder = oldNewSSHAgentForwarder
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
routeSSHAgentForwarding = oldRouteSSHAgentForwarding
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
})
@@ -247,10 +326,12 @@ func TestRequestAgentForwardingReturnsDeniedError(t *testing.T) {
star := &StarSSH{}
star.setTransport(&ssh.Client{}, nil)
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
return sshagent.NewKeyring(), &testCloser{}, nil
probeSSHAgentForwarding = func(timeout time.Duration) error {
return nil
}
routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) {
return &testCloser{}, nil
}
routeSSHAgentForwarding = func(client *ssh.Client, keyring sshagent.Agent) error { return nil }
requestSSHAgentForwarding = func(session *ssh.Session) error {
return errors.New("forwarding request denied")
}
@@ -263,12 +344,12 @@ func TestRequestAgentForwardingReturnsDeniedError(t *testing.T) {
func TestNewExecSessionIgnoresAgentForwardingDenied(t *testing.T) {
oldNewSSHSession := newSSHSession
oldNewSSHAgentForwarder := newSSHAgentForwarder
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
oldRouteSSHAgentForwarding := routeSSHAgentForwarding
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
t.Cleanup(func() {
newSSHSession = oldNewSSHSession
newSSHAgentForwarder = oldNewSSHAgentForwarder
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
routeSSHAgentForwarding = oldRouteSSHAgentForwarding
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
})
@@ -283,10 +364,12 @@ func TestNewExecSessionIgnoresAgentForwardingDenied(t *testing.T) {
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
return &ssh.Session{}, nil
}
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
return sshagent.NewKeyring(), &testCloser{}, nil
probeSSHAgentForwarding = func(timeout time.Duration) error {
return nil
}
routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) {
return &testCloser{}, nil
}
routeSSHAgentForwarding = func(client *ssh.Client, keyring sshagent.Agent) error { return nil }
requestSSHAgentForwarding = func(session *ssh.Session) error {
return errors.New("forwarding request denied")
}
@@ -298,10 +381,10 @@ func TestNewExecSessionIgnoresAgentForwardingDenied(t *testing.T) {
func TestNewExecSessionIgnoresAgentForwardingUnavailable(t *testing.T) {
oldNewSSHSession := newSSHSession
oldNewSSHAgentForwarder := newSSHAgentForwarder
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
t.Cleanup(func() {
newSSHSession = oldNewSSHSession
newSSHAgentForwarder = oldNewSSHAgentForwarder
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
})
star := &StarSSH{
@@ -314,8 +397,8 @@ func TestNewExecSessionIgnoresAgentForwardingUnavailable(t *testing.T) {
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
return &ssh.Session{}, nil
}
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
return nil, nil, errors.New("ssh-agent forwarding unavailable: ssh-agent unavailable")
probeSSHAgentForwarding = func(timeout time.Duration) error {
return errors.New("ssh-agent forwarding unavailable: ssh-agent unavailable")
}
if _, err := star.NewExecSession(); err != nil {
@@ -325,10 +408,10 @@ func TestNewExecSessionIgnoresAgentForwardingUnavailable(t *testing.T) {
func TestNewExecSessionIgnoresAgentForwardingSetupError(t *testing.T) {
oldNewSSHSession := newSSHSession
oldNewSSHAgentForwarder := newSSHAgentForwarder
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
t.Cleanup(func() {
newSSHSession = oldNewSSHSession
newSSHAgentForwarder = oldNewSSHAgentForwarder
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
})
star := &StarSSH{
@@ -341,8 +424,8 @@ func TestNewExecSessionIgnoresAgentForwardingSetupError(t *testing.T) {
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
return &ssh.Session{}, nil
}
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
return nil, nil, errors.New("dial unix /tmp/ssh-broken.sock: connect: connection refused")
probeSSHAgentForwarding = func(timeout time.Duration) error {
return errors.New("dial unix /tmp/ssh-broken.sock: connect: connection refused")
}
if _, err := star.NewExecSession(); err != nil {
@@ -351,11 +434,11 @@ func TestNewExecSessionIgnoresAgentForwardingSetupError(t *testing.T) {
}
func TestEnsureAgentForwardingClosesNewForwarderWhenCloseStarts(t *testing.T) {
oldNewSSHAgentForwarder := newSSHAgentForwarder
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
oldRouteSSHAgentForwarding := routeSSHAgentForwarding
oldCloseSSHClient := closeSSHClient
t.Cleanup(func() {
newSSHAgentForwarder = oldNewSSHAgentForwarder
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
routeSSHAgentForwarding = oldRouteSSHAgentForwarding
closeSSHClient = oldCloseSSHClient
})
@@ -370,13 +453,13 @@ func TestEnsureAgentForwardingClosesNewForwarderWhenCloseStarts(t *testing.T) {
started := make(chan struct{})
release := make(chan struct{})
closer := &testCloser{}
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
probeSSHAgentForwarding = func(timeout time.Duration) error {
return nil
}
routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) {
close(started)
<-release
return sshagent.NewKeyring(), closer, nil
}
routeSSHAgentForwarding = func(client *ssh.Client, keyring sshagent.Agent) error {
return nil
return closer, nil
}
closeSSHClient = func(client sshClientRequester) error { return nil }
@@ -415,3 +498,75 @@ func TestEnsureAgentForwardingClosesNewForwarderWhenCloseStarts(t *testing.T) {
t.Fatal("expected no leaked agent forwarder after close race")
}
}
func TestProxySSHAgentChannelClosesBlockedAgentConnWhenRemoteChannelEnds(t *testing.T) {
agentConn, peerConn := net.Pipe()
defer peerConn.Close()
tracked := &trackedConn{Conn: agentConn}
channel := newTestSSHChannel(func(p []byte) (int, error) {
return 0, io.EOF
})
done := make(chan struct{})
go func() {
proxySSHAgentChannel(channel, tracked)
close(done)
}()
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("proxySSHAgentChannel did not exit after remote EOF")
}
if tracked.closed.Load() == 0 {
t.Fatal("expected local agent connection to be closed")
}
if channel.closed.Load() == 0 {
t.Fatal("expected ssh channel to be closed")
}
}
func TestSSHAgentForwardProxyCloseClosesActiveBridges(t *testing.T) {
agentConn, peerConn := net.Pipe()
defer peerConn.Close()
tracked := &trackedConn{Conn: agentConn}
channel := newBlockingTestSSHChannel()
proxy := &sshAgentForwardProxy{
stopCh: make(chan struct{}),
active: make(map[*sshAgentForwardBridge]struct{}),
}
bridge := &sshAgentForwardBridge{
proxy: proxy,
channel: channel,
conn: tracked,
}
if !proxy.registerBridge(bridge) {
t.Fatal("expected bridge registration to succeed")
}
done := make(chan struct{})
go func() {
bridge.run()
close(done)
}()
if err := proxy.Close(); err != nil {
t.Fatalf("close proxy: %v", err)
}
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("bridge did not exit after proxy close")
}
if tracked.closed.Load() == 0 {
t.Fatal("expected proxy close to close local agent connection")
}
if channel.closed.Load() == 0 {
t.Fatal("expected proxy close to close ssh channel")
}
}