- 将 agent forwarding 从长期持有本地 agent 连接改为按 auth-agent@openssh.com channel 动态建桥 - 新增 ssh-agent 可用性探测流程,区分 forwarding 探测与实际转发注册 - 重构 forwarding 注册接口,按连接超时创建本地 agent bridge,不再复用单个长期占用的 agent 连接 - 新增 sshAgentForwardProxy 与 sshAgentForwardBridge,显式管理活跃 agent bridge 生命周期 - 在远端 channel 单侧 EOF、bridge 关闭和 proxy.Close() 路径上主动关闭本地 agent 连接与 SSH channel,避免 goroutine 和 agent 句柄残留 - 保留现有 denied / unavailable / close-race 语义,并继续保证自动 forwarding 的 best-effort 行为 - 扩充 agent forwarding 回归测试,覆盖单次启用、禁用、denied、unavailable、close race、单侧 EOF 释放以及 proxy Close 主动回收活跃 bridge 等关键场景
573 lines
15 KiB
Go
573 lines
15 KiB
Go
package starssh
|
|
|
|
import (
|
|
"bytes"
|
|
"errors"
|
|
"io"
|
|
"net"
|
|
"sync"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"golang.org/x/crypto/ssh"
|
|
)
|
|
|
|
// testCloser is a minimal io.Closer double that only records how many times
// Close was invoked. It never fails.
type testCloser struct {
	// closed counts every Close call; repeated calls are not deduplicated.
	closed atomic.Int32
}

// Close bumps the call counter and reports success.
func (c *testCloser) Close() error {
	c.closed.Add(1)
	return nil
}
|
|
|
|
// trackedConn wraps a net.Conn and counts Close calls so tests can assert that
// code under test released the connection.
type trackedConn struct {
	net.Conn

	// closed counts every Close call, including calls on a nil inner Conn.
	closed atomic.Int32
}

// Close records the call and then closes the wrapped connection, if any,
// returning its error. With no wrapped connection it reports success.
func (c *trackedConn) Close() error {
	c.closed.Add(1)
	if inner := c.Conn; inner != nil {
		return inner.Close()
	}
	return nil
}
|
|
|
|
// testSSHChannel is an in-memory SSH channel double that provides the same
// method set as ssh.Channel. Reads are delegated to readFunc, writes are
// discarded, and Close is idempotent and signals closeCh.
type testSSHChannel struct {
	// readFunc produces the result of Read; when nil, Read reports io.EOF.
	readFunc func([]byte) (int, error)

	// stderr backs the stream returned by Stderr.
	stderr bytes.Buffer
	// closed counts effective closes; closeOnce keeps it at most 1.
	closed    atomic.Int32
	closeOnce sync.Once
	// closeCh is closed by the first Close call.
	closeCh chan struct{}
}

// newTestSSHChannel builds an open channel whose reads are served by readFunc
// (nil means every read returns io.EOF).
func newTestSSHChannel(readFunc func([]byte) (int, error)) *testSSHChannel {
	ch := &testSSHChannel{closeCh: make(chan struct{})}
	ch.readFunc = readFunc
	return ch
}

// newBlockingTestSSHChannel builds a channel whose Read parks until Close is
// called and then reports io.EOF, imitating a remote peer that stays silent.
func newBlockingTestSSHChannel() *testSSHChannel {
	var ch *testSSHChannel
	waitForClose := func([]byte) (int, error) {
		<-ch.closeCh
		return 0, io.EOF
	}
	ch = newTestSSHChannel(waitForClose)
	return ch
}

// Read delegates to readFunc. A nil receiver or a missing readFunc yields
// io.EOF immediately.
func (c *testSSHChannel) Read(p []byte) (int, error) {
	if c == nil || c.readFunc == nil {
		return 0, io.EOF
	}
	return c.readFunc(p)
}

// Write discards p and reports it as fully written.
func (c *testSSHChannel) Write(p []byte) (int, error) {
	n := len(p)
	return n, nil
}

// Close marks the channel closed exactly once, counting the close before
// releasing anything blocked on closeCh. Safe on a nil receiver.
func (c *testSSHChannel) Close() error {
	if c != nil {
		c.closeOnce.Do(func() {
			c.closed.Add(1)
			close(c.closeCh)
		})
	}
	return nil
}

// CloseWrite is a no-op that always succeeds.
func (c *testSSHChannel) CloseWrite() error {
	return nil
}

// SendRequest rejects every request without error.
func (c *testSSHChannel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) {
	const accepted = false
	return accepted, nil
}

// Stderr exposes the channel's in-memory stderr buffer.
func (c *testSSHChannel) Stderr() io.ReadWriter {
	return &c.stderr
}
|
|
|
|
func TestNewExecSessionEnablesAgentForwardingOnce(t *testing.T) {
|
|
oldNewSSHSession := newSSHSession
|
|
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
|
|
oldRouteSSHAgentForwarding := routeSSHAgentForwarding
|
|
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
|
|
oldCloseSSHClient := closeSSHClient
|
|
t.Cleanup(func() {
|
|
newSSHSession = oldNewSSHSession
|
|
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
|
|
routeSSHAgentForwarding = oldRouteSSHAgentForwarding
|
|
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
|
|
closeSSHClient = oldCloseSSHClient
|
|
})
|
|
|
|
baseClient := &ssh.Client{}
|
|
star := &StarSSH{
|
|
LoginInfo: LoginInput{
|
|
ForwardSSHAgent: true,
|
|
Timeout: time.Second,
|
|
},
|
|
}
|
|
star.setTransport(baseClient, nil)
|
|
|
|
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
|
|
if client != baseClient {
|
|
t.Fatalf("unexpected ssh client %p", client)
|
|
}
|
|
return &ssh.Session{}, nil
|
|
}
|
|
|
|
var probeCalls atomic.Int32
|
|
closer := &testCloser{}
|
|
probeSSHAgentForwarding = func(timeout time.Duration) error {
|
|
probeCalls.Add(1)
|
|
if timeout != time.Second {
|
|
t.Fatalf("unexpected forwarding timeout: %v", timeout)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
var routeCalls atomic.Int32
|
|
routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) {
|
|
routeCalls.Add(1)
|
|
if client != baseClient {
|
|
t.Fatalf("unexpected routed client %p", client)
|
|
}
|
|
if timeout != time.Second {
|
|
t.Fatalf("unexpected routed timeout: %v", timeout)
|
|
}
|
|
return closer, nil
|
|
}
|
|
|
|
var requestCalls atomic.Int32
|
|
requestSSHAgentForwarding = func(session *ssh.Session) error {
|
|
requestCalls.Add(1)
|
|
if session == nil {
|
|
t.Fatal("expected non-nil ssh session")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
if _, err := star.NewExecSession(); err != nil {
|
|
t.Fatalf("first exec session: %v", err)
|
|
}
|
|
if _, err := star.NewExecSession(); err != nil {
|
|
t.Fatalf("second exec session: %v", err)
|
|
}
|
|
|
|
if probeCalls.Load() != 1 {
|
|
t.Fatalf("expected one agent probe, got %d", probeCalls.Load())
|
|
}
|
|
if routeCalls.Load() != 1 {
|
|
t.Fatalf("expected one agent route registration, got %d", routeCalls.Load())
|
|
}
|
|
if requestCalls.Load() != 2 {
|
|
t.Fatalf("expected agent forwarding request on each session, got %d", requestCalls.Load())
|
|
}
|
|
|
|
closeSSHClient = func(client sshClientRequester) error { return nil }
|
|
if err := star.Close(); err != nil {
|
|
t.Fatalf("close starssh: %v", err)
|
|
}
|
|
if closer.closed.Load() != 1 {
|
|
t.Fatalf("expected forwarded agent closer to run once, got %d", closer.closed.Load())
|
|
}
|
|
}
|
|
|
|
func TestNewPTYSessionEnablesAgentForwardingWhenConfigured(t *testing.T) {
|
|
oldNewSSHSession := newSSHSession
|
|
oldRequestSessionPTY := requestSessionPTY
|
|
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
|
|
oldRouteSSHAgentForwarding := routeSSHAgentForwarding
|
|
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
|
|
t.Cleanup(func() {
|
|
newSSHSession = oldNewSSHSession
|
|
requestSessionPTY = oldRequestSessionPTY
|
|
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
|
|
routeSSHAgentForwarding = oldRouteSSHAgentForwarding
|
|
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
|
|
})
|
|
|
|
star := &StarSSH{
|
|
LoginInfo: LoginInput{
|
|
ForwardSSHAgent: true,
|
|
},
|
|
}
|
|
star.setTransport(&ssh.Client{}, nil)
|
|
|
|
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
|
|
return &ssh.Session{}, nil
|
|
}
|
|
|
|
var ptyCalls atomic.Int32
|
|
requestSessionPTY = func(session *ssh.Session, config TerminalConfig) error {
|
|
ptyCalls.Add(1)
|
|
return nil
|
|
}
|
|
|
|
probeSSHAgentForwarding = func(timeout time.Duration) error {
|
|
return nil
|
|
}
|
|
routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) {
|
|
return &testCloser{}, nil
|
|
}
|
|
|
|
var requestCalls atomic.Int32
|
|
requestSSHAgentForwarding = func(session *ssh.Session) error {
|
|
requestCalls.Add(1)
|
|
return nil
|
|
}
|
|
|
|
if _, err := star.NewPTYSession(nil); err != nil {
|
|
t.Fatalf("new pty session: %v", err)
|
|
}
|
|
if ptyCalls.Load() != 1 {
|
|
t.Fatalf("expected one PTY request, got %d", ptyCalls.Load())
|
|
}
|
|
if requestCalls.Load() != 1 {
|
|
t.Fatalf("expected one agent forwarding request, got %d", requestCalls.Load())
|
|
}
|
|
}
|
|
|
|
func TestNewExecSessionSkipsAgentForwardingWhenDisabled(t *testing.T) {
|
|
oldNewSSHSession := newSSHSession
|
|
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
|
|
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
|
|
t.Cleanup(func() {
|
|
newSSHSession = oldNewSSHSession
|
|
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
|
|
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
|
|
})
|
|
|
|
star := &StarSSH{}
|
|
star.setTransport(&ssh.Client{}, nil)
|
|
|
|
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
|
|
return &ssh.Session{}, nil
|
|
}
|
|
probeSSHAgentForwarding = func(timeout time.Duration) error {
|
|
t.Fatal("agent forwarding probe should not run when disabled")
|
|
return nil
|
|
}
|
|
requestSSHAgentForwarding = func(session *ssh.Session) error {
|
|
t.Fatal("agent forwarding should not be requested when disabled")
|
|
return nil
|
|
}
|
|
|
|
if _, err := star.NewExecSession(); err != nil {
|
|
t.Fatalf("new exec session without forwarding: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestRequestAgentForwardingReturnsUnavailableError(t *testing.T) {
|
|
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
|
|
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
|
|
t.Cleanup(func() {
|
|
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
|
|
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
|
|
})
|
|
|
|
star := &StarSSH{}
|
|
star.setTransport(&ssh.Client{}, nil)
|
|
|
|
probeSSHAgentForwarding = func(timeout time.Duration) error {
|
|
return errors.New("ssh-agent forwarding unavailable: ssh-agent unavailable")
|
|
}
|
|
requestSSHAgentForwarding = func(session *ssh.Session) error {
|
|
t.Fatal("session request should not run when agent forwarder init fails")
|
|
return nil
|
|
}
|
|
|
|
err := star.RequestAgentForwarding(&ssh.Session{})
|
|
if err == nil {
|
|
t.Fatal("expected agent forwarding init error")
|
|
}
|
|
}
|
|
|
|
func TestRequestAgentForwardingWrapsSetupErrorAsUnavailable(t *testing.T) {
|
|
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
|
|
t.Cleanup(func() {
|
|
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
|
|
})
|
|
|
|
star := &StarSSH{}
|
|
star.setTransport(&ssh.Client{}, nil)
|
|
|
|
probeSSHAgentForwarding = func(timeout time.Duration) error {
|
|
return errors.New("dial unix /tmp/ssh-broken.sock: connect: permission denied")
|
|
}
|
|
|
|
err := star.RequestAgentForwarding(&ssh.Session{})
|
|
if !isSSHAgentForwardingUnavailableError(err) {
|
|
t.Fatalf("expected unavailable error, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestRequestAgentForwardingReturnsDeniedError(t *testing.T) {
|
|
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
|
|
oldRouteSSHAgentForwarding := routeSSHAgentForwarding
|
|
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
|
|
t.Cleanup(func() {
|
|
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
|
|
routeSSHAgentForwarding = oldRouteSSHAgentForwarding
|
|
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
|
|
})
|
|
|
|
star := &StarSSH{}
|
|
star.setTransport(&ssh.Client{}, nil)
|
|
|
|
probeSSHAgentForwarding = func(timeout time.Duration) error {
|
|
return nil
|
|
}
|
|
routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) {
|
|
return &testCloser{}, nil
|
|
}
|
|
requestSSHAgentForwarding = func(session *ssh.Session) error {
|
|
return errors.New("forwarding request denied")
|
|
}
|
|
|
|
err := star.RequestAgentForwarding(&ssh.Session{})
|
|
if !isSSHAgentForwardingDeniedError(err) {
|
|
t.Fatalf("expected forwarding denied error, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestNewExecSessionIgnoresAgentForwardingDenied(t *testing.T) {
|
|
oldNewSSHSession := newSSHSession
|
|
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
|
|
oldRouteSSHAgentForwarding := routeSSHAgentForwarding
|
|
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
|
|
t.Cleanup(func() {
|
|
newSSHSession = oldNewSSHSession
|
|
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
|
|
routeSSHAgentForwarding = oldRouteSSHAgentForwarding
|
|
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
|
|
})
|
|
|
|
star := &StarSSH{
|
|
LoginInfo: LoginInput{
|
|
ForwardSSHAgent: true,
|
|
},
|
|
}
|
|
star.setTransport(&ssh.Client{}, nil)
|
|
|
|
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
|
|
return &ssh.Session{}, nil
|
|
}
|
|
probeSSHAgentForwarding = func(timeout time.Duration) error {
|
|
return nil
|
|
}
|
|
routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) {
|
|
return &testCloser{}, nil
|
|
}
|
|
requestSSHAgentForwarding = func(session *ssh.Session) error {
|
|
return errors.New("forwarding request denied")
|
|
}
|
|
|
|
if _, err := star.NewExecSession(); err != nil {
|
|
t.Fatalf("new exec session should ignore denied agent forwarding: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestNewExecSessionIgnoresAgentForwardingUnavailable(t *testing.T) {
|
|
oldNewSSHSession := newSSHSession
|
|
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
|
|
t.Cleanup(func() {
|
|
newSSHSession = oldNewSSHSession
|
|
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
|
|
})
|
|
|
|
star := &StarSSH{
|
|
LoginInfo: LoginInput{
|
|
ForwardSSHAgent: true,
|
|
},
|
|
}
|
|
star.setTransport(&ssh.Client{}, nil)
|
|
|
|
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
|
|
return &ssh.Session{}, nil
|
|
}
|
|
probeSSHAgentForwarding = func(timeout time.Duration) error {
|
|
return errors.New("ssh-agent forwarding unavailable: ssh-agent unavailable")
|
|
}
|
|
|
|
if _, err := star.NewExecSession(); err != nil {
|
|
t.Fatalf("new exec session should ignore unavailable agent forwarding: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestNewExecSessionIgnoresAgentForwardingSetupError(t *testing.T) {
|
|
oldNewSSHSession := newSSHSession
|
|
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
|
|
t.Cleanup(func() {
|
|
newSSHSession = oldNewSSHSession
|
|
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
|
|
})
|
|
|
|
star := &StarSSH{
|
|
LoginInfo: LoginInput{
|
|
ForwardSSHAgent: true,
|
|
},
|
|
}
|
|
star.setTransport(&ssh.Client{}, nil)
|
|
|
|
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
|
|
return &ssh.Session{}, nil
|
|
}
|
|
probeSSHAgentForwarding = func(timeout time.Duration) error {
|
|
return errors.New("dial unix /tmp/ssh-broken.sock: connect: connection refused")
|
|
}
|
|
|
|
if _, err := star.NewExecSession(); err != nil {
|
|
t.Fatalf("new exec session should ignore agent setup error: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestEnsureAgentForwardingClosesNewForwarderWhenCloseStarts(t *testing.T) {
|
|
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
|
|
oldRouteSSHAgentForwarding := routeSSHAgentForwarding
|
|
oldCloseSSHClient := closeSSHClient
|
|
t.Cleanup(func() {
|
|
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
|
|
routeSSHAgentForwarding = oldRouteSSHAgentForwarding
|
|
closeSSHClient = oldCloseSSHClient
|
|
})
|
|
|
|
star := &StarSSH{
|
|
LoginInfo: LoginInput{
|
|
ForwardSSHAgent: true,
|
|
},
|
|
}
|
|
star.setTransport(&ssh.Client{}, nil)
|
|
|
|
started := make(chan struct{})
|
|
release := make(chan struct{})
|
|
closer := &testCloser{}
|
|
probeSSHAgentForwarding = func(timeout time.Duration) error {
|
|
return nil
|
|
}
|
|
routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) {
|
|
close(started)
|
|
<-release
|
|
return closer, nil
|
|
}
|
|
closeSSHClient = func(client sshClientRequester) error { return nil }
|
|
|
|
errCh := make(chan error, 1)
|
|
go func() {
|
|
errCh <- star.ensureAgentForwarding()
|
|
}()
|
|
|
|
<-started
|
|
closeDone := make(chan struct{})
|
|
go func() {
|
|
_ = star.Close()
|
|
close(closeDone)
|
|
}()
|
|
|
|
deadline := time.Now().Add(time.Second)
|
|
for !star.closing.Load() {
|
|
if time.Now().After(deadline) {
|
|
t.Fatal("close did not enter closing state in time")
|
|
}
|
|
time.Sleep(time.Millisecond)
|
|
}
|
|
|
|
close(release)
|
|
|
|
err := <-errCh
|
|
if !errors.Is(err, errSSHClientClosing) {
|
|
t.Fatalf("expected closing error, got %v", err)
|
|
}
|
|
<-closeDone
|
|
|
|
if closer.closed.Load() != 1 {
|
|
t.Fatalf("expected new forwarder closer to be closed once, got %d", closer.closed.Load())
|
|
}
|
|
if got := star.takeAgentForwarder(); got != nil {
|
|
t.Fatal("expected no leaked agent forwarder after close race")
|
|
}
|
|
}
|
|
|
|
func TestProxySSHAgentChannelClosesBlockedAgentConnWhenRemoteChannelEnds(t *testing.T) {
|
|
agentConn, peerConn := net.Pipe()
|
|
defer peerConn.Close()
|
|
|
|
tracked := &trackedConn{Conn: agentConn}
|
|
channel := newTestSSHChannel(func(p []byte) (int, error) {
|
|
return 0, io.EOF
|
|
})
|
|
|
|
done := make(chan struct{})
|
|
go func() {
|
|
proxySSHAgentChannel(channel, tracked)
|
|
close(done)
|
|
}()
|
|
|
|
select {
|
|
case <-done:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("proxySSHAgentChannel did not exit after remote EOF")
|
|
}
|
|
|
|
if tracked.closed.Load() == 0 {
|
|
t.Fatal("expected local agent connection to be closed")
|
|
}
|
|
if channel.closed.Load() == 0 {
|
|
t.Fatal("expected ssh channel to be closed")
|
|
}
|
|
}
|
|
|
|
func TestSSHAgentForwardProxyCloseClosesActiveBridges(t *testing.T) {
|
|
agentConn, peerConn := net.Pipe()
|
|
defer peerConn.Close()
|
|
|
|
tracked := &trackedConn{Conn: agentConn}
|
|
channel := newBlockingTestSSHChannel()
|
|
proxy := &sshAgentForwardProxy{
|
|
stopCh: make(chan struct{}),
|
|
active: make(map[*sshAgentForwardBridge]struct{}),
|
|
}
|
|
bridge := &sshAgentForwardBridge{
|
|
proxy: proxy,
|
|
channel: channel,
|
|
conn: tracked,
|
|
}
|
|
if !proxy.registerBridge(bridge) {
|
|
t.Fatal("expected bridge registration to succeed")
|
|
}
|
|
|
|
done := make(chan struct{})
|
|
go func() {
|
|
bridge.run()
|
|
close(done)
|
|
}()
|
|
|
|
if err := proxy.Close(); err != nil {
|
|
t.Fatalf("close proxy: %v", err)
|
|
}
|
|
|
|
select {
|
|
case <-done:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("bridge did not exit after proxy close")
|
|
}
|
|
|
|
if tracked.closed.Load() == 0 {
|
|
t.Fatal("expected proxy close to close local agent connection")
|
|
}
|
|
if channel.closed.Load() == 0 {
|
|
t.Fatal("expected proxy close to close ssh channel")
|
|
}
|
|
}
|