fix: 重构 ssh agent forwarding 转发代理并修复资源残留
- 将 agent forwarding 从长持有本地 agent 连接改为按 uth-agent@openssh.com channel 动态建桥 - 新增 ssh-agent 可用性探测流程,区分 forwarding 探测与实际转发注册 - 重构 forwarding 注册接口,按连接超时创建本地 agent bridge,不再复用单个长期占用的 agent 连接 - 新增 sshAgentForwardProxy 与 sshAgentForwardBridge,显式管理活跃 agent bridge 生命周期 - 在远端 channel 单侧 EOF、bridge 关闭和 proxy.Close() 路径上主动关闭本地 agent 连接与 SSH channel,避免 goroutine 和 agent 句柄残留 - 保留现有 denied / unavailable / close-race 语义,并继续保证自动 forwarding 的 best-effort 行为 - 扩充 agent forwarding 回归测试,覆盖单次启用、禁用、denied、unavailable、close race、单侧 EOF 释放以及 proxy Close 主动回收活跃 bridge 等关键场景
This commit is contained in:
parent
1625997d8f
commit
ad7c8b0587
226
agent_forward.go
226
agent_forward.go
@ -4,7 +4,9 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
@ -15,24 +17,53 @@ var requestSSHAgentForwarding = func(session *ssh.Session) error {
|
||||
return sshagent.RequestAgentForwarding(session)
|
||||
}
|
||||
|
||||
var routeSSHAgentForwarding = func(client *ssh.Client, keyring sshagent.Agent) error {
|
||||
return sshagent.ForwardToAgent(client, keyring)
|
||||
const sshAgentChannelType = "auth-agent@openssh.com"
|
||||
|
||||
var routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) {
|
||||
return startSSHAgentForwardProxy(client, timeout)
|
||||
}
|
||||
|
||||
var newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
|
||||
var probeSSHAgentForwarding = func(timeout time.Duration) error {
|
||||
conn, err := dialSSHAgent(timeout)
|
||||
if err != nil {
|
||||
return nil, nil, wrapSSHAgentForwardingUnavailable(err)
|
||||
return wrapSSHAgentForwardingUnavailable(err)
|
||||
}
|
||||
if conn == nil {
|
||||
return nil, nil, wrapSSHAgentForwardingUnavailable(errors.New("empty agent connection"))
|
||||
return wrapSSHAgentForwardingUnavailable(errors.New("empty agent connection"))
|
||||
}
|
||||
return sshagent.NewClient(conn), conn, nil
|
||||
return conn.Close()
|
||||
}
|
||||
|
||||
var errSSHAgentForwardingDenied = errors.New("ssh-agent forwarding request denied")
|
||||
var errSSHAgentForwardingUnavailable = errors.New("ssh-agent forwarding unavailable")
|
||||
|
||||
type sshAgentForwardProxy struct {
|
||||
stopOnce sync.Once
|
||||
stopCh chan struct{}
|
||||
|
||||
activeMu sync.Mutex
|
||||
active map[*sshAgentForwardBridge]struct{}
|
||||
}
|
||||
|
||||
func (p *sshAgentForwardProxy) Close() error {
|
||||
if p == nil {
|
||||
return nil
|
||||
}
|
||||
p.stopOnce.Do(func() {
|
||||
close(p.stopCh)
|
||||
})
|
||||
p.closeActive()
|
||||
return nil
|
||||
}
|
||||
|
||||
type sshAgentForwardBridge struct {
|
||||
proxy *sshAgentForwardProxy
|
||||
channel ssh.Channel
|
||||
conn net.Conn
|
||||
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
func (s *StarSSH) RequestAgentForwarding(session *ssh.Session) error {
|
||||
if s == nil {
|
||||
return errors.New("ssh client is nil")
|
||||
@ -80,20 +111,21 @@ func (s *StarSSH) ensureAgentForwarding() error {
|
||||
return err
|
||||
}
|
||||
|
||||
keyring, closer, err := newSSHAgentForwarder(effectiveDialTimeout(s.LoginInfo))
|
||||
if err != nil {
|
||||
timeout := effectiveDialTimeout(s.LoginInfo)
|
||||
if err := probeSSHAgentForwarding(timeout); err != nil {
|
||||
return wrapSSHAgentForwardingUnavailable(err)
|
||||
}
|
||||
if s.closing.Load() {
|
||||
_ = closer.Close()
|
||||
return errSSHClientClosing
|
||||
}
|
||||
if err := routeSSHAgentForwarding(client, keyring); err != nil {
|
||||
_ = closer.Close()
|
||||
closer, err := routeSSHAgentForwarding(client, timeout)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !s.canAttachAgentForwarder(client) {
|
||||
if closer != nil {
|
||||
_ = closer.Close()
|
||||
}
|
||||
return errSSHClientClosing
|
||||
}
|
||||
s.agentForwarder = closer
|
||||
@ -149,3 +181,175 @@ func wrapSSHAgentForwardingUnavailable(err error) error {
|
||||
}
|
||||
return fmt.Errorf("%w: %v", errSSHAgentForwardingUnavailable, err)
|
||||
}
|
||||
|
||||
func startSSHAgentForwardProxy(client *ssh.Client, timeout time.Duration) (io.Closer, error) {
|
||||
if client == nil {
|
||||
return nil, errors.New("ssh client is nil")
|
||||
}
|
||||
channels := client.HandleChannelOpen(sshAgentChannelType)
|
||||
if channels == nil {
|
||||
return nil, errors.New("agent: already have handler for " + sshAgentChannelType)
|
||||
}
|
||||
|
||||
proxy := &sshAgentForwardProxy{
|
||||
stopCh: make(chan struct{}),
|
||||
active: make(map[*sshAgentForwardBridge]struct{}),
|
||||
}
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-proxy.stopCh:
|
||||
return
|
||||
case ch, ok := <-channels:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
go handleSSHAgentForwardChannel(proxy, ch, timeout)
|
||||
}
|
||||
}
|
||||
}()
|
||||
return proxy, nil
|
||||
}
|
||||
|
||||
func handleSSHAgentForwardChannel(proxy *sshAgentForwardProxy, ch ssh.NewChannel, timeout time.Duration) {
|
||||
if ch == nil {
|
||||
return
|
||||
}
|
||||
conn, err := dialSSHAgent(timeout)
|
||||
if err != nil {
|
||||
_ = ch.Reject(ssh.ConnectionFailed, err.Error())
|
||||
return
|
||||
}
|
||||
if conn == nil {
|
||||
_ = ch.Reject(ssh.ConnectionFailed, "ssh-agent connection unavailable")
|
||||
return
|
||||
}
|
||||
|
||||
channel, reqs, err := ch.Accept()
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
go ssh.DiscardRequests(reqs)
|
||||
|
||||
bridge := &sshAgentForwardBridge{
|
||||
proxy: proxy,
|
||||
channel: channel,
|
||||
conn: conn,
|
||||
}
|
||||
if !proxy.registerBridge(bridge) {
|
||||
bridge.close()
|
||||
return
|
||||
}
|
||||
go bridge.run()
|
||||
}
|
||||
|
||||
func proxySSHAgentChannel(channel ssh.Channel, conn net.Conn) {
|
||||
bridge := &sshAgentForwardBridge{
|
||||
channel: channel,
|
||||
conn: conn,
|
||||
}
|
||||
bridge.run()
|
||||
}
|
||||
|
||||
func (b *sshAgentForwardBridge) run() {
|
||||
if b == nil {
|
||||
return
|
||||
}
|
||||
defer b.unregister()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, _ = io.Copy(b.channel, b.conn)
|
||||
b.close()
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, _ = io.Copy(b.conn, b.channel)
|
||||
b.close()
|
||||
}()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (b *sshAgentForwardBridge) close() {
|
||||
if b == nil {
|
||||
return
|
||||
}
|
||||
b.closeOnce.Do(func() {
|
||||
closeWriter(b.channel)
|
||||
closeWriter(b.conn)
|
||||
if b.channel != nil {
|
||||
_ = b.channel.Close()
|
||||
}
|
||||
if b.conn != nil {
|
||||
_ = b.conn.Close()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (b *sshAgentForwardBridge) unregister() {
|
||||
if b == nil || b.proxy == nil {
|
||||
return
|
||||
}
|
||||
b.proxy.unregisterBridge(b)
|
||||
}
|
||||
|
||||
func (p *sshAgentForwardProxy) registerBridge(bridge *sshAgentForwardBridge) bool {
|
||||
if p == nil || bridge == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
p.activeMu.Lock()
|
||||
defer p.activeMu.Unlock()
|
||||
|
||||
select {
|
||||
case <-p.stopCh:
|
||||
return false
|
||||
default:
|
||||
}
|
||||
if p.active == nil {
|
||||
p.active = make(map[*sshAgentForwardBridge]struct{})
|
||||
}
|
||||
p.active[bridge] = struct{}{}
|
||||
return true
|
||||
}
|
||||
|
||||
func (p *sshAgentForwardProxy) unregisterBridge(bridge *sshAgentForwardBridge) {
|
||||
if p == nil || bridge == nil {
|
||||
return
|
||||
}
|
||||
|
||||
p.activeMu.Lock()
|
||||
defer p.activeMu.Unlock()
|
||||
|
||||
delete(p.active, bridge)
|
||||
}
|
||||
|
||||
func (p *sshAgentForwardProxy) closeActive() {
|
||||
if p == nil {
|
||||
return
|
||||
}
|
||||
|
||||
p.activeMu.Lock()
|
||||
active := make([]*sshAgentForwardBridge, 0, len(p.active))
|
||||
for bridge := range p.active {
|
||||
active = append(active, bridge)
|
||||
}
|
||||
p.active = make(map[*sshAgentForwardBridge]struct{})
|
||||
p.activeMu.Unlock()
|
||||
|
||||
for _, bridge := range active {
|
||||
bridge.close()
|
||||
}
|
||||
}
|
||||
|
||||
func closeWriter(value any) {
|
||||
type closeWriter interface {
|
||||
CloseWrite() error
|
||||
}
|
||||
if cw, ok := value.(closeWriter); ok {
|
||||
_ = cw.CloseWrite()
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,14 +1,16 @@
|
||||
package starssh
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
sshagent "golang.org/x/crypto/ssh/agent"
|
||||
)
|
||||
|
||||
type testCloser struct {
|
||||
@ -20,15 +22,90 @@ func (c *testCloser) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type trackedConn struct {
|
||||
net.Conn
|
||||
closed atomic.Int32
|
||||
}
|
||||
|
||||
func (c *trackedConn) Close() error {
|
||||
c.closed.Add(1)
|
||||
if c.Conn == nil {
|
||||
return nil
|
||||
}
|
||||
return c.Conn.Close()
|
||||
}
|
||||
|
||||
type testSSHChannel struct {
|
||||
readFunc func([]byte) (int, error)
|
||||
|
||||
stderr bytes.Buffer
|
||||
closed atomic.Int32
|
||||
closeOnce sync.Once
|
||||
closeCh chan struct{}
|
||||
}
|
||||
|
||||
func newTestSSHChannel(readFunc func([]byte) (int, error)) *testSSHChannel {
|
||||
return &testSSHChannel{
|
||||
readFunc: readFunc,
|
||||
closeCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func newBlockingTestSSHChannel() *testSSHChannel {
|
||||
ch := newTestSSHChannel(nil)
|
||||
ch.readFunc = func(p []byte) (int, error) {
|
||||
<-ch.closeCh
|
||||
return 0, io.EOF
|
||||
}
|
||||
return ch
|
||||
}
|
||||
|
||||
func (c *testSSHChannel) Read(p []byte) (int, error) {
|
||||
if c == nil {
|
||||
return 0, io.EOF
|
||||
}
|
||||
if c.readFunc != nil {
|
||||
return c.readFunc(p)
|
||||
}
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
func (c *testSSHChannel) Write(p []byte) (int, error) {
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (c *testSSHChannel) Close() error {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
c.closeOnce.Do(func() {
|
||||
c.closed.Add(1)
|
||||
close(c.closeCh)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *testSSHChannel) CloseWrite() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *testSSHChannel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (c *testSSHChannel) Stderr() io.ReadWriter {
|
||||
return &c.stderr
|
||||
}
|
||||
|
||||
func TestNewExecSessionEnablesAgentForwardingOnce(t *testing.T) {
|
||||
oldNewSSHSession := newSSHSession
|
||||
oldNewSSHAgentForwarder := newSSHAgentForwarder
|
||||
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
|
||||
oldRouteSSHAgentForwarding := routeSSHAgentForwarding
|
||||
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
|
||||
oldCloseSSHClient := closeSSHClient
|
||||
t.Cleanup(func() {
|
||||
newSSHSession = oldNewSSHSession
|
||||
newSSHAgentForwarder = oldNewSSHAgentForwarder
|
||||
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
|
||||
routeSSHAgentForwarding = oldRouteSSHAgentForwarding
|
||||
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
|
||||
closeSSHClient = oldCloseSSHClient
|
||||
@ -50,26 +127,26 @@ func TestNewExecSessionEnablesAgentForwardingOnce(t *testing.T) {
|
||||
return &ssh.Session{}, nil
|
||||
}
|
||||
|
||||
var agentInitCalls atomic.Int32
|
||||
var probeCalls atomic.Int32
|
||||
closer := &testCloser{}
|
||||
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
|
||||
agentInitCalls.Add(1)
|
||||
probeSSHAgentForwarding = func(timeout time.Duration) error {
|
||||
probeCalls.Add(1)
|
||||
if timeout != time.Second {
|
||||
t.Fatalf("unexpected forwarding timeout: %v", timeout)
|
||||
}
|
||||
return sshagent.NewKeyring(), closer, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
var routeCalls atomic.Int32
|
||||
routeSSHAgentForwarding = func(client *ssh.Client, keyring sshagent.Agent) error {
|
||||
routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) {
|
||||
routeCalls.Add(1)
|
||||
if client != baseClient {
|
||||
t.Fatalf("unexpected routed client %p", client)
|
||||
}
|
||||
if keyring == nil {
|
||||
t.Fatal("expected non-nil forwarded agent keyring")
|
||||
if timeout != time.Second {
|
||||
t.Fatalf("unexpected routed timeout: %v", timeout)
|
||||
}
|
||||
return nil
|
||||
return closer, nil
|
||||
}
|
||||
|
||||
var requestCalls atomic.Int32
|
||||
@ -88,8 +165,8 @@ func TestNewExecSessionEnablesAgentForwardingOnce(t *testing.T) {
|
||||
t.Fatalf("second exec session: %v", err)
|
||||
}
|
||||
|
||||
if agentInitCalls.Load() != 1 {
|
||||
t.Fatalf("expected one agent forwarder init, got %d", agentInitCalls.Load())
|
||||
if probeCalls.Load() != 1 {
|
||||
t.Fatalf("expected one agent probe, got %d", probeCalls.Load())
|
||||
}
|
||||
if routeCalls.Load() != 1 {
|
||||
t.Fatalf("expected one agent route registration, got %d", routeCalls.Load())
|
||||
@ -110,13 +187,13 @@ func TestNewExecSessionEnablesAgentForwardingOnce(t *testing.T) {
|
||||
func TestNewPTYSessionEnablesAgentForwardingWhenConfigured(t *testing.T) {
|
||||
oldNewSSHSession := newSSHSession
|
||||
oldRequestSessionPTY := requestSessionPTY
|
||||
oldNewSSHAgentForwarder := newSSHAgentForwarder
|
||||
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
|
||||
oldRouteSSHAgentForwarding := routeSSHAgentForwarding
|
||||
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
|
||||
t.Cleanup(func() {
|
||||
newSSHSession = oldNewSSHSession
|
||||
requestSessionPTY = oldRequestSessionPTY
|
||||
newSSHAgentForwarder = oldNewSSHAgentForwarder
|
||||
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
|
||||
routeSSHAgentForwarding = oldRouteSSHAgentForwarding
|
||||
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
|
||||
})
|
||||
@ -138,10 +215,12 @@ func TestNewPTYSessionEnablesAgentForwardingWhenConfigured(t *testing.T) {
|
||||
return nil
|
||||
}
|
||||
|
||||
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
|
||||
return sshagent.NewKeyring(), &testCloser{}, nil
|
||||
probeSSHAgentForwarding = func(timeout time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) {
|
||||
return &testCloser{}, nil
|
||||
}
|
||||
routeSSHAgentForwarding = func(client *ssh.Client, keyring sshagent.Agent) error { return nil }
|
||||
|
||||
var requestCalls atomic.Int32
|
||||
requestSSHAgentForwarding = func(session *ssh.Session) error {
|
||||
@ -162,11 +241,11 @@ func TestNewPTYSessionEnablesAgentForwardingWhenConfigured(t *testing.T) {
|
||||
|
||||
func TestNewExecSessionSkipsAgentForwardingWhenDisabled(t *testing.T) {
|
||||
oldNewSSHSession := newSSHSession
|
||||
oldNewSSHAgentForwarder := newSSHAgentForwarder
|
||||
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
|
||||
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
|
||||
t.Cleanup(func() {
|
||||
newSSHSession = oldNewSSHSession
|
||||
newSSHAgentForwarder = oldNewSSHAgentForwarder
|
||||
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
|
||||
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
|
||||
})
|
||||
|
||||
@ -176,9 +255,9 @@ func TestNewExecSessionSkipsAgentForwardingWhenDisabled(t *testing.T) {
|
||||
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
|
||||
return &ssh.Session{}, nil
|
||||
}
|
||||
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
|
||||
t.Fatal("agent forwarder should not initialize when disabled")
|
||||
return nil, nil, nil
|
||||
probeSSHAgentForwarding = func(timeout time.Duration) error {
|
||||
t.Fatal("agent forwarding probe should not run when disabled")
|
||||
return nil
|
||||
}
|
||||
requestSSHAgentForwarding = func(session *ssh.Session) error {
|
||||
t.Fatal("agent forwarding should not be requested when disabled")
|
||||
@ -191,18 +270,18 @@ func TestNewExecSessionSkipsAgentForwardingWhenDisabled(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRequestAgentForwardingReturnsUnavailableError(t *testing.T) {
|
||||
oldNewSSHAgentForwarder := newSSHAgentForwarder
|
||||
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
|
||||
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
|
||||
t.Cleanup(func() {
|
||||
newSSHAgentForwarder = oldNewSSHAgentForwarder
|
||||
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
|
||||
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
|
||||
})
|
||||
|
||||
star := &StarSSH{}
|
||||
star.setTransport(&ssh.Client{}, nil)
|
||||
|
||||
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
|
||||
return nil, nil, errors.New("ssh-agent forwarding unavailable: ssh-agent unavailable")
|
||||
probeSSHAgentForwarding = func(timeout time.Duration) error {
|
||||
return errors.New("ssh-agent forwarding unavailable: ssh-agent unavailable")
|
||||
}
|
||||
requestSSHAgentForwarding = func(session *ssh.Session) error {
|
||||
t.Fatal("session request should not run when agent forwarder init fails")
|
||||
@ -216,16 +295,16 @@ func TestRequestAgentForwardingReturnsUnavailableError(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRequestAgentForwardingWrapsSetupErrorAsUnavailable(t *testing.T) {
|
||||
oldNewSSHAgentForwarder := newSSHAgentForwarder
|
||||
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
|
||||
t.Cleanup(func() {
|
||||
newSSHAgentForwarder = oldNewSSHAgentForwarder
|
||||
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
|
||||
})
|
||||
|
||||
star := &StarSSH{}
|
||||
star.setTransport(&ssh.Client{}, nil)
|
||||
|
||||
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
|
||||
return nil, nil, errors.New("dial unix /tmp/ssh-broken.sock: connect: permission denied")
|
||||
probeSSHAgentForwarding = func(timeout time.Duration) error {
|
||||
return errors.New("dial unix /tmp/ssh-broken.sock: connect: permission denied")
|
||||
}
|
||||
|
||||
err := star.RequestAgentForwarding(&ssh.Session{})
|
||||
@ -235,11 +314,11 @@ func TestRequestAgentForwardingWrapsSetupErrorAsUnavailable(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRequestAgentForwardingReturnsDeniedError(t *testing.T) {
|
||||
oldNewSSHAgentForwarder := newSSHAgentForwarder
|
||||
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
|
||||
oldRouteSSHAgentForwarding := routeSSHAgentForwarding
|
||||
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
|
||||
t.Cleanup(func() {
|
||||
newSSHAgentForwarder = oldNewSSHAgentForwarder
|
||||
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
|
||||
routeSSHAgentForwarding = oldRouteSSHAgentForwarding
|
||||
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
|
||||
})
|
||||
@ -247,10 +326,12 @@ func TestRequestAgentForwardingReturnsDeniedError(t *testing.T) {
|
||||
star := &StarSSH{}
|
||||
star.setTransport(&ssh.Client{}, nil)
|
||||
|
||||
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
|
||||
return sshagent.NewKeyring(), &testCloser{}, nil
|
||||
probeSSHAgentForwarding = func(timeout time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) {
|
||||
return &testCloser{}, nil
|
||||
}
|
||||
routeSSHAgentForwarding = func(client *ssh.Client, keyring sshagent.Agent) error { return nil }
|
||||
requestSSHAgentForwarding = func(session *ssh.Session) error {
|
||||
return errors.New("forwarding request denied")
|
||||
}
|
||||
@ -263,12 +344,12 @@ func TestRequestAgentForwardingReturnsDeniedError(t *testing.T) {
|
||||
|
||||
func TestNewExecSessionIgnoresAgentForwardingDenied(t *testing.T) {
|
||||
oldNewSSHSession := newSSHSession
|
||||
oldNewSSHAgentForwarder := newSSHAgentForwarder
|
||||
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
|
||||
oldRouteSSHAgentForwarding := routeSSHAgentForwarding
|
||||
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
|
||||
t.Cleanup(func() {
|
||||
newSSHSession = oldNewSSHSession
|
||||
newSSHAgentForwarder = oldNewSSHAgentForwarder
|
||||
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
|
||||
routeSSHAgentForwarding = oldRouteSSHAgentForwarding
|
||||
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
|
||||
})
|
||||
@ -283,10 +364,12 @@ func TestNewExecSessionIgnoresAgentForwardingDenied(t *testing.T) {
|
||||
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
|
||||
return &ssh.Session{}, nil
|
||||
}
|
||||
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
|
||||
return sshagent.NewKeyring(), &testCloser{}, nil
|
||||
probeSSHAgentForwarding = func(timeout time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) {
|
||||
return &testCloser{}, nil
|
||||
}
|
||||
routeSSHAgentForwarding = func(client *ssh.Client, keyring sshagent.Agent) error { return nil }
|
||||
requestSSHAgentForwarding = func(session *ssh.Session) error {
|
||||
return errors.New("forwarding request denied")
|
||||
}
|
||||
@ -298,10 +381,10 @@ func TestNewExecSessionIgnoresAgentForwardingDenied(t *testing.T) {
|
||||
|
||||
func TestNewExecSessionIgnoresAgentForwardingUnavailable(t *testing.T) {
|
||||
oldNewSSHSession := newSSHSession
|
||||
oldNewSSHAgentForwarder := newSSHAgentForwarder
|
||||
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
|
||||
t.Cleanup(func() {
|
||||
newSSHSession = oldNewSSHSession
|
||||
newSSHAgentForwarder = oldNewSSHAgentForwarder
|
||||
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
|
||||
})
|
||||
|
||||
star := &StarSSH{
|
||||
@ -314,8 +397,8 @@ func TestNewExecSessionIgnoresAgentForwardingUnavailable(t *testing.T) {
|
||||
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
|
||||
return &ssh.Session{}, nil
|
||||
}
|
||||
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
|
||||
return nil, nil, errors.New("ssh-agent forwarding unavailable: ssh-agent unavailable")
|
||||
probeSSHAgentForwarding = func(timeout time.Duration) error {
|
||||
return errors.New("ssh-agent forwarding unavailable: ssh-agent unavailable")
|
||||
}
|
||||
|
||||
if _, err := star.NewExecSession(); err != nil {
|
||||
@ -325,10 +408,10 @@ func TestNewExecSessionIgnoresAgentForwardingUnavailable(t *testing.T) {
|
||||
|
||||
func TestNewExecSessionIgnoresAgentForwardingSetupError(t *testing.T) {
|
||||
oldNewSSHSession := newSSHSession
|
||||
oldNewSSHAgentForwarder := newSSHAgentForwarder
|
||||
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
|
||||
t.Cleanup(func() {
|
||||
newSSHSession = oldNewSSHSession
|
||||
newSSHAgentForwarder = oldNewSSHAgentForwarder
|
||||
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
|
||||
})
|
||||
|
||||
star := &StarSSH{
|
||||
@ -341,8 +424,8 @@ func TestNewExecSessionIgnoresAgentForwardingSetupError(t *testing.T) {
|
||||
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
|
||||
return &ssh.Session{}, nil
|
||||
}
|
||||
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
|
||||
return nil, nil, errors.New("dial unix /tmp/ssh-broken.sock: connect: connection refused")
|
||||
probeSSHAgentForwarding = func(timeout time.Duration) error {
|
||||
return errors.New("dial unix /tmp/ssh-broken.sock: connect: connection refused")
|
||||
}
|
||||
|
||||
if _, err := star.NewExecSession(); err != nil {
|
||||
@ -351,11 +434,11 @@ func TestNewExecSessionIgnoresAgentForwardingSetupError(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestEnsureAgentForwardingClosesNewForwarderWhenCloseStarts(t *testing.T) {
|
||||
oldNewSSHAgentForwarder := newSSHAgentForwarder
|
||||
oldProbeSSHAgentForwarding := probeSSHAgentForwarding
|
||||
oldRouteSSHAgentForwarding := routeSSHAgentForwarding
|
||||
oldCloseSSHClient := closeSSHClient
|
||||
t.Cleanup(func() {
|
||||
newSSHAgentForwarder = oldNewSSHAgentForwarder
|
||||
probeSSHAgentForwarding = oldProbeSSHAgentForwarding
|
||||
routeSSHAgentForwarding = oldRouteSSHAgentForwarding
|
||||
closeSSHClient = oldCloseSSHClient
|
||||
})
|
||||
@ -370,13 +453,13 @@ func TestEnsureAgentForwardingClosesNewForwarderWhenCloseStarts(t *testing.T) {
|
||||
started := make(chan struct{})
|
||||
release := make(chan struct{})
|
||||
closer := &testCloser{}
|
||||
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
|
||||
probeSSHAgentForwarding = func(timeout time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) {
|
||||
close(started)
|
||||
<-release
|
||||
return sshagent.NewKeyring(), closer, nil
|
||||
}
|
||||
routeSSHAgentForwarding = func(client *ssh.Client, keyring sshagent.Agent) error {
|
||||
return nil
|
||||
return closer, nil
|
||||
}
|
||||
closeSSHClient = func(client sshClientRequester) error { return nil }
|
||||
|
||||
@ -415,3 +498,75 @@ func TestEnsureAgentForwardingClosesNewForwarderWhenCloseStarts(t *testing.T) {
|
||||
t.Fatal("expected no leaked agent forwarder after close race")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxySSHAgentChannelClosesBlockedAgentConnWhenRemoteChannelEnds(t *testing.T) {
|
||||
agentConn, peerConn := net.Pipe()
|
||||
defer peerConn.Close()
|
||||
|
||||
tracked := &trackedConn{Conn: agentConn}
|
||||
channel := newTestSSHChannel(func(p []byte) (int, error) {
|
||||
return 0, io.EOF
|
||||
})
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
proxySSHAgentChannel(channel, tracked)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("proxySSHAgentChannel did not exit after remote EOF")
|
||||
}
|
||||
|
||||
if tracked.closed.Load() == 0 {
|
||||
t.Fatal("expected local agent connection to be closed")
|
||||
}
|
||||
if channel.closed.Load() == 0 {
|
||||
t.Fatal("expected ssh channel to be closed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHAgentForwardProxyCloseClosesActiveBridges(t *testing.T) {
|
||||
agentConn, peerConn := net.Pipe()
|
||||
defer peerConn.Close()
|
||||
|
||||
tracked := &trackedConn{Conn: agentConn}
|
||||
channel := newBlockingTestSSHChannel()
|
||||
proxy := &sshAgentForwardProxy{
|
||||
stopCh: make(chan struct{}),
|
||||
active: make(map[*sshAgentForwardBridge]struct{}),
|
||||
}
|
||||
bridge := &sshAgentForwardBridge{
|
||||
proxy: proxy,
|
||||
channel: channel,
|
||||
conn: tracked,
|
||||
}
|
||||
if !proxy.registerBridge(bridge) {
|
||||
t.Fatal("expected bridge registration to succeed")
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
bridge.run()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
if err := proxy.Close(); err != nil {
|
||||
t.Fatalf("close proxy: %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("bridge did not exit after proxy close")
|
||||
}
|
||||
|
||||
if tracked.closed.Load() == 0 {
|
||||
t.Fatal("expected proxy close to close local agent connection")
|
||||
}
|
||||
if channel.closed.Load() == 0 {
|
||||
t.Fatal("expected proxy close to close ssh channel")
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user