starssh/agent_forward_test.go
starainrt 0c23e7d4bf
feat: 增强 ssh-agent 认证与转发可靠性
- 拆分 ssh-agent 认证、连接与 endpoint 解析逻辑
- 新增 IdentityAgent、SSHAgentTimeout、SSHAgentForwardTimeout 和调试事件
- 为 agent list/sign 操作增加独立 deadline,避免硬件 agent 卡死登录
- 支持 agent signer 失败后跳过坏 key 并重试后续 key
- 优先处理 RSA-SHA2 签名,兼容现代 OpenSSH 认证要求
- 增强 agent forwarding 的探测、通道空闲超时和关闭清理
- 补充 Windows OpenSSH pipe 与 GPG S.gpg-agent.ssh socket 文件支持
- 增加相关回归测试和 Windows 编译验证覆盖
2026-05-27 13:10:35 +08:00

666 lines
17 KiB
Go

package starssh
import (
"bytes"
"errors"
"io"
"net"
"sync"
"sync/atomic"
"testing"
"time"
"golang.org/x/crypto/ssh"
)
// testCloser is an io.Closer stub that counts how many times Close was
// invoked; it never fails.
type testCloser struct {
	closed atomic.Int32 // number of Close calls observed so far
}

// Close records one more call and reports success.
func (tc *testCloser) Close() error {
	tc.closed.Add(1)
	var ok error
	return ok
}
// trackedConn wraps a net.Conn and counts Close calls so tests can assert
// that the code under test released the connection. A nil inner Conn is
// permitted; closing it then only bumps the counter.
type trackedConn struct {
	net.Conn
	closed atomic.Int32
}

// Close increments the counter first, then closes the wrapped connection
// (when there is one) and returns its error.
func (tc *trackedConn) Close() error {
	tc.closed.Add(1)
	if inner := tc.Conn; inner != nil {
		return inner.Close()
	}
	return nil
}
// testSSHChannel is an in-memory stand-in for an ssh.Channel used by the
// forwarding tests.
type testSSHChannel struct {
	// readFunc, when non-nil, produces the result of every Read call.
	readFunc func([]byte) (int, error)

	// closeCh is closed once (guarded by closeOnce) when the channel is
	// closed; closed counts those transitions.
	closeCh   chan struct{}
	closeOnce sync.Once
	closed    atomic.Int32

	// stderr backs the stream returned by Stderr.
	stderr bytes.Buffer
}
// testNewChannel is a fake incoming channel request that hands out a
// pre-built channel and records whether Accept or Reject was called.
type testNewChannel struct {
	channel  ssh.Channel
	accepted atomic.Bool
	rejected atomic.Bool
}

// Accept flags the request as accepted and returns the stored channel with
// a request stream that is already closed (ranging over it ends at once).
func (nc *testNewChannel) Accept() (ssh.Channel, <-chan *ssh.Request, error) {
	nc.accepted.Store(true)
	reqs := make(chan *ssh.Request)
	close(reqs)
	return nc.channel, reqs, nil
}

// Reject flags the request as rejected; reason and message are ignored.
func (nc *testNewChannel) Reject(reason ssh.RejectionReason, message string) error {
	nc.rejected.Store(true)
	return nil
}

// ChannelType always reports the agent-forwarding channel type.
func (nc *testNewChannel) ChannelType() string { return sshAgentChannelType }

// ExtraData reports no extra payload.
func (nc *testNewChannel) ExtraData() []byte { return nil }
// newTestSSHChannel returns a fresh testSSHChannel whose Read delegates to
// readFunc (nil means every Read returns io.EOF) and whose close signal
// channel is allocated and open.
func newTestSSHChannel(readFunc func([]byte) (int, error)) *testSSHChannel {
	ch := new(testSSHChannel)
	ch.readFunc = readFunc
	ch.closeCh = make(chan struct{})
	return ch
}
// newBlockingTestSSHChannel returns a testSSHChannel whose Read blocks until
// the channel is closed and then reports io.EOF.
func newBlockingTestSSHChannel() *testSSHChannel {
	blocking := &testSSHChannel{closeCh: make(chan struct{})}
	blocking.readFunc = func([]byte) (int, error) {
		<-blocking.closeCh
		return 0, io.EOF
	}
	return blocking
}
// Read delegates to readFunc. A nil receiver or an unset readFunc behaves
// like an exhausted stream and returns io.EOF.
func (c *testSSHChannel) Read(p []byte) (int, error) {
	if c == nil || c.readFunc == nil {
		return 0, io.EOF
	}
	return c.readFunc(p)
}
// Write discards p and reports every byte as written.
func (c *testSSHChannel) Write(p []byte) (int, error) {
	written := len(p)
	return written, nil
}
// Close is nil-safe and idempotent: the first call increments the closed
// counter and closes closeCh (waking readers blocked on it); later calls do
// nothing. It always returns nil.
func (c *testSSHChannel) Close() error {
	if c != nil {
		c.closeOnce.Do(func() {
			c.closed.Add(1)
			close(c.closeCh)
		})
	}
	return nil
}
// CloseWrite is a no-op that always succeeds.
func (c *testSSHChannel) CloseWrite() error { return nil }

// SendRequest declines every request without reporting an error.
func (c *testSSHChannel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) {
	return false, nil
}

// Stderr exposes the in-memory buffer that backs the extended-data stream.
func (c *testSSHChannel) Stderr() io.ReadWriter {
	var rw io.ReadWriter = &c.stderr
	return rw
}
// TestNewExecSessionEnablesAgentForwardingOnce verifies three things:
//   - Two exec sessions share a single probe and route setup, and both
//     receive the timeouts configured in LoginInput.
//   - The per-session forwarding request is sent for every session.
//   - Close releases the routed forwarder exactly once.
func TestNewExecSessionEnablesAgentForwardingOnce(t *testing.T) {
	savedNewSession := newSSHSession
	t.Cleanup(func() { newSSHSession = savedNewSession })
	savedProbe := probeSSHAgentForwarding
	t.Cleanup(func() { probeSSHAgentForwarding = savedProbe })
	savedRoute := routeSSHAgentForwarding
	t.Cleanup(func() { routeSSHAgentForwarding = savedRoute })
	savedRequest := requestSSHAgentForwarding
	t.Cleanup(func() { requestSSHAgentForwarding = savedRequest })
	savedCloseClient := closeSSHClient
	t.Cleanup(func() { closeSSHClient = savedCloseClient })

	baseClient := &ssh.Client{}
	star := &StarSSH{LoginInfo: LoginInput{
		ForwardSSHAgent:        true,
		Timeout:                time.Second,
		SSHAgentTimeout:        3 * time.Second,
		SSHAgentForwardTimeout: 4 * time.Second,
	}}
	star.setTransport(baseClient, nil)

	// wantTimeouts asserts that the timeouts handed to a hook mirror the
	// LoginInput above; label distinguishes the probe from the route hook.
	wantTimeouts := func(label string, got sshAgentTimeouts) {
		t.Helper()
		if got.Dial != time.Second {
			t.Fatalf("unexpected %s dial timeout: %v", label, got.Dial)
		}
		if got.Operation != 3*time.Second {
			t.Fatalf("unexpected %s operation timeout: %v", label, got.Operation)
		}
		if got.Forward != 4*time.Second {
			t.Fatalf("unexpected %s idle timeout: %v", label, got.Forward)
		}
	}

	var probes, routes, requests atomic.Int32
	forwarder := &testCloser{}

	newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
		if client != baseClient {
			t.Fatalf("unexpected ssh client %p", client)
		}
		return &ssh.Session{}, nil
	}
	probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error {
		probes.Add(1)
		wantTimeouts("forwarding", timeouts)
		return nil
	}
	routeSSHAgentForwarding = func(client *ssh.Client, timeouts sshAgentTimeouts) (io.Closer, error) {
		routes.Add(1)
		if client != baseClient {
			t.Fatalf("unexpected routed client %p", client)
		}
		wantTimeouts("routed", timeouts)
		return forwarder, nil
	}
	requestSSHAgentForwarding = func(session *ssh.Session) error {
		requests.Add(1)
		if session == nil {
			t.Fatal("expected non-nil ssh session")
		}
		return nil
	}

	for _, label := range []string{"first", "second"} {
		if _, err := star.NewExecSession(); err != nil {
			t.Fatalf("%s exec session: %v", label, err)
		}
	}

	if got := probes.Load(); got != 1 {
		t.Fatalf("expected one agent probe, got %d", got)
	}
	if got := routes.Load(); got != 1 {
		t.Fatalf("expected one agent route registration, got %d", got)
	}
	if got := requests.Load(); got != 2 {
		t.Fatalf("expected agent forwarding request on each session, got %d", got)
	}

	closeSSHClient = func(sshClientRequester) error { return nil }
	if err := star.Close(); err != nil {
		t.Fatalf("close starssh: %v", err)
	}
	if got := forwarder.closed.Load(); got != 1 {
		t.Fatalf("expected forwarded agent closer to run once, got %d", got)
	}
}
// TestNewPTYSessionEnablesAgentForwardingWhenConfigured verifies that, with
// ForwardSSHAgent set, NewPTYSession issues exactly one PTY request and
// exactly one agent forwarding request.
func TestNewPTYSessionEnablesAgentForwardingWhenConfigured(t *testing.T) {
	prevNewSession := newSSHSession
	t.Cleanup(func() { newSSHSession = prevNewSession })
	prevPTY := requestSessionPTY
	t.Cleanup(func() { requestSessionPTY = prevPTY })
	prevProbe := probeSSHAgentForwarding
	t.Cleanup(func() { probeSSHAgentForwarding = prevProbe })
	prevRoute := routeSSHAgentForwarding
	t.Cleanup(func() { routeSSHAgentForwarding = prevRoute })
	prevRequest := requestSSHAgentForwarding
	t.Cleanup(func() { requestSSHAgentForwarding = prevRequest })

	star := &StarSSH{LoginInfo: LoginInput{ForwardSSHAgent: true}}
	star.setTransport(&ssh.Client{}, nil)

	var ptyRequests, forwardRequests atomic.Int32
	newSSHSession = func(*ssh.Client) (*ssh.Session, error) {
		return &ssh.Session{}, nil
	}
	requestSessionPTY = func(*ssh.Session, TerminalConfig) error {
		ptyRequests.Add(1)
		return nil
	}
	probeSSHAgentForwarding = func(sshAgentTimeouts) error { return nil }
	routeSSHAgentForwarding = func(*ssh.Client, sshAgentTimeouts) (io.Closer, error) {
		return &testCloser{}, nil
	}
	requestSSHAgentForwarding = func(*ssh.Session) error {
		forwardRequests.Add(1)
		return nil
	}

	if _, err := star.NewPTYSession(nil); err != nil {
		t.Fatalf("new pty session: %v", err)
	}
	if n := ptyRequests.Load(); n != 1 {
		t.Fatalf("expected one PTY request, got %d", n)
	}
	if n := forwardRequests.Load(); n != 1 {
		t.Fatalf("expected one agent forwarding request, got %d", n)
	}
}
// TestNewExecSessionSkipsAgentForwardingWhenDisabled verifies that with a
// zero-value LoginInput (forwarding off) neither the agent probe nor the
// per-session forwarding request is invoked.
func TestNewExecSessionSkipsAgentForwardingWhenDisabled(t *testing.T) {
	prevNewSession, prevProbe, prevRequest := newSSHSession, probeSSHAgentForwarding, requestSSHAgentForwarding
	t.Cleanup(func() {
		newSSHSession, probeSSHAgentForwarding, requestSSHAgentForwarding = prevNewSession, prevProbe, prevRequest
	})

	star := new(StarSSH)
	star.setTransport(&ssh.Client{}, nil)

	newSSHSession = func(*ssh.Client) (*ssh.Session, error) {
		return &ssh.Session{}, nil
	}
	probeSSHAgentForwarding = func(sshAgentTimeouts) error {
		t.Fatal("agent forwarding probe should not run when disabled")
		return nil
	}
	requestSSHAgentForwarding = func(*ssh.Session) error {
		t.Fatal("agent forwarding should not be requested when disabled")
		return nil
	}

	_, err := star.NewExecSession()
	if err != nil {
		t.Fatalf("new exec session without forwarding: %v", err)
	}
}
// TestRequestAgentForwardingReturnsUnavailableError verifies that an explicit
// RequestAgentForwarding call returns an error when the local agent probe
// fails. It also checks that the session-level forwarding request is never
// sent. The returned error must be classified as "unavailable" (not merely
// non-nil), matching what the test name promises.
func TestRequestAgentForwardingReturnsUnavailableError(t *testing.T) {
	oldProbeSSHAgentForwarding := probeSSHAgentForwarding
	oldRequestSSHAgentForwarding := requestSSHAgentForwarding
	t.Cleanup(func() {
		probeSSHAgentForwarding = oldProbeSSHAgentForwarding
		requestSSHAgentForwarding = oldRequestSSHAgentForwarding
	})
	star := &StarSSH{}
	star.setTransport(&ssh.Client{}, nil)
	probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error {
		return errors.New("ssh-agent forwarding unavailable: ssh-agent unavailable")
	}
	requestSSHAgentForwarding = func(session *ssh.Session) error {
		t.Fatal("session request should not run when agent forwarder init fails")
		return nil
	}
	err := star.RequestAgentForwarding(&ssh.Session{})
	if err == nil {
		t.Fatal("expected agent forwarding init error")
	}
	// Callers distinguish "no local agent" from "server denied"; pin the
	// classification rather than just the presence of an error.
	if !isSSHAgentForwardingUnavailableError(err) {
		t.Fatalf("expected unavailable error, got %v", err)
	}
}
// TestRequestAgentForwardingWrapsSetupErrorAsUnavailable verifies that a raw
// dial failure from the agent probe surfaces as an "unavailable" error.
func TestRequestAgentForwardingWrapsSetupErrorAsUnavailable(t *testing.T) {
	restoreProbe := probeSSHAgentForwarding
	t.Cleanup(func() { probeSSHAgentForwarding = restoreProbe })

	star := &StarSSH{}
	star.setTransport(&ssh.Client{}, nil)

	probeSSHAgentForwarding = func(sshAgentTimeouts) error {
		return errors.New("dial unix /tmp/ssh-broken.sock: connect: permission denied")
	}

	if err := star.RequestAgentForwarding(&ssh.Session{}); !isSSHAgentForwardingUnavailableError(err) {
		t.Fatalf("expected unavailable error, got %v", err)
	}
}
// TestRequestAgentForwardingReturnsDeniedError verifies that when the local
// agent probe and route succeed but the session-level forwarding request
// fails, RequestAgentForwarding reports a "denied" error.
func TestRequestAgentForwardingReturnsDeniedError(t *testing.T) {
	prevProbe, prevRoute, prevRequest := probeSSHAgentForwarding, routeSSHAgentForwarding, requestSSHAgentForwarding
	t.Cleanup(func() {
		requestSSHAgentForwarding = prevRequest
		routeSSHAgentForwarding = prevRoute
		probeSSHAgentForwarding = prevProbe
	})

	star := &StarSSH{}
	star.setTransport(&ssh.Client{}, nil)

	probeSSHAgentForwarding = func(sshAgentTimeouts) error { return nil }
	routeSSHAgentForwarding = func(*ssh.Client, sshAgentTimeouts) (io.Closer, error) {
		return &testCloser{}, nil
	}
	requestSSHAgentForwarding = func(*ssh.Session) error {
		return errors.New("forwarding request denied")
	}

	err := star.RequestAgentForwarding(&ssh.Session{})
	if isSSHAgentForwardingDeniedError(err) {
		return
	}
	t.Fatalf("expected forwarding denied error, got %v", err)
}
// TestNewExecSessionIgnoresAgentForwardingDenied verifies that a failed
// session-level forwarding request does not make NewExecSession fail.
func TestNewExecSessionIgnoresAgentForwardingDenied(t *testing.T) {
	prevNewSession := newSSHSession
	prevProbe := probeSSHAgentForwarding
	prevRoute := routeSSHAgentForwarding
	prevRequest := requestSSHAgentForwarding
	t.Cleanup(func() {
		requestSSHAgentForwarding = prevRequest
		routeSSHAgentForwarding = prevRoute
		probeSSHAgentForwarding = prevProbe
		newSSHSession = prevNewSession
	})

	star := &StarSSH{LoginInfo: LoginInput{ForwardSSHAgent: true}}
	star.setTransport(&ssh.Client{}, nil)

	newSSHSession = func(*ssh.Client) (*ssh.Session, error) { return &ssh.Session{}, nil }
	probeSSHAgentForwarding = func(sshAgentTimeouts) error { return nil }
	routeSSHAgentForwarding = func(*ssh.Client, sshAgentTimeouts) (io.Closer, error) {
		return &testCloser{}, nil
	}
	requestSSHAgentForwarding = func(*ssh.Session) error {
		return errors.New("forwarding request denied")
	}

	if _, err := star.NewExecSession(); err != nil {
		t.Fatalf("new exec session should ignore denied agent forwarding: %v", err)
	}
}
// TestNewExecSessionIgnoresAgentForwardingUnavailable verifies that a failed
// local agent probe does not make NewExecSession fail when forwarding is on.
func TestNewExecSessionIgnoresAgentForwardingUnavailable(t *testing.T) {
	restoreSession, restoreProbe := newSSHSession, probeSSHAgentForwarding
	t.Cleanup(func() {
		newSSHSession, probeSSHAgentForwarding = restoreSession, restoreProbe
	})

	star := &StarSSH{LoginInfo: LoginInput{ForwardSSHAgent: true}}
	star.setTransport(&ssh.Client{}, nil)

	newSSHSession = func(*ssh.Client) (*ssh.Session, error) {
		return &ssh.Session{}, nil
	}
	probeSSHAgentForwarding = func(sshAgentTimeouts) error {
		return errors.New("ssh-agent forwarding unavailable: ssh-agent unavailable")
	}

	_, err := star.NewExecSession()
	if err != nil {
		t.Fatalf("new exec session should ignore unavailable agent forwarding: %v", err)
	}
}
// TestNewExecSessionIgnoresAgentForwardingSetupError verifies that a raw
// dial error from the agent probe does not make NewExecSession fail when
// forwarding is on.
func TestNewExecSessionIgnoresAgentForwardingSetupError(t *testing.T) {
	restoreSession, restoreProbe := newSSHSession, probeSSHAgentForwarding
	t.Cleanup(func() {
		newSSHSession, probeSSHAgentForwarding = restoreSession, restoreProbe
	})

	star := &StarSSH{LoginInfo: LoginInput{ForwardSSHAgent: true}}
	star.setTransport(&ssh.Client{}, nil)

	newSSHSession = func(*ssh.Client) (*ssh.Session, error) {
		return &ssh.Session{}, nil
	}
	probeSSHAgentForwarding = func(sshAgentTimeouts) error {
		return errors.New("dial unix /tmp/ssh-broken.sock: connect: connection refused")
	}

	_, err := star.NewExecSession()
	if err != nil {
		t.Fatalf("new exec session should ignore agent setup error: %v", err)
	}
}
// TestEnsureAgentForwardingClosesNewForwarderWhenCloseStarts covers the race
// where Close begins while ensureAgentForwarding is still inside the route
// registration hook. The test requires three outcomes:
//   - The in-flight setup fails with errSSHClientClosing.
//   - The forwarder it just created is closed exactly once.
//   - No forwarder is left stored on the StarSSH afterwards.
//
// Every wait is bounded so a regression fails the test instead of hanging
// the package. The blocked route hook is always released during cleanup,
// so its goroutine and the pending Close can finish even when the test
// fails midway.
func TestEnsureAgentForwardingClosesNewForwarderWhenCloseStarts(t *testing.T) {
	oldProbeSSHAgentForwarding := probeSSHAgentForwarding
	oldRouteSSHAgentForwarding := routeSSHAgentForwarding
	oldCloseSSHClient := closeSSHClient
	t.Cleanup(func() {
		probeSSHAgentForwarding = oldProbeSSHAgentForwarding
		routeSSHAgentForwarding = oldRouteSSHAgentForwarding
		closeSSHClient = oldCloseSSHClient
	})
	const waitTimeout = time.Second

	star := &StarSSH{
		LoginInfo: LoginInput{
			ForwardSSHAgent: true,
		},
	}
	star.setTransport(&ssh.Client{}, nil)

	started := make(chan struct{})
	release := make(chan struct{})
	var releaseOnce sync.Once
	releaseRoute := func() { releaseOnce.Do(func() { close(release) }) }
	// Registered after the restore cleanup, so it runs first (cleanups are
	// LIFO): a failing t.Fatal below must not leave the route hook parked.
	t.Cleanup(releaseRoute)

	closer := &testCloser{}
	probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error {
		return nil
	}
	routeSSHAgentForwarding = func(client *ssh.Client, timeouts sshAgentTimeouts) (io.Closer, error) {
		close(started)
		<-release
		return closer, nil
	}
	closeSSHClient = func(client sshClientRequester) error { return nil }

	errCh := make(chan error, 1)
	go func() {
		errCh <- star.ensureAgentForwarding()
	}()

	select {
	case <-started:
	case <-time.After(waitTimeout):
		t.Fatal("ensureAgentForwarding did not reach route registration in time")
	}

	closeDone := make(chan struct{})
	go func() {
		_ = star.Close()
		close(closeDone)
	}()

	waitUntil(t, waitTimeout, func() bool {
		return star.closing.Load()
	}, "close did not enter closing state in time")

	// Only now let the route hook return its forwarder, so setup completes
	// strictly after Close has started.
	releaseRoute()

	select {
	case err := <-errCh:
		if !errors.Is(err, errSSHClientClosing) {
			t.Fatalf("expected closing error, got %v", err)
		}
	case <-time.After(waitTimeout):
		t.Fatal("ensureAgentForwarding did not return after route release")
	}

	select {
	case <-closeDone:
	case <-time.After(waitTimeout):
		t.Fatal("Close did not return after route release")
	}

	if got := closer.closed.Load(); got != 1 {
		t.Fatalf("expected new forwarder closer to be closed once, got %d", got)
	}
	if got := star.takeAgentForwarder(); got != nil {
		t.Fatal("expected no leaked agent forwarder after close race")
	}
}
// TestProxySSHAgentChannelClosesBlockedAgentConnWhenRemoteChannelEnds
// verifies that proxySSHAgentChannel returns once the remote channel reports
// EOF, even though the local agent connection (an end of a net.Pipe whose
// peer never writes) would block. Both the agent connection and the SSH
// channel must end up closed.
func TestProxySSHAgentChannelClosesBlockedAgentConnWhenRemoteChannelEnds(t *testing.T) {
	local, peer := net.Pipe()
	defer peer.Close()

	agent := &trackedConn{Conn: local}
	remote := newTestSSHChannel(func([]byte) (int, error) { return 0, io.EOF })

	finished := make(chan struct{})
	go func() {
		defer close(finished)
		proxySSHAgentChannel(remote, agent)
	}()

	timer := time.NewTimer(time.Second)
	defer timer.Stop()
	select {
	case <-finished:
	case <-timer.C:
		t.Fatal("proxySSHAgentChannel did not exit after remote EOF")
	}

	switch {
	case agent.closed.Load() == 0:
		t.Fatal("expected local agent connection to be closed")
	case remote.closed.Load() == 0:
		t.Fatal("expected ssh channel to be closed")
	}
}
// TestSSHAgentForwardProxyCloseClosesActiveBridges verifies that closing the
// proxy tears down a registered bridge that is blocked reading an idle
// channel. The bridge's run loop must exit, and both its agent connection
// and SSH channel must be closed.
func TestSSHAgentForwardProxyCloseClosesActiveBridges(t *testing.T) {
	local, peer := net.Pipe()
	defer peer.Close()
	agent := &trackedConn{Conn: local}
	remote := newBlockingTestSSHChannel()

	proxy := &sshAgentForwardProxy{
		active: make(map[*sshAgentForwardBridge]struct{}),
		stopCh: make(chan struct{}),
	}
	bridge := &sshAgentForwardBridge{proxy: proxy, channel: remote, conn: agent}
	if registered := proxy.registerBridge(bridge); !registered {
		t.Fatal("expected bridge registration to succeed")
	}

	exited := make(chan struct{})
	go func() {
		defer close(exited)
		bridge.run()
	}()

	if err := proxy.Close(); err != nil {
		t.Fatalf("close proxy: %v", err)
	}

	select {
	case <-exited:
	case <-time.After(time.Second):
		t.Fatal("bridge did not exit after proxy close")
	}
	if agent.closed.Load() == 0 {
		t.Fatal("expected proxy close to close local agent connection")
	}
	if remote.closed.Load() == 0 {
		t.Fatal("expected proxy close to close ssh channel")
	}
}
// TestHandleSSHAgentForwardChannelUsesForwardTimeout sets up an accepted
// forwarding channel whose remote side blocks and whose local peer never
// writes. It verifies that both ends are closed once the configured Forward
// idle timeout expires, and that the bridge then unregisters from the proxy.
func TestHandleSSHAgentForwardChannelUsesForwardTimeout(t *testing.T) {
	prevDial := dialResolvedSSHAgentFunc
	t.Cleanup(func() { dialResolvedSSHAgentFunc = prevDial })

	local, peer := net.Pipe()
	defer peer.Close()
	agent := &trackedConn{Conn: local}
	dialResolvedSSHAgentFunc = func(resolvedSSHAgentEndpoint, time.Duration) (net.Conn, error) {
		return agent, nil
	}

	remote := newBlockingTestSSHChannel()
	incoming := &testNewChannel{channel: remote}
	proxy := &sshAgentForwardProxy{
		stopCh: make(chan struct{}),
		active: make(map[*sshAgentForwardBridge]struct{}),
	}
	timeouts := sshAgentTimeouts{
		Endpoint: "/tmp/agent.sock",
		Forward:  20 * time.Millisecond,
	}

	handleSSHAgentForwardChannel(proxy, incoming, timeouts)
	if !incoming.accepted.Load() {
		t.Fatal("expected channel to be accepted")
	}

	bothClosed := func() bool {
		return agent.closed.Load() > 0 && remote.closed.Load() > 0
	}
	waitUntil(t, time.Second, bothClosed, "forwarded agent bridge did not close both sides after idle timeout")

	unregistered := func() bool {
		proxy.activeMu.Lock()
		defer proxy.activeMu.Unlock()
		return len(proxy.active) == 0
	}
	waitUntil(t, time.Second, unregistered, "forwarded agent bridge did not unregister after idle timeout")
}
// waitUntil polls condition about once per millisecond until it returns true
// or timeout elapses. If the deadline passes first, the test fails with
// message.
//
// The condition is always evaluated at least once, and once more after the
// final sleep. A zero or negative timeout therefore still succeeds when the
// condition already holds. A condition that becomes true during the last
// sleep is not misreported as a timeout.
func waitUntil(t *testing.T, timeout time.Duration, condition func() bool, message string) {
	t.Helper()
	deadline := time.Now().Add(timeout)
	for {
		if condition() {
			return
		}
		if !time.Now().Before(deadline) {
			break
		}
		time.Sleep(time.Millisecond)
	}
	t.Fatal(message)
}