fix: 拆分 starssh 的拨号超时与认证超时语义

- 为 LoginInput 新增 DialTimeout,明确区分【TCP/proxy/ssh-agent 拨号超时】和【SSH 握手/认证超时】
- 将 Timeout 收口为握手/认证阶段超时,0 表示不限制,不再在登录入口自动回填默认值
- 新增 effectiveLoginTimeout/effectiveDialTimeout,统一超时决策逻辑
- 调整 login 流程,仅对 login context、ssh.ClientConfig 和握手阶段连接 deadline 使用认证超时
- 调整 transport 拨号链路,默认 TCP dial、proxy dial 与 ssh-agent 建连统一改用 DialTimeout
- 修正 agent forwarding 初始化仍错误复用 LoginInfo.Timeout 的问题
- 保持 LoginSimple 的直观行为:传入 timeout 时同时映射到 Timeout 和 DialTimeout
- 新增 login_timeout_test,覆盖零值不回填、DialTimeout 优先级,以及 ssh-agent 认证路径使用拨号超时的回归测试
This commit is contained in:
兔子 2026-04-26 23:29:36 +08:00
parent b29246a9c4
commit 1625997d8f
Signed by: b612
GPG Key ID: 99DD2222B612B612
5 changed files with 143 additions and 18 deletions

View File

@ -80,7 +80,7 @@ func (s *StarSSH) ensureAgentForwarding() error {
return err return err
} }
keyring, closer, err := newSSHAgentForwarder(s.LoginInfo.Timeout) keyring, closer, err := newSSHAgentForwarder(effectiveDialTimeout(s.LoginInfo))
if err != nil { if err != nil {
return wrapSSHAgentForwardingUnavailable(err) return wrapSSHAgentForwardingUnavailable(err)
} }

View File

@ -16,6 +16,7 @@ import (
var ErrHostKeyCallbackRequired = errors.New("host key callback is required; use DefaultAllowHostKeyCallback to explicitly allow any host key") var ErrHostKeyCallbackRequired = errors.New("host key callback is required; use DefaultAllowHostKeyCallback to explicitly allow any host key")
var errSSHAgentUnavailable = errors.New("ssh-agent unavailable") var errSSHAgentUnavailable = errors.New("ssh-agent unavailable")
var buildSSHAgentAuthMethodFunc = buildSSHAgentAuthMethod
var defaultAuthOrder = []AuthMethodKind{ var defaultAuthOrder = []AuthMethodKind{
AuthMethodSSHAgent, AuthMethodSSHAgent,
@ -42,7 +43,8 @@ func loginWithContext(ctx context.Context, info LoginInput) (*StarSSH, error) {
return nil, ErrHostKeyCallbackRequired return nil, ErrHostKeyCallbackRequired
} }
loginCtx, cancel := contextWithLoginTimeout(ctx, info.Timeout) authTimeout := effectiveLoginTimeout(info)
loginCtx, cancel := contextWithLoginTimeout(ctx, authTimeout)
defer cancel() defer cancel()
sshInfo := &StarSSH{ sshInfo := &StarSSH{
@ -76,7 +78,7 @@ func loginWithContext(ctx context.Context, info LoginInput) (*StarSSH, error) {
clientConfig := &ssh.ClientConfig{ clientConfig := &ssh.ClientConfig{
User: info.User, User: info.User,
Auth: auth, Auth: auth,
Timeout: info.Timeout, Timeout: authTimeout,
HostKeyCallback: hostKeyCallback, HostKeyCallback: hostKeyCallback,
BannerCallback: bannerCallback, BannerCallback: bannerCallback,
} }
@ -93,7 +95,7 @@ func loginWithContext(ctx context.Context, info LoginInput) (*StarSSH, error) {
if err != nil { if err != nil {
return sshInfo, err return sshInfo, err
} }
restoreDeadline := applyConnDeadline(rawConn, loginCtx, info.Timeout) restoreDeadline := applyConnDeadline(rawConn, loginCtx, authTimeout)
defer restoreDeadline() defer restoreDeadline()
clientConn, chans, reqs, err := ssh.NewClientConn(rawConn, targetAddr, clientConfig) clientConn, chans, reqs, err := ssh.NewClientConn(rawConn, targetAddr, clientConfig)
@ -130,6 +132,7 @@ func LoginSimple(host string, user string, passwd string, prikeyPath string, por
Addr: host, Addr: host,
Port: port, Port: port,
Timeout: timeout, Timeout: timeout,
DialTimeout: timeout,
User: user, User: user,
HostKeyCallback: DefaultAllowHostKeyCallback, HostKeyCallback: DefaultAllowHostKeyCallback,
} }
@ -154,12 +157,29 @@ func normalizeLoginInput(info LoginInput) LoginInput {
if info.Port <= 0 { if info.Port <= 0 {
info.Port = defaultSSHPort info.Port = defaultSSHPort
} }
if info.Timeout <= 0 {
info.Timeout = defaultLoginTimeout
}
return info return info
} }
func effectiveLoginTimeout(info LoginInput) time.Duration {
if info.Timeout <= 0 {
return 0
}
return info.Timeout
}
func effectiveDialTimeout(info LoginInput) time.Duration {
switch {
case info.DialTimeout < 0:
return 0
case info.DialTimeout > 0:
return info.DialTimeout
case info.Timeout > 0:
return info.Timeout
default:
return defaultLoginTimeout
}
}
func buildAuthMethods(info LoginInput) ([]ssh.AuthMethod, func(), error) { func buildAuthMethods(info LoginInput) ([]ssh.AuthMethod, func(), error) {
order, err := normalizeAuthOrder(info.AuthOrder) order, err := normalizeAuthOrder(info.AuthOrder)
if err != nil { if err != nil {
@ -194,7 +214,7 @@ func buildAuthMethods(info LoginInput) ([]ssh.AuthMethod, func(), error) {
if info.DisableSSHAgent { if info.DisableSSHAgent {
continue continue
} }
agentMethod, cleanup, err := buildSSHAgentAuthMethod(info.Timeout) agentMethod, cleanup, err := buildSSHAgentAuthMethodFunc(effectiveDialTimeout(info))
if err != nil { if err != nil {
agentErr = err agentErr = err
continue continue

99
login_timeout_test.go Normal file
View File

@ -0,0 +1,99 @@
package starssh
import (
"testing"
"time"
"golang.org/x/crypto/ssh"
)
func TestNormalizeLoginInputKeepsZeroAuthTimeout(t *testing.T) {
info := normalizeLoginInput(LoginInput{})
if info.Port != defaultSSHPort {
t.Fatalf("Port=%d want %d", info.Port, defaultSSHPort)
}
if info.Timeout != 0 {
t.Fatalf("Timeout=%v want 0", info.Timeout)
}
if info.DialTimeout != 0 {
t.Fatalf("DialTimeout=%v want 0", info.DialTimeout)
}
}
func TestEffectiveLoginTimeout(t *testing.T) {
if got := effectiveLoginTimeout(LoginInput{}); got != 0 {
t.Fatalf("zero login timeout should stay zero, got %v", got)
}
if got := effectiveLoginTimeout(LoginInput{Timeout: 7 * time.Second}); got != 7*time.Second {
t.Fatalf("expected explicit login timeout, got %v", got)
}
}
func TestEffectiveDialTimeout(t *testing.T) {
tests := []struct {
name string
info LoginInput
want time.Duration
}{
{
name: "default fallback",
info: LoginInput{},
want: defaultLoginTimeout,
},
{
name: "reuse timeout when dial timeout omitted",
info: LoginInput{Timeout: 9 * time.Second},
want: 9 * time.Second,
},
{
name: "explicit dial timeout wins",
info: LoginInput{Timeout: 9 * time.Second, DialTimeout: 3 * time.Second},
want: 3 * time.Second,
},
{
name: "negative dial timeout disables default dial deadline",
info: LoginInput{Timeout: 9 * time.Second, DialTimeout: -1},
want: 0,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
if got := effectiveDialTimeout(tc.info); got != tc.want {
t.Fatalf("effectiveDialTimeout(%+v)=%v want %v", tc.info, got, tc.want)
}
})
}
}
func TestBuildAuthMethodsUsesDialTimeoutInsteadOfAuthTimeout(t *testing.T) {
oldBuilder := buildSSHAgentAuthMethodFunc
t.Cleanup(func() {
buildSSHAgentAuthMethodFunc = oldBuilder
})
captured := time.Duration(-2)
buildSSHAgentAuthMethodFunc = func(timeout time.Duration) (ssh.AuthMethod, func(), error) {
captured = timeout
return ssh.Password("agent"), nil, nil
}
info := LoginInput{
Timeout: 0,
DialTimeout: 11 * time.Second,
AuthOrder: []AuthMethodKind{AuthMethodSSHAgent},
}
auth, cleanup, err := buildAuthMethods(info)
if err != nil {
t.Fatalf("buildAuthMethods: %v", err)
}
if cleanup != nil {
cleanup()
}
if len(auth) != 1 {
t.Fatalf("expected one auth method, got %d", len(auth))
}
if captured != 11*time.Second {
t.Fatalf("agent auth builder timeout=%v want %v", captured, 11*time.Second)
}
}

View File

@ -32,7 +32,7 @@ func resolveDialContext(info LoginInput) DialContextFunc {
} }
dialer := &net.Dialer{ dialer := &net.Dialer{
Timeout: info.Timeout, Timeout: effectiveDialTimeout(info),
} }
return dialer.DialContext return dialer.DialContext
} }
@ -44,7 +44,7 @@ func dialTargetConn(ctx context.Context, info LoginInput) (net.Conn, *StarSSH, e
} }
dialContext := resolveDialContext(info) dialContext := resolveDialContext(info)
proxyConfig := normalizeProxyConfig(info.Proxy, info.Timeout) proxyConfig := normalizeProxyConfig(info.Proxy, effectiveDialTimeout(info))
if proxyConfig != nil { if proxyConfig != nil {
return dialViaProxy(ctx, dialContext, *proxyConfig, targetAddr) return dialViaProxy(ctx, dialContext, *proxyConfig, targetAddr)
} }

View File

@ -94,7 +94,13 @@ type LoginInput struct {
AuthOrder []AuthMethodKind AuthOrder []AuthMethodKind
Addr string Addr string
Port int Port int
// Timeout limits the SSH handshake/authentication phase after a TCP connection has
// already been established. Zero means no authentication timeout.
Timeout time.Duration Timeout time.Duration
// DialTimeout limits outbound dial steps such as TCP connect, proxy connect, and
// local ssh-agent socket connect. Zero falls back to Timeout when set, otherwise
// uses the package default dial timeout. Negative disables the default dial timeout.
DialTimeout time.Duration
DialContext DialContextFunc DialContext DialContextFunc
Proxy *ProxyConfig Proxy *ProxyConfig
Jump *LoginInput Jump *LoginInput