From 1625997d8f77ac18665ac739a264fe4731d900de Mon Sep 17 00:00:00 2001 From: starainrt Date: Sun, 26 Apr 2026 23:29:36 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E6=8B=86=E5=88=86=20starssh=20=E7=9A=84?= =?UTF-8?q?=E6=8B=A8=E5=8F=B7=E8=B6=85=E6=97=B6=E4=B8=8E=E8=AE=A4=E8=AF=81?= =?UTF-8?q?=E8=B6=85=E6=97=B6=E8=AF=AD=E4=B9=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 为 LoginInput 新增 DialTimeout,明确区分【TCP/proxy/ssh-agent 拨号超时】和【SSH 握手/认证超时】 - 将 Timeout 收口为握手/认证阶段超时,0 表示不限制,不再在登录入口自动回填默认值 - 新增 effectiveLoginTimeout/effectiveDialTimeout,统一超时决策逻辑 - 调整 login 流程,仅对 login context、ssh.ClientConfig 和握手阶段连接 deadline 使用认证超时 - 调整 transport 拨号链路,默认 TCP dial、proxy dial 与 ssh-agent 建连统一改用 DialTimeout - 修正 agent forwarding 初始化仍错误复用 LoginInfo.Timeout 的问题 - 保持 LoginSimple 的直观行为:传入 timeout 时同时映射到 Timeout 和 DialTimeout - 新增 login_timeout_test,覆盖零值不回填、DialTimeout 优先级,以及 ssh-agent 认证路径使用拨号超时的回归测试 --- agent_forward.go | 2 +- login.go | 34 ++++++++++++--- login_timeout_test.go | 99 +++++++++++++++++++++++++++++++++++++++++++ transport.go | 4 +- types.go | 22 ++++++---- 5 files changed, 143 insertions(+), 18 deletions(-) create mode 100644 login_timeout_test.go diff --git a/agent_forward.go b/agent_forward.go index 14a76e3..dc0069f 100644 --- a/agent_forward.go +++ b/agent_forward.go @@ -80,7 +80,7 @@ func (s *StarSSH) ensureAgentForwarding() error { return err } - keyring, closer, err := newSSHAgentForwarder(s.LoginInfo.Timeout) + keyring, closer, err := newSSHAgentForwarder(effectiveDialTimeout(s.LoginInfo)) if err != nil { return wrapSSHAgentForwardingUnavailable(err) } diff --git a/login.go b/login.go index 948d62e..19f739f 100644 --- a/login.go +++ b/login.go @@ -16,6 +16,7 @@ import ( var ErrHostKeyCallbackRequired = errors.New("host key callback is required; use DefaultAllowHostKeyCallback to explicitly allow any host key") var errSSHAgentUnavailable = errors.New("ssh-agent unavailable") +var buildSSHAgentAuthMethodFunc = buildSSHAgentAuthMethod var defaultAuthOrder = []AuthMethodKind{ AuthMethodSSHAgent, @@ -42,7 +43,8 @@ func loginWithContext(ctx context.Context, info LoginInput) (*StarSSH, error) { return nil, ErrHostKeyCallbackRequired } - loginCtx, cancel := contextWithLoginTimeout(ctx, info.Timeout) + authTimeout := effectiveLoginTimeout(info) + loginCtx, cancel := contextWithLoginTimeout(ctx, authTimeout) defer cancel() sshInfo := &StarSSH{ @@ -76,7 +78,7 @@ func loginWithContext(ctx context.Context, info LoginInput) (*StarSSH, error) { clientConfig := &ssh.ClientConfig{ User: info.User, Auth: auth, - Timeout: info.Timeout, + Timeout: authTimeout, HostKeyCallback: hostKeyCallback, BannerCallback: bannerCallback, } @@ -93,7 +95,7 @@ func loginWithContext(ctx context.Context, info LoginInput) (*StarSSH, error) { if err != nil { return sshInfo, err } - restoreDeadline := applyConnDeadline(rawConn, loginCtx, info.Timeout) + restoreDeadline := applyConnDeadline(rawConn, loginCtx, authTimeout) defer restoreDeadline() clientConn, chans, reqs, err := ssh.NewClientConn(rawConn, targetAddr, clientConfig) @@ -130,6 +132,7 @@ func LoginSimple(host string, user string, passwd string, prikeyPath string, por Addr: host, Port: port, Timeout: timeout, + DialTimeout: timeout, User: user, HostKeyCallback: DefaultAllowHostKeyCallback, } @@ -154,12 +157,29 @@ func normalizeLoginInput(info LoginInput) LoginInput { if info.Port <= 0 { info.Port = defaultSSHPort } - if info.Timeout <= 0 { - info.Timeout = defaultLoginTimeout - } return info } +func effectiveLoginTimeout(info LoginInput) time.Duration { + if info.Timeout <= 0 { + return 0 + } + return info.Timeout +} + +func effectiveDialTimeout(info LoginInput) time.Duration { + switch { + case info.DialTimeout < 0: + return 0 + case info.DialTimeout > 0: + return info.DialTimeout + case info.Timeout > 0: + return info.Timeout + default: + return defaultLoginTimeout + } +} + func buildAuthMethods(info LoginInput) ([]ssh.AuthMethod, func(), error) { order, err := normalizeAuthOrder(info.AuthOrder) if err != nil { @@ -194,7 +214,7 @@ func buildAuthMethods(info LoginInput) ([]ssh.AuthMethod, func(), error) { if info.DisableSSHAgent { continue } - agentMethod, cleanup, err := buildSSHAgentAuthMethod(info.Timeout) + agentMethod, cleanup, err := buildSSHAgentAuthMethodFunc(effectiveDialTimeout(info)) if err != nil { agentErr = err continue diff --git a/login_timeout_test.go b/login_timeout_test.go new file mode 100644 index 0000000..3f407fa --- /dev/null +++ b/login_timeout_test.go @@ -0,0 +1,99 @@ +package starssh + +import ( + "testing" + "time" + + "golang.org/x/crypto/ssh" +) + +func TestNormalizeLoginInputKeepsZeroAuthTimeout(t *testing.T) { + info := normalizeLoginInput(LoginInput{}) + if info.Port != defaultSSHPort { + t.Fatalf("Port=%d want %d", info.Port, defaultSSHPort) + } + if info.Timeout != 0 { + t.Fatalf("Timeout=%v want 0", info.Timeout) + } + if info.DialTimeout != 0 { + t.Fatalf("DialTimeout=%v want 0", info.DialTimeout) + } +} + +func TestEffectiveLoginTimeout(t *testing.T) { + if got := effectiveLoginTimeout(LoginInput{}); got != 0 { + t.Fatalf("zero login timeout should stay zero, got %v", got) + } + if got := effectiveLoginTimeout(LoginInput{Timeout: 7 * time.Second}); got != 7*time.Second { + t.Fatalf("expected explicit login timeout, got %v", got) + } +} + +func TestEffectiveDialTimeout(t *testing.T) { + tests := []struct { + name string + info LoginInput + want time.Duration + }{ + { + name: "default fallback", + info: LoginInput{}, + want: defaultLoginTimeout, + }, + { + name: "reuse timeout when dial timeout omitted", + info: LoginInput{Timeout: 9 * time.Second}, + want: 9 * time.Second, + }, + { + name: "explicit dial timeout wins", + info: LoginInput{Timeout: 9 * time.Second, DialTimeout: 3 * time.Second}, + want: 3 * time.Second, + }, + { + name: "negative dial timeout disables default dial deadline", + info: LoginInput{Timeout: 9 * time.Second, DialTimeout: -1}, + want: 0, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if got := effectiveDialTimeout(tc.info); got != tc.want { + t.Fatalf("effectiveDialTimeout(%+v)=%v want %v", tc.info, got, tc.want) + } + }) + } +} + +func TestBuildAuthMethodsUsesDialTimeoutInsteadOfAuthTimeout(t *testing.T) { + oldBuilder := buildSSHAgentAuthMethodFunc + t.Cleanup(func() { + buildSSHAgentAuthMethodFunc = oldBuilder + }) + + captured := time.Duration(-2) + buildSSHAgentAuthMethodFunc = func(timeout time.Duration) (ssh.AuthMethod, func(), error) { + captured = timeout + return ssh.Password("agent"), nil, nil + } + + info := LoginInput{ + Timeout: 0, + DialTimeout: 11 * time.Second, + AuthOrder: []AuthMethodKind{AuthMethodSSHAgent}, + } + auth, cleanup, err := buildAuthMethods(info) + if err != nil { + t.Fatalf("buildAuthMethods: %v", err) + } + if cleanup != nil { + cleanup() + } + if len(auth) != 1 { + t.Fatalf("expected one auth method, got %d", len(auth)) + } + if captured != 11*time.Second { + t.Fatalf("agent auth builder timeout=%v want %v", captured, 11*time.Second) + } +} diff --git a/transport.go b/transport.go index 92be42b..a2a11a8 100644 --- a/transport.go +++ b/transport.go @@ -32,7 +32,7 @@ func resolveDialContext(info LoginInput) DialContextFunc { } dialer := &net.Dialer{ - Timeout: info.Timeout, + Timeout: effectiveDialTimeout(info), } return dialer.DialContext } @@ -44,7 +44,7 @@ func dialTargetConn(ctx context.Context, info LoginInput) (net.Conn, *StarSSH, e } dialContext := resolveDialContext(info) - proxyConfig := normalizeProxyConfig(info.Proxy, info.Timeout) + proxyConfig := normalizeProxyConfig(info.Proxy, effectiveDialTimeout(info)) if proxyConfig != nil { return dialViaProxy(ctx, dialContext, *proxyConfig, targetAddr) } diff --git a/types.go b/types.go index 27d9028..acac5a1 100644 --- a/types.go +++ b/types.go @@ -94,14 +94,20 @@ type LoginInput struct { AuthOrder []AuthMethodKind Addr string Port int - Timeout time.Duration - DialContext DialContextFunc - Proxy *ProxyConfig - Jump *LoginInput - KeepAliveInterval time.Duration - KeepAliveTimeout time.Duration - HostKeyCallback func(string, net.Addr, ssh.PublicKey) error - BannerCallback func(string) error + // Timeout limits the SSH handshake/authentication phase after a TCP connection has + // already been established. Zero means no authentication timeout. + Timeout time.Duration + // DialTimeout limits outbound dial steps such as TCP connect, proxy connect, and + // local ssh-agent socket connect. Zero falls back to Timeout when set, otherwise + // uses the package default dial timeout. Negative disables the default dial timeout. + DialTimeout time.Duration + DialContext DialContextFunc + Proxy *ProxyConfig + Jump *LoginInput + KeepAliveInterval time.Duration + KeepAliveTimeout time.Duration + HostKeyCallback func(string, net.Addr, ssh.PublicKey) error + BannerCallback func(string) error } // StarShell keeps the legacy prompt-driven helper for POSIX-style scripted shell interactions.