feat: 增强 ssh-agent 认证与转发可靠性
- 拆分 ssh-agent 认证、连接与 endpoint 解析逻辑 - 新增 IdentityAgent、SSHAgentTimeout、SSHAgentForwardTimeout 和调试事件 - 为 agent list/sign 操作增加独立 deadline,避免硬件 agent 卡死登录 - 支持 agent signer 失败后跳过坏 key 并重试后续 key - 优先处理 RSA-SHA2 签名,兼容现代 OpenSSH 认证要求 - 增强 agent forwarding 的探测、通道空闲超时和关闭清理 - 补充 Windows OpenSSH pipe 与 GPG S.gpg-agent.ssh socket 文件支持 - 增加相关回归测试和 Windows 编译验证覆盖
This commit is contained in:
+114
-21
@@ -44,6 +44,32 @@ type testSSHChannel struct {
|
||||
closeCh chan struct{}
|
||||
}
|
||||
|
||||
type testNewChannel struct {
|
||||
channel ssh.Channel
|
||||
accepted atomic.Bool
|
||||
rejected atomic.Bool
|
||||
}
|
||||
|
||||
func (c *testNewChannel) Accept() (ssh.Channel, <-chan *ssh.Request, error) {
|
||||
c.accepted.Store(true)
|
||||
requests := make(chan *ssh.Request)
|
||||
close(requests)
|
||||
return c.channel, requests, nil
|
||||
}
|
||||
|
||||
func (c *testNewChannel) Reject(reason ssh.RejectionReason, message string) error {
|
||||
c.rejected.Store(true)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *testNewChannel) ChannelType() string {
|
||||
return sshAgentChannelType
|
||||
}
|
||||
|
||||
func (c *testNewChannel) ExtraData() []byte {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newTestSSHChannel(readFunc func([]byte) (int, error)) *testSSHChannel {
|
||||
return &testSSHChannel{
|
||||
readFunc: readFunc,
|
||||
@@ -114,8 +140,10 @@ func TestNewExecSessionEnablesAgentForwardingOnce(t *testing.T) {
|
||||
baseClient := &ssh.Client{}
|
||||
star := &StarSSH{
|
||||
LoginInfo: LoginInput{
|
||||
ForwardSSHAgent: true,
|
||||
Timeout: time.Second,
|
||||
ForwardSSHAgent: true,
|
||||
Timeout: time.Second,
|
||||
SSHAgentTimeout: 3 * time.Second,
|
||||
SSHAgentForwardTimeout: 4 * time.Second,
|
||||
},
|
||||
}
|
||||
star.setTransport(baseClient, nil)
|
||||
@@ -129,22 +157,34 @@ func TestNewExecSessionEnablesAgentForwardingOnce(t *testing.T) {
|
||||
|
||||
var probeCalls atomic.Int32
|
||||
closer := &testCloser{}
|
||||
probeSSHAgentForwarding = func(timeout time.Duration) error {
|
||||
probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error {
|
||||
probeCalls.Add(1)
|
||||
if timeout != time.Second {
|
||||
t.Fatalf("unexpected forwarding timeout: %v", timeout)
|
||||
if timeouts.Dial != time.Second {
|
||||
t.Fatalf("unexpected forwarding dial timeout: %v", timeouts.Dial)
|
||||
}
|
||||
if timeouts.Operation != 3*time.Second {
|
||||
t.Fatalf("unexpected forwarding operation timeout: %v", timeouts.Operation)
|
||||
}
|
||||
if timeouts.Forward != 4*time.Second {
|
||||
t.Fatalf("unexpected forwarding idle timeout: %v", timeouts.Forward)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var routeCalls atomic.Int32
|
||||
routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) {
|
||||
routeSSHAgentForwarding = func(client *ssh.Client, timeouts sshAgentTimeouts) (io.Closer, error) {
|
||||
routeCalls.Add(1)
|
||||
if client != baseClient {
|
||||
t.Fatalf("unexpected routed client %p", client)
|
||||
}
|
||||
if timeout != time.Second {
|
||||
t.Fatalf("unexpected routed timeout: %v", timeout)
|
||||
if timeouts.Dial != time.Second {
|
||||
t.Fatalf("unexpected routed dial timeout: %v", timeouts.Dial)
|
||||
}
|
||||
if timeouts.Operation != 3*time.Second {
|
||||
t.Fatalf("unexpected routed operation timeout: %v", timeouts.Operation)
|
||||
}
|
||||
if timeouts.Forward != 4*time.Second {
|
||||
t.Fatalf("unexpected routed idle timeout: %v", timeouts.Forward)
|
||||
}
|
||||
return closer, nil
|
||||
}
|
||||
@@ -215,10 +255,10 @@ func TestNewPTYSessionEnablesAgentForwardingWhenConfigured(t *testing.T) {
|
||||
return nil
|
||||
}
|
||||
|
||||
probeSSHAgentForwarding = func(timeout time.Duration) error {
|
||||
probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error {
|
||||
return nil
|
||||
}
|
||||
routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) {
|
||||
routeSSHAgentForwarding = func(client *ssh.Client, timeouts sshAgentTimeouts) (io.Closer, error) {
|
||||
return &testCloser{}, nil
|
||||
}
|
||||
|
||||
@@ -255,7 +295,7 @@ func TestNewExecSessionSkipsAgentForwardingWhenDisabled(t *testing.T) {
|
||||
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
|
||||
return &ssh.Session{}, nil
|
||||
}
|
||||
probeSSHAgentForwarding = func(timeout time.Duration) error {
|
||||
probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error {
|
||||
t.Fatal("agent forwarding probe should not run when disabled")
|
||||
return nil
|
||||
}
|
||||
@@ -280,7 +320,7 @@ func TestRequestAgentForwardingReturnsUnavailableError(t *testing.T) {
|
||||
star := &StarSSH{}
|
||||
star.setTransport(&ssh.Client{}, nil)
|
||||
|
||||
probeSSHAgentForwarding = func(timeout time.Duration) error {
|
||||
probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error {
|
||||
return errors.New("ssh-agent forwarding unavailable: ssh-agent unavailable")
|
||||
}
|
||||
requestSSHAgentForwarding = func(session *ssh.Session) error {
|
||||
@@ -303,7 +343,7 @@ func TestRequestAgentForwardingWrapsSetupErrorAsUnavailable(t *testing.T) {
|
||||
star := &StarSSH{}
|
||||
star.setTransport(&ssh.Client{}, nil)
|
||||
|
||||
probeSSHAgentForwarding = func(timeout time.Duration) error {
|
||||
probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error {
|
||||
return errors.New("dial unix /tmp/ssh-broken.sock: connect: permission denied")
|
||||
}
|
||||
|
||||
@@ -326,10 +366,10 @@ func TestRequestAgentForwardingReturnsDeniedError(t *testing.T) {
|
||||
star := &StarSSH{}
|
||||
star.setTransport(&ssh.Client{}, nil)
|
||||
|
||||
probeSSHAgentForwarding = func(timeout time.Duration) error {
|
||||
probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error {
|
||||
return nil
|
||||
}
|
||||
routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) {
|
||||
routeSSHAgentForwarding = func(client *ssh.Client, timeouts sshAgentTimeouts) (io.Closer, error) {
|
||||
return &testCloser{}, nil
|
||||
}
|
||||
requestSSHAgentForwarding = func(session *ssh.Session) error {
|
||||
@@ -364,10 +404,10 @@ func TestNewExecSessionIgnoresAgentForwardingDenied(t *testing.T) {
|
||||
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
|
||||
return &ssh.Session{}, nil
|
||||
}
|
||||
probeSSHAgentForwarding = func(timeout time.Duration) error {
|
||||
probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error {
|
||||
return nil
|
||||
}
|
||||
routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) {
|
||||
routeSSHAgentForwarding = func(client *ssh.Client, timeouts sshAgentTimeouts) (io.Closer, error) {
|
||||
return &testCloser{}, nil
|
||||
}
|
||||
requestSSHAgentForwarding = func(session *ssh.Session) error {
|
||||
@@ -397,7 +437,7 @@ func TestNewExecSessionIgnoresAgentForwardingUnavailable(t *testing.T) {
|
||||
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
|
||||
return &ssh.Session{}, nil
|
||||
}
|
||||
probeSSHAgentForwarding = func(timeout time.Duration) error {
|
||||
probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error {
|
||||
return errors.New("ssh-agent forwarding unavailable: ssh-agent unavailable")
|
||||
}
|
||||
|
||||
@@ -424,7 +464,7 @@ func TestNewExecSessionIgnoresAgentForwardingSetupError(t *testing.T) {
|
||||
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
|
||||
return &ssh.Session{}, nil
|
||||
}
|
||||
probeSSHAgentForwarding = func(timeout time.Duration) error {
|
||||
probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error {
|
||||
return errors.New("dial unix /tmp/ssh-broken.sock: connect: connection refused")
|
||||
}
|
||||
|
||||
@@ -453,10 +493,10 @@ func TestEnsureAgentForwardingClosesNewForwarderWhenCloseStarts(t *testing.T) {
|
||||
started := make(chan struct{})
|
||||
release := make(chan struct{})
|
||||
closer := &testCloser{}
|
||||
probeSSHAgentForwarding = func(timeout time.Duration) error {
|
||||
probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error {
|
||||
return nil
|
||||
}
|
||||
routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) {
|
||||
routeSSHAgentForwarding = func(client *ssh.Client, timeouts sshAgentTimeouts) (io.Closer, error) {
|
||||
close(started)
|
||||
<-release
|
||||
return closer, nil
|
||||
@@ -570,3 +610,56 @@ func TestSSHAgentForwardProxyCloseClosesActiveBridges(t *testing.T) {
|
||||
t.Fatal("expected proxy close to close ssh channel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleSSHAgentForwardChannelUsesForwardTimeout(t *testing.T) {
|
||||
oldDialResolvedSSHAgent := dialResolvedSSHAgentFunc
|
||||
t.Cleanup(func() {
|
||||
dialResolvedSSHAgentFunc = oldDialResolvedSSHAgent
|
||||
})
|
||||
|
||||
agentConn, peerConn := net.Pipe()
|
||||
defer peerConn.Close()
|
||||
tracked := &trackedConn{Conn: agentConn}
|
||||
dialResolvedSSHAgentFunc = func(resolved resolvedSSHAgentEndpoint, timeout time.Duration) (net.Conn, error) {
|
||||
return tracked, nil
|
||||
}
|
||||
|
||||
channel := newBlockingTestSSHChannel()
|
||||
newChannel := &testNewChannel{
|
||||
channel: channel,
|
||||
}
|
||||
proxy := &sshAgentForwardProxy{
|
||||
stopCh: make(chan struct{}),
|
||||
active: make(map[*sshAgentForwardBridge]struct{}),
|
||||
}
|
||||
handleSSHAgentForwardChannel(proxy, newChannel, sshAgentTimeouts{
|
||||
Endpoint: "/tmp/agent.sock",
|
||||
Forward: 20 * time.Millisecond,
|
||||
})
|
||||
|
||||
if !newChannel.accepted.Load() {
|
||||
t.Fatal("expected channel to be accepted")
|
||||
}
|
||||
|
||||
waitUntil(t, time.Second, func() bool {
|
||||
return tracked.closed.Load() > 0 && channel.closed.Load() > 0
|
||||
}, "forwarded agent bridge did not close both sides after idle timeout")
|
||||
|
||||
waitUntil(t, time.Second, func() bool {
|
||||
proxy.activeMu.Lock()
|
||||
defer proxy.activeMu.Unlock()
|
||||
return len(proxy.active) == 0
|
||||
}, "forwarded agent bridge did not unregister after idle timeout")
|
||||
}
|
||||
|
||||
func waitUntil(t *testing.T, timeout time.Duration, condition func() bool, message string) {
|
||||
t.Helper()
|
||||
deadline := time.Now().Add(timeout)
|
||||
for time.Now().Before(deadline) {
|
||||
if condition() {
|
||||
return
|
||||
}
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
t.Fatal(message)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user