fix: 重构 ssh agent forwarding 转发代理并修复资源残留
- 将 agent forwarding 从长持有本地 agent 连接改为按 uth-agent@openssh.com channel 动态建桥 - 新增 ssh-agent 可用性探测流程,区分 forwarding 探测与实际转发注册 - 重构 forwarding 注册接口,按连接超时创建本地 agent bridge,不再复用单个长期占用的 agent 连接 - 新增 sshAgentForwardProxy 与 sshAgentForwardBridge,显式管理活跃 agent bridge 生命周期 - 在远端 channel 单侧 EOF、bridge 关闭和 proxy.Close() 路径上主动关闭本地 agent 连接与 SSH channel,避免 goroutine 和 agent 句柄残留 - 保留现有 denied / unavailable / close-race 语义,并继续保证自动 forwarding 的 best-effort 行为 - 扩充 agent forwarding 回归测试,覆盖单次启用、禁用、denied、unavailable、close race、单侧 EOF 释放以及 proxy Close 主动回收活跃 bridge 等关键场景
This commit is contained in:
+216
-12
@@ -4,7 +4,9 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
@@ -15,24 +17,53 @@ var requestSSHAgentForwarding = func(session *ssh.Session) error {
|
||||
return sshagent.RequestAgentForwarding(session)
|
||||
}
|
||||
|
||||
var routeSSHAgentForwarding = func(client *ssh.Client, keyring sshagent.Agent) error {
|
||||
return sshagent.ForwardToAgent(client, keyring)
|
||||
const sshAgentChannelType = "auth-agent@openssh.com"
|
||||
|
||||
var routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) {
|
||||
return startSSHAgentForwardProxy(client, timeout)
|
||||
}
|
||||
|
||||
var newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
|
||||
var probeSSHAgentForwarding = func(timeout time.Duration) error {
|
||||
conn, err := dialSSHAgent(timeout)
|
||||
if err != nil {
|
||||
return nil, nil, wrapSSHAgentForwardingUnavailable(err)
|
||||
return wrapSSHAgentForwardingUnavailable(err)
|
||||
}
|
||||
if conn == nil {
|
||||
return nil, nil, wrapSSHAgentForwardingUnavailable(errors.New("empty agent connection"))
|
||||
return wrapSSHAgentForwardingUnavailable(errors.New("empty agent connection"))
|
||||
}
|
||||
return sshagent.NewClient(conn), conn, nil
|
||||
return conn.Close()
|
||||
}
|
||||
|
||||
var errSSHAgentForwardingDenied = errors.New("ssh-agent forwarding request denied")
|
||||
var errSSHAgentForwardingUnavailable = errors.New("ssh-agent forwarding unavailable")
|
||||
|
||||
type sshAgentForwardProxy struct {
|
||||
stopOnce sync.Once
|
||||
stopCh chan struct{}
|
||||
|
||||
activeMu sync.Mutex
|
||||
active map[*sshAgentForwardBridge]struct{}
|
||||
}
|
||||
|
||||
func (p *sshAgentForwardProxy) Close() error {
|
||||
if p == nil {
|
||||
return nil
|
||||
}
|
||||
p.stopOnce.Do(func() {
|
||||
close(p.stopCh)
|
||||
})
|
||||
p.closeActive()
|
||||
return nil
|
||||
}
|
||||
|
||||
type sshAgentForwardBridge struct {
|
||||
proxy *sshAgentForwardProxy
|
||||
channel ssh.Channel
|
||||
conn net.Conn
|
||||
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
func (s *StarSSH) RequestAgentForwarding(session *ssh.Session) error {
|
||||
if s == nil {
|
||||
return errors.New("ssh client is nil")
|
||||
@@ -80,20 +111,21 @@ func (s *StarSSH) ensureAgentForwarding() error {
|
||||
return err
|
||||
}
|
||||
|
||||
keyring, closer, err := newSSHAgentForwarder(effectiveDialTimeout(s.LoginInfo))
|
||||
if err != nil {
|
||||
timeout := effectiveDialTimeout(s.LoginInfo)
|
||||
if err := probeSSHAgentForwarding(timeout); err != nil {
|
||||
return wrapSSHAgentForwardingUnavailable(err)
|
||||
}
|
||||
if s.closing.Load() {
|
||||
_ = closer.Close()
|
||||
return errSSHClientClosing
|
||||
}
|
||||
if err := routeSSHAgentForwarding(client, keyring); err != nil {
|
||||
_ = closer.Close()
|
||||
closer, err := routeSSHAgentForwarding(client, timeout)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !s.canAttachAgentForwarder(client) {
|
||||
_ = closer.Close()
|
||||
if closer != nil {
|
||||
_ = closer.Close()
|
||||
}
|
||||
return errSSHClientClosing
|
||||
}
|
||||
s.agentForwarder = closer
|
||||
@@ -149,3 +181,175 @@ func wrapSSHAgentForwardingUnavailable(err error) error {
|
||||
}
|
||||
return fmt.Errorf("%w: %v", errSSHAgentForwardingUnavailable, err)
|
||||
}
|
||||
|
||||
func startSSHAgentForwardProxy(client *ssh.Client, timeout time.Duration) (io.Closer, error) {
|
||||
if client == nil {
|
||||
return nil, errors.New("ssh client is nil")
|
||||
}
|
||||
channels := client.HandleChannelOpen(sshAgentChannelType)
|
||||
if channels == nil {
|
||||
return nil, errors.New("agent: already have handler for " + sshAgentChannelType)
|
||||
}
|
||||
|
||||
proxy := &sshAgentForwardProxy{
|
||||
stopCh: make(chan struct{}),
|
||||
active: make(map[*sshAgentForwardBridge]struct{}),
|
||||
}
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-proxy.stopCh:
|
||||
return
|
||||
case ch, ok := <-channels:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
go handleSSHAgentForwardChannel(proxy, ch, timeout)
|
||||
}
|
||||
}
|
||||
}()
|
||||
return proxy, nil
|
||||
}
|
||||
|
||||
func handleSSHAgentForwardChannel(proxy *sshAgentForwardProxy, ch ssh.NewChannel, timeout time.Duration) {
|
||||
if ch == nil {
|
||||
return
|
||||
}
|
||||
conn, err := dialSSHAgent(timeout)
|
||||
if err != nil {
|
||||
_ = ch.Reject(ssh.ConnectionFailed, err.Error())
|
||||
return
|
||||
}
|
||||
if conn == nil {
|
||||
_ = ch.Reject(ssh.ConnectionFailed, "ssh-agent connection unavailable")
|
||||
return
|
||||
}
|
||||
|
||||
channel, reqs, err := ch.Accept()
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
go ssh.DiscardRequests(reqs)
|
||||
|
||||
bridge := &sshAgentForwardBridge{
|
||||
proxy: proxy,
|
||||
channel: channel,
|
||||
conn: conn,
|
||||
}
|
||||
if !proxy.registerBridge(bridge) {
|
||||
bridge.close()
|
||||
return
|
||||
}
|
||||
go bridge.run()
|
||||
}
|
||||
|
||||
func proxySSHAgentChannel(channel ssh.Channel, conn net.Conn) {
|
||||
bridge := &sshAgentForwardBridge{
|
||||
channel: channel,
|
||||
conn: conn,
|
||||
}
|
||||
bridge.run()
|
||||
}
|
||||
|
||||
func (b *sshAgentForwardBridge) run() {
|
||||
if b == nil {
|
||||
return
|
||||
}
|
||||
defer b.unregister()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, _ = io.Copy(b.channel, b.conn)
|
||||
b.close()
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, _ = io.Copy(b.conn, b.channel)
|
||||
b.close()
|
||||
}()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (b *sshAgentForwardBridge) close() {
|
||||
if b == nil {
|
||||
return
|
||||
}
|
||||
b.closeOnce.Do(func() {
|
||||
closeWriter(b.channel)
|
||||
closeWriter(b.conn)
|
||||
if b.channel != nil {
|
||||
_ = b.channel.Close()
|
||||
}
|
||||
if b.conn != nil {
|
||||
_ = b.conn.Close()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (b *sshAgentForwardBridge) unregister() {
|
||||
if b == nil || b.proxy == nil {
|
||||
return
|
||||
}
|
||||
b.proxy.unregisterBridge(b)
|
||||
}
|
||||
|
||||
func (p *sshAgentForwardProxy) registerBridge(bridge *sshAgentForwardBridge) bool {
|
||||
if p == nil || bridge == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
p.activeMu.Lock()
|
||||
defer p.activeMu.Unlock()
|
||||
|
||||
select {
|
||||
case <-p.stopCh:
|
||||
return false
|
||||
default:
|
||||
}
|
||||
if p.active == nil {
|
||||
p.active = make(map[*sshAgentForwardBridge]struct{})
|
||||
}
|
||||
p.active[bridge] = struct{}{}
|
||||
return true
|
||||
}
|
||||
|
||||
func (p *sshAgentForwardProxy) unregisterBridge(bridge *sshAgentForwardBridge) {
|
||||
if p == nil || bridge == nil {
|
||||
return
|
||||
}
|
||||
|
||||
p.activeMu.Lock()
|
||||
defer p.activeMu.Unlock()
|
||||
|
||||
delete(p.active, bridge)
|
||||
}
|
||||
|
||||
func (p *sshAgentForwardProxy) closeActive() {
|
||||
if p == nil {
|
||||
return
|
||||
}
|
||||
|
||||
p.activeMu.Lock()
|
||||
active := make([]*sshAgentForwardBridge, 0, len(p.active))
|
||||
for bridge := range p.active {
|
||||
active = append(active, bridge)
|
||||
}
|
||||
p.active = make(map[*sshAgentForwardBridge]struct{})
|
||||
p.activeMu.Unlock()
|
||||
|
||||
for _, bridge := range active {
|
||||
bridge.close()
|
||||
}
|
||||
}
|
||||
|
||||
func closeWriter(value any) {
|
||||
type closeWriter interface {
|
||||
CloseWrite() error
|
||||
}
|
||||
if cw, ok := value.(closeWriter); ok {
|
||||
_ = cw.CloseWrite()
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user