starssh/login_timeout_test.go
starainrt 0c23e7d4bf
feat: 增强 ssh-agent 认证与转发可靠性
- 拆分 ssh-agent 认证、连接与 endpoint 解析逻辑
- 新增 IdentityAgent、SSHAgentTimeout、SSHAgentForwardTimeout 和调试事件
- 为 agent list/sign 操作增加独立 deadline,避免硬件 agent 卡死登录
- 支持 agent signer 失败后跳过坏 key 并重试后续 key
- 优先处理 RSA-SHA2 签名,兼容现代 OpenSSH 认证要求
- 增强 agent forwarding 的探测、通道空闲超时和关闭清理
- 补充 Windows OpenSSH pipe 与 GPG S.gpg-agent.ssh socket 文件支持
- 增加相关回归测试和 Windows 编译验证覆盖
2026-05-27 13:10:35 +08:00

764 lines
22 KiB
Go

package starssh
import (
	"crypto/ed25519"
	"crypto/rand"
	"crypto/rsa"
	"errors"
	"io"
	"net"
	"os"
	"path/filepath"
	"sync"
	"testing"
	"time"

	"golang.org/x/crypto/ssh"
	sshagent "golang.org/x/crypto/ssh/agent"
)
// TestNormalizeLoginInputKeepsZeroAuthTimeout checks that normalizing an empty
// LoginInput fills in defaultSSHPort but leaves every timeout field at zero.
func TestNormalizeLoginInputKeepsZeroAuthTimeout(t *testing.T) {
	normalized := normalizeLoginInput(LoginInput{})

	// The first mismatching field aborts the test via Fatalf.
	switch {
	case normalized.Port != defaultSSHPort:
		t.Fatalf("Port=%d want %d", normalized.Port, defaultSSHPort)
	case normalized.Timeout != 0:
		t.Fatalf("Timeout=%v want 0", normalized.Timeout)
	case normalized.DialTimeout != 0:
		t.Fatalf("DialTimeout=%v want 0", normalized.DialTimeout)
	case normalized.SSHAgentTimeout != 0:
		t.Fatalf("SSHAgentTimeout=%v want 0", normalized.SSHAgentTimeout)
	case normalized.SSHAgentForwardTimeout != 0:
		t.Fatalf("SSHAgentForwardTimeout=%v want 0", normalized.SSHAgentForwardTimeout)
	}
}
func TestEffectiveLoginTimeout(t *testing.T) {
if got := effectiveLoginTimeout(LoginInput{}); got != 0 {
t.Fatalf("zero login timeout should stay zero, got %v", got)
}
if got := effectiveLoginTimeout(LoginInput{Timeout: 7 * time.Second}); got != 7*time.Second {
t.Fatalf("expected explicit login timeout, got %v", got)
}
}
func TestEffectiveDialTimeout(t *testing.T) {
tests := []struct {
name string
info LoginInput
want time.Duration
}{
{
name: "default fallback",
info: LoginInput{},
want: defaultLoginTimeout,
},
{
name: "reuse timeout when dial timeout omitted",
info: LoginInput{Timeout: 9 * time.Second},
want: 9 * time.Second,
},
{
name: "explicit dial timeout wins",
info: LoginInput{Timeout: 9 * time.Second, DialTimeout: 3 * time.Second},
want: 3 * time.Second,
},
{
name: "negative dial timeout disables default dial deadline",
info: LoginInput{Timeout: 9 * time.Second, DialTimeout: -1},
want: 0,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
if got := effectiveDialTimeout(tc.info); got != tc.want {
t.Fatalf("effectiveDialTimeout(%+v)=%v want %v", tc.info, got, tc.want)
}
})
}
}
// TestEffectiveSSHAgentTimeout covers the agent list/sign deadline:
//   - defaultSSHAgentTimeout is used whether or not a login Timeout is set;
//   - an explicit SSHAgentTimeout wins;
//   - a negative SSHAgentTimeout yields 0.
func TestEffectiveSSHAgentTimeout(t *testing.T) {
	table := []struct {
		label  string
		in     LoginInput
		expect time.Duration
	}{
		{
			label:  "default fallback without auth timeout",
			in:     LoginInput{},
			expect: defaultSSHAgentTimeout,
		},
		{
			label:  "auth timeout does not cap default",
			in:     LoginInput{Timeout: 9 * time.Second},
			expect: defaultSSHAgentTimeout,
		},
		{
			label: "explicit agent timeout wins",
			in: LoginInput{
				Timeout:         9 * time.Second,
				DialTimeout:     3 * time.Second,
				SSHAgentTimeout: 90 * time.Second,
			},
			expect: 90 * time.Second,
		},
		{
			label:  "negative agent timeout disables operation deadline",
			in:     LoginInput{SSHAgentTimeout: -1},
			expect: 0,
		},
	}
	for _, row := range table {
		row := row
		t.Run(row.label, func(t *testing.T) {
			if got := effectiveSSHAgentTimeout(row.in); got != row.expect {
				t.Fatalf("effectiveSSHAgentTimeout(%+v)=%v want %v", row.in, got, row.expect)
			}
		})
	}
}
func TestEffectiveSSHAgentForwardTimeout(t *testing.T) {
if got := effectiveSSHAgentForwardTimeout(LoginInput{}); got != 0 {
t.Fatalf("zero forward timeout should stay zero, got %v", got)
}
if got := effectiveSSHAgentForwardTimeout(LoginInput{SSHAgentForwardTimeout: 4 * time.Second}); got != 4*time.Second {
t.Fatalf("expected explicit forward timeout, got %v", got)
}
}
func TestBuildAuthMethodsUsesSeparateSSHAgentTimeouts(t *testing.T) {
oldBuilder := buildSSHAgentAuthMethodFunc
t.Cleanup(func() {
buildSSHAgentAuthMethodFunc = oldBuilder
})
captured := sshAgentTimeouts{Dial: -2, Operation: -2, Forward: -2}
buildSSHAgentAuthMethodFunc = func(timeouts sshAgentTimeouts) (ssh.AuthMethod, func(), error) {
captured = timeouts
return ssh.Password("agent"), nil, nil
}
info := LoginInput{
Timeout: 0,
DialTimeout: 11 * time.Second,
SSHAgentTimeout: 90 * time.Second,
SSHAgentForwardTimeout: 4 * time.Second,
IdentityAgent: "/tmp/custom-agent.sock",
AuthOrder: []AuthMethodKind{AuthMethodSSHAgent},
}
auth, cleanup, err := buildAuthMethods(info)
if err != nil {
t.Fatalf("buildAuthMethods: %v", err)
}
if cleanup != nil {
cleanup()
}
if len(auth) != 1 {
t.Fatalf("expected one auth method, got %d", len(auth))
}
if captured.Dial != 11*time.Second {
t.Fatalf("agent auth builder dial timeout=%v want %v", captured.Dial, 11*time.Second)
}
if captured.Operation != 90*time.Second {
t.Fatalf("agent auth builder operation timeout=%v want %v", captured.Operation, 90*time.Second)
}
if captured.Forward != 4*time.Second {
t.Fatalf("agent auth builder forward timeout=%v want %v", captured.Forward, 4*time.Second)
}
if captured.Endpoint != "/tmp/custom-agent.sock" {
t.Fatalf("agent auth builder endpoint=%q want custom endpoint", captured.Endpoint)
}
}
// TestBuildAuthMethodsUsesSingleAgentAuthMethod checks that an ssh-agent-only
// auth order produces exactly one ssh.AuthMethod.
func TestBuildAuthMethodsUsesSingleAgentAuthMethod(t *testing.T) {
	saved := buildSSHAgentAuthMethodFunc
	defer func() { buildSSHAgentAuthMethodFunc = saved }()

	buildSSHAgentAuthMethodFunc = func(sshAgentTimeouts) (ssh.AuthMethod, func(), error) {
		return ssh.Password("agent"), nil, nil
	}

	input := LoginInput{AuthOrder: []AuthMethodKind{AuthMethodSSHAgent}}
	methods, release, err := buildAuthMethods(input)
	if err != nil {
		t.Fatalf("buildAuthMethods: %v", err)
	}
	if release != nil {
		release()
	}
	if count := len(methods); count != 1 {
		t.Fatalf("auth methods=%d, want 1", count)
	}
}
// TestShouldRetrySSHAgentAuthWhenAgentIsNotFirst checks three things:
//   - a retry is requested when ssh-agent follows password auth;
//   - DisableSSHAgent suppresses the retry;
//   - no retry is requested when ssh-agent is absent from the auth order.
func TestShouldRetrySSHAgentAuthWhenAgentIsNotFirst(t *testing.T) {
	passwordThenAgent := []AuthMethodKind{AuthMethodPassword, AuthMethodSSHAgent}
	passwordOnly := []AuthMethodKind{AuthMethodPassword}

	if retry := shouldRetrySSHAgentAuth(LoginInput{}, passwordThenAgent); !retry {
		t.Fatal("expected ssh-agent retry when ssh-agent is present after password")
	}

	disabled := LoginInput{DisableSSHAgent: true}
	if retry := shouldRetrySSHAgentAuth(disabled, passwordThenAgent); retry {
		t.Fatal("expected ssh-agent retry disabled when DisableSSHAgent is true")
	}

	if retry := shouldRetrySSHAgentAuth(LoginInput{}, passwordOnly); retry {
		t.Fatal("expected no ssh-agent retry when ssh-agent auth is absent")
	}
}
// TestBuildAuthMethodsWithAgentAttemptMarksNonFirstAgentForRetry checks the
// case where ssh-agent comes after password auth. The agent builder must get a
// SignFailure callback and an empty (nil) skip list, and both methods must be
// returned.
func TestBuildAuthMethodsWithAgentAttemptMarksNonFirstAgentForRetry(t *testing.T) {
	restore := buildSSHAgentAuthMethodFunc
	t.Cleanup(func() { buildSSHAgentAuthMethodFunc = restore })

	buildSSHAgentAuthMethodFunc = func(opts sshAgentTimeouts) (ssh.AuthMethod, func(), error) {
		switch {
		case opts.SignFailure == nil:
			t.Fatal("expected SignFailure callback for non-first ssh-agent auth")
		case opts.SkipFingerprints != nil:
			t.Fatalf("unexpected initial skip fingerprints: %#v", opts.SkipFingerprints)
		}
		return ssh.Password("agent"), nil, nil
	}

	input := LoginInput{
		Password:  "secret",
		AuthOrder: []AuthMethodKind{AuthMethodPassword, AuthMethodSSHAgent},
	}
	methods, release, err := buildAuthMethodsWithAgentAttempt(input, newSSHAgentAuthAttempt())
	if err != nil {
		t.Fatalf("buildAuthMethodsWithAgentAttempt: %v", err)
	}
	if release != nil {
		release()
	}
	if n := len(methods); n != 2 {
		t.Fatalf("auth methods=%d want 2", n)
	}
}
// TestAgentRetryPendingBlocksFallbackAuthThenResets checks the retry-pending
// signal. Once a fingerprint is skipped, checkSSHAgentRetryPending reports
// errRetrySSHAgentAuth. Calling begin() on the attempt clears that state.
func TestAgentRetryPendingBlocksFallbackAuthThenResets(t *testing.T) {
	attempt := newSSHAgentAuthAttempt()
	attempt.skipFingerprint("SHA256:test")

	err := checkSSHAgentRetryPending(attempt)
	if !errors.Is(err, errRetrySSHAgentAuth) {
		t.Fatalf("retry pending err=%v want errRetrySSHAgentAuth", err)
	}

	attempt.begin()
	if err = checkSSHAgentRetryPending(attempt); err != nil {
		t.Fatalf("retry should reset on next attempt: %v", err)
	}
}
// TestAgentRetryPendingBlocksPrivateKeyAuth walks the private-key callback
// through three phases:
//   - before any agent failure, it offers the key;
//   - while a retry is pending, it offers nothing and returns errRetrySSHAgentAuth;
//   - after begin(), it offers the key again.
func TestAgentRetryPendingBlocksPrivateKeyAuth(t *testing.T) {
	key := mustGenerateTestSigner(t)
	attempt := newSSHAgentAuthAttempt()
	fetch := privateKeySignersCallback(key, attempt)
	onlyOriginal := func(list []ssh.Signer) bool {
		return len(list) == 1 && list[0] == key
	}

	// Phase 1: nothing pending.
	got, err := fetch()
	if err != nil {
		t.Fatalf("private key callback before retry: %v", err)
	}
	if !onlyOriginal(got) {
		t.Fatalf("private key callback returned %#v, want original signer", got)
	}

	// Phase 2: an agent key was marked as skipped, so a retry is pending.
	attempt.skipFingerprint("SHA256:test")
	got, err = fetch()
	if !errors.Is(err, errRetrySSHAgentAuth) {
		t.Fatalf("private key callback err=%v want errRetrySSHAgentAuth", err)
	}
	if got != nil {
		t.Fatalf("private key callback signers=%#v want nil while retry pending", got)
	}

	// Phase 3: a new attempt clears the pending state.
	attempt.begin()
	if got, err = fetch(); err != nil {
		t.Fatalf("private key callback after retry reset: %v", err)
	}
	if !onlyOriginal(got) {
		t.Fatalf("private key callback after retry returned %#v, want original signer", got)
	}
}
// TestFilterSSHAgentSignersSkipsSignerAfterSignFailure checks how a signing
// failure affects the next filtering pass. In the first pass both keys are
// offered, and the failing one surfaces errRetrySSHAgentAuth. In the second
// pass only the healthy key remains, and it signs successfully.
func TestFilterSSHAgentSignersSkipsSignerAfterSignFailure(t *testing.T) {
	healthy := mustGenerateTestSigner(t)
	broken := &testFailingSigner{
		Signer: mustGenerateTestSigner(t),
		err:    errors.New("first agent key cannot sign"),
	}
	attempt := newSSHAgentAuthAttempt()

	// First pass: nothing skipped yet.
	initial := filterSSHAgentSignersForRetry([]ssh.Signer{broken, healthy}, sshAgentTimeouts{
		SignFailure:      attempt.recordSignFailure,
		SkipFingerprints: attempt.skipSnapshot(),
	})
	if len(initial) != 2 {
		t.Fatalf("first auth method signers=%d want 2", len(initial))
	}
	if _, signErr := initial[0].Sign(nil, []byte("challenge")); !errors.Is(signErr, errRetrySSHAgentAuth) {
		t.Fatalf("first signer err=%v want errRetrySSHAgentAuth", signErr)
	}

	// Second pass: the snapshot now includes the broken key's fingerprint.
	retried := filterSSHAgentSignersForRetry([]ssh.Signer{broken, healthy}, sshAgentTimeouts{
		SignFailure:      attempt.recordSignFailure,
		SkipFingerprints: attempt.skipSnapshot(),
	})
	if len(retried) != 1 {
		t.Fatalf("second auth method signers=%d want 1", len(retried))
	}
	survivorKey := string(retried[0].PublicKey().Marshal())
	if survivorKey != string(healthy.PublicKey().Marshal()) {
		t.Fatalf("second auth method did not skip failed first key")
	}

	sig, signErr := retried[0].Sign(nil, []byte("challenge"))
	switch {
	case signErr != nil:
		t.Fatalf("second signer Sign: %v", signErr)
	case sig == nil:
		t.Fatal("second signer returned nil signature")
	}
}
func TestBuildAuthMethodsSkipsFailedAgentSignerOnRetry(t *testing.T) {
firstSigner := mustGenerateTestSigner(t)
secondSigner := mustGenerateTestSigner(t)
wantErr := errors.New("first agent key cannot sign")
failingFirstSigner := &testFailingSigner{Signer: firstSigner, err: wantErr}
oldBuilder := buildSSHAgentAuthMethodFunc
t.Cleanup(func() {
buildSSHAgentAuthMethodFunc = oldBuilder
})
var buildCalls int
buildSSHAgentAuthMethodFunc = func(timeouts sshAgentTimeouts) (ssh.AuthMethod, func(), error) {
buildCalls++
filteredSigners := filterSSHAgentSignersForRetry([]ssh.Signer{failingFirstSigner, secondSigner}, timeouts)
if buildCalls == 1 {
if len(filteredSigners) != 2 {
t.Fatalf("first build signers=%d want 2", len(filteredSigners))
}
return ssh.PublicKeys(filteredSigners...), nil, nil
}
if len(filteredSigners) != 1 {
t.Fatalf("retry build signers=%d want 1", len(filteredSigners))
}
if string(filteredSigners[0].PublicKey().Marshal()) != string(secondSigner.PublicKey().Marshal()) {
t.Fatal("retry build did not skip failed signer")
}
return ssh.PublicKeys(filteredSigners...), nil, nil
}
attempt := newSSHAgentAuthAttempt()
attempt.begin()
auth, cleanup, err := buildAuthMethodsWithAgentAttempt(LoginInput{
AuthOrder: []AuthMethodKind{AuthMethodSSHAgent},
}, attempt)
if err != nil {
t.Fatalf("first buildAuthMethodsWithAgentAttempt: %v", err)
}
if cleanup != nil {
cleanup()
}
if len(auth) != 1 {
t.Fatalf("first auth methods=%d want 1", len(auth))
}
if _, err := failingFirstSigner.Sign(rand.Reader, []byte("challenge")); !errors.Is(err, wantErr) {
t.Fatalf("raw failing signer err=%v", err)
}
firstWrapped := filterSSHAgentSignersForRetry([]ssh.Signer{failingFirstSigner}, sshAgentTimeouts{
SignFailure: attempt.recordSignFailure,
})[0]
if _, err := firstWrapped.Sign(rand.Reader, []byte("challenge")); !errors.Is(err, errRetrySSHAgentAuth) {
t.Fatalf("wrapped failing signer err=%v want errRetrySSHAgentAuth", err)
}
attempt.begin()
auth, cleanup, err = buildAuthMethodsWithAgentAttempt(LoginInput{
AuthOrder: []AuthMethodKind{AuthMethodSSHAgent},
}, attempt)
if err != nil {
t.Fatalf("retry buildAuthMethodsWithAgentAttempt: %v", err)
}
if cleanup != nil {
cleanup()
}
if len(auth) != 1 {
t.Fatalf("retry auth methods=%d want 1", len(auth))
}
if buildCalls != 2 {
t.Fatalf("build calls=%d want 2", buildCalls)
}
}
// TestOrderSSHAgentSignersPrefersPriorityComment checks that a signer whose
// comment carries a "priority=" hint is moved ahead of an unannotated signer.
func TestOrderSSHAgentSignersPrefersPriorityComment(t *testing.T) {
	plain := mustGenerateTestSigner(t)
	preferred := mustGenerateCommentedTestSigner(t, "priority=40")

	result := orderSSHAgentSigners([]ssh.Signer{plain, preferred})
	if got := len(result); got != 2 {
		t.Fatalf("ordered signers=%d want 2", got)
	}

	leadKey := string(result[0].PublicKey().Marshal())
	if leadKey != string(preferred.PublicKey().Marshal()) {
		t.Fatalf("priority signer should be first, got %s", sshAgentSignerComment(result[0]))
	}
}
// TestOrderSSHAgentSignersPrefersCardKeys checks that a signer commented with
// "cardno:..." is ordered ahead of an unannotated signer.
func TestOrderSSHAgentSignersPrefersCardKeys(t *testing.T) {
	plain := mustGenerateTestSigner(t)
	card := mustGenerateCommentedTestSigner(t, "cardno:26_865_673")

	result := orderSSHAgentSigners([]ssh.Signer{plain, card})
	if got := len(result); got != 2 {
		t.Fatalf("ordered signers=%d want 2", got)
	}

	leadKey := string(result[0].PublicKey().Marshal())
	if leadKey != string(card.PublicKey().Marshal()) {
		t.Fatalf("card signer should be first, got %s", sshAgentSignerComment(result[0]))
	}
}
// TestOrderSSHAgentSignersKeepsStableOrderWithoutHints checks that signers
// without priority or card hints keep their input order.
func TestOrderSSHAgentSignersKeepsStableOrderWithoutHints(t *testing.T) {
	alpha := mustGenerateTestSigner(t)
	beta := mustGenerateTestSigner(t)

	result := orderSSHAgentSigners([]ssh.Signer{alpha, beta})
	if got := len(result); got != 2 {
		t.Fatalf("ordered signers=%d want 2", got)
	}

	switch {
	case string(result[0].PublicKey().Marshal()) != string(alpha.PublicKey().Marshal()):
		t.Fatalf("first signer changed order without hints")
	case string(result[1].PublicKey().Marshal()) != string(beta.PublicKey().Marshal()):
		t.Fatalf("second signer changed order without hints")
	}
}
func TestSSHAgentSignerEmitsSignDebugWithoutChangingError(t *testing.T) {
signer := mustGenerateTestSigner(t)
wantErr := errors.New("agent refused operation")
var debugCalls int
wrapped := wrapSSHAgentSigner(&testFailingSigner{Signer: signer, err: wantErr}, sshAgentSignerOptions{
Resolved: resolvedSSHAgentEndpoint{
Endpoint: "/tmp/debug-agent.sock",
Source: "identity-agent",
Network: "unix",
},
Debug: func(event SSHAgentDebugEvent) {
debugCalls++
if event.Step != "auth" || event.Phase != "sign" {
t.Fatalf("unexpected debug event: %+v", event)
}
if event.Endpoint != "/tmp/debug-agent.sock" || event.Source != "identity-agent" || event.Network != "unix" {
t.Fatalf("unexpected endpoint details: %+v", event)
}
if event.Status != "error" || !errors.Is(event.Err, wantErr) {
t.Fatalf("unexpected sign status: %+v", event)
}
},
})
_, err := wrapped.Sign(rand.Reader, []byte("challenge"))
if !errors.Is(err, wantErr) {
t.Fatalf("Sign err=%v want original signer error", err)
}
if debugCalls != 1 {
t.Fatalf("debug calls=%d want 1", debugCalls)
}
}
func TestSSHAgentRetrySignerPrefersRSASHA2(t *testing.T) {
signer := mustGenerateRSATestSigner(t)
spy := &testAlgorithmSpySigner{Signer: signer}
wrapped, ok := wrapSSHAgentSignerForRetry(spy, func(ssh.PublicKey, error) {}).(ssh.AlgorithmSigner)
if !ok {
t.Fatal("wrapped signer does not implement AlgorithmSigner")
}
signature, err := wrapped.SignWithAlgorithm(rand.Reader, []byte("challenge"), ssh.KeyAlgoRSA)
if err != nil {
t.Fatalf("SignWithAlgorithm: %v", err)
}
if spy.lastAlgorithm != ssh.KeyAlgoRSASHA256 {
t.Fatalf("last algorithm=%q want %q", spy.lastAlgorithm, ssh.KeyAlgoRSASHA256)
}
if signature.Format != ssh.KeyAlgoRSASHA256 {
t.Fatalf("signature format=%q want %q", signature.Format, ssh.KeyAlgoRSASHA256)
}
}
func TestSSHAgentRetrySignerKeepsRestrictedRSA(t *testing.T) {
signer := mustGenerateRSATestSigner(t)
restricted, err := ssh.NewSignerWithAlgorithms(signer.(ssh.AlgorithmSigner), []string{ssh.KeyAlgoRSA})
if err != nil {
t.Fatalf("NewSignerWithAlgorithms: %v", err)
}
spy := &testMultiAlgorithmSpySigner{
testAlgorithmSpySigner: &testAlgorithmSpySigner{Signer: restricted},
}
wrapped, ok := wrapSSHAgentSignerForRetry(spy, func(ssh.PublicKey, error) {}).(ssh.AlgorithmSigner)
if !ok {
t.Fatal("wrapped signer does not implement AlgorithmSigner")
}
signature, err := wrapped.SignWithAlgorithm(rand.Reader, []byte("challenge"), ssh.KeyAlgoRSA)
if err != nil {
t.Fatalf("SignWithAlgorithm: %v", err)
}
if spy.lastAlgorithm != ssh.KeyAlgoRSA {
t.Fatalf("last algorithm=%q want %q", spy.lastAlgorithm, ssh.KeyAlgoRSA)
}
if signature.Format != ssh.KeyAlgoRSA {
t.Fatalf("signature format=%q want %q", signature.Format, ssh.KeyAlgoRSA)
}
}
// deadlineSpyConn is a fake net.Conn that records SetDeadline calls and returns
// canned results from Read and Write. The embedded net.Conn is left nil, so any
// method not overridden in this file will panic if called.
type deadlineSpyConn struct {
	net.Conn

	// readErr and writeErr are returned by Read/Write when non-nil.
	readErr  error
	writeErr error

	// mu guards deadlines.
	mu        sync.Mutex
	deadlines []time.Time
}
// testFailingSigner keeps the embedded signer's public key but makes every
// signing call fail with err.
type testFailingSigner struct {
	ssh.Signer
	err error
}

// Sign ignores its inputs and returns the configured error.
func (s *testFailingSigner) Sign(rand io.Reader, data []byte) (*ssh.Signature, error) {
	return nil, s.err
}

// SignWithAlgorithm fails exactly like Sign, whatever algorithm is requested.
func (s *testFailingSigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*ssh.Signature, error) {
	return s.Sign(rand, data)
}
// testAlgorithmSpySigner wraps a signer and records the algorithm passed to the
// most recent SignWithAlgorithm call, so tests can see what a wrapper requested.
type testAlgorithmSpySigner struct {
	ssh.Signer
	lastAlgorithm string
}

// SignWithAlgorithm records algorithm and delegates to the wrapped signer.
//
// If the wrapped signer is not an ssh.AlgorithmSigner, an error is returned
// rather than panicking on an unchecked type assertion. A panic would abort
// the whole test binary instead of failing only the test that misused the spy.
func (s *testAlgorithmSpySigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*ssh.Signature, error) {
	s.lastAlgorithm = algorithm
	algorithmSigner, ok := s.Signer.(ssh.AlgorithmSigner)
	if !ok {
		return nil, errors.New("testAlgorithmSpySigner: wrapped signer does not implement ssh.AlgorithmSigner")
	}
	return algorithmSigner.SignWithAlgorithm(rand, data, algorithm)
}
// testMultiAlgorithmSpySigner adds an Algorithms method to
// testAlgorithmSpySigner. Algorithms reports the wrapped signer's algorithm
// restrictions, so the spy can stand in for an ssh.MultiAlgorithmSigner.
type testMultiAlgorithmSpySigner struct {
	*testAlgorithmSpySigner
}

// Algorithms forwards to the wrapped signer when it is an
// ssh.MultiAlgorithmSigner and returns nil otherwise.
func (s *testMultiAlgorithmSpySigner) Algorithms() []string {
	restricted, ok := s.Signer.(ssh.MultiAlgorithmSigner)
	if !ok {
		return nil
	}
	return restricted.Algorithms()
}
// mustGenerateTestSigner returns an ssh.Signer for a freshly generated Ed25519
// key and fails the test immediately if either step fails.
func mustGenerateTestSigner(t *testing.T) ssh.Signer {
	t.Helper()

	_, privateKey, genErr := ed25519.GenerateKey(rand.Reader)
	if genErr != nil {
		t.Fatalf("generate test private key: %v", genErr)
	}

	s, wrapErr := ssh.NewSignerFromKey(privateKey)
	if wrapErr != nil {
		t.Fatalf("new test signer: %v", wrapErr)
	}
	return s
}
// mustGenerateCommentedTestSigner returns an Ed25519 test signer whose public
// key is an *sshagent.Key carrying comment. Agent keys have this shape, which
// lets ordering logic inspect comment hints.
func mustGenerateCommentedTestSigner(t *testing.T, comment string) ssh.Signer {
	t.Helper()

	inner := mustGenerateTestSigner(t)
	pub := inner.PublicKey()
	annotated := &sshagent.Key{
		Format:  pub.Type(),
		Blob:    pub.Marshal(),
		Comment: comment,
	}
	return &commentedTestSigner{Signer: inner, publicKey: annotated}
}
// commentedTestSigner signs with the embedded signer but reports publicKey,
// which may carry extra metadata such as an agent comment, from PublicKey.
type commentedTestSigner struct {
	ssh.Signer
	publicKey ssh.PublicKey
}

// PublicKey returns the overriding public key instead of the embedded signer's.
func (s *commentedTestSigner) PublicKey() ssh.PublicKey { return s.publicKey }
func mustGenerateRSATestSigner(t *testing.T) ssh.Signer {
t.Helper()
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("generate rsa test private key: %v", err)
}
signer, err := ssh.NewSignerFromKey(key)
if err != nil {
t.Fatalf("new rsa test signer: %v", err)
}
return signer
}
// SetDeadline records the requested deadline instead of applying it and
// always reports success.
func (c *deadlineSpyConn) SetDeadline(deadline time.Time) error {
	c.mu.Lock()
	c.deadlines = append(c.deadlines, deadline)
	c.mu.Unlock()
	return nil
}

// deadlineCount returns how many deadlines have been recorded so far.
func (c *deadlineSpyConn) deadlineCount() int {
	c.mu.Lock()
	n := len(c.deadlines)
	c.mu.Unlock()
	return n
}

// firstDeadline returns the earliest recorded deadline. It panics if none has
// been recorded; the deferred unlock still releases the mutex in that case.
func (c *deadlineSpyConn) firstDeadline() time.Time {
	c.mu.Lock()
	defer c.mu.Unlock()
	first := c.deadlines[0]
	return first
}
// Read never fills p. It returns (0, readErr), which is (0, nil) when no read
// error was configured.
func (c *deadlineSpyConn) Read(p []byte) (int, error) {
	return 0, c.readErr
}

// Write returns writeErr with a count of 0 when a write error is configured;
// otherwise it claims the whole buffer was written.
func (c *deadlineSpyConn) Write(p []byte) (int, error) {
	written := len(p)
	if c.writeErr != nil {
		written = 0
	}
	return written, c.writeErr
}
// TestWrapSSHAgentConnWithDeadlineSetsReadDeadline checks three things about
// Read through the deadline wrapper:
//   - the underlying read error is passed through;
//   - exactly one deadline is set;
//   - that deadline lies in the future.
func TestWrapSSHAgentConnWithDeadlineSetsReadDeadline(t *testing.T) {
	spy := &deadlineSpyConn{readErr: io.EOF}
	wrapped := wrapSSHAgentConnWithDeadline(spy, 2*time.Second)

	if _, readErr := wrapped.Read(make([]byte, 1)); !errors.Is(readErr, io.EOF) {
		t.Fatalf("Read err=%v", readErr)
	}
	if n := spy.deadlineCount(); n != 1 {
		t.Fatalf("deadlines=%d want 1", n)
	}

	deadline := spy.firstDeadline()
	if remaining := time.Until(deadline); remaining <= 0 {
		t.Fatalf("deadline=%v should be in the future", deadline)
	}
}
// TestWrapSSHAgentConnWithDeadlineSetsWriteDeadline checks that a successful
// Write through the deadline wrapper sets exactly one deadline on the
// underlying connection.
func TestWrapSSHAgentConnWithDeadlineSetsWriteDeadline(t *testing.T) {
	spy := &deadlineSpyConn{}
	wrapped := wrapSSHAgentConnWithDeadline(spy, 2*time.Second)

	if _, writeErr := wrapped.Write([]byte("x")); writeErr != nil {
		t.Fatalf("Write err=%v", writeErr)
	}
	if n := spy.deadlineCount(); n != 1 {
		t.Fatalf("deadlines=%d want 1", n)
	}
}
// TestResolveSSHAgentEndpointUsesIdentityAgent checks endpoint resolution when
// both an explicit endpoint and SSH_AUTH_SOCK are set. The explicit endpoint
// (surrounding whitespace trimmed) must win, with source "identity-agent".
func TestResolveSSHAgentEndpointUsesIdentityAgent(t *testing.T) {
	t.Setenv("SSH_AUTH_SOCK", "/tmp/env-agent.sock")

	got, err := resolveSSHAgentEndpoint(sshAgentDialOptions{Endpoint: " /tmp/identity-agent.sock "})
	if err != nil {
		t.Fatalf("resolveSSHAgentEndpoint: %v", err)
	}
	switch {
	case got.Endpoint != "/tmp/identity-agent.sock":
		t.Fatalf("endpoint=%q", got.Endpoint)
	case got.Source != "identity-agent":
		t.Fatalf("source=%q", got.Source)
	}
}
// TestResolveSSHAgentEndpointUsesSSHAuthSock checks that, without an explicit
// endpoint, resolution falls back to SSH_AUTH_SOCK and reports it as the source.
func TestResolveSSHAgentEndpointUsesSSHAuthSock(t *testing.T) {
	t.Setenv("SSH_AUTH_SOCK", "/tmp/env-agent.sock")

	got, err := resolveSSHAgentEndpoint(sshAgentDialOptions{})
	if err != nil {
		t.Fatalf("resolveSSHAgentEndpoint: %v", err)
	}
	switch {
	case got.Endpoint != "/tmp/env-agent.sock":
		t.Fatalf("endpoint=%q", got.Endpoint)
	case got.Source != "SSH_AUTH_SOCK":
		t.Fatalf("source=%q", got.Source)
	}
}
func TestBuildSSHAgentAuthMethodTimesOutWhenAgentDoesNotRespond(t *testing.T) {
server, client := net.Pipe()
defer server.Close()
oldDialResolvedSSHAgent := dialResolvedSSHAgentFunc
t.Cleanup(func() {
dialResolvedSSHAgentFunc = oldDialResolvedSSHAgent
})
dialResolvedSSHAgentFunc = func(resolved resolvedSSHAgentEndpoint, timeout time.Duration) (net.Conn, error) {
return client, nil
}
_, cleanup, err := buildSSHAgentAuthMethod(sshAgentTimeouts{
Operation: 20 * time.Millisecond,
Endpoint: "/tmp/hung-agent.sock",
})
if cleanup != nil {
cleanup()
}
if !errors.Is(err, ErrSSHAgentTimeout) {
t.Fatalf("err=%v want ErrSSHAgentTimeout", err)
}
}
// TestBuildSSHAgentAuthMethodEmitsDebugEvents runs buildSSHAgentAuthMethod
// against a real Unix socket whose fake agent hangs up right after accepting.
// It checks that the first debug event describes the "auth"/"dial" step for the
// configured endpoint, with source "identity-agent". The build's return value
// is ignored; presumably it fails because the agent hangs up.
func TestBuildSSHAgentAuthMethodEmitsDebugEvents(t *testing.T) {
	socketPath := tempUnixSocketPath(t)
	listener, err := net.Listen("unix", socketPath)
	if err != nil {
		t.Fatalf("listen unix: %v", err)
	}
	defer listener.Close()

	// Fake agent: accept a single connection and close it immediately.
	done := make(chan struct{})
	go func() {
		defer close(done)
		conn, err := listener.Accept()
		if err != nil {
			return
		}
		_ = conn.Close()
	}()

	var events []SSHAgentDebugEvent
	_, cleanup, _ := buildSSHAgentAuthMethod(sshAgentTimeouts{
		Dial:      time.Second,
		Operation: time.Second,
		Endpoint:  socketPath,
		Debug: func(event SSHAgentDebugEvent) {
			events = append(events, event)
		},
	})
	// Release anything the build acquired instead of discarding the cleanup.
	if cleanup != nil {
		cleanup()
	}

	// Close the listener before waiting for the accept goroutine. If the build
	// never dialed, Accept would otherwise block forever and <-done would
	// deadlock the test.
	_ = listener.Close()
	<-done

	if len(events) == 0 {
		t.Fatal("expected debug events")
	}
	first := events[0]
	if first.Step != "auth" || first.Phase != "dial" {
		t.Fatalf("unexpected first event: %+v", first)
	}
	if first.Endpoint != socketPath || first.Source != "identity-agent" {
		t.Fatalf("unexpected endpoint event: %+v", first)
	}
}
// tempUnixSocketPath returns a path for a Unix domain socket inside a fresh
// temporary directory. The directory is removed when the test finishes.
//
// t.TempDir is deliberately not used. Its directory name embeds the full test
// name, and on macOS the resulting path can exceed the 104-byte sun_path limit
// for AF_UNIX addresses, making net.Listen("unix", ...) fail. A short
// os.MkdirTemp prefix keeps the path well under the limit.
func tempUnixSocketPath(t *testing.T) string {
	t.Helper()
	dir, err := os.MkdirTemp("", "ssha")
	if err != nil {
		t.Fatalf("create temp socket dir: %v", err)
	}
	t.Cleanup(func() {
		_ = os.RemoveAll(dir)
	})
	return filepath.Join(dir, "agent.sock")
}