package starssh import ( "testing" "time" "golang.org/x/crypto/ssh" ) func TestNormalizeLoginInputKeepsZeroAuthTimeout(t *testing.T) { info := normalizeLoginInput(LoginInput{}) if info.Port != defaultSSHPort { t.Fatalf("Port=%d want %d", info.Port, defaultSSHPort) } if info.Timeout != 0 { t.Fatalf("Timeout=%v want 0", info.Timeout) } if info.DialTimeout != 0 { t.Fatalf("DialTimeout=%v want 0", info.DialTimeout) } } func TestEffectiveLoginTimeout(t *testing.T) { if got := effectiveLoginTimeout(LoginInput{}); got != 0 { t.Fatalf("zero login timeout should stay zero, got %v", got) } if got := effectiveLoginTimeout(LoginInput{Timeout: 7 * time.Second}); got != 7*time.Second { t.Fatalf("expected explicit login timeout, got %v", got) } } func TestEffectiveDialTimeout(t *testing.T) { tests := []struct { name string info LoginInput want time.Duration }{ { name: "default fallback", info: LoginInput{}, want: defaultLoginTimeout, }, { name: "reuse timeout when dial timeout omitted", info: LoginInput{Timeout: 9 * time.Second}, want: 9 * time.Second, }, { name: "explicit dial timeout wins", info: LoginInput{Timeout: 9 * time.Second, DialTimeout: 3 * time.Second}, want: 3 * time.Second, }, { name: "negative dial timeout disables default dial deadline", info: LoginInput{Timeout: 9 * time.Second, DialTimeout: -1}, want: 0, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { if got := effectiveDialTimeout(tc.info); got != tc.want { t.Fatalf("effectiveDialTimeout(%+v)=%v want %v", tc.info, got, tc.want) } }) } } func TestBuildAuthMethodsUsesDialTimeoutInsteadOfAuthTimeout(t *testing.T) { oldBuilder := buildSSHAgentAuthMethodFunc t.Cleanup(func() { buildSSHAgentAuthMethodFunc = oldBuilder }) captured := time.Duration(-2) buildSSHAgentAuthMethodFunc = func(timeout time.Duration) (ssh.AuthMethod, func(), error) { captured = timeout return ssh.Password("agent"), nil, nil } info := LoginInput{ Timeout: 0, DialTimeout: 11 * time.Second, AuthOrder: []AuthMethodKind{AuthMethodSSHAgent}, } auth, cleanup, err := buildAuthMethods(info) if err != nil { t.Fatalf("buildAuthMethods: %v", err) } if cleanup != nil { cleanup() } if len(auth) != 1 { t.Fatalf("expected one auth method, got %d", len(auth)) } if captured != 11*time.Second { t.Fatalf("agent auth builder timeout=%v want %v", captured, 11*time.Second) } }