Compare commits

...

2 Commits

Author SHA1 Message Date
ad7c8b0587
fix: 重构 ssh agent forwarding 转发代理并修复资源残留
- 将 agent forwarding 从长持有本地 agent 连接改为按 uth-agent@openssh.com channel 动态建桥
- 新增 ssh-agent 可用性探测流程,区分 forwarding 探测与实际转发注册
- 重构 forwarding 注册接口,按连接超时创建本地 agent bridge,不再复用单个长期占用的 agent 连接
- 新增 sshAgentForwardProxy 与 sshAgentForwardBridge,显式管理活跃 agent bridge 生命周期
- 在远端 channel 单侧 EOF、bridge 关闭和 proxy.Close() 路径上主动关闭本地 agent 连接与 SSH channel,避免 goroutine 和 agent 句柄残留
- 保留现有 denied / unavailable / close-race 语义,并继续保证自动 forwarding 的 best-effort 行为
- 扩充 agent forwarding 回归测试,覆盖单次启用、禁用、denied、unavailable、close race、单侧 EOF 释放以及 proxy Close 主动回收活跃 bridge 等关键场景
2026-04-27 00:06:32 +08:00
1625997d8f
fix: 拆分 starssh 的拨号超时与认证超时语义
- 为 LoginInput 新增 DialTimeout,明确区分【TCP/proxy/ssh-agent 拨号超时】和【SSH 握手/认证超时】
- 将 Timeout 收口为握手/认证阶段超时,0 表示不限制,不再在登录入口自动回填默认值
- 新增 effectiveLoginTimeout/effectiveDialTimeout,统一超时决策逻辑
- 调整 login 流程,仅对 login context、ssh.ClientConfig 和握手阶段连接 deadline 使用认证超时
- 调整 transport 拨号链路,默认 TCP dial、proxy dial 与 ssh-agent 建连统一改用 DialTimeout
- 修正 agent forwarding 初始化仍错误复用 LoginInfo.Timeout 的问题
- 保持 LoginSimple 的直观行为:传入 timeout 时同时映射到 Timeout 和 DialTimeout
- 新增 login_timeout_test,覆盖零值不回填、DialTimeout 优先级,以及 ssh-agent 认证路径使用拨号超时的回归测试
2026-04-26 23:29:36 +08:00
6 changed files with 569 additions and 85 deletions

View File

@ -4,7 +4,9 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"net"
"strings" "strings"
"sync"
"time" "time"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
@ -15,24 +17,53 @@ var requestSSHAgentForwarding = func(session *ssh.Session) error {
return sshagent.RequestAgentForwarding(session) return sshagent.RequestAgentForwarding(session)
} }
var routeSSHAgentForwarding = func(client *ssh.Client, keyring sshagent.Agent) error { const sshAgentChannelType = "auth-agent@openssh.com"
return sshagent.ForwardToAgent(client, keyring)
var routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) {
return startSSHAgentForwardProxy(client, timeout)
} }
var newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) { var probeSSHAgentForwarding = func(timeout time.Duration) error {
conn, err := dialSSHAgent(timeout) conn, err := dialSSHAgent(timeout)
if err != nil { if err != nil {
return nil, nil, wrapSSHAgentForwardingUnavailable(err) return wrapSSHAgentForwardingUnavailable(err)
} }
if conn == nil { if conn == nil {
return nil, nil, wrapSSHAgentForwardingUnavailable(errors.New("empty agent connection")) return wrapSSHAgentForwardingUnavailable(errors.New("empty agent connection"))
} }
return sshagent.NewClient(conn), conn, nil return conn.Close()
} }
var errSSHAgentForwardingDenied = errors.New("ssh-agent forwarding request denied") var errSSHAgentForwardingDenied = errors.New("ssh-agent forwarding request denied")
var errSSHAgentForwardingUnavailable = errors.New("ssh-agent forwarding unavailable") var errSSHAgentForwardingUnavailable = errors.New("ssh-agent forwarding unavailable")
type sshAgentForwardProxy struct {
stopOnce sync.Once
stopCh chan struct{}
activeMu sync.Mutex
active map[*sshAgentForwardBridge]struct{}
}
func (p *sshAgentForwardProxy) Close() error {
if p == nil {
return nil
}
p.stopOnce.Do(func() {
close(p.stopCh)
})
p.closeActive()
return nil
}
type sshAgentForwardBridge struct {
proxy *sshAgentForwardProxy
channel ssh.Channel
conn net.Conn
closeOnce sync.Once
}
func (s *StarSSH) RequestAgentForwarding(session *ssh.Session) error { func (s *StarSSH) RequestAgentForwarding(session *ssh.Session) error {
if s == nil { if s == nil {
return errors.New("ssh client is nil") return errors.New("ssh client is nil")
@ -80,20 +111,21 @@ func (s *StarSSH) ensureAgentForwarding() error {
return err return err
} }
keyring, closer, err := newSSHAgentForwarder(s.LoginInfo.Timeout) timeout := effectiveDialTimeout(s.LoginInfo)
if err != nil { if err := probeSSHAgentForwarding(timeout); err != nil {
return wrapSSHAgentForwardingUnavailable(err) return wrapSSHAgentForwardingUnavailable(err)
} }
if s.closing.Load() { if s.closing.Load() {
_ = closer.Close()
return errSSHClientClosing return errSSHClientClosing
} }
if err := routeSSHAgentForwarding(client, keyring); err != nil { closer, err := routeSSHAgentForwarding(client, timeout)
_ = closer.Close() if err != nil {
return err return err
} }
if !s.canAttachAgentForwarder(client) { if !s.canAttachAgentForwarder(client) {
if closer != nil {
_ = closer.Close() _ = closer.Close()
}
return errSSHClientClosing return errSSHClientClosing
} }
s.agentForwarder = closer s.agentForwarder = closer
@ -149,3 +181,175 @@ func wrapSSHAgentForwardingUnavailable(err error) error {
} }
return fmt.Errorf("%w: %v", errSSHAgentForwardingUnavailable, err) return fmt.Errorf("%w: %v", errSSHAgentForwardingUnavailable, err)
} }
func startSSHAgentForwardProxy(client *ssh.Client, timeout time.Duration) (io.Closer, error) {
if client == nil {
return nil, errors.New("ssh client is nil")
}
channels := client.HandleChannelOpen(sshAgentChannelType)
if channels == nil {
return nil, errors.New("agent: already have handler for " + sshAgentChannelType)
}
proxy := &sshAgentForwardProxy{
stopCh: make(chan struct{}),
active: make(map[*sshAgentForwardBridge]struct{}),
}
go func() {
for {
select {
case <-proxy.stopCh:
return
case ch, ok := <-channels:
if !ok {
return
}
go handleSSHAgentForwardChannel(proxy, ch, timeout)
}
}
}()
return proxy, nil
}
func handleSSHAgentForwardChannel(proxy *sshAgentForwardProxy, ch ssh.NewChannel, timeout time.Duration) {
if ch == nil {
return
}
conn, err := dialSSHAgent(timeout)
if err != nil {
_ = ch.Reject(ssh.ConnectionFailed, err.Error())
return
}
if conn == nil {
_ = ch.Reject(ssh.ConnectionFailed, "ssh-agent connection unavailable")
return
}
channel, reqs, err := ch.Accept()
if err != nil {
_ = conn.Close()
return
}
go ssh.DiscardRequests(reqs)
bridge := &sshAgentForwardBridge{
proxy: proxy,
channel: channel,
conn: conn,
}
if !proxy.registerBridge(bridge) {
bridge.close()
return
}
go bridge.run()
}
func proxySSHAgentChannel(channel ssh.Channel, conn net.Conn) {
bridge := &sshAgentForwardBridge{
channel: channel,
conn: conn,
}
bridge.run()
}
func (b *sshAgentForwardBridge) run() {
if b == nil {
return
}
defer b.unregister()
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
_, _ = io.Copy(b.channel, b.conn)
b.close()
}()
go func() {
defer wg.Done()
_, _ = io.Copy(b.conn, b.channel)
b.close()
}()
wg.Wait()
}
func (b *sshAgentForwardBridge) close() {
if b == nil {
return
}
b.closeOnce.Do(func() {
closeWriter(b.channel)
closeWriter(b.conn)
if b.channel != nil {
_ = b.channel.Close()
}
if b.conn != nil {
_ = b.conn.Close()
}
})
}
func (b *sshAgentForwardBridge) unregister() {
if b == nil || b.proxy == nil {
return
}
b.proxy.unregisterBridge(b)
}
func (p *sshAgentForwardProxy) registerBridge(bridge *sshAgentForwardBridge) bool {
if p == nil || bridge == nil {
return false
}
p.activeMu.Lock()
defer p.activeMu.Unlock()
select {
case <-p.stopCh:
return false
default:
}
if p.active == nil {
p.active = make(map[*sshAgentForwardBridge]struct{})
}
p.active[bridge] = struct{}{}
return true
}
func (p *sshAgentForwardProxy) unregisterBridge(bridge *sshAgentForwardBridge) {
if p == nil || bridge == nil {
return
}
p.activeMu.Lock()
defer p.activeMu.Unlock()
delete(p.active, bridge)
}
func (p *sshAgentForwardProxy) closeActive() {
if p == nil {
return
}
p.activeMu.Lock()
active := make([]*sshAgentForwardBridge, 0, len(p.active))
for bridge := range p.active {
active = append(active, bridge)
}
p.active = make(map[*sshAgentForwardBridge]struct{})
p.activeMu.Unlock()
for _, bridge := range active {
bridge.close()
}
}
func closeWriter(value any) {
type closeWriter interface {
CloseWrite() error
}
if cw, ok := value.(closeWriter); ok {
_ = cw.CloseWrite()
}
}

View File

@ -1,14 +1,16 @@
package starssh package starssh
import ( import (
"bytes"
"errors" "errors"
"io" "io"
"net"
"sync"
"sync/atomic" "sync/atomic"
"testing" "testing"
"time" "time"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
sshagent "golang.org/x/crypto/ssh/agent"
) )
type testCloser struct { type testCloser struct {
@ -20,15 +22,90 @@ func (c *testCloser) Close() error {
return nil return nil
} }
type trackedConn struct {
net.Conn
closed atomic.Int32
}
func (c *trackedConn) Close() error {
c.closed.Add(1)
if c.Conn == nil {
return nil
}
return c.Conn.Close()
}
type testSSHChannel struct {
readFunc func([]byte) (int, error)
stderr bytes.Buffer
closed atomic.Int32
closeOnce sync.Once
closeCh chan struct{}
}
func newTestSSHChannel(readFunc func([]byte) (int, error)) *testSSHChannel {
return &testSSHChannel{
readFunc: readFunc,
closeCh: make(chan struct{}),
}
}
func newBlockingTestSSHChannel() *testSSHChannel {
ch := newTestSSHChannel(nil)
ch.readFunc = func(p []byte) (int, error) {
<-ch.closeCh
return 0, io.EOF
}
return ch
}
func (c *testSSHChannel) Read(p []byte) (int, error) {
if c == nil {
return 0, io.EOF
}
if c.readFunc != nil {
return c.readFunc(p)
}
return 0, io.EOF
}
func (c *testSSHChannel) Write(p []byte) (int, error) {
return len(p), nil
}
func (c *testSSHChannel) Close() error {
if c == nil {
return nil
}
c.closeOnce.Do(func() {
c.closed.Add(1)
close(c.closeCh)
})
return nil
}
func (c *testSSHChannel) CloseWrite() error {
return nil
}
func (c *testSSHChannel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) {
return false, nil
}
func (c *testSSHChannel) Stderr() io.ReadWriter {
return &c.stderr
}
func TestNewExecSessionEnablesAgentForwardingOnce(t *testing.T) { func TestNewExecSessionEnablesAgentForwardingOnce(t *testing.T) {
oldNewSSHSession := newSSHSession oldNewSSHSession := newSSHSession
oldNewSSHAgentForwarder := newSSHAgentForwarder oldProbeSSHAgentForwarding := probeSSHAgentForwarding
oldRouteSSHAgentForwarding := routeSSHAgentForwarding oldRouteSSHAgentForwarding := routeSSHAgentForwarding
oldRequestSSHAgentForwarding := requestSSHAgentForwarding oldRequestSSHAgentForwarding := requestSSHAgentForwarding
oldCloseSSHClient := closeSSHClient oldCloseSSHClient := closeSSHClient
t.Cleanup(func() { t.Cleanup(func() {
newSSHSession = oldNewSSHSession newSSHSession = oldNewSSHSession
newSSHAgentForwarder = oldNewSSHAgentForwarder probeSSHAgentForwarding = oldProbeSSHAgentForwarding
routeSSHAgentForwarding = oldRouteSSHAgentForwarding routeSSHAgentForwarding = oldRouteSSHAgentForwarding
requestSSHAgentForwarding = oldRequestSSHAgentForwarding requestSSHAgentForwarding = oldRequestSSHAgentForwarding
closeSSHClient = oldCloseSSHClient closeSSHClient = oldCloseSSHClient
@ -50,26 +127,26 @@ func TestNewExecSessionEnablesAgentForwardingOnce(t *testing.T) {
return &ssh.Session{}, nil return &ssh.Session{}, nil
} }
var agentInitCalls atomic.Int32 var probeCalls atomic.Int32
closer := &testCloser{} closer := &testCloser{}
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) { probeSSHAgentForwarding = func(timeout time.Duration) error {
agentInitCalls.Add(1) probeCalls.Add(1)
if timeout != time.Second { if timeout != time.Second {
t.Fatalf("unexpected forwarding timeout: %v", timeout) t.Fatalf("unexpected forwarding timeout: %v", timeout)
} }
return sshagent.NewKeyring(), closer, nil return nil
} }
var routeCalls atomic.Int32 var routeCalls atomic.Int32
routeSSHAgentForwarding = func(client *ssh.Client, keyring sshagent.Agent) error { routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) {
routeCalls.Add(1) routeCalls.Add(1)
if client != baseClient { if client != baseClient {
t.Fatalf("unexpected routed client %p", client) t.Fatalf("unexpected routed client %p", client)
} }
if keyring == nil { if timeout != time.Second {
t.Fatal("expected non-nil forwarded agent keyring") t.Fatalf("unexpected routed timeout: %v", timeout)
} }
return nil return closer, nil
} }
var requestCalls atomic.Int32 var requestCalls atomic.Int32
@ -88,8 +165,8 @@ func TestNewExecSessionEnablesAgentForwardingOnce(t *testing.T) {
t.Fatalf("second exec session: %v", err) t.Fatalf("second exec session: %v", err)
} }
if agentInitCalls.Load() != 1 { if probeCalls.Load() != 1 {
t.Fatalf("expected one agent forwarder init, got %d", agentInitCalls.Load()) t.Fatalf("expected one agent probe, got %d", probeCalls.Load())
} }
if routeCalls.Load() != 1 { if routeCalls.Load() != 1 {
t.Fatalf("expected one agent route registration, got %d", routeCalls.Load()) t.Fatalf("expected one agent route registration, got %d", routeCalls.Load())
@ -110,13 +187,13 @@ func TestNewExecSessionEnablesAgentForwardingOnce(t *testing.T) {
func TestNewPTYSessionEnablesAgentForwardingWhenConfigured(t *testing.T) { func TestNewPTYSessionEnablesAgentForwardingWhenConfigured(t *testing.T) {
oldNewSSHSession := newSSHSession oldNewSSHSession := newSSHSession
oldRequestSessionPTY := requestSessionPTY oldRequestSessionPTY := requestSessionPTY
oldNewSSHAgentForwarder := newSSHAgentForwarder oldProbeSSHAgentForwarding := probeSSHAgentForwarding
oldRouteSSHAgentForwarding := routeSSHAgentForwarding oldRouteSSHAgentForwarding := routeSSHAgentForwarding
oldRequestSSHAgentForwarding := requestSSHAgentForwarding oldRequestSSHAgentForwarding := requestSSHAgentForwarding
t.Cleanup(func() { t.Cleanup(func() {
newSSHSession = oldNewSSHSession newSSHSession = oldNewSSHSession
requestSessionPTY = oldRequestSessionPTY requestSessionPTY = oldRequestSessionPTY
newSSHAgentForwarder = oldNewSSHAgentForwarder probeSSHAgentForwarding = oldProbeSSHAgentForwarding
routeSSHAgentForwarding = oldRouteSSHAgentForwarding routeSSHAgentForwarding = oldRouteSSHAgentForwarding
requestSSHAgentForwarding = oldRequestSSHAgentForwarding requestSSHAgentForwarding = oldRequestSSHAgentForwarding
}) })
@ -138,10 +215,12 @@ func TestNewPTYSessionEnablesAgentForwardingWhenConfigured(t *testing.T) {
return nil return nil
} }
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) { probeSSHAgentForwarding = func(timeout time.Duration) error {
return sshagent.NewKeyring(), &testCloser{}, nil return nil
}
routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) {
return &testCloser{}, nil
} }
routeSSHAgentForwarding = func(client *ssh.Client, keyring sshagent.Agent) error { return nil }
var requestCalls atomic.Int32 var requestCalls atomic.Int32
requestSSHAgentForwarding = func(session *ssh.Session) error { requestSSHAgentForwarding = func(session *ssh.Session) error {
@ -162,11 +241,11 @@ func TestNewPTYSessionEnablesAgentForwardingWhenConfigured(t *testing.T) {
func TestNewExecSessionSkipsAgentForwardingWhenDisabled(t *testing.T) { func TestNewExecSessionSkipsAgentForwardingWhenDisabled(t *testing.T) {
oldNewSSHSession := newSSHSession oldNewSSHSession := newSSHSession
oldNewSSHAgentForwarder := newSSHAgentForwarder oldProbeSSHAgentForwarding := probeSSHAgentForwarding
oldRequestSSHAgentForwarding := requestSSHAgentForwarding oldRequestSSHAgentForwarding := requestSSHAgentForwarding
t.Cleanup(func() { t.Cleanup(func() {
newSSHSession = oldNewSSHSession newSSHSession = oldNewSSHSession
newSSHAgentForwarder = oldNewSSHAgentForwarder probeSSHAgentForwarding = oldProbeSSHAgentForwarding
requestSSHAgentForwarding = oldRequestSSHAgentForwarding requestSSHAgentForwarding = oldRequestSSHAgentForwarding
}) })
@ -176,9 +255,9 @@ func TestNewExecSessionSkipsAgentForwardingWhenDisabled(t *testing.T) {
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) { newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
return &ssh.Session{}, nil return &ssh.Session{}, nil
} }
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) { probeSSHAgentForwarding = func(timeout time.Duration) error {
t.Fatal("agent forwarder should not initialize when disabled") t.Fatal("agent forwarding probe should not run when disabled")
return nil, nil, nil return nil
} }
requestSSHAgentForwarding = func(session *ssh.Session) error { requestSSHAgentForwarding = func(session *ssh.Session) error {
t.Fatal("agent forwarding should not be requested when disabled") t.Fatal("agent forwarding should not be requested when disabled")
@ -191,18 +270,18 @@ func TestNewExecSessionSkipsAgentForwardingWhenDisabled(t *testing.T) {
} }
func TestRequestAgentForwardingReturnsUnavailableError(t *testing.T) { func TestRequestAgentForwardingReturnsUnavailableError(t *testing.T) {
oldNewSSHAgentForwarder := newSSHAgentForwarder oldProbeSSHAgentForwarding := probeSSHAgentForwarding
oldRequestSSHAgentForwarding := requestSSHAgentForwarding oldRequestSSHAgentForwarding := requestSSHAgentForwarding
t.Cleanup(func() { t.Cleanup(func() {
newSSHAgentForwarder = oldNewSSHAgentForwarder probeSSHAgentForwarding = oldProbeSSHAgentForwarding
requestSSHAgentForwarding = oldRequestSSHAgentForwarding requestSSHAgentForwarding = oldRequestSSHAgentForwarding
}) })
star := &StarSSH{} star := &StarSSH{}
star.setTransport(&ssh.Client{}, nil) star.setTransport(&ssh.Client{}, nil)
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) { probeSSHAgentForwarding = func(timeout time.Duration) error {
return nil, nil, errors.New("ssh-agent forwarding unavailable: ssh-agent unavailable") return errors.New("ssh-agent forwarding unavailable: ssh-agent unavailable")
} }
requestSSHAgentForwarding = func(session *ssh.Session) error { requestSSHAgentForwarding = func(session *ssh.Session) error {
t.Fatal("session request should not run when agent forwarder init fails") t.Fatal("session request should not run when agent forwarder init fails")
@ -216,16 +295,16 @@ func TestRequestAgentForwardingReturnsUnavailableError(t *testing.T) {
} }
func TestRequestAgentForwardingWrapsSetupErrorAsUnavailable(t *testing.T) { func TestRequestAgentForwardingWrapsSetupErrorAsUnavailable(t *testing.T) {
oldNewSSHAgentForwarder := newSSHAgentForwarder oldProbeSSHAgentForwarding := probeSSHAgentForwarding
t.Cleanup(func() { t.Cleanup(func() {
newSSHAgentForwarder = oldNewSSHAgentForwarder probeSSHAgentForwarding = oldProbeSSHAgentForwarding
}) })
star := &StarSSH{} star := &StarSSH{}
star.setTransport(&ssh.Client{}, nil) star.setTransport(&ssh.Client{}, nil)
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) { probeSSHAgentForwarding = func(timeout time.Duration) error {
return nil, nil, errors.New("dial unix /tmp/ssh-broken.sock: connect: permission denied") return errors.New("dial unix /tmp/ssh-broken.sock: connect: permission denied")
} }
err := star.RequestAgentForwarding(&ssh.Session{}) err := star.RequestAgentForwarding(&ssh.Session{})
@ -235,11 +314,11 @@ func TestRequestAgentForwardingWrapsSetupErrorAsUnavailable(t *testing.T) {
} }
func TestRequestAgentForwardingReturnsDeniedError(t *testing.T) { func TestRequestAgentForwardingReturnsDeniedError(t *testing.T) {
oldNewSSHAgentForwarder := newSSHAgentForwarder oldProbeSSHAgentForwarding := probeSSHAgentForwarding
oldRouteSSHAgentForwarding := routeSSHAgentForwarding oldRouteSSHAgentForwarding := routeSSHAgentForwarding
oldRequestSSHAgentForwarding := requestSSHAgentForwarding oldRequestSSHAgentForwarding := requestSSHAgentForwarding
t.Cleanup(func() { t.Cleanup(func() {
newSSHAgentForwarder = oldNewSSHAgentForwarder probeSSHAgentForwarding = oldProbeSSHAgentForwarding
routeSSHAgentForwarding = oldRouteSSHAgentForwarding routeSSHAgentForwarding = oldRouteSSHAgentForwarding
requestSSHAgentForwarding = oldRequestSSHAgentForwarding requestSSHAgentForwarding = oldRequestSSHAgentForwarding
}) })
@ -247,10 +326,12 @@ func TestRequestAgentForwardingReturnsDeniedError(t *testing.T) {
star := &StarSSH{} star := &StarSSH{}
star.setTransport(&ssh.Client{}, nil) star.setTransport(&ssh.Client{}, nil)
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) { probeSSHAgentForwarding = func(timeout time.Duration) error {
return sshagent.NewKeyring(), &testCloser{}, nil return nil
}
routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) {
return &testCloser{}, nil
} }
routeSSHAgentForwarding = func(client *ssh.Client, keyring sshagent.Agent) error { return nil }
requestSSHAgentForwarding = func(session *ssh.Session) error { requestSSHAgentForwarding = func(session *ssh.Session) error {
return errors.New("forwarding request denied") return errors.New("forwarding request denied")
} }
@ -263,12 +344,12 @@ func TestRequestAgentForwardingReturnsDeniedError(t *testing.T) {
func TestNewExecSessionIgnoresAgentForwardingDenied(t *testing.T) { func TestNewExecSessionIgnoresAgentForwardingDenied(t *testing.T) {
oldNewSSHSession := newSSHSession oldNewSSHSession := newSSHSession
oldNewSSHAgentForwarder := newSSHAgentForwarder oldProbeSSHAgentForwarding := probeSSHAgentForwarding
oldRouteSSHAgentForwarding := routeSSHAgentForwarding oldRouteSSHAgentForwarding := routeSSHAgentForwarding
oldRequestSSHAgentForwarding := requestSSHAgentForwarding oldRequestSSHAgentForwarding := requestSSHAgentForwarding
t.Cleanup(func() { t.Cleanup(func() {
newSSHSession = oldNewSSHSession newSSHSession = oldNewSSHSession
newSSHAgentForwarder = oldNewSSHAgentForwarder probeSSHAgentForwarding = oldProbeSSHAgentForwarding
routeSSHAgentForwarding = oldRouteSSHAgentForwarding routeSSHAgentForwarding = oldRouteSSHAgentForwarding
requestSSHAgentForwarding = oldRequestSSHAgentForwarding requestSSHAgentForwarding = oldRequestSSHAgentForwarding
}) })
@ -283,10 +364,12 @@ func TestNewExecSessionIgnoresAgentForwardingDenied(t *testing.T) {
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) { newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
return &ssh.Session{}, nil return &ssh.Session{}, nil
} }
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) { probeSSHAgentForwarding = func(timeout time.Duration) error {
return sshagent.NewKeyring(), &testCloser{}, nil return nil
}
routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) {
return &testCloser{}, nil
} }
routeSSHAgentForwarding = func(client *ssh.Client, keyring sshagent.Agent) error { return nil }
requestSSHAgentForwarding = func(session *ssh.Session) error { requestSSHAgentForwarding = func(session *ssh.Session) error {
return errors.New("forwarding request denied") return errors.New("forwarding request denied")
} }
@ -298,10 +381,10 @@ func TestNewExecSessionIgnoresAgentForwardingDenied(t *testing.T) {
func TestNewExecSessionIgnoresAgentForwardingUnavailable(t *testing.T) { func TestNewExecSessionIgnoresAgentForwardingUnavailable(t *testing.T) {
oldNewSSHSession := newSSHSession oldNewSSHSession := newSSHSession
oldNewSSHAgentForwarder := newSSHAgentForwarder oldProbeSSHAgentForwarding := probeSSHAgentForwarding
t.Cleanup(func() { t.Cleanup(func() {
newSSHSession = oldNewSSHSession newSSHSession = oldNewSSHSession
newSSHAgentForwarder = oldNewSSHAgentForwarder probeSSHAgentForwarding = oldProbeSSHAgentForwarding
}) })
star := &StarSSH{ star := &StarSSH{
@ -314,8 +397,8 @@ func TestNewExecSessionIgnoresAgentForwardingUnavailable(t *testing.T) {
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) { newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
return &ssh.Session{}, nil return &ssh.Session{}, nil
} }
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) { probeSSHAgentForwarding = func(timeout time.Duration) error {
return nil, nil, errors.New("ssh-agent forwarding unavailable: ssh-agent unavailable") return errors.New("ssh-agent forwarding unavailable: ssh-agent unavailable")
} }
if _, err := star.NewExecSession(); err != nil { if _, err := star.NewExecSession(); err != nil {
@ -325,10 +408,10 @@ func TestNewExecSessionIgnoresAgentForwardingUnavailable(t *testing.T) {
func TestNewExecSessionIgnoresAgentForwardingSetupError(t *testing.T) { func TestNewExecSessionIgnoresAgentForwardingSetupError(t *testing.T) {
oldNewSSHSession := newSSHSession oldNewSSHSession := newSSHSession
oldNewSSHAgentForwarder := newSSHAgentForwarder oldProbeSSHAgentForwarding := probeSSHAgentForwarding
t.Cleanup(func() { t.Cleanup(func() {
newSSHSession = oldNewSSHSession newSSHSession = oldNewSSHSession
newSSHAgentForwarder = oldNewSSHAgentForwarder probeSSHAgentForwarding = oldProbeSSHAgentForwarding
}) })
star := &StarSSH{ star := &StarSSH{
@ -341,8 +424,8 @@ func TestNewExecSessionIgnoresAgentForwardingSetupError(t *testing.T) {
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) { newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
return &ssh.Session{}, nil return &ssh.Session{}, nil
} }
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) { probeSSHAgentForwarding = func(timeout time.Duration) error {
return nil, nil, errors.New("dial unix /tmp/ssh-broken.sock: connect: connection refused") return errors.New("dial unix /tmp/ssh-broken.sock: connect: connection refused")
} }
if _, err := star.NewExecSession(); err != nil { if _, err := star.NewExecSession(); err != nil {
@ -351,11 +434,11 @@ func TestNewExecSessionIgnoresAgentForwardingSetupError(t *testing.T) {
} }
func TestEnsureAgentForwardingClosesNewForwarderWhenCloseStarts(t *testing.T) { func TestEnsureAgentForwardingClosesNewForwarderWhenCloseStarts(t *testing.T) {
oldNewSSHAgentForwarder := newSSHAgentForwarder oldProbeSSHAgentForwarding := probeSSHAgentForwarding
oldRouteSSHAgentForwarding := routeSSHAgentForwarding oldRouteSSHAgentForwarding := routeSSHAgentForwarding
oldCloseSSHClient := closeSSHClient oldCloseSSHClient := closeSSHClient
t.Cleanup(func() { t.Cleanup(func() {
newSSHAgentForwarder = oldNewSSHAgentForwarder probeSSHAgentForwarding = oldProbeSSHAgentForwarding
routeSSHAgentForwarding = oldRouteSSHAgentForwarding routeSSHAgentForwarding = oldRouteSSHAgentForwarding
closeSSHClient = oldCloseSSHClient closeSSHClient = oldCloseSSHClient
}) })
@ -370,13 +453,13 @@ func TestEnsureAgentForwardingClosesNewForwarderWhenCloseStarts(t *testing.T) {
started := make(chan struct{}) started := make(chan struct{})
release := make(chan struct{}) release := make(chan struct{})
closer := &testCloser{} closer := &testCloser{}
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) { probeSSHAgentForwarding = func(timeout time.Duration) error {
return nil
}
routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) {
close(started) close(started)
<-release <-release
return sshagent.NewKeyring(), closer, nil return closer, nil
}
routeSSHAgentForwarding = func(client *ssh.Client, keyring sshagent.Agent) error {
return nil
} }
closeSSHClient = func(client sshClientRequester) error { return nil } closeSSHClient = func(client sshClientRequester) error { return nil }
@ -415,3 +498,75 @@ func TestEnsureAgentForwardingClosesNewForwarderWhenCloseStarts(t *testing.T) {
t.Fatal("expected no leaked agent forwarder after close race") t.Fatal("expected no leaked agent forwarder after close race")
} }
} }
func TestProxySSHAgentChannelClosesBlockedAgentConnWhenRemoteChannelEnds(t *testing.T) {
agentConn, peerConn := net.Pipe()
defer peerConn.Close()
tracked := &trackedConn{Conn: agentConn}
channel := newTestSSHChannel(func(p []byte) (int, error) {
return 0, io.EOF
})
done := make(chan struct{})
go func() {
proxySSHAgentChannel(channel, tracked)
close(done)
}()
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("proxySSHAgentChannel did not exit after remote EOF")
}
if tracked.closed.Load() == 0 {
t.Fatal("expected local agent connection to be closed")
}
if channel.closed.Load() == 0 {
t.Fatal("expected ssh channel to be closed")
}
}
func TestSSHAgentForwardProxyCloseClosesActiveBridges(t *testing.T) {
agentConn, peerConn := net.Pipe()
defer peerConn.Close()
tracked := &trackedConn{Conn: agentConn}
channel := newBlockingTestSSHChannel()
proxy := &sshAgentForwardProxy{
stopCh: make(chan struct{}),
active: make(map[*sshAgentForwardBridge]struct{}),
}
bridge := &sshAgentForwardBridge{
proxy: proxy,
channel: channel,
conn: tracked,
}
if !proxy.registerBridge(bridge) {
t.Fatal("expected bridge registration to succeed")
}
done := make(chan struct{})
go func() {
bridge.run()
close(done)
}()
if err := proxy.Close(); err != nil {
t.Fatalf("close proxy: %v", err)
}
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("bridge did not exit after proxy close")
}
if tracked.closed.Load() == 0 {
t.Fatal("expected proxy close to close local agent connection")
}
if channel.closed.Load() == 0 {
t.Fatal("expected proxy close to close ssh channel")
}
}

View File

@ -16,6 +16,7 @@ import (
var ErrHostKeyCallbackRequired = errors.New("host key callback is required; use DefaultAllowHostKeyCallback to explicitly allow any host key") var ErrHostKeyCallbackRequired = errors.New("host key callback is required; use DefaultAllowHostKeyCallback to explicitly allow any host key")
var errSSHAgentUnavailable = errors.New("ssh-agent unavailable") var errSSHAgentUnavailable = errors.New("ssh-agent unavailable")
var buildSSHAgentAuthMethodFunc = buildSSHAgentAuthMethod
var defaultAuthOrder = []AuthMethodKind{ var defaultAuthOrder = []AuthMethodKind{
AuthMethodSSHAgent, AuthMethodSSHAgent,
@ -42,7 +43,8 @@ func loginWithContext(ctx context.Context, info LoginInput) (*StarSSH, error) {
return nil, ErrHostKeyCallbackRequired return nil, ErrHostKeyCallbackRequired
} }
loginCtx, cancel := contextWithLoginTimeout(ctx, info.Timeout) authTimeout := effectiveLoginTimeout(info)
loginCtx, cancel := contextWithLoginTimeout(ctx, authTimeout)
defer cancel() defer cancel()
sshInfo := &StarSSH{ sshInfo := &StarSSH{
@ -76,7 +78,7 @@ func loginWithContext(ctx context.Context, info LoginInput) (*StarSSH, error) {
clientConfig := &ssh.ClientConfig{ clientConfig := &ssh.ClientConfig{
User: info.User, User: info.User,
Auth: auth, Auth: auth,
Timeout: info.Timeout, Timeout: authTimeout,
HostKeyCallback: hostKeyCallback, HostKeyCallback: hostKeyCallback,
BannerCallback: bannerCallback, BannerCallback: bannerCallback,
} }
@ -93,7 +95,7 @@ func loginWithContext(ctx context.Context, info LoginInput) (*StarSSH, error) {
if err != nil { if err != nil {
return sshInfo, err return sshInfo, err
} }
restoreDeadline := applyConnDeadline(rawConn, loginCtx, info.Timeout) restoreDeadline := applyConnDeadline(rawConn, loginCtx, authTimeout)
defer restoreDeadline() defer restoreDeadline()
clientConn, chans, reqs, err := ssh.NewClientConn(rawConn, targetAddr, clientConfig) clientConn, chans, reqs, err := ssh.NewClientConn(rawConn, targetAddr, clientConfig)
@ -130,6 +132,7 @@ func LoginSimple(host string, user string, passwd string, prikeyPath string, por
Addr: host, Addr: host,
Port: port, Port: port,
Timeout: timeout, Timeout: timeout,
DialTimeout: timeout,
User: user, User: user,
HostKeyCallback: DefaultAllowHostKeyCallback, HostKeyCallback: DefaultAllowHostKeyCallback,
} }
@ -154,12 +157,29 @@ func normalizeLoginInput(info LoginInput) LoginInput {
if info.Port <= 0 { if info.Port <= 0 {
info.Port = defaultSSHPort info.Port = defaultSSHPort
} }
if info.Timeout <= 0 {
info.Timeout = defaultLoginTimeout
}
return info return info
} }
func effectiveLoginTimeout(info LoginInput) time.Duration {
if info.Timeout <= 0 {
return 0
}
return info.Timeout
}
func effectiveDialTimeout(info LoginInput) time.Duration {
switch {
case info.DialTimeout < 0:
return 0
case info.DialTimeout > 0:
return info.DialTimeout
case info.Timeout > 0:
return info.Timeout
default:
return defaultLoginTimeout
}
}
func buildAuthMethods(info LoginInput) ([]ssh.AuthMethod, func(), error) { func buildAuthMethods(info LoginInput) ([]ssh.AuthMethod, func(), error) {
order, err := normalizeAuthOrder(info.AuthOrder) order, err := normalizeAuthOrder(info.AuthOrder)
if err != nil { if err != nil {
@ -194,7 +214,7 @@ func buildAuthMethods(info LoginInput) ([]ssh.AuthMethod, func(), error) {
if info.DisableSSHAgent { if info.DisableSSHAgent {
continue continue
} }
agentMethod, cleanup, err := buildSSHAgentAuthMethod(info.Timeout) agentMethod, cleanup, err := buildSSHAgentAuthMethodFunc(effectiveDialTimeout(info))
if err != nil { if err != nil {
agentErr = err agentErr = err
continue continue

99
login_timeout_test.go Normal file
View File

@ -0,0 +1,99 @@
package starssh
import (
"testing"
"time"
"golang.org/x/crypto/ssh"
)
func TestNormalizeLoginInputKeepsZeroAuthTimeout(t *testing.T) {
info := normalizeLoginInput(LoginInput{})
if info.Port != defaultSSHPort {
t.Fatalf("Port=%d want %d", info.Port, defaultSSHPort)
}
if info.Timeout != 0 {
t.Fatalf("Timeout=%v want 0", info.Timeout)
}
if info.DialTimeout != 0 {
t.Fatalf("DialTimeout=%v want 0", info.DialTimeout)
}
}
func TestEffectiveLoginTimeout(t *testing.T) {
if got := effectiveLoginTimeout(LoginInput{}); got != 0 {
t.Fatalf("zero login timeout should stay zero, got %v", got)
}
if got := effectiveLoginTimeout(LoginInput{Timeout: 7 * time.Second}); got != 7*time.Second {
t.Fatalf("expected explicit login timeout, got %v", got)
}
}
func TestEffectiveDialTimeout(t *testing.T) {
tests := []struct {
name string
info LoginInput
want time.Duration
}{
{
name: "default fallback",
info: LoginInput{},
want: defaultLoginTimeout,
},
{
name: "reuse timeout when dial timeout omitted",
info: LoginInput{Timeout: 9 * time.Second},
want: 9 * time.Second,
},
{
name: "explicit dial timeout wins",
info: LoginInput{Timeout: 9 * time.Second, DialTimeout: 3 * time.Second},
want: 3 * time.Second,
},
{
name: "negative dial timeout disables default dial deadline",
info: LoginInput{Timeout: 9 * time.Second, DialTimeout: -1},
want: 0,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
if got := effectiveDialTimeout(tc.info); got != tc.want {
t.Fatalf("effectiveDialTimeout(%+v)=%v want %v", tc.info, got, tc.want)
}
})
}
}
func TestBuildAuthMethodsUsesDialTimeoutInsteadOfAuthTimeout(t *testing.T) {
oldBuilder := buildSSHAgentAuthMethodFunc
t.Cleanup(func() {
buildSSHAgentAuthMethodFunc = oldBuilder
})
captured := time.Duration(-2)
buildSSHAgentAuthMethodFunc = func(timeout time.Duration) (ssh.AuthMethod, func(), error) {
captured = timeout
return ssh.Password("agent"), nil, nil
}
info := LoginInput{
Timeout: 0,
DialTimeout: 11 * time.Second,
AuthOrder: []AuthMethodKind{AuthMethodSSHAgent},
}
auth, cleanup, err := buildAuthMethods(info)
if err != nil {
t.Fatalf("buildAuthMethods: %v", err)
}
if cleanup != nil {
cleanup()
}
if len(auth) != 1 {
t.Fatalf("expected one auth method, got %d", len(auth))
}
if captured != 11*time.Second {
t.Fatalf("agent auth builder timeout=%v want %v", captured, 11*time.Second)
}
}

View File

@ -32,7 +32,7 @@ func resolveDialContext(info LoginInput) DialContextFunc {
} }
dialer := &net.Dialer{ dialer := &net.Dialer{
Timeout: info.Timeout, Timeout: effectiveDialTimeout(info),
} }
return dialer.DialContext return dialer.DialContext
} }
@ -44,7 +44,7 @@ func dialTargetConn(ctx context.Context, info LoginInput) (net.Conn, *StarSSH, e
} }
dialContext := resolveDialContext(info) dialContext := resolveDialContext(info)
proxyConfig := normalizeProxyConfig(info.Proxy, info.Timeout) proxyConfig := normalizeProxyConfig(info.Proxy, effectiveDialTimeout(info))
if proxyConfig != nil { if proxyConfig != nil {
return dialViaProxy(ctx, dialContext, *proxyConfig, targetAddr) return dialViaProxy(ctx, dialContext, *proxyConfig, targetAddr)
} }

View File

@ -94,7 +94,13 @@ type LoginInput struct {
AuthOrder []AuthMethodKind AuthOrder []AuthMethodKind
Addr string Addr string
Port int Port int
// Timeout limits the SSH handshake/authentication phase after a TCP connection has
// already been established. Zero means no authentication timeout.
Timeout time.Duration Timeout time.Duration
// DialTimeout limits outbound dial steps such as TCP connect, proxy connect, and
// local ssh-agent socket connect. Zero falls back to Timeout when set, otherwise
// uses the package default dial timeout. Negative disables the default dial timeout.
DialTimeout time.Duration
DialContext DialContextFunc DialContext DialContextFunc
Proxy *ProxyConfig Proxy *ProxyConfig
Jump *LoginInput Jump *LoginInput