fix: 拆分 starssh 的拨号超时与认证超时语义
- 为 LoginInput 新增 DialTimeout,明确区分【TCP/proxy/ssh-agent 拨号超时】和【SSH 握手/认证超时】 - 将 Timeout 收口为握手/认证阶段超时,0 表示不限制,不再在登录入口自动回填默认值 - 新增 effectiveLoginTimeout/effectiveDialTimeout,统一超时决策逻辑 - 调整 login 流程,仅对 login context、ssh.ClientConfig 和握手阶段连接 deadline 使用认证超时 - 调整 transport 拨号链路,默认 TCP dial、proxy dial 与 ssh-agent 建连统一改用 DialTimeout - 修正 agent forwarding 初始化仍错误复用 LoginInfo.Timeout 的问题 - 保持 LoginSimple 的直观行为:传入 timeout 时同时映射到 Timeout 和 DialTimeout - 新增 login_timeout_test,覆盖零值不回填、DialTimeout 优先级,以及 ssh-agent 认证路径使用拨号超时的回归测试
This commit is contained in:
parent
b29246a9c4
commit
1625997d8f
@ -80,7 +80,7 @@ func (s *StarSSH) ensureAgentForwarding() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
keyring, closer, err := newSSHAgentForwarder(s.LoginInfo.Timeout)
|
keyring, closer, err := newSSHAgentForwarder(effectiveDialTimeout(s.LoginInfo))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return wrapSSHAgentForwardingUnavailable(err)
|
return wrapSSHAgentForwardingUnavailable(err)
|
||||||
}
|
}
|
||||||
|
|||||||
34
login.go
34
login.go
@ -16,6 +16,7 @@ import (
|
|||||||
|
|
||||||
var ErrHostKeyCallbackRequired = errors.New("host key callback is required; use DefaultAllowHostKeyCallback to explicitly allow any host key")
|
var ErrHostKeyCallbackRequired = errors.New("host key callback is required; use DefaultAllowHostKeyCallback to explicitly allow any host key")
|
||||||
var errSSHAgentUnavailable = errors.New("ssh-agent unavailable")
|
var errSSHAgentUnavailable = errors.New("ssh-agent unavailable")
|
||||||
|
var buildSSHAgentAuthMethodFunc = buildSSHAgentAuthMethod
|
||||||
|
|
||||||
var defaultAuthOrder = []AuthMethodKind{
|
var defaultAuthOrder = []AuthMethodKind{
|
||||||
AuthMethodSSHAgent,
|
AuthMethodSSHAgent,
|
||||||
@ -42,7 +43,8 @@ func loginWithContext(ctx context.Context, info LoginInput) (*StarSSH, error) {
|
|||||||
return nil, ErrHostKeyCallbackRequired
|
return nil, ErrHostKeyCallbackRequired
|
||||||
}
|
}
|
||||||
|
|
||||||
loginCtx, cancel := contextWithLoginTimeout(ctx, info.Timeout)
|
authTimeout := effectiveLoginTimeout(info)
|
||||||
|
loginCtx, cancel := contextWithLoginTimeout(ctx, authTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
sshInfo := &StarSSH{
|
sshInfo := &StarSSH{
|
||||||
@ -76,7 +78,7 @@ func loginWithContext(ctx context.Context, info LoginInput) (*StarSSH, error) {
|
|||||||
clientConfig := &ssh.ClientConfig{
|
clientConfig := &ssh.ClientConfig{
|
||||||
User: info.User,
|
User: info.User,
|
||||||
Auth: auth,
|
Auth: auth,
|
||||||
Timeout: info.Timeout,
|
Timeout: authTimeout,
|
||||||
HostKeyCallback: hostKeyCallback,
|
HostKeyCallback: hostKeyCallback,
|
||||||
BannerCallback: bannerCallback,
|
BannerCallback: bannerCallback,
|
||||||
}
|
}
|
||||||
@ -93,7 +95,7 @@ func loginWithContext(ctx context.Context, info LoginInput) (*StarSSH, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return sshInfo, err
|
return sshInfo, err
|
||||||
}
|
}
|
||||||
restoreDeadline := applyConnDeadline(rawConn, loginCtx, info.Timeout)
|
restoreDeadline := applyConnDeadline(rawConn, loginCtx, authTimeout)
|
||||||
defer restoreDeadline()
|
defer restoreDeadline()
|
||||||
|
|
||||||
clientConn, chans, reqs, err := ssh.NewClientConn(rawConn, targetAddr, clientConfig)
|
clientConn, chans, reqs, err := ssh.NewClientConn(rawConn, targetAddr, clientConfig)
|
||||||
@ -130,6 +132,7 @@ func LoginSimple(host string, user string, passwd string, prikeyPath string, por
|
|||||||
Addr: host,
|
Addr: host,
|
||||||
Port: port,
|
Port: port,
|
||||||
Timeout: timeout,
|
Timeout: timeout,
|
||||||
|
DialTimeout: timeout,
|
||||||
User: user,
|
User: user,
|
||||||
HostKeyCallback: DefaultAllowHostKeyCallback,
|
HostKeyCallback: DefaultAllowHostKeyCallback,
|
||||||
}
|
}
|
||||||
@ -154,12 +157,29 @@ func normalizeLoginInput(info LoginInput) LoginInput {
|
|||||||
if info.Port <= 0 {
|
if info.Port <= 0 {
|
||||||
info.Port = defaultSSHPort
|
info.Port = defaultSSHPort
|
||||||
}
|
}
|
||||||
if info.Timeout <= 0 {
|
|
||||||
info.Timeout = defaultLoginTimeout
|
|
||||||
}
|
|
||||||
return info
|
return info
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func effectiveLoginTimeout(info LoginInput) time.Duration {
|
||||||
|
if info.Timeout <= 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return info.Timeout
|
||||||
|
}
|
||||||
|
|
||||||
|
func effectiveDialTimeout(info LoginInput) time.Duration {
|
||||||
|
switch {
|
||||||
|
case info.DialTimeout < 0:
|
||||||
|
return 0
|
||||||
|
case info.DialTimeout > 0:
|
||||||
|
return info.DialTimeout
|
||||||
|
case info.Timeout > 0:
|
||||||
|
return info.Timeout
|
||||||
|
default:
|
||||||
|
return defaultLoginTimeout
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func buildAuthMethods(info LoginInput) ([]ssh.AuthMethod, func(), error) {
|
func buildAuthMethods(info LoginInput) ([]ssh.AuthMethod, func(), error) {
|
||||||
order, err := normalizeAuthOrder(info.AuthOrder)
|
order, err := normalizeAuthOrder(info.AuthOrder)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -194,7 +214,7 @@ func buildAuthMethods(info LoginInput) ([]ssh.AuthMethod, func(), error) {
|
|||||||
if info.DisableSSHAgent {
|
if info.DisableSSHAgent {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
agentMethod, cleanup, err := buildSSHAgentAuthMethod(info.Timeout)
|
agentMethod, cleanup, err := buildSSHAgentAuthMethodFunc(effectiveDialTimeout(info))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
agentErr = err
|
agentErr = err
|
||||||
continue
|
continue
|
||||||
|
|||||||
99
login_timeout_test.go
Normal file
99
login_timeout_test.go
Normal file
@ -0,0 +1,99 @@
|
|||||||
|
package starssh
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNormalizeLoginInputKeepsZeroAuthTimeout(t *testing.T) {
|
||||||
|
info := normalizeLoginInput(LoginInput{})
|
||||||
|
if info.Port != defaultSSHPort {
|
||||||
|
t.Fatalf("Port=%d want %d", info.Port, defaultSSHPort)
|
||||||
|
}
|
||||||
|
if info.Timeout != 0 {
|
||||||
|
t.Fatalf("Timeout=%v want 0", info.Timeout)
|
||||||
|
}
|
||||||
|
if info.DialTimeout != 0 {
|
||||||
|
t.Fatalf("DialTimeout=%v want 0", info.DialTimeout)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEffectiveLoginTimeout(t *testing.T) {
|
||||||
|
if got := effectiveLoginTimeout(LoginInput{}); got != 0 {
|
||||||
|
t.Fatalf("zero login timeout should stay zero, got %v", got)
|
||||||
|
}
|
||||||
|
if got := effectiveLoginTimeout(LoginInput{Timeout: 7 * time.Second}); got != 7*time.Second {
|
||||||
|
t.Fatalf("expected explicit login timeout, got %v", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEffectiveDialTimeout(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
info LoginInput
|
||||||
|
want time.Duration
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "default fallback",
|
||||||
|
info: LoginInput{},
|
||||||
|
want: defaultLoginTimeout,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "reuse timeout when dial timeout omitted",
|
||||||
|
info: LoginInput{Timeout: 9 * time.Second},
|
||||||
|
want: 9 * time.Second,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "explicit dial timeout wins",
|
||||||
|
info: LoginInput{Timeout: 9 * time.Second, DialTimeout: 3 * time.Second},
|
||||||
|
want: 3 * time.Second,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "negative dial timeout disables default dial deadline",
|
||||||
|
info: LoginInput{Timeout: 9 * time.Second, DialTimeout: -1},
|
||||||
|
want: 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
if got := effectiveDialTimeout(tc.info); got != tc.want {
|
||||||
|
t.Fatalf("effectiveDialTimeout(%+v)=%v want %v", tc.info, got, tc.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildAuthMethodsUsesDialTimeoutInsteadOfAuthTimeout(t *testing.T) {
|
||||||
|
oldBuilder := buildSSHAgentAuthMethodFunc
|
||||||
|
t.Cleanup(func() {
|
||||||
|
buildSSHAgentAuthMethodFunc = oldBuilder
|
||||||
|
})
|
||||||
|
|
||||||
|
captured := time.Duration(-2)
|
||||||
|
buildSSHAgentAuthMethodFunc = func(timeout time.Duration) (ssh.AuthMethod, func(), error) {
|
||||||
|
captured = timeout
|
||||||
|
return ssh.Password("agent"), nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
info := LoginInput{
|
||||||
|
Timeout: 0,
|
||||||
|
DialTimeout: 11 * time.Second,
|
||||||
|
AuthOrder: []AuthMethodKind{AuthMethodSSHAgent},
|
||||||
|
}
|
||||||
|
auth, cleanup, err := buildAuthMethods(info)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("buildAuthMethods: %v", err)
|
||||||
|
}
|
||||||
|
if cleanup != nil {
|
||||||
|
cleanup()
|
||||||
|
}
|
||||||
|
if len(auth) != 1 {
|
||||||
|
t.Fatalf("expected one auth method, got %d", len(auth))
|
||||||
|
}
|
||||||
|
if captured != 11*time.Second {
|
||||||
|
t.Fatalf("agent auth builder timeout=%v want %v", captured, 11*time.Second)
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -32,7 +32,7 @@ func resolveDialContext(info LoginInput) DialContextFunc {
|
|||||||
}
|
}
|
||||||
|
|
||||||
dialer := &net.Dialer{
|
dialer := &net.Dialer{
|
||||||
Timeout: info.Timeout,
|
Timeout: effectiveDialTimeout(info),
|
||||||
}
|
}
|
||||||
return dialer.DialContext
|
return dialer.DialContext
|
||||||
}
|
}
|
||||||
@ -44,7 +44,7 @@ func dialTargetConn(ctx context.Context, info LoginInput) (net.Conn, *StarSSH, e
|
|||||||
}
|
}
|
||||||
|
|
||||||
dialContext := resolveDialContext(info)
|
dialContext := resolveDialContext(info)
|
||||||
proxyConfig := normalizeProxyConfig(info.Proxy, info.Timeout)
|
proxyConfig := normalizeProxyConfig(info.Proxy, effectiveDialTimeout(info))
|
||||||
if proxyConfig != nil {
|
if proxyConfig != nil {
|
||||||
return dialViaProxy(ctx, dialContext, *proxyConfig, targetAddr)
|
return dialViaProxy(ctx, dialContext, *proxyConfig, targetAddr)
|
||||||
}
|
}
|
||||||
|
|||||||
6
types.go
6
types.go
@ -94,7 +94,13 @@ type LoginInput struct {
|
|||||||
AuthOrder []AuthMethodKind
|
AuthOrder []AuthMethodKind
|
||||||
Addr string
|
Addr string
|
||||||
Port int
|
Port int
|
||||||
|
// Timeout limits the SSH handshake/authentication phase after a TCP connection has
|
||||||
|
// already been established. Zero means no authentication timeout.
|
||||||
Timeout time.Duration
|
Timeout time.Duration
|
||||||
|
// DialTimeout limits outbound dial steps such as TCP connect, proxy connect, and
|
||||||
|
// local ssh-agent socket connect. Zero falls back to Timeout when set, otherwise
|
||||||
|
// uses the package default dial timeout. Negative disables the default dial timeout.
|
||||||
|
DialTimeout time.Duration
|
||||||
DialContext DialContextFunc
|
DialContext DialContextFunc
|
||||||
Proxy *ProxyConfig
|
Proxy *ProxyConfig
|
||||||
Jump *LoginInput
|
Jump *LoginInput
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user