- 拆分 ssh-agent 认证、连接与 endpoint 解析逻辑 - 新增 IdentityAgent、SSHAgentTimeout、SSHAgentForwardTimeout 和调试事件 - 为 agent list/sign 操作增加独立 deadline,避免硬件 agent 卡死登录 - 支持 agent signer 失败后跳过坏 key 并重试后续 key - 优先处理 RSA-SHA2 签名,兼容现代 OpenSSH 认证要求 - 增强 agent forwarding 的探测、通道空闲超时和关闭清理 - 补充 Windows OpenSSH pipe 与 GPG S.gpg-agent.ssh socket 文件支持 - 增加相关回归测试和 Windows 编译验证覆盖
764 lines
22 KiB
Go
764 lines
22 KiB
Go
package starssh
|
|
|
|
import (
|
|
"crypto/ed25519"
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"errors"
|
|
"io"
|
|
"net"
|
|
"os"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"golang.org/x/crypto/ssh"
|
|
sshagent "golang.org/x/crypto/ssh/agent"
|
|
)
|
|
|
|
func TestNormalizeLoginInputKeepsZeroAuthTimeout(t *testing.T) {
|
|
info := normalizeLoginInput(LoginInput{})
|
|
if info.Port != defaultSSHPort {
|
|
t.Fatalf("Port=%d want %d", info.Port, defaultSSHPort)
|
|
}
|
|
if info.Timeout != 0 {
|
|
t.Fatalf("Timeout=%v want 0", info.Timeout)
|
|
}
|
|
if info.DialTimeout != 0 {
|
|
t.Fatalf("DialTimeout=%v want 0", info.DialTimeout)
|
|
}
|
|
if info.SSHAgentTimeout != 0 {
|
|
t.Fatalf("SSHAgentTimeout=%v want 0", info.SSHAgentTimeout)
|
|
}
|
|
if info.SSHAgentForwardTimeout != 0 {
|
|
t.Fatalf("SSHAgentForwardTimeout=%v want 0", info.SSHAgentForwardTimeout)
|
|
}
|
|
}
|
|
|
|
func TestEffectiveLoginTimeout(t *testing.T) {
|
|
if got := effectiveLoginTimeout(LoginInput{}); got != 0 {
|
|
t.Fatalf("zero login timeout should stay zero, got %v", got)
|
|
}
|
|
if got := effectiveLoginTimeout(LoginInput{Timeout: 7 * time.Second}); got != 7*time.Second {
|
|
t.Fatalf("expected explicit login timeout, got %v", got)
|
|
}
|
|
}
|
|
|
|
func TestEffectiveDialTimeout(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
info LoginInput
|
|
want time.Duration
|
|
}{
|
|
{
|
|
name: "default fallback",
|
|
info: LoginInput{},
|
|
want: defaultLoginTimeout,
|
|
},
|
|
{
|
|
name: "reuse timeout when dial timeout omitted",
|
|
info: LoginInput{Timeout: 9 * time.Second},
|
|
want: 9 * time.Second,
|
|
},
|
|
{
|
|
name: "explicit dial timeout wins",
|
|
info: LoginInput{Timeout: 9 * time.Second, DialTimeout: 3 * time.Second},
|
|
want: 3 * time.Second,
|
|
},
|
|
{
|
|
name: "negative dial timeout disables default dial deadline",
|
|
info: LoginInput{Timeout: 9 * time.Second, DialTimeout: -1},
|
|
want: 0,
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
if got := effectiveDialTimeout(tc.info); got != tc.want {
|
|
t.Fatalf("effectiveDialTimeout(%+v)=%v want %v", tc.info, got, tc.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestEffectiveSSHAgentTimeout(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
info LoginInput
|
|
want time.Duration
|
|
}{
|
|
{
|
|
name: "default fallback without auth timeout",
|
|
info: LoginInput{},
|
|
want: defaultSSHAgentTimeout,
|
|
},
|
|
{
|
|
name: "auth timeout does not cap default",
|
|
info: LoginInput{Timeout: 9 * time.Second},
|
|
want: defaultSSHAgentTimeout,
|
|
},
|
|
{
|
|
name: "explicit agent timeout wins",
|
|
info: LoginInput{Timeout: 9 * time.Second, DialTimeout: 3 * time.Second, SSHAgentTimeout: 90 * time.Second},
|
|
want: 90 * time.Second,
|
|
},
|
|
{
|
|
name: "negative agent timeout disables operation deadline",
|
|
info: LoginInput{SSHAgentTimeout: -1},
|
|
want: 0,
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
if got := effectiveSSHAgentTimeout(tc.info); got != tc.want {
|
|
t.Fatalf("effectiveSSHAgentTimeout(%+v)=%v want %v", tc.info, got, tc.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestEffectiveSSHAgentForwardTimeout(t *testing.T) {
|
|
if got := effectiveSSHAgentForwardTimeout(LoginInput{}); got != 0 {
|
|
t.Fatalf("zero forward timeout should stay zero, got %v", got)
|
|
}
|
|
if got := effectiveSSHAgentForwardTimeout(LoginInput{SSHAgentForwardTimeout: 4 * time.Second}); got != 4*time.Second {
|
|
t.Fatalf("expected explicit forward timeout, got %v", got)
|
|
}
|
|
}
|
|
|
|
func TestBuildAuthMethodsUsesSeparateSSHAgentTimeouts(t *testing.T) {
|
|
oldBuilder := buildSSHAgentAuthMethodFunc
|
|
t.Cleanup(func() {
|
|
buildSSHAgentAuthMethodFunc = oldBuilder
|
|
})
|
|
|
|
captured := sshAgentTimeouts{Dial: -2, Operation: -2, Forward: -2}
|
|
buildSSHAgentAuthMethodFunc = func(timeouts sshAgentTimeouts) (ssh.AuthMethod, func(), error) {
|
|
captured = timeouts
|
|
return ssh.Password("agent"), nil, nil
|
|
}
|
|
|
|
info := LoginInput{
|
|
Timeout: 0,
|
|
DialTimeout: 11 * time.Second,
|
|
SSHAgentTimeout: 90 * time.Second,
|
|
SSHAgentForwardTimeout: 4 * time.Second,
|
|
IdentityAgent: "/tmp/custom-agent.sock",
|
|
AuthOrder: []AuthMethodKind{AuthMethodSSHAgent},
|
|
}
|
|
auth, cleanup, err := buildAuthMethods(info)
|
|
if err != nil {
|
|
t.Fatalf("buildAuthMethods: %v", err)
|
|
}
|
|
if cleanup != nil {
|
|
cleanup()
|
|
}
|
|
if len(auth) != 1 {
|
|
t.Fatalf("expected one auth method, got %d", len(auth))
|
|
}
|
|
if captured.Dial != 11*time.Second {
|
|
t.Fatalf("agent auth builder dial timeout=%v want %v", captured.Dial, 11*time.Second)
|
|
}
|
|
if captured.Operation != 90*time.Second {
|
|
t.Fatalf("agent auth builder operation timeout=%v want %v", captured.Operation, 90*time.Second)
|
|
}
|
|
if captured.Forward != 4*time.Second {
|
|
t.Fatalf("agent auth builder forward timeout=%v want %v", captured.Forward, 4*time.Second)
|
|
}
|
|
if captured.Endpoint != "/tmp/custom-agent.sock" {
|
|
t.Fatalf("agent auth builder endpoint=%q want custom endpoint", captured.Endpoint)
|
|
}
|
|
}
|
|
|
|
func TestBuildAuthMethodsUsesSingleAgentAuthMethod(t *testing.T) {
|
|
oldBuilder := buildSSHAgentAuthMethodFunc
|
|
t.Cleanup(func() {
|
|
buildSSHAgentAuthMethodFunc = oldBuilder
|
|
})
|
|
|
|
buildSSHAgentAuthMethodFunc = func(timeouts sshAgentTimeouts) (ssh.AuthMethod, func(), error) {
|
|
return ssh.Password("agent"), nil, nil
|
|
}
|
|
|
|
auth, cleanup, err := buildAuthMethods(LoginInput{
|
|
AuthOrder: []AuthMethodKind{AuthMethodSSHAgent},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("buildAuthMethods: %v", err)
|
|
}
|
|
if cleanup != nil {
|
|
cleanup()
|
|
}
|
|
if len(auth) != 1 {
|
|
t.Fatalf("auth methods=%d, want 1", len(auth))
|
|
}
|
|
}
|
|
|
|
func TestShouldRetrySSHAgentAuthWhenAgentIsNotFirst(t *testing.T) {
|
|
order := []AuthMethodKind{AuthMethodPassword, AuthMethodSSHAgent}
|
|
if !shouldRetrySSHAgentAuth(LoginInput{}, order) {
|
|
t.Fatal("expected ssh-agent retry when ssh-agent is present after password")
|
|
}
|
|
if shouldRetrySSHAgentAuth(LoginInput{DisableSSHAgent: true}, order) {
|
|
t.Fatal("expected ssh-agent retry disabled when DisableSSHAgent is true")
|
|
}
|
|
if shouldRetrySSHAgentAuth(LoginInput{}, []AuthMethodKind{AuthMethodPassword}) {
|
|
t.Fatal("expected no ssh-agent retry when ssh-agent auth is absent")
|
|
}
|
|
}
|
|
|
|
func TestBuildAuthMethodsWithAgentAttemptMarksNonFirstAgentForRetry(t *testing.T) {
|
|
oldBuilder := buildSSHAgentAuthMethodFunc
|
|
t.Cleanup(func() {
|
|
buildSSHAgentAuthMethodFunc = oldBuilder
|
|
})
|
|
|
|
buildSSHAgentAuthMethodFunc = func(timeouts sshAgentTimeouts) (ssh.AuthMethod, func(), error) {
|
|
if timeouts.SignFailure == nil {
|
|
t.Fatal("expected SignFailure callback for non-first ssh-agent auth")
|
|
}
|
|
if timeouts.SkipFingerprints != nil {
|
|
t.Fatalf("unexpected initial skip fingerprints: %#v", timeouts.SkipFingerprints)
|
|
}
|
|
return ssh.Password("agent"), nil, nil
|
|
}
|
|
|
|
auth, cleanup, err := buildAuthMethodsWithAgentAttempt(LoginInput{
|
|
Password: "secret",
|
|
AuthOrder: []AuthMethodKind{AuthMethodPassword, AuthMethodSSHAgent},
|
|
}, newSSHAgentAuthAttempt())
|
|
if err != nil {
|
|
t.Fatalf("buildAuthMethodsWithAgentAttempt: %v", err)
|
|
}
|
|
if cleanup != nil {
|
|
cleanup()
|
|
}
|
|
if len(auth) != 2 {
|
|
t.Fatalf("auth methods=%d want 2", len(auth))
|
|
}
|
|
}
|
|
|
|
func TestAgentRetryPendingBlocksFallbackAuthThenResets(t *testing.T) {
|
|
attempt := newSSHAgentAuthAttempt()
|
|
attempt.skipFingerprint("SHA256:test")
|
|
if err := checkSSHAgentRetryPending(attempt); !errors.Is(err, errRetrySSHAgentAuth) {
|
|
t.Fatalf("retry pending err=%v want errRetrySSHAgentAuth", err)
|
|
}
|
|
attempt.begin()
|
|
if err := checkSSHAgentRetryPending(attempt); err != nil {
|
|
t.Fatalf("retry should reset on next attempt: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestAgentRetryPendingBlocksPrivateKeyAuth(t *testing.T) {
|
|
signer := mustGenerateTestSigner(t)
|
|
attempt := newSSHAgentAuthAttempt()
|
|
callback := privateKeySignersCallback(signer, attempt)
|
|
|
|
signers, err := callback()
|
|
if err != nil {
|
|
t.Fatalf("private key callback before retry: %v", err)
|
|
}
|
|
if len(signers) != 1 || signers[0] != signer {
|
|
t.Fatalf("private key callback returned %#v, want original signer", signers)
|
|
}
|
|
|
|
attempt.skipFingerprint("SHA256:test")
|
|
signers, err = callback()
|
|
if !errors.Is(err, errRetrySSHAgentAuth) {
|
|
t.Fatalf("private key callback err=%v want errRetrySSHAgentAuth", err)
|
|
}
|
|
if signers != nil {
|
|
t.Fatalf("private key callback signers=%#v want nil while retry pending", signers)
|
|
}
|
|
|
|
attempt.begin()
|
|
signers, err = callback()
|
|
if err != nil {
|
|
t.Fatalf("private key callback after retry reset: %v", err)
|
|
}
|
|
if len(signers) != 1 || signers[0] != signer {
|
|
t.Fatalf("private key callback after retry returned %#v, want original signer", signers)
|
|
}
|
|
}
|
|
|
|
func TestFilterSSHAgentSignersSkipsSignerAfterSignFailure(t *testing.T) {
|
|
firstSigner := mustGenerateTestSigner(t)
|
|
secondSigner := mustGenerateTestSigner(t)
|
|
failingFirstSigner := &testFailingSigner{Signer: firstSigner, err: errors.New("first agent key cannot sign")}
|
|
|
|
attempt := newSSHAgentAuthAttempt()
|
|
firstMethods := filterSSHAgentSignersForRetry([]ssh.Signer{failingFirstSigner, secondSigner}, sshAgentTimeouts{
|
|
SignFailure: attempt.recordSignFailure,
|
|
SkipFingerprints: attempt.skipSnapshot(),
|
|
})
|
|
if len(firstMethods) != 2 {
|
|
t.Fatalf("first auth method signers=%d want 2", len(firstMethods))
|
|
}
|
|
if _, err := firstMethods[0].Sign(nil, []byte("challenge")); !errors.Is(err, errRetrySSHAgentAuth) {
|
|
t.Fatalf("first signer err=%v want errRetrySSHAgentAuth", err)
|
|
}
|
|
secondMethods := filterSSHAgentSignersForRetry([]ssh.Signer{failingFirstSigner, secondSigner}, sshAgentTimeouts{
|
|
SignFailure: attempt.recordSignFailure,
|
|
SkipFingerprints: attempt.skipSnapshot(),
|
|
})
|
|
if len(secondMethods) != 1 {
|
|
t.Fatalf("second auth method signers=%d want 1", len(secondMethods))
|
|
}
|
|
if string(secondMethods[0].PublicKey().Marshal()) != string(secondSigner.PublicKey().Marshal()) {
|
|
t.Fatalf("second auth method did not skip failed first key")
|
|
}
|
|
signature, err := secondMethods[0].Sign(nil, []byte("challenge"))
|
|
if err != nil {
|
|
t.Fatalf("second signer Sign: %v", err)
|
|
}
|
|
if signature == nil {
|
|
t.Fatal("second signer returned nil signature")
|
|
}
|
|
}
|
|
|
|
func TestBuildAuthMethodsSkipsFailedAgentSignerOnRetry(t *testing.T) {
|
|
firstSigner := mustGenerateTestSigner(t)
|
|
secondSigner := mustGenerateTestSigner(t)
|
|
wantErr := errors.New("first agent key cannot sign")
|
|
failingFirstSigner := &testFailingSigner{Signer: firstSigner, err: wantErr}
|
|
|
|
oldBuilder := buildSSHAgentAuthMethodFunc
|
|
t.Cleanup(func() {
|
|
buildSSHAgentAuthMethodFunc = oldBuilder
|
|
})
|
|
|
|
var buildCalls int
|
|
buildSSHAgentAuthMethodFunc = func(timeouts sshAgentTimeouts) (ssh.AuthMethod, func(), error) {
|
|
buildCalls++
|
|
filteredSigners := filterSSHAgentSignersForRetry([]ssh.Signer{failingFirstSigner, secondSigner}, timeouts)
|
|
if buildCalls == 1 {
|
|
if len(filteredSigners) != 2 {
|
|
t.Fatalf("first build signers=%d want 2", len(filteredSigners))
|
|
}
|
|
return ssh.PublicKeys(filteredSigners...), nil, nil
|
|
}
|
|
if len(filteredSigners) != 1 {
|
|
t.Fatalf("retry build signers=%d want 1", len(filteredSigners))
|
|
}
|
|
if string(filteredSigners[0].PublicKey().Marshal()) != string(secondSigner.PublicKey().Marshal()) {
|
|
t.Fatal("retry build did not skip failed signer")
|
|
}
|
|
return ssh.PublicKeys(filteredSigners...), nil, nil
|
|
}
|
|
|
|
attempt := newSSHAgentAuthAttempt()
|
|
attempt.begin()
|
|
auth, cleanup, err := buildAuthMethodsWithAgentAttempt(LoginInput{
|
|
AuthOrder: []AuthMethodKind{AuthMethodSSHAgent},
|
|
}, attempt)
|
|
if err != nil {
|
|
t.Fatalf("first buildAuthMethodsWithAgentAttempt: %v", err)
|
|
}
|
|
if cleanup != nil {
|
|
cleanup()
|
|
}
|
|
if len(auth) != 1 {
|
|
t.Fatalf("first auth methods=%d want 1", len(auth))
|
|
}
|
|
if _, err := failingFirstSigner.Sign(rand.Reader, []byte("challenge")); !errors.Is(err, wantErr) {
|
|
t.Fatalf("raw failing signer err=%v", err)
|
|
}
|
|
firstWrapped := filterSSHAgentSignersForRetry([]ssh.Signer{failingFirstSigner}, sshAgentTimeouts{
|
|
SignFailure: attempt.recordSignFailure,
|
|
})[0]
|
|
if _, err := firstWrapped.Sign(rand.Reader, []byte("challenge")); !errors.Is(err, errRetrySSHAgentAuth) {
|
|
t.Fatalf("wrapped failing signer err=%v want errRetrySSHAgentAuth", err)
|
|
}
|
|
|
|
attempt.begin()
|
|
auth, cleanup, err = buildAuthMethodsWithAgentAttempt(LoginInput{
|
|
AuthOrder: []AuthMethodKind{AuthMethodSSHAgent},
|
|
}, attempt)
|
|
if err != nil {
|
|
t.Fatalf("retry buildAuthMethodsWithAgentAttempt: %v", err)
|
|
}
|
|
if cleanup != nil {
|
|
cleanup()
|
|
}
|
|
if len(auth) != 1 {
|
|
t.Fatalf("retry auth methods=%d want 1", len(auth))
|
|
}
|
|
if buildCalls != 2 {
|
|
t.Fatalf("build calls=%d want 2", buildCalls)
|
|
}
|
|
}
|
|
|
|
func TestOrderSSHAgentSignersPrefersPriorityComment(t *testing.T) {
|
|
plainSigner := mustGenerateTestSigner(t)
|
|
prioritySigner := mustGenerateCommentedTestSigner(t, "priority=40")
|
|
|
|
ordered := orderSSHAgentSigners([]ssh.Signer{plainSigner, prioritySigner})
|
|
if len(ordered) != 2 {
|
|
t.Fatalf("ordered signers=%d want 2", len(ordered))
|
|
}
|
|
if string(ordered[0].PublicKey().Marshal()) != string(prioritySigner.PublicKey().Marshal()) {
|
|
t.Fatalf("priority signer should be first, got %s", sshAgentSignerComment(ordered[0]))
|
|
}
|
|
}
|
|
|
|
func TestOrderSSHAgentSignersPrefersCardKeys(t *testing.T) {
|
|
plainSigner := mustGenerateTestSigner(t)
|
|
cardSigner := mustGenerateCommentedTestSigner(t, "cardno:26_865_673")
|
|
|
|
ordered := orderSSHAgentSigners([]ssh.Signer{plainSigner, cardSigner})
|
|
if len(ordered) != 2 {
|
|
t.Fatalf("ordered signers=%d want 2", len(ordered))
|
|
}
|
|
if string(ordered[0].PublicKey().Marshal()) != string(cardSigner.PublicKey().Marshal()) {
|
|
t.Fatalf("card signer should be first, got %s", sshAgentSignerComment(ordered[0]))
|
|
}
|
|
}
|
|
|
|
func TestOrderSSHAgentSignersKeepsStableOrderWithoutHints(t *testing.T) {
|
|
firstSigner := mustGenerateTestSigner(t)
|
|
secondSigner := mustGenerateTestSigner(t)
|
|
|
|
ordered := orderSSHAgentSigners([]ssh.Signer{firstSigner, secondSigner})
|
|
if len(ordered) != 2 {
|
|
t.Fatalf("ordered signers=%d want 2", len(ordered))
|
|
}
|
|
if string(ordered[0].PublicKey().Marshal()) != string(firstSigner.PublicKey().Marshal()) {
|
|
t.Fatalf("first signer changed order without hints")
|
|
}
|
|
if string(ordered[1].PublicKey().Marshal()) != string(secondSigner.PublicKey().Marshal()) {
|
|
t.Fatalf("second signer changed order without hints")
|
|
}
|
|
}
|
|
|
|
func TestSSHAgentSignerEmitsSignDebugWithoutChangingError(t *testing.T) {
|
|
signer := mustGenerateTestSigner(t)
|
|
wantErr := errors.New("agent refused operation")
|
|
var debugCalls int
|
|
wrapped := wrapSSHAgentSigner(&testFailingSigner{Signer: signer, err: wantErr}, sshAgentSignerOptions{
|
|
Resolved: resolvedSSHAgentEndpoint{
|
|
Endpoint: "/tmp/debug-agent.sock",
|
|
Source: "identity-agent",
|
|
Network: "unix",
|
|
},
|
|
Debug: func(event SSHAgentDebugEvent) {
|
|
debugCalls++
|
|
if event.Step != "auth" || event.Phase != "sign" {
|
|
t.Fatalf("unexpected debug event: %+v", event)
|
|
}
|
|
if event.Endpoint != "/tmp/debug-agent.sock" || event.Source != "identity-agent" || event.Network != "unix" {
|
|
t.Fatalf("unexpected endpoint details: %+v", event)
|
|
}
|
|
if event.Status != "error" || !errors.Is(event.Err, wantErr) {
|
|
t.Fatalf("unexpected sign status: %+v", event)
|
|
}
|
|
},
|
|
})
|
|
|
|
_, err := wrapped.Sign(rand.Reader, []byte("challenge"))
|
|
if !errors.Is(err, wantErr) {
|
|
t.Fatalf("Sign err=%v want original signer error", err)
|
|
}
|
|
if debugCalls != 1 {
|
|
t.Fatalf("debug calls=%d want 1", debugCalls)
|
|
}
|
|
}
|
|
|
|
func TestSSHAgentRetrySignerPrefersRSASHA2(t *testing.T) {
|
|
signer := mustGenerateRSATestSigner(t)
|
|
spy := &testAlgorithmSpySigner{Signer: signer}
|
|
wrapped, ok := wrapSSHAgentSignerForRetry(spy, func(ssh.PublicKey, error) {}).(ssh.AlgorithmSigner)
|
|
if !ok {
|
|
t.Fatal("wrapped signer does not implement AlgorithmSigner")
|
|
}
|
|
|
|
signature, err := wrapped.SignWithAlgorithm(rand.Reader, []byte("challenge"), ssh.KeyAlgoRSA)
|
|
if err != nil {
|
|
t.Fatalf("SignWithAlgorithm: %v", err)
|
|
}
|
|
if spy.lastAlgorithm != ssh.KeyAlgoRSASHA256 {
|
|
t.Fatalf("last algorithm=%q want %q", spy.lastAlgorithm, ssh.KeyAlgoRSASHA256)
|
|
}
|
|
if signature.Format != ssh.KeyAlgoRSASHA256 {
|
|
t.Fatalf("signature format=%q want %q", signature.Format, ssh.KeyAlgoRSASHA256)
|
|
}
|
|
}
|
|
|
|
func TestSSHAgentRetrySignerKeepsRestrictedRSA(t *testing.T) {
|
|
signer := mustGenerateRSATestSigner(t)
|
|
restricted, err := ssh.NewSignerWithAlgorithms(signer.(ssh.AlgorithmSigner), []string{ssh.KeyAlgoRSA})
|
|
if err != nil {
|
|
t.Fatalf("NewSignerWithAlgorithms: %v", err)
|
|
}
|
|
spy := &testMultiAlgorithmSpySigner{
|
|
testAlgorithmSpySigner: &testAlgorithmSpySigner{Signer: restricted},
|
|
}
|
|
wrapped, ok := wrapSSHAgentSignerForRetry(spy, func(ssh.PublicKey, error) {}).(ssh.AlgorithmSigner)
|
|
if !ok {
|
|
t.Fatal("wrapped signer does not implement AlgorithmSigner")
|
|
}
|
|
|
|
signature, err := wrapped.SignWithAlgorithm(rand.Reader, []byte("challenge"), ssh.KeyAlgoRSA)
|
|
if err != nil {
|
|
t.Fatalf("SignWithAlgorithm: %v", err)
|
|
}
|
|
if spy.lastAlgorithm != ssh.KeyAlgoRSA {
|
|
t.Fatalf("last algorithm=%q want %q", spy.lastAlgorithm, ssh.KeyAlgoRSA)
|
|
}
|
|
if signature.Format != ssh.KeyAlgoRSA {
|
|
t.Fatalf("signature format=%q want %q", signature.Format, ssh.KeyAlgoRSA)
|
|
}
|
|
}
|
|
|
|
// deadlineSpyConn is a fake net.Conn used by the deadline-wrapper tests.
// It records every SetDeadline call and returns canned errors from Read and
// Write. The tests in this file leave the embedded net.Conn nil, so only the
// methods this type overrides are safe to call.
type deadlineSpyConn struct {
	net.Conn

	readErr  error // returned by Read when non-nil
	writeErr error // returned by Write when non-nil

	mu        sync.Mutex
	deadlines []time.Time // every value passed to SetDeadline; guarded by mu
}
|
|
|
|
type testFailingSigner struct {
|
|
ssh.Signer
|
|
err error
|
|
}
|
|
|
|
func (s *testFailingSigner) Sign(rand io.Reader, data []byte) (*ssh.Signature, error) {
|
|
return nil, s.err
|
|
}
|
|
|
|
func (s *testFailingSigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*ssh.Signature, error) {
|
|
return nil, s.err
|
|
}
|
|
|
|
type testAlgorithmSpySigner struct {
|
|
ssh.Signer
|
|
lastAlgorithm string
|
|
}
|
|
|
|
func (s *testAlgorithmSpySigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*ssh.Signature, error) {
|
|
s.lastAlgorithm = algorithm
|
|
return s.Signer.(ssh.AlgorithmSigner).SignWithAlgorithm(rand, data, algorithm)
|
|
}
|
|
|
|
type testMultiAlgorithmSpySigner struct {
|
|
*testAlgorithmSpySigner
|
|
}
|
|
|
|
func (s *testMultiAlgorithmSpySigner) Algorithms() []string {
|
|
if multiAlgorithmSigner, ok := s.Signer.(ssh.MultiAlgorithmSigner); ok {
|
|
return multiAlgorithmSigner.Algorithms()
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func mustGenerateTestSigner(t *testing.T) ssh.Signer {
|
|
t.Helper()
|
|
_, key, err := ed25519.GenerateKey(rand.Reader)
|
|
if err != nil {
|
|
t.Fatalf("generate test private key: %v", err)
|
|
}
|
|
signer, err := ssh.NewSignerFromKey(key)
|
|
if err != nil {
|
|
t.Fatalf("new test signer: %v", err)
|
|
}
|
|
return signer
|
|
}
|
|
|
|
func mustGenerateCommentedTestSigner(t *testing.T, comment string) ssh.Signer {
|
|
t.Helper()
|
|
baseSigner := mustGenerateTestSigner(t)
|
|
publicKey := baseSigner.PublicKey()
|
|
return &commentedTestSigner{
|
|
Signer: baseSigner,
|
|
publicKey: &sshagent.Key{
|
|
Format: publicKey.Type(),
|
|
Blob: publicKey.Marshal(),
|
|
Comment: comment,
|
|
},
|
|
}
|
|
}
|
|
|
|
type commentedTestSigner struct {
|
|
ssh.Signer
|
|
publicKey ssh.PublicKey
|
|
}
|
|
|
|
func (s *commentedTestSigner) PublicKey() ssh.PublicKey {
|
|
return s.publicKey
|
|
}
|
|
|
|
func mustGenerateRSATestSigner(t *testing.T) ssh.Signer {
|
|
t.Helper()
|
|
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
if err != nil {
|
|
t.Fatalf("generate rsa test private key: %v", err)
|
|
}
|
|
signer, err := ssh.NewSignerFromKey(key)
|
|
if err != nil {
|
|
t.Fatalf("new rsa test signer: %v", err)
|
|
}
|
|
return signer
|
|
}
|
|
|
|
func (c *deadlineSpyConn) SetDeadline(deadline time.Time) error {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
c.deadlines = append(c.deadlines, deadline)
|
|
return nil
|
|
}
|
|
|
|
func (c *deadlineSpyConn) deadlineCount() int {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
return len(c.deadlines)
|
|
}
|
|
|
|
func (c *deadlineSpyConn) firstDeadline() time.Time {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
return c.deadlines[0]
|
|
}
|
|
|
|
func (c *deadlineSpyConn) Read(p []byte) (int, error) {
|
|
if c.readErr != nil {
|
|
return 0, c.readErr
|
|
}
|
|
return 0, nil
|
|
}
|
|
|
|
func (c *deadlineSpyConn) Write(p []byte) (int, error) {
|
|
if c.writeErr != nil {
|
|
return 0, c.writeErr
|
|
}
|
|
return len(p), nil
|
|
}
|
|
|
|
func TestWrapSSHAgentConnWithDeadlineSetsReadDeadline(t *testing.T) {
|
|
spy := &deadlineSpyConn{readErr: io.EOF}
|
|
conn := wrapSSHAgentConnWithDeadline(spy, 2*time.Second)
|
|
buf := make([]byte, 1)
|
|
if _, err := conn.Read(buf); !errors.Is(err, io.EOF) {
|
|
t.Fatalf("Read err=%v", err)
|
|
}
|
|
if spy.deadlineCount() != 1 {
|
|
t.Fatalf("deadlines=%d want 1", spy.deadlineCount())
|
|
}
|
|
if firstDeadline := spy.firstDeadline(); time.Until(firstDeadline) <= 0 {
|
|
t.Fatalf("deadline=%v should be in the future", firstDeadline)
|
|
}
|
|
}
|
|
|
|
func TestWrapSSHAgentConnWithDeadlineSetsWriteDeadline(t *testing.T) {
|
|
spy := &deadlineSpyConn{}
|
|
conn := wrapSSHAgentConnWithDeadline(spy, 2*time.Second)
|
|
if _, err := conn.Write([]byte("x")); err != nil {
|
|
t.Fatalf("Write err=%v", err)
|
|
}
|
|
if spy.deadlineCount() != 1 {
|
|
t.Fatalf("deadlines=%d want 1", spy.deadlineCount())
|
|
}
|
|
}
|
|
|
|
func TestResolveSSHAgentEndpointUsesIdentityAgent(t *testing.T) {
|
|
t.Setenv("SSH_AUTH_SOCK", "/tmp/env-agent.sock")
|
|
resolved, err := resolveSSHAgentEndpoint(sshAgentDialOptions{Endpoint: " /tmp/identity-agent.sock "})
|
|
if err != nil {
|
|
t.Fatalf("resolveSSHAgentEndpoint: %v", err)
|
|
}
|
|
if resolved.Endpoint != "/tmp/identity-agent.sock" {
|
|
t.Fatalf("endpoint=%q", resolved.Endpoint)
|
|
}
|
|
if resolved.Source != "identity-agent" {
|
|
t.Fatalf("source=%q", resolved.Source)
|
|
}
|
|
}
|
|
|
|
func TestResolveSSHAgentEndpointUsesSSHAuthSock(t *testing.T) {
|
|
t.Setenv("SSH_AUTH_SOCK", "/tmp/env-agent.sock")
|
|
resolved, err := resolveSSHAgentEndpoint(sshAgentDialOptions{})
|
|
if err != nil {
|
|
t.Fatalf("resolveSSHAgentEndpoint: %v", err)
|
|
}
|
|
if resolved.Endpoint != "/tmp/env-agent.sock" {
|
|
t.Fatalf("endpoint=%q", resolved.Endpoint)
|
|
}
|
|
if resolved.Source != "SSH_AUTH_SOCK" {
|
|
t.Fatalf("source=%q", resolved.Source)
|
|
}
|
|
}
|
|
|
|
func TestBuildSSHAgentAuthMethodTimesOutWhenAgentDoesNotRespond(t *testing.T) {
|
|
server, client := net.Pipe()
|
|
defer server.Close()
|
|
|
|
oldDialResolvedSSHAgent := dialResolvedSSHAgentFunc
|
|
t.Cleanup(func() {
|
|
dialResolvedSSHAgentFunc = oldDialResolvedSSHAgent
|
|
})
|
|
dialResolvedSSHAgentFunc = func(resolved resolvedSSHAgentEndpoint, timeout time.Duration) (net.Conn, error) {
|
|
return client, nil
|
|
}
|
|
|
|
_, cleanup, err := buildSSHAgentAuthMethod(sshAgentTimeouts{
|
|
Operation: 20 * time.Millisecond,
|
|
Endpoint: "/tmp/hung-agent.sock",
|
|
})
|
|
if cleanup != nil {
|
|
cleanup()
|
|
}
|
|
if !errors.Is(err, ErrSSHAgentTimeout) {
|
|
t.Fatalf("err=%v want ErrSSHAgentTimeout", err)
|
|
}
|
|
}
|
|
|
|
func TestBuildSSHAgentAuthMethodEmitsDebugEvents(t *testing.T) {
|
|
socketPath := tempUnixSocketPath(t)
|
|
listener, err := net.Listen("unix", socketPath)
|
|
if err != nil {
|
|
t.Fatalf("listen unix: %v", err)
|
|
}
|
|
defer listener.Close()
|
|
|
|
done := make(chan struct{})
|
|
go func() {
|
|
defer close(done)
|
|
conn, err := listener.Accept()
|
|
if err != nil {
|
|
return
|
|
}
|
|
_ = conn.Close()
|
|
}()
|
|
|
|
var events []SSHAgentDebugEvent
|
|
_, _, _ = buildSSHAgentAuthMethod(sshAgentTimeouts{
|
|
Dial: time.Second,
|
|
Operation: time.Second,
|
|
Endpoint: socketPath,
|
|
Debug: func(event SSHAgentDebugEvent) {
|
|
events = append(events, event)
|
|
},
|
|
})
|
|
<-done
|
|
|
|
if len(events) == 0 {
|
|
t.Fatal("expected debug events")
|
|
}
|
|
if events[0].Step != "auth" || events[0].Phase != "dial" {
|
|
t.Fatalf("unexpected first event: %+v", events[0])
|
|
}
|
|
if events[0].Endpoint != socketPath || events[0].Source != "identity-agent" {
|
|
t.Fatalf("unexpected endpoint event: %+v", events[0])
|
|
}
|
|
}
|
|
|
|
// tempUnixSocketPath returns a path for a Unix-domain socket named
// "agent.sock" inside a fresh temporary directory, which is removed when the
// test finishes.
//
// It uses os.MkdirTemp with a short prefix instead of t.TempDir. t.TempDir
// embeds the (often long) test name in the directory name, and on macOS the
// resulting path can exceed the 104-byte sun_path limit, which makes
// net.Listen("unix", ...) fail.
func tempUnixSocketPath(t *testing.T) string {
	t.Helper()
	dir, err := os.MkdirTemp("", "sshagent")
	if err != nil {
		t.Fatalf("create temp socket dir: %v", err)
	}
	t.Cleanup(func() {
		_ = os.RemoveAll(dir)
	})
	return dir + string(os.PathSeparator) + "agent.sock"
}
|