fix: 拆分 starssh 的拨号超时与认证超时语义
- 为 LoginInput 新增 DialTimeout,明确区分【TCP/proxy/ssh-agent 拨号超时】和【SSH 握手/认证超时】 - 将 Timeout 收口为握手/认证阶段超时,0 表示不限制,不再在登录入口自动回填默认值 - 新增 effectiveLoginTimeout/effectiveDialTimeout,统一超时决策逻辑 - 调整 login 流程,仅对 login context、ssh.ClientConfig 和握手阶段连接 deadline 使用认证超时 - 调整 transport 拨号链路,默认 TCP dial、proxy dial 与 ssh-agent 建连统一改用 DialTimeout - 修正 agent forwarding 初始化仍错误复用 LoginInfo.Timeout 的问题 - 保持 LoginSimple 的直观行为:传入 timeout 时同时映射到 Timeout 和 DialTimeout - 新增 login_timeout_test,覆盖零值不回填、DialTimeout 优先级,以及 ssh-agent 认证路径使用拨号超时的回归测试
This commit is contained in:
parent
b29246a9c4
commit
1625997d8f
@ -80,7 +80,7 @@ func (s *StarSSH) ensureAgentForwarding() error {
|
||||
return err
|
||||
}
|
||||
|
||||
keyring, closer, err := newSSHAgentForwarder(s.LoginInfo.Timeout)
|
||||
keyring, closer, err := newSSHAgentForwarder(effectiveDialTimeout(s.LoginInfo))
|
||||
if err != nil {
|
||||
return wrapSSHAgentForwardingUnavailable(err)
|
||||
}
|
||||
|
||||
34
login.go
34
login.go
@ -16,6 +16,7 @@ import (
|
||||
|
||||
var ErrHostKeyCallbackRequired = errors.New("host key callback is required; use DefaultAllowHostKeyCallback to explicitly allow any host key")
|
||||
var errSSHAgentUnavailable = errors.New("ssh-agent unavailable")
|
||||
var buildSSHAgentAuthMethodFunc = buildSSHAgentAuthMethod
|
||||
|
||||
var defaultAuthOrder = []AuthMethodKind{
|
||||
AuthMethodSSHAgent,
|
||||
@ -42,7 +43,8 @@ func loginWithContext(ctx context.Context, info LoginInput) (*StarSSH, error) {
|
||||
return nil, ErrHostKeyCallbackRequired
|
||||
}
|
||||
|
||||
loginCtx, cancel := contextWithLoginTimeout(ctx, info.Timeout)
|
||||
authTimeout := effectiveLoginTimeout(info)
|
||||
loginCtx, cancel := contextWithLoginTimeout(ctx, authTimeout)
|
||||
defer cancel()
|
||||
|
||||
sshInfo := &StarSSH{
|
||||
@ -76,7 +78,7 @@ func loginWithContext(ctx context.Context, info LoginInput) (*StarSSH, error) {
|
||||
clientConfig := &ssh.ClientConfig{
|
||||
User: info.User,
|
||||
Auth: auth,
|
||||
Timeout: info.Timeout,
|
||||
Timeout: authTimeout,
|
||||
HostKeyCallback: hostKeyCallback,
|
||||
BannerCallback: bannerCallback,
|
||||
}
|
||||
@ -93,7 +95,7 @@ func loginWithContext(ctx context.Context, info LoginInput) (*StarSSH, error) {
|
||||
if err != nil {
|
||||
return sshInfo, err
|
||||
}
|
||||
restoreDeadline := applyConnDeadline(rawConn, loginCtx, info.Timeout)
|
||||
restoreDeadline := applyConnDeadline(rawConn, loginCtx, authTimeout)
|
||||
defer restoreDeadline()
|
||||
|
||||
clientConn, chans, reqs, err := ssh.NewClientConn(rawConn, targetAddr, clientConfig)
|
||||
@ -130,6 +132,7 @@ func LoginSimple(host string, user string, passwd string, prikeyPath string, por
|
||||
Addr: host,
|
||||
Port: port,
|
||||
Timeout: timeout,
|
||||
DialTimeout: timeout,
|
||||
User: user,
|
||||
HostKeyCallback: DefaultAllowHostKeyCallback,
|
||||
}
|
||||
@ -154,12 +157,29 @@ func normalizeLoginInput(info LoginInput) LoginInput {
|
||||
if info.Port <= 0 {
|
||||
info.Port = defaultSSHPort
|
||||
}
|
||||
if info.Timeout <= 0 {
|
||||
info.Timeout = defaultLoginTimeout
|
||||
}
|
||||
return info
|
||||
}
|
||||
|
||||
func effectiveLoginTimeout(info LoginInput) time.Duration {
|
||||
if info.Timeout <= 0 {
|
||||
return 0
|
||||
}
|
||||
return info.Timeout
|
||||
}
|
||||
|
||||
func effectiveDialTimeout(info LoginInput) time.Duration {
|
||||
switch {
|
||||
case info.DialTimeout < 0:
|
||||
return 0
|
||||
case info.DialTimeout > 0:
|
||||
return info.DialTimeout
|
||||
case info.Timeout > 0:
|
||||
return info.Timeout
|
||||
default:
|
||||
return defaultLoginTimeout
|
||||
}
|
||||
}
|
||||
|
||||
func buildAuthMethods(info LoginInput) ([]ssh.AuthMethod, func(), error) {
|
||||
order, err := normalizeAuthOrder(info.AuthOrder)
|
||||
if err != nil {
|
||||
@ -194,7 +214,7 @@ func buildAuthMethods(info LoginInput) ([]ssh.AuthMethod, func(), error) {
|
||||
if info.DisableSSHAgent {
|
||||
continue
|
||||
}
|
||||
agentMethod, cleanup, err := buildSSHAgentAuthMethod(info.Timeout)
|
||||
agentMethod, cleanup, err := buildSSHAgentAuthMethodFunc(effectiveDialTimeout(info))
|
||||
if err != nil {
|
||||
agentErr = err
|
||||
continue
|
||||
|
||||
99
login_timeout_test.go
Normal file
99
login_timeout_test.go
Normal file
@ -0,0 +1,99 @@
|
||||
package starssh
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
func TestNormalizeLoginInputKeepsZeroAuthTimeout(t *testing.T) {
|
||||
info := normalizeLoginInput(LoginInput{})
|
||||
if info.Port != defaultSSHPort {
|
||||
t.Fatalf("Port=%d want %d", info.Port, defaultSSHPort)
|
||||
}
|
||||
if info.Timeout != 0 {
|
||||
t.Fatalf("Timeout=%v want 0", info.Timeout)
|
||||
}
|
||||
if info.DialTimeout != 0 {
|
||||
t.Fatalf("DialTimeout=%v want 0", info.DialTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEffectiveLoginTimeout(t *testing.T) {
|
||||
if got := effectiveLoginTimeout(LoginInput{}); got != 0 {
|
||||
t.Fatalf("zero login timeout should stay zero, got %v", got)
|
||||
}
|
||||
if got := effectiveLoginTimeout(LoginInput{Timeout: 7 * time.Second}); got != 7*time.Second {
|
||||
t.Fatalf("expected explicit login timeout, got %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEffectiveDialTimeout(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
info LoginInput
|
||||
want time.Duration
|
||||
}{
|
||||
{
|
||||
name: "default fallback",
|
||||
info: LoginInput{},
|
||||
want: defaultLoginTimeout,
|
||||
},
|
||||
{
|
||||
name: "reuse timeout when dial timeout omitted",
|
||||
info: LoginInput{Timeout: 9 * time.Second},
|
||||
want: 9 * time.Second,
|
||||
},
|
||||
{
|
||||
name: "explicit dial timeout wins",
|
||||
info: LoginInput{Timeout: 9 * time.Second, DialTimeout: 3 * time.Second},
|
||||
want: 3 * time.Second,
|
||||
},
|
||||
{
|
||||
name: "negative dial timeout disables default dial deadline",
|
||||
info: LoginInput{Timeout: 9 * time.Second, DialTimeout: -1},
|
||||
want: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if got := effectiveDialTimeout(tc.info); got != tc.want {
|
||||
t.Fatalf("effectiveDialTimeout(%+v)=%v want %v", tc.info, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAuthMethodsUsesDialTimeoutInsteadOfAuthTimeout(t *testing.T) {
|
||||
oldBuilder := buildSSHAgentAuthMethodFunc
|
||||
t.Cleanup(func() {
|
||||
buildSSHAgentAuthMethodFunc = oldBuilder
|
||||
})
|
||||
|
||||
captured := time.Duration(-2)
|
||||
buildSSHAgentAuthMethodFunc = func(timeout time.Duration) (ssh.AuthMethod, func(), error) {
|
||||
captured = timeout
|
||||
return ssh.Password("agent"), nil, nil
|
||||
}
|
||||
|
||||
info := LoginInput{
|
||||
Timeout: 0,
|
||||
DialTimeout: 11 * time.Second,
|
||||
AuthOrder: []AuthMethodKind{AuthMethodSSHAgent},
|
||||
}
|
||||
auth, cleanup, err := buildAuthMethods(info)
|
||||
if err != nil {
|
||||
t.Fatalf("buildAuthMethods: %v", err)
|
||||
}
|
||||
if cleanup != nil {
|
||||
cleanup()
|
||||
}
|
||||
if len(auth) != 1 {
|
||||
t.Fatalf("expected one auth method, got %d", len(auth))
|
||||
}
|
||||
if captured != 11*time.Second {
|
||||
t.Fatalf("agent auth builder timeout=%v want %v", captured, 11*time.Second)
|
||||
}
|
||||
}
|
||||
@ -32,7 +32,7 @@ func resolveDialContext(info LoginInput) DialContextFunc {
|
||||
}
|
||||
|
||||
dialer := &net.Dialer{
|
||||
Timeout: info.Timeout,
|
||||
Timeout: effectiveDialTimeout(info),
|
||||
}
|
||||
return dialer.DialContext
|
||||
}
|
||||
@ -44,7 +44,7 @@ func dialTargetConn(ctx context.Context, info LoginInput) (net.Conn, *StarSSH, e
|
||||
}
|
||||
|
||||
dialContext := resolveDialContext(info)
|
||||
proxyConfig := normalizeProxyConfig(info.Proxy, info.Timeout)
|
||||
proxyConfig := normalizeProxyConfig(info.Proxy, effectiveDialTimeout(info))
|
||||
if proxyConfig != nil {
|
||||
return dialViaProxy(ctx, dialContext, *proxyConfig, targetAddr)
|
||||
}
|
||||
|
||||
22
types.go
22
types.go
@ -94,14 +94,20 @@ type LoginInput struct {
|
||||
AuthOrder []AuthMethodKind
|
||||
Addr string
|
||||
Port int
|
||||
Timeout time.Duration
|
||||
DialContext DialContextFunc
|
||||
Proxy *ProxyConfig
|
||||
Jump *LoginInput
|
||||
KeepAliveInterval time.Duration
|
||||
KeepAliveTimeout time.Duration
|
||||
HostKeyCallback func(string, net.Addr, ssh.PublicKey) error
|
||||
BannerCallback func(string) error
|
||||
// Timeout limits the SSH handshake/authentication phase after a TCP connection has
|
||||
// already been established. Zero means no authentication timeout.
|
||||
Timeout time.Duration
|
||||
// DialTimeout limits outbound dial steps such as TCP connect, proxy connect, and
|
||||
// local ssh-agent socket connect. Zero falls back to Timeout when set, otherwise
|
||||
// uses the package default dial timeout. Negative disables the default dial timeout.
|
||||
DialTimeout time.Duration
|
||||
DialContext DialContextFunc
|
||||
Proxy *ProxyConfig
|
||||
Jump *LoginInput
|
||||
KeepAliveInterval time.Duration
|
||||
KeepAliveTimeout time.Duration
|
||||
HostKeyCallback func(string, net.Addr, ssh.PublicKey) error
|
||||
BannerCallback func(string) error
|
||||
}
|
||||
|
||||
// StarShell keeps the legacy prompt-driven helper for POSIX-style scripted shell interactions.
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user