feat: 增强 ssh-agent 认证与转发可靠性

- 拆分 ssh-agent 认证、连接与 endpoint 解析逻辑
- 新增 IdentityAgent、SSHAgentTimeout、SSHAgentForwardTimeout 和调试事件
- 为 agent list/sign 操作增加独立 deadline,避免硬件 agent 卡死登录
- 支持 agent signer 失败后跳过坏 key 并重试后续 key
- 优先处理 RSA-SHA2 签名,兼容现代 OpenSSH 认证要求
- 增强 agent forwarding 的探测、通道空闲超时和关闭清理
- 补充 Windows OpenSSH pipe 与 GPG S.gpg-agent.ssh socket 文件支持
- 增加相关回归测试和 Windows 编译验证覆盖
This commit is contained in:
兔子 2026-05-27 13:10:35 +08:00
parent ad7c8b0587
commit 0c23e7d4bf
Signed by: b612
GPG Key ID: 99DD2222B612B612
10 changed files with 2173 additions and 294 deletions

View File

@ -19,12 +19,12 @@ var requestSSHAgentForwarding = func(session *ssh.Session) error {
const sshAgentChannelType = "auth-agent@openssh.com"
var routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) {
return startSSHAgentForwardProxy(client, timeout)
var routeSSHAgentForwarding = func(client *ssh.Client, timeouts sshAgentTimeouts) (io.Closer, error) {
return startSSHAgentForwardProxy(client, timeouts)
}
var probeSSHAgentForwarding = func(timeout time.Duration) error {
conn, err := dialSSHAgent(timeout)
var probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error {
conn, _, err := dialSSHAgentWithDebug("forward-probe", timeouts)
if err != nil {
return wrapSSHAgentForwardingUnavailable(err)
}
@ -60,8 +60,12 @@ type sshAgentForwardBridge struct {
proxy *sshAgentForwardProxy
channel ssh.Channel
conn net.Conn
idleTimeout time.Duration
closeOnce sync.Once
signalOnce sync.Once
done chan struct{}
activity chan struct{}
}
func (s *StarSSH) RequestAgentForwarding(session *ssh.Session) error {
@ -111,14 +115,14 @@ func (s *StarSSH) ensureAgentForwarding() error {
return err
}
timeout := effectiveDialTimeout(s.LoginInfo)
if err := probeSSHAgentForwarding(timeout); err != nil {
timeouts := effectiveSSHAgentTimeouts(s.LoginInfo)
if err := probeSSHAgentForwarding(timeouts); err != nil {
return wrapSSHAgentForwardingUnavailable(err)
}
if s.closing.Load() {
return errSSHClientClosing
}
closer, err := routeSSHAgentForwarding(client, timeout)
closer, err := routeSSHAgentForwarding(client, timeouts)
if err != nil {
return err
}
@ -182,7 +186,7 @@ func wrapSSHAgentForwardingUnavailable(err error) error {
return fmt.Errorf("%w: %v", errSSHAgentForwardingUnavailable, err)
}
func startSSHAgentForwardProxy(client *ssh.Client, timeout time.Duration) (io.Closer, error) {
func startSSHAgentForwardProxy(client *ssh.Client, timeouts sshAgentTimeouts) (io.Closer, error) {
if client == nil {
return nil, errors.New("ssh client is nil")
}
@ -204,18 +208,18 @@ func startSSHAgentForwardProxy(client *ssh.Client, timeout time.Duration) (io.Cl
if !ok {
return
}
go handleSSHAgentForwardChannel(proxy, ch, timeout)
go handleSSHAgentForwardChannel(proxy, ch, timeouts)
}
}
}()
return proxy, nil
}
func handleSSHAgentForwardChannel(proxy *sshAgentForwardProxy, ch ssh.NewChannel, timeout time.Duration) {
func handleSSHAgentForwardChannel(proxy *sshAgentForwardProxy, ch ssh.NewChannel, timeouts sshAgentTimeouts) {
if ch == nil {
return
}
conn, err := dialSSHAgent(timeout)
conn, _, err := dialSSHAgentWithDebug("forward-channel", timeouts)
if err != nil {
_ = ch.Reject(ssh.ConnectionFailed, err.Error())
return
@ -224,7 +228,6 @@ func handleSSHAgentForwardChannel(proxy *sshAgentForwardProxy, ch ssh.NewChannel
_ = ch.Reject(ssh.ConnectionFailed, "ssh-agent connection unavailable")
return
}
channel, reqs, err := ch.Accept()
if err != nil {
_ = conn.Close()
@ -236,6 +239,7 @@ func handleSSHAgentForwardChannel(proxy *sshAgentForwardProxy, ch ssh.NewChannel
proxy: proxy,
channel: channel,
conn: conn,
idleTimeout: timeouts.Forward,
}
if !proxy.registerBridge(bridge) {
bridge.close()
@ -256,18 +260,27 @@ func (b *sshAgentForwardBridge) run() {
if b == nil {
return
}
b.ensureSignals()
stopWatchdog := b.startIdleWatchdog()
defer stopWatchdog()
defer b.unregister()
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
_, _ = io.Copy(b.channel, b.conn)
_, _ = io.Copy(
sshAgentForwardActivityWriter{Writer: b.channel, touch: b.touch},
sshAgentForwardActivityReader{Reader: b.conn, touch: b.touch},
)
b.close()
}()
go func() {
defer wg.Done()
_, _ = io.Copy(b.conn, b.channel)
_, _ = io.Copy(
sshAgentForwardActivityWriter{Writer: b.conn, touch: b.touch},
sshAgentForwardActivityReader{Reader: b.channel, touch: b.touch},
)
b.close()
}()
wg.Wait()
@ -278,6 +291,8 @@ func (b *sshAgentForwardBridge) close() {
return
}
b.closeOnce.Do(func() {
b.ensureSignals()
close(b.done)
closeWriter(b.channel)
closeWriter(b.conn)
if b.channel != nil {
@ -289,6 +304,90 @@ func (b *sshAgentForwardBridge) close() {
})
}
func (b *sshAgentForwardBridge) ensureSignals() {
if b == nil {
return
}
b.signalOnce.Do(func() {
b.done = make(chan struct{})
b.activity = make(chan struct{}, 1)
})
}
func (b *sshAgentForwardBridge) startIdleWatchdog() func() {
if b == nil || b.idleTimeout <= 0 {
return func() {}
}
b.ensureSignals()
timer := time.NewTimer(b.idleTimeout)
stopped := make(chan struct{})
go func() {
defer timer.Stop()
for {
select {
case <-timer.C:
b.close()
return
case <-b.activity:
resetTimer(timer, b.idleTimeout)
case <-b.done:
return
case <-stopped:
return
}
}
}()
return func() {
close(stopped)
}
}
func (b *sshAgentForwardBridge) touch() {
if b == nil || b.idleTimeout <= 0 || b.activity == nil {
return
}
select {
case b.activity <- struct{}{}:
default:
}
}
type sshAgentForwardActivityReader struct {
io.Reader
touch func()
}
func (r sshAgentForwardActivityReader) Read(p []byte) (int, error) {
n, err := r.Reader.Read(p)
if n > 0 && r.touch != nil {
r.touch()
}
return n, err
}
type sshAgentForwardActivityWriter struct {
io.Writer
touch func()
}
func (w sshAgentForwardActivityWriter) Write(p []byte) (int, error) {
n, err := w.Writer.Write(p)
if n > 0 && w.touch != nil {
w.touch()
}
return n, err
}
func resetTimer(timer *time.Timer, timeout time.Duration) {
if !timer.Stop() {
select {
case <-timer.C:
default:
}
}
timer.Reset(timeout)
}
func (b *sshAgentForwardBridge) unregister() {
if b == nil || b.proxy == nil {
return

View File

@ -44,6 +44,32 @@ type testSSHChannel struct {
closeCh chan struct{}
}
type testNewChannel struct {
channel ssh.Channel
accepted atomic.Bool
rejected atomic.Bool
}
func (c *testNewChannel) Accept() (ssh.Channel, <-chan *ssh.Request, error) {
c.accepted.Store(true)
requests := make(chan *ssh.Request)
close(requests)
return c.channel, requests, nil
}
func (c *testNewChannel) Reject(reason ssh.RejectionReason, message string) error {
c.rejected.Store(true)
return nil
}
func (c *testNewChannel) ChannelType() string {
return sshAgentChannelType
}
func (c *testNewChannel) ExtraData() []byte {
return nil
}
func newTestSSHChannel(readFunc func([]byte) (int, error)) *testSSHChannel {
return &testSSHChannel{
readFunc: readFunc,
@ -116,6 +142,8 @@ func TestNewExecSessionEnablesAgentForwardingOnce(t *testing.T) {
LoginInfo: LoginInput{
ForwardSSHAgent: true,
Timeout: time.Second,
SSHAgentTimeout: 3 * time.Second,
SSHAgentForwardTimeout: 4 * time.Second,
},
}
star.setTransport(baseClient, nil)
@ -129,22 +157,34 @@ func TestNewExecSessionEnablesAgentForwardingOnce(t *testing.T) {
var probeCalls atomic.Int32
closer := &testCloser{}
probeSSHAgentForwarding = func(timeout time.Duration) error {
probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error {
probeCalls.Add(1)
if timeout != time.Second {
t.Fatalf("unexpected forwarding timeout: %v", timeout)
if timeouts.Dial != time.Second {
t.Fatalf("unexpected forwarding dial timeout: %v", timeouts.Dial)
}
if timeouts.Operation != 3*time.Second {
t.Fatalf("unexpected forwarding operation timeout: %v", timeouts.Operation)
}
if timeouts.Forward != 4*time.Second {
t.Fatalf("unexpected forwarding idle timeout: %v", timeouts.Forward)
}
return nil
}
var routeCalls atomic.Int32
routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) {
routeSSHAgentForwarding = func(client *ssh.Client, timeouts sshAgentTimeouts) (io.Closer, error) {
routeCalls.Add(1)
if client != baseClient {
t.Fatalf("unexpected routed client %p", client)
}
if timeout != time.Second {
t.Fatalf("unexpected routed timeout: %v", timeout)
if timeouts.Dial != time.Second {
t.Fatalf("unexpected routed dial timeout: %v", timeouts.Dial)
}
if timeouts.Operation != 3*time.Second {
t.Fatalf("unexpected routed operation timeout: %v", timeouts.Operation)
}
if timeouts.Forward != 4*time.Second {
t.Fatalf("unexpected routed idle timeout: %v", timeouts.Forward)
}
return closer, nil
}
@ -215,10 +255,10 @@ func TestNewPTYSessionEnablesAgentForwardingWhenConfigured(t *testing.T) {
return nil
}
probeSSHAgentForwarding = func(timeout time.Duration) error {
probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error {
return nil
}
routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) {
routeSSHAgentForwarding = func(client *ssh.Client, timeouts sshAgentTimeouts) (io.Closer, error) {
return &testCloser{}, nil
}
@ -255,7 +295,7 @@ func TestNewExecSessionSkipsAgentForwardingWhenDisabled(t *testing.T) {
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
return &ssh.Session{}, nil
}
probeSSHAgentForwarding = func(timeout time.Duration) error {
probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error {
t.Fatal("agent forwarding probe should not run when disabled")
return nil
}
@ -280,7 +320,7 @@ func TestRequestAgentForwardingReturnsUnavailableError(t *testing.T) {
star := &StarSSH{}
star.setTransport(&ssh.Client{}, nil)
probeSSHAgentForwarding = func(timeout time.Duration) error {
probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error {
return errors.New("ssh-agent forwarding unavailable: ssh-agent unavailable")
}
requestSSHAgentForwarding = func(session *ssh.Session) error {
@ -303,7 +343,7 @@ func TestRequestAgentForwardingWrapsSetupErrorAsUnavailable(t *testing.T) {
star := &StarSSH{}
star.setTransport(&ssh.Client{}, nil)
probeSSHAgentForwarding = func(timeout time.Duration) error {
probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error {
return errors.New("dial unix /tmp/ssh-broken.sock: connect: permission denied")
}
@ -326,10 +366,10 @@ func TestRequestAgentForwardingReturnsDeniedError(t *testing.T) {
star := &StarSSH{}
star.setTransport(&ssh.Client{}, nil)
probeSSHAgentForwarding = func(timeout time.Duration) error {
probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error {
return nil
}
routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) {
routeSSHAgentForwarding = func(client *ssh.Client, timeouts sshAgentTimeouts) (io.Closer, error) {
return &testCloser{}, nil
}
requestSSHAgentForwarding = func(session *ssh.Session) error {
@ -364,10 +404,10 @@ func TestNewExecSessionIgnoresAgentForwardingDenied(t *testing.T) {
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
return &ssh.Session{}, nil
}
probeSSHAgentForwarding = func(timeout time.Duration) error {
probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error {
return nil
}
routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) {
routeSSHAgentForwarding = func(client *ssh.Client, timeouts sshAgentTimeouts) (io.Closer, error) {
return &testCloser{}, nil
}
requestSSHAgentForwarding = func(session *ssh.Session) error {
@ -397,7 +437,7 @@ func TestNewExecSessionIgnoresAgentForwardingUnavailable(t *testing.T) {
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
return &ssh.Session{}, nil
}
probeSSHAgentForwarding = func(timeout time.Duration) error {
probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error {
return errors.New("ssh-agent forwarding unavailable: ssh-agent unavailable")
}
@ -424,7 +464,7 @@ func TestNewExecSessionIgnoresAgentForwardingSetupError(t *testing.T) {
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
return &ssh.Session{}, nil
}
probeSSHAgentForwarding = func(timeout time.Duration) error {
probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error {
return errors.New("dial unix /tmp/ssh-broken.sock: connect: connection refused")
}
@ -453,10 +493,10 @@ func TestEnsureAgentForwardingClosesNewForwarderWhenCloseStarts(t *testing.T) {
started := make(chan struct{})
release := make(chan struct{})
closer := &testCloser{}
probeSSHAgentForwarding = func(timeout time.Duration) error {
probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error {
return nil
}
routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) {
routeSSHAgentForwarding = func(client *ssh.Client, timeouts sshAgentTimeouts) (io.Closer, error) {
close(started)
<-release
return closer, nil
@ -570,3 +610,56 @@ func TestSSHAgentForwardProxyCloseClosesActiveBridges(t *testing.T) {
t.Fatal("expected proxy close to close ssh channel")
}
}
func TestHandleSSHAgentForwardChannelUsesForwardTimeout(t *testing.T) {
oldDialResolvedSSHAgent := dialResolvedSSHAgentFunc
t.Cleanup(func() {
dialResolvedSSHAgentFunc = oldDialResolvedSSHAgent
})
agentConn, peerConn := net.Pipe()
defer peerConn.Close()
tracked := &trackedConn{Conn: agentConn}
dialResolvedSSHAgentFunc = func(resolved resolvedSSHAgentEndpoint, timeout time.Duration) (net.Conn, error) {
return tracked, nil
}
channel := newBlockingTestSSHChannel()
newChannel := &testNewChannel{
channel: channel,
}
proxy := &sshAgentForwardProxy{
stopCh: make(chan struct{}),
active: make(map[*sshAgentForwardBridge]struct{}),
}
handleSSHAgentForwardChannel(proxy, newChannel, sshAgentTimeouts{
Endpoint: "/tmp/agent.sock",
Forward: 20 * time.Millisecond,
})
if !newChannel.accepted.Load() {
t.Fatal("expected channel to be accepted")
}
waitUntil(t, time.Second, func() bool {
return tracked.closed.Load() > 0 && channel.closed.Load() > 0
}, "forwarded agent bridge did not close both sides after idle timeout")
waitUntil(t, time.Second, func() bool {
proxy.activeMu.Lock()
defer proxy.activeMu.Unlock()
return len(proxy.active) == 0
}, "forwarded agent bridge did not unregister after idle timeout")
}
func waitUntil(t *testing.T, timeout time.Duration, condition func() bool, message string) {
t.Helper()
deadline := time.Now().Add(timeout)
for time.Now().Before(deadline) {
if condition() {
return
}
time.Sleep(time.Millisecond)
}
t.Fatal(message)
}

243
login.go
View File

@ -4,26 +4,14 @@ import (
"context"
"encoding/base64"
"errors"
"fmt"
"net"
"os"
"strings"
"time"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
)
var ErrHostKeyCallbackRequired = errors.New("host key callback is required; use DefaultAllowHostKeyCallback to explicitly allow any host key")
var errSSHAgentUnavailable = errors.New("ssh-agent unavailable")
var buildSSHAgentAuthMethodFunc = buildSSHAgentAuthMethod
var defaultAuthOrder = []AuthMethodKind{
AuthMethodSSHAgent,
AuthMethodPrivateKey,
AuthMethodPassword,
AuthMethodKeyboardInteractive,
}
func DefaultAllowHostKeyCallback(hostname string, remote net.Addr, key ssh.PublicKey) error {
return nil
@ -47,11 +35,35 @@ func loginWithContext(ctx context.Context, info LoginInput) (*StarSSH, error) {
loginCtx, cancel := contextWithLoginTimeout(ctx, authTimeout)
defer cancel()
order, err := normalizeAuthOrder(info.AuthOrder)
if err != nil {
return nil, err
}
if shouldRetrySSHAgentAuth(info, order) {
agentAttempt := newSSHAgentAuthAttempt()
for {
agentAttempt.begin()
sshInfo, err := loginOnceWithContext(loginCtx, info, authTimeout, agentAttempt)
if err == nil {
return sshInfo, nil
}
if errors.Is(err, errRetrySSHAgentAuth) && loginCtx.Err() == nil {
continue
}
return sshInfo, err
}
}
return loginOnceWithContext(loginCtx, info, authTimeout, nil)
}
func loginOnceWithContext(ctx context.Context, info LoginInput, authTimeout time.Duration, agentAttempt *sshAgentAuthAttempt) (*StarSSH, error) {
sshInfo := &StarSSH{
LoginInfo: info,
}
auth, authCleanup, err := buildAuthMethods(info)
auth, authCleanup, err := buildAuthMethodsWithAgentAttempt(info, agentAttempt)
if err != nil {
return nil, err
}
@ -91,11 +103,11 @@ func loginWithContext(ctx context.Context, info LoginInput) (*StarSSH, error) {
}
targetAddr := joinHostPort(info.Addr, info.Port)
rawConn, upstream, err := dialTargetConn(loginCtx, info)
rawConn, upstream, err := dialTargetConn(ctx, info)
if err != nil {
return sshInfo, err
}
restoreDeadline := applyConnDeadline(rawConn, loginCtx, authTimeout)
restoreDeadline := applyConnDeadline(rawConn, ctx, authTimeout)
defer restoreDeadline()
clientConn, chans, reqs, err := ssh.NewClientConn(rawConn, targetAddr, clientConfig)
@ -179,204 +191,3 @@ func effectiveDialTimeout(info LoginInput) time.Duration {
return defaultLoginTimeout
}
}
func buildAuthMethods(info LoginInput) ([]ssh.AuthMethod, func(), error) {
order, err := normalizeAuthOrder(info.AuthOrder)
if err != nil {
return nil, nil, err
}
auth := make([]ssh.AuthMethod, 0, len(order))
var agentErr error
var cleanupFuncs []func()
for _, methodKind := range order {
switch methodKind {
case AuthMethodPrivateKey:
method, err := buildPrivateKeyAuthMethod(info)
if err != nil {
return nil, nil, err
}
if method != nil {
auth = append(auth, method)
}
case AuthMethodPassword:
method := buildPasswordAuthMethod(info.Password, info.PasswordCallback)
if method != nil {
auth = append(auth, method)
}
case AuthMethodKeyboardInteractive:
method := buildKeyboardInteractiveAuthMethod(info.Password, info.PasswordCallback, info.KeyboardInteractiveCallback)
if method != nil {
auth = append(auth, method)
}
case AuthMethodSSHAgent:
if info.DisableSSHAgent {
continue
}
agentMethod, cleanup, err := buildSSHAgentAuthMethodFunc(effectiveDialTimeout(info))
if err != nil {
agentErr = err
continue
}
if agentMethod != nil {
auth = append(auth, agentMethod)
}
if cleanup != nil {
cleanupFuncs = append(cleanupFuncs, cleanup)
}
}
}
if len(auth) == 0 {
if agentErr != nil {
return nil, nil, fmt.Errorf("no authentication method provided; ssh-agent unavailable: %w", agentErr)
}
return nil, nil, errors.New("no authentication method provided: password, private key, or ssh-agent is required")
}
return auth, composeCleanup(cleanupFuncs...), nil
}
func normalizeAuthOrder(order []AuthMethodKind) ([]AuthMethodKind, error) {
if len(order) == 0 {
return append([]AuthMethodKind(nil), defaultAuthOrder...), nil
}
normalized := make([]AuthMethodKind, 0, len(order))
seen := make(map[AuthMethodKind]struct{}, len(order))
for _, raw := range order {
kind := AuthMethodKind(strings.ToLower(strings.TrimSpace(string(raw))))
if kind == "" {
return nil, errors.New("auth order contains an empty auth method")
}
if !isSupportedAuthMethodKind(kind) {
return nil, fmt.Errorf("unsupported auth method %q", raw)
}
if _, exists := seen[kind]; exists {
continue
}
seen[kind] = struct{}{}
normalized = append(normalized, kind)
}
if len(normalized) == 0 {
return nil, errors.New("auth order is empty")
}
return normalized, nil
}
func isSupportedAuthMethodKind(kind AuthMethodKind) bool {
switch kind {
case AuthMethodPrivateKey, AuthMethodPassword, AuthMethodKeyboardInteractive, AuthMethodSSHAgent:
return true
default:
return false
}
}
func buildPrivateKeyAuthMethod(info LoginInput) (ssh.AuthMethod, error) {
if strings.TrimSpace(info.Prikey) == "" {
return nil, nil
}
pemBytes := []byte(info.Prikey)
if info.PrikeyPwd == "" {
signer, err := ssh.ParsePrivateKey(pemBytes)
if err != nil {
return nil, err
}
return ssh.PublicKeys(signer), nil
}
signer, err := ssh.ParsePrivateKeyWithPassphrase(pemBytes, []byte(info.PrikeyPwd))
if err != nil {
return nil, err
}
return ssh.PublicKeys(signer), nil
}
func buildPasswordAuthMethod(password string, callback func() (string, error)) ssh.AuthMethod {
if password != "" {
return ssh.Password(password)
}
if callback != nil {
return ssh.PasswordCallback(callback)
}
return nil
}
func buildKeyboardInteractiveAuthMethod(
password string,
passwordCallback func() (string, error),
challenge ssh.KeyboardInteractiveChallenge,
) ssh.AuthMethod {
if challenge != nil {
return ssh.KeyboardInteractive(challenge)
}
if password == "" && passwordCallback == nil {
return nil
}
keyboardInteractiveChallenge := func(user, instruction string, questions []string, echos []bool) ([]string, error) {
if len(questions) == 0 {
return []string{}, nil
}
answer := password
if answer == "" {
var err error
answer, err = passwordCallback()
if err != nil {
return nil, err
}
}
answers := make([]string, len(questions))
for i := range questions {
answers[i] = answer
}
return answers, nil
}
return ssh.KeyboardInteractive(keyboardInteractiveChallenge)
}
func buildSSHAgentAuthMethod(timeout time.Duration) (ssh.AuthMethod, func(), error) {
conn, err := dialSSHAgent(timeout)
if err != nil {
if errors.Is(err, errSSHAgentUnavailable) {
return nil, nil, nil
}
return nil, nil, err
}
if conn == nil {
return nil, nil, nil
}
signers, err := agent.NewClient(conn).Signers()
if err != nil {
_ = conn.Close()
return nil, nil, err
}
if len(signers) == 0 {
_ = conn.Close()
return nil, nil, errors.New("ssh-agent has no loaded keys")
}
return ssh.PublicKeys(signers...), func() {
_ = conn.Close()
}, nil
}
func composeCleanup(funcs ...func()) func() {
if len(funcs) == 0 {
return nil
}
return func() {
for i := len(funcs) - 1; i >= 0; i-- {
if funcs[i] != nil {
funcs[i]()
}
}
}
}

View File

@ -1,10 +1,19 @@
package starssh
import (
"crypto/ed25519"
"crypto/rand"
"crypto/rsa"
"errors"
"io"
"net"
"os"
"sync"
"testing"
"time"
"golang.org/x/crypto/ssh"
sshagent "golang.org/x/crypto/ssh/agent"
)
func TestNormalizeLoginInputKeepsZeroAuthTimeout(t *testing.T) {
@ -18,6 +27,12 @@ func TestNormalizeLoginInputKeepsZeroAuthTimeout(t *testing.T) {
if info.DialTimeout != 0 {
t.Fatalf("DialTimeout=%v want 0", info.DialTimeout)
}
if info.SSHAgentTimeout != 0 {
t.Fatalf("SSHAgentTimeout=%v want 0", info.SSHAgentTimeout)
}
if info.SSHAgentForwardTimeout != 0 {
t.Fatalf("SSHAgentForwardTimeout=%v want 0", info.SSHAgentForwardTimeout)
}
}
func TestEffectiveLoginTimeout(t *testing.T) {
@ -66,21 +81,70 @@ func TestEffectiveDialTimeout(t *testing.T) {
}
}
func TestBuildAuthMethodsUsesDialTimeoutInsteadOfAuthTimeout(t *testing.T) {
func TestEffectiveSSHAgentTimeout(t *testing.T) {
tests := []struct {
name string
info LoginInput
want time.Duration
}{
{
name: "default fallback without auth timeout",
info: LoginInput{},
want: defaultSSHAgentTimeout,
},
{
name: "auth timeout does not cap default",
info: LoginInput{Timeout: 9 * time.Second},
want: defaultSSHAgentTimeout,
},
{
name: "explicit agent timeout wins",
info: LoginInput{Timeout: 9 * time.Second, DialTimeout: 3 * time.Second, SSHAgentTimeout: 90 * time.Second},
want: 90 * time.Second,
},
{
name: "negative agent timeout disables operation deadline",
info: LoginInput{SSHAgentTimeout: -1},
want: 0,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
if got := effectiveSSHAgentTimeout(tc.info); got != tc.want {
t.Fatalf("effectiveSSHAgentTimeout(%+v)=%v want %v", tc.info, got, tc.want)
}
})
}
}
func TestEffectiveSSHAgentForwardTimeout(t *testing.T) {
if got := effectiveSSHAgentForwardTimeout(LoginInput{}); got != 0 {
t.Fatalf("zero forward timeout should stay zero, got %v", got)
}
if got := effectiveSSHAgentForwardTimeout(LoginInput{SSHAgentForwardTimeout: 4 * time.Second}); got != 4*time.Second {
t.Fatalf("expected explicit forward timeout, got %v", got)
}
}
func TestBuildAuthMethodsUsesSeparateSSHAgentTimeouts(t *testing.T) {
oldBuilder := buildSSHAgentAuthMethodFunc
t.Cleanup(func() {
buildSSHAgentAuthMethodFunc = oldBuilder
})
captured := time.Duration(-2)
buildSSHAgentAuthMethodFunc = func(timeout time.Duration) (ssh.AuthMethod, func(), error) {
captured = timeout
captured := sshAgentTimeouts{Dial: -2, Operation: -2, Forward: -2}
buildSSHAgentAuthMethodFunc = func(timeouts sshAgentTimeouts) (ssh.AuthMethod, func(), error) {
captured = timeouts
return ssh.Password("agent"), nil, nil
}
info := LoginInput{
Timeout: 0,
DialTimeout: 11 * time.Second,
SSHAgentTimeout: 90 * time.Second,
SSHAgentForwardTimeout: 4 * time.Second,
IdentityAgent: "/tmp/custom-agent.sock",
AuthOrder: []AuthMethodKind{AuthMethodSSHAgent},
}
auth, cleanup, err := buildAuthMethods(info)
@ -93,7 +157,607 @@ func TestBuildAuthMethodsUsesDialTimeoutInsteadOfAuthTimeout(t *testing.T) {
if len(auth) != 1 {
t.Fatalf("expected one auth method, got %d", len(auth))
}
if captured != 11*time.Second {
t.Fatalf("agent auth builder timeout=%v want %v", captured, 11*time.Second)
if captured.Dial != 11*time.Second {
t.Fatalf("agent auth builder dial timeout=%v want %v", captured.Dial, 11*time.Second)
}
if captured.Operation != 90*time.Second {
t.Fatalf("agent auth builder operation timeout=%v want %v", captured.Operation, 90*time.Second)
}
if captured.Forward != 4*time.Second {
t.Fatalf("agent auth builder forward timeout=%v want %v", captured.Forward, 4*time.Second)
}
if captured.Endpoint != "/tmp/custom-agent.sock" {
t.Fatalf("agent auth builder endpoint=%q want custom endpoint", captured.Endpoint)
}
}
func TestBuildAuthMethodsUsesSingleAgentAuthMethod(t *testing.T) {
oldBuilder := buildSSHAgentAuthMethodFunc
t.Cleanup(func() {
buildSSHAgentAuthMethodFunc = oldBuilder
})
buildSSHAgentAuthMethodFunc = func(timeouts sshAgentTimeouts) (ssh.AuthMethod, func(), error) {
return ssh.Password("agent"), nil, nil
}
auth, cleanup, err := buildAuthMethods(LoginInput{
AuthOrder: []AuthMethodKind{AuthMethodSSHAgent},
})
if err != nil {
t.Fatalf("buildAuthMethods: %v", err)
}
if cleanup != nil {
cleanup()
}
if len(auth) != 1 {
t.Fatalf("auth methods=%d, want 1", len(auth))
}
}
func TestShouldRetrySSHAgentAuthWhenAgentIsNotFirst(t *testing.T) {
order := []AuthMethodKind{AuthMethodPassword, AuthMethodSSHAgent}
if !shouldRetrySSHAgentAuth(LoginInput{}, order) {
t.Fatal("expected ssh-agent retry when ssh-agent is present after password")
}
if shouldRetrySSHAgentAuth(LoginInput{DisableSSHAgent: true}, order) {
t.Fatal("expected ssh-agent retry disabled when DisableSSHAgent is true")
}
if shouldRetrySSHAgentAuth(LoginInput{}, []AuthMethodKind{AuthMethodPassword}) {
t.Fatal("expected no ssh-agent retry when ssh-agent auth is absent")
}
}
func TestBuildAuthMethodsWithAgentAttemptMarksNonFirstAgentForRetry(t *testing.T) {
oldBuilder := buildSSHAgentAuthMethodFunc
t.Cleanup(func() {
buildSSHAgentAuthMethodFunc = oldBuilder
})
buildSSHAgentAuthMethodFunc = func(timeouts sshAgentTimeouts) (ssh.AuthMethod, func(), error) {
if timeouts.SignFailure == nil {
t.Fatal("expected SignFailure callback for non-first ssh-agent auth")
}
if timeouts.SkipFingerprints != nil {
t.Fatalf("unexpected initial skip fingerprints: %#v", timeouts.SkipFingerprints)
}
return ssh.Password("agent"), nil, nil
}
auth, cleanup, err := buildAuthMethodsWithAgentAttempt(LoginInput{
Password: "secret",
AuthOrder: []AuthMethodKind{AuthMethodPassword, AuthMethodSSHAgent},
}, newSSHAgentAuthAttempt())
if err != nil {
t.Fatalf("buildAuthMethodsWithAgentAttempt: %v", err)
}
if cleanup != nil {
cleanup()
}
if len(auth) != 2 {
t.Fatalf("auth methods=%d want 2", len(auth))
}
}
func TestAgentRetryPendingBlocksFallbackAuthThenResets(t *testing.T) {
attempt := newSSHAgentAuthAttempt()
attempt.skipFingerprint("SHA256:test")
if err := checkSSHAgentRetryPending(attempt); !errors.Is(err, errRetrySSHAgentAuth) {
t.Fatalf("retry pending err=%v want errRetrySSHAgentAuth", err)
}
attempt.begin()
if err := checkSSHAgentRetryPending(attempt); err != nil {
t.Fatalf("retry should reset on next attempt: %v", err)
}
}
func TestAgentRetryPendingBlocksPrivateKeyAuth(t *testing.T) {
signer := mustGenerateTestSigner(t)
attempt := newSSHAgentAuthAttempt()
callback := privateKeySignersCallback(signer, attempt)
signers, err := callback()
if err != nil {
t.Fatalf("private key callback before retry: %v", err)
}
if len(signers) != 1 || signers[0] != signer {
t.Fatalf("private key callback returned %#v, want original signer", signers)
}
attempt.skipFingerprint("SHA256:test")
signers, err = callback()
if !errors.Is(err, errRetrySSHAgentAuth) {
t.Fatalf("private key callback err=%v want errRetrySSHAgentAuth", err)
}
if signers != nil {
t.Fatalf("private key callback signers=%#v want nil while retry pending", signers)
}
attempt.begin()
signers, err = callback()
if err != nil {
t.Fatalf("private key callback after retry reset: %v", err)
}
if len(signers) != 1 || signers[0] != signer {
t.Fatalf("private key callback after retry returned %#v, want original signer", signers)
}
}
func TestFilterSSHAgentSignersSkipsSignerAfterSignFailure(t *testing.T) {
firstSigner := mustGenerateTestSigner(t)
secondSigner := mustGenerateTestSigner(t)
failingFirstSigner := &testFailingSigner{Signer: firstSigner, err: errors.New("first agent key cannot sign")}
attempt := newSSHAgentAuthAttempt()
firstMethods := filterSSHAgentSignersForRetry([]ssh.Signer{failingFirstSigner, secondSigner}, sshAgentTimeouts{
SignFailure: attempt.recordSignFailure,
SkipFingerprints: attempt.skipSnapshot(),
})
if len(firstMethods) != 2 {
t.Fatalf("first auth method signers=%d want 2", len(firstMethods))
}
if _, err := firstMethods[0].Sign(nil, []byte("challenge")); !errors.Is(err, errRetrySSHAgentAuth) {
t.Fatalf("first signer err=%v want errRetrySSHAgentAuth", err)
}
secondMethods := filterSSHAgentSignersForRetry([]ssh.Signer{failingFirstSigner, secondSigner}, sshAgentTimeouts{
SignFailure: attempt.recordSignFailure,
SkipFingerprints: attempt.skipSnapshot(),
})
if len(secondMethods) != 1 {
t.Fatalf("second auth method signers=%d want 1", len(secondMethods))
}
if string(secondMethods[0].PublicKey().Marshal()) != string(secondSigner.PublicKey().Marshal()) {
t.Fatalf("second auth method did not skip failed first key")
}
signature, err := secondMethods[0].Sign(nil, []byte("challenge"))
if err != nil {
t.Fatalf("second signer Sign: %v", err)
}
if signature == nil {
t.Fatal("second signer returned nil signature")
}
}
func TestBuildAuthMethodsSkipsFailedAgentSignerOnRetry(t *testing.T) {
firstSigner := mustGenerateTestSigner(t)
secondSigner := mustGenerateTestSigner(t)
wantErr := errors.New("first agent key cannot sign")
failingFirstSigner := &testFailingSigner{Signer: firstSigner, err: wantErr}
oldBuilder := buildSSHAgentAuthMethodFunc
t.Cleanup(func() {
buildSSHAgentAuthMethodFunc = oldBuilder
})
var buildCalls int
buildSSHAgentAuthMethodFunc = func(timeouts sshAgentTimeouts) (ssh.AuthMethod, func(), error) {
buildCalls++
filteredSigners := filterSSHAgentSignersForRetry([]ssh.Signer{failingFirstSigner, secondSigner}, timeouts)
if buildCalls == 1 {
if len(filteredSigners) != 2 {
t.Fatalf("first build signers=%d want 2", len(filteredSigners))
}
return ssh.PublicKeys(filteredSigners...), nil, nil
}
if len(filteredSigners) != 1 {
t.Fatalf("retry build signers=%d want 1", len(filteredSigners))
}
if string(filteredSigners[0].PublicKey().Marshal()) != string(secondSigner.PublicKey().Marshal()) {
t.Fatal("retry build did not skip failed signer")
}
return ssh.PublicKeys(filteredSigners...), nil, nil
}
attempt := newSSHAgentAuthAttempt()
attempt.begin()
auth, cleanup, err := buildAuthMethodsWithAgentAttempt(LoginInput{
AuthOrder: []AuthMethodKind{AuthMethodSSHAgent},
}, attempt)
if err != nil {
t.Fatalf("first buildAuthMethodsWithAgentAttempt: %v", err)
}
if cleanup != nil {
cleanup()
}
if len(auth) != 1 {
t.Fatalf("first auth methods=%d want 1", len(auth))
}
if _, err := failingFirstSigner.Sign(rand.Reader, []byte("challenge")); !errors.Is(err, wantErr) {
t.Fatalf("raw failing signer err=%v", err)
}
firstWrapped := filterSSHAgentSignersForRetry([]ssh.Signer{failingFirstSigner}, sshAgentTimeouts{
SignFailure: attempt.recordSignFailure,
})[0]
if _, err := firstWrapped.Sign(rand.Reader, []byte("challenge")); !errors.Is(err, errRetrySSHAgentAuth) {
t.Fatalf("wrapped failing signer err=%v want errRetrySSHAgentAuth", err)
}
attempt.begin()
auth, cleanup, err = buildAuthMethodsWithAgentAttempt(LoginInput{
AuthOrder: []AuthMethodKind{AuthMethodSSHAgent},
}, attempt)
if err != nil {
t.Fatalf("retry buildAuthMethodsWithAgentAttempt: %v", err)
}
if cleanup != nil {
cleanup()
}
if len(auth) != 1 {
t.Fatalf("retry auth methods=%d want 1", len(auth))
}
if buildCalls != 2 {
t.Fatalf("build calls=%d want 2", buildCalls)
}
}
func TestOrderSSHAgentSignersPrefersPriorityComment(t *testing.T) {
plainSigner := mustGenerateTestSigner(t)
prioritySigner := mustGenerateCommentedTestSigner(t, "priority=40")
ordered := orderSSHAgentSigners([]ssh.Signer{plainSigner, prioritySigner})
if len(ordered) != 2 {
t.Fatalf("ordered signers=%d want 2", len(ordered))
}
if string(ordered[0].PublicKey().Marshal()) != string(prioritySigner.PublicKey().Marshal()) {
t.Fatalf("priority signer should be first, got %s", sshAgentSignerComment(ordered[0]))
}
}
func TestOrderSSHAgentSignersPrefersCardKeys(t *testing.T) {
plainSigner := mustGenerateTestSigner(t)
cardSigner := mustGenerateCommentedTestSigner(t, "cardno:26_865_673")
ordered := orderSSHAgentSigners([]ssh.Signer{plainSigner, cardSigner})
if len(ordered) != 2 {
t.Fatalf("ordered signers=%d want 2", len(ordered))
}
if string(ordered[0].PublicKey().Marshal()) != string(cardSigner.PublicKey().Marshal()) {
t.Fatalf("card signer should be first, got %s", sshAgentSignerComment(ordered[0]))
}
}
func TestOrderSSHAgentSignersKeepsStableOrderWithoutHints(t *testing.T) {
firstSigner := mustGenerateTestSigner(t)
secondSigner := mustGenerateTestSigner(t)
ordered := orderSSHAgentSigners([]ssh.Signer{firstSigner, secondSigner})
if len(ordered) != 2 {
t.Fatalf("ordered signers=%d want 2", len(ordered))
}
if string(ordered[0].PublicKey().Marshal()) != string(firstSigner.PublicKey().Marshal()) {
t.Fatalf("first signer changed order without hints")
}
if string(ordered[1].PublicKey().Marshal()) != string(secondSigner.PublicKey().Marshal()) {
t.Fatalf("second signer changed order without hints")
}
}
func TestSSHAgentSignerEmitsSignDebugWithoutChangingError(t *testing.T) {
signer := mustGenerateTestSigner(t)
wantErr := errors.New("agent refused operation")
var debugCalls int
wrapped := wrapSSHAgentSigner(&testFailingSigner{Signer: signer, err: wantErr}, sshAgentSignerOptions{
Resolved: resolvedSSHAgentEndpoint{
Endpoint: "/tmp/debug-agent.sock",
Source: "identity-agent",
Network: "unix",
},
Debug: func(event SSHAgentDebugEvent) {
debugCalls++
if event.Step != "auth" || event.Phase != "sign" {
t.Fatalf("unexpected debug event: %+v", event)
}
if event.Endpoint != "/tmp/debug-agent.sock" || event.Source != "identity-agent" || event.Network != "unix" {
t.Fatalf("unexpected endpoint details: %+v", event)
}
if event.Status != "error" || !errors.Is(event.Err, wantErr) {
t.Fatalf("unexpected sign status: %+v", event)
}
},
})
_, err := wrapped.Sign(rand.Reader, []byte("challenge"))
if !errors.Is(err, wantErr) {
t.Fatalf("Sign err=%v want original signer error", err)
}
if debugCalls != 1 {
t.Fatalf("debug calls=%d want 1", debugCalls)
}
}
func TestSSHAgentRetrySignerPrefersRSASHA2(t *testing.T) {
signer := mustGenerateRSATestSigner(t)
spy := &testAlgorithmSpySigner{Signer: signer}
wrapped, ok := wrapSSHAgentSignerForRetry(spy, func(ssh.PublicKey, error) {}).(ssh.AlgorithmSigner)
if !ok {
t.Fatal("wrapped signer does not implement AlgorithmSigner")
}
signature, err := wrapped.SignWithAlgorithm(rand.Reader, []byte("challenge"), ssh.KeyAlgoRSA)
if err != nil {
t.Fatalf("SignWithAlgorithm: %v", err)
}
if spy.lastAlgorithm != ssh.KeyAlgoRSASHA256 {
t.Fatalf("last algorithm=%q want %q", spy.lastAlgorithm, ssh.KeyAlgoRSASHA256)
}
if signature.Format != ssh.KeyAlgoRSASHA256 {
t.Fatalf("signature format=%q want %q", signature.Format, ssh.KeyAlgoRSASHA256)
}
}
func TestSSHAgentRetrySignerKeepsRestrictedRSA(t *testing.T) {
signer := mustGenerateRSATestSigner(t)
restricted, err := ssh.NewSignerWithAlgorithms(signer.(ssh.AlgorithmSigner), []string{ssh.KeyAlgoRSA})
if err != nil {
t.Fatalf("NewSignerWithAlgorithms: %v", err)
}
spy := &testMultiAlgorithmSpySigner{
testAlgorithmSpySigner: &testAlgorithmSpySigner{Signer: restricted},
}
wrapped, ok := wrapSSHAgentSignerForRetry(spy, func(ssh.PublicKey, error) {}).(ssh.AlgorithmSigner)
if !ok {
t.Fatal("wrapped signer does not implement AlgorithmSigner")
}
signature, err := wrapped.SignWithAlgorithm(rand.Reader, []byte("challenge"), ssh.KeyAlgoRSA)
if err != nil {
t.Fatalf("SignWithAlgorithm: %v", err)
}
if spy.lastAlgorithm != ssh.KeyAlgoRSA {
t.Fatalf("last algorithm=%q want %q", spy.lastAlgorithm, ssh.KeyAlgoRSA)
}
if signature.Format != ssh.KeyAlgoRSA {
t.Fatalf("signature format=%q want %q", signature.Format, ssh.KeyAlgoRSA)
}
}
type deadlineSpyConn struct {
net.Conn
mu sync.Mutex
deadlines []time.Time
readErr error
writeErr error
}
type testFailingSigner struct {
ssh.Signer
err error
}
func (s *testFailingSigner) Sign(rand io.Reader, data []byte) (*ssh.Signature, error) {
return nil, s.err
}
func (s *testFailingSigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*ssh.Signature, error) {
return nil, s.err
}
type testAlgorithmSpySigner struct {
ssh.Signer
lastAlgorithm string
}
func (s *testAlgorithmSpySigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*ssh.Signature, error) {
s.lastAlgorithm = algorithm
return s.Signer.(ssh.AlgorithmSigner).SignWithAlgorithm(rand, data, algorithm)
}
type testMultiAlgorithmSpySigner struct {
*testAlgorithmSpySigner
}
func (s *testMultiAlgorithmSpySigner) Algorithms() []string {
if multiAlgorithmSigner, ok := s.Signer.(ssh.MultiAlgorithmSigner); ok {
return multiAlgorithmSigner.Algorithms()
}
return nil
}
func mustGenerateTestSigner(t *testing.T) ssh.Signer {
t.Helper()
_, key, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
t.Fatalf("generate test private key: %v", err)
}
signer, err := ssh.NewSignerFromKey(key)
if err != nil {
t.Fatalf("new test signer: %v", err)
}
return signer
}
func mustGenerateCommentedTestSigner(t *testing.T, comment string) ssh.Signer {
t.Helper()
baseSigner := mustGenerateTestSigner(t)
publicKey := baseSigner.PublicKey()
return &commentedTestSigner{
Signer: baseSigner,
publicKey: &sshagent.Key{
Format: publicKey.Type(),
Blob: publicKey.Marshal(),
Comment: comment,
},
}
}
type commentedTestSigner struct {
ssh.Signer
publicKey ssh.PublicKey
}
func (s *commentedTestSigner) PublicKey() ssh.PublicKey {
return s.publicKey
}
func mustGenerateRSATestSigner(t *testing.T) ssh.Signer {
t.Helper()
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("generate rsa test private key: %v", err)
}
signer, err := ssh.NewSignerFromKey(key)
if err != nil {
t.Fatalf("new rsa test signer: %v", err)
}
return signer
}
func (c *deadlineSpyConn) SetDeadline(deadline time.Time) error {
c.mu.Lock()
defer c.mu.Unlock()
c.deadlines = append(c.deadlines, deadline)
return nil
}
func (c *deadlineSpyConn) deadlineCount() int {
c.mu.Lock()
defer c.mu.Unlock()
return len(c.deadlines)
}
func (c *deadlineSpyConn) firstDeadline() time.Time {
c.mu.Lock()
defer c.mu.Unlock()
return c.deadlines[0]
}
func (c *deadlineSpyConn) Read(p []byte) (int, error) {
if c.readErr != nil {
return 0, c.readErr
}
return 0, nil
}
func (c *deadlineSpyConn) Write(p []byte) (int, error) {
if c.writeErr != nil {
return 0, c.writeErr
}
return len(p), nil
}
func TestWrapSSHAgentConnWithDeadlineSetsReadDeadline(t *testing.T) {
spy := &deadlineSpyConn{readErr: io.EOF}
conn := wrapSSHAgentConnWithDeadline(spy, 2*time.Second)
buf := make([]byte, 1)
if _, err := conn.Read(buf); !errors.Is(err, io.EOF) {
t.Fatalf("Read err=%v", err)
}
if spy.deadlineCount() != 1 {
t.Fatalf("deadlines=%d want 1", spy.deadlineCount())
}
if firstDeadline := spy.firstDeadline(); time.Until(firstDeadline) <= 0 {
t.Fatalf("deadline=%v should be in the future", firstDeadline)
}
}
func TestWrapSSHAgentConnWithDeadlineSetsWriteDeadline(t *testing.T) {
spy := &deadlineSpyConn{}
conn := wrapSSHAgentConnWithDeadline(spy, 2*time.Second)
if _, err := conn.Write([]byte("x")); err != nil {
t.Fatalf("Write err=%v", err)
}
if spy.deadlineCount() != 1 {
t.Fatalf("deadlines=%d want 1", spy.deadlineCount())
}
}
func TestResolveSSHAgentEndpointUsesIdentityAgent(t *testing.T) {
t.Setenv("SSH_AUTH_SOCK", "/tmp/env-agent.sock")
resolved, err := resolveSSHAgentEndpoint(sshAgentDialOptions{Endpoint: " /tmp/identity-agent.sock "})
if err != nil {
t.Fatalf("resolveSSHAgentEndpoint: %v", err)
}
if resolved.Endpoint != "/tmp/identity-agent.sock" {
t.Fatalf("endpoint=%q", resolved.Endpoint)
}
if resolved.Source != "identity-agent" {
t.Fatalf("source=%q", resolved.Source)
}
}
func TestResolveSSHAgentEndpointUsesSSHAuthSock(t *testing.T) {
t.Setenv("SSH_AUTH_SOCK", "/tmp/env-agent.sock")
resolved, err := resolveSSHAgentEndpoint(sshAgentDialOptions{})
if err != nil {
t.Fatalf("resolveSSHAgentEndpoint: %v", err)
}
if resolved.Endpoint != "/tmp/env-agent.sock" {
t.Fatalf("endpoint=%q", resolved.Endpoint)
}
if resolved.Source != "SSH_AUTH_SOCK" {
t.Fatalf("source=%q", resolved.Source)
}
}
func TestBuildSSHAgentAuthMethodTimesOutWhenAgentDoesNotRespond(t *testing.T) {
server, client := net.Pipe()
defer server.Close()
oldDialResolvedSSHAgent := dialResolvedSSHAgentFunc
t.Cleanup(func() {
dialResolvedSSHAgentFunc = oldDialResolvedSSHAgent
})
dialResolvedSSHAgentFunc = func(resolved resolvedSSHAgentEndpoint, timeout time.Duration) (net.Conn, error) {
return client, nil
}
_, cleanup, err := buildSSHAgentAuthMethod(sshAgentTimeouts{
Operation: 20 * time.Millisecond,
Endpoint: "/tmp/hung-agent.sock",
})
if cleanup != nil {
cleanup()
}
if !errors.Is(err, ErrSSHAgentTimeout) {
t.Fatalf("err=%v want ErrSSHAgentTimeout", err)
}
}
func TestBuildSSHAgentAuthMethodEmitsDebugEvents(t *testing.T) {
socketPath := tempUnixSocketPath(t)
listener, err := net.Listen("unix", socketPath)
if err != nil {
t.Fatalf("listen unix: %v", err)
}
defer listener.Close()
done := make(chan struct{})
go func() {
defer close(done)
conn, err := listener.Accept()
if err != nil {
return
}
_ = conn.Close()
}()
var events []SSHAgentDebugEvent
_, _, _ = buildSSHAgentAuthMethod(sshAgentTimeouts{
Dial: time.Second,
Operation: time.Second,
Endpoint: socketPath,
Debug: func(event SSHAgentDebugEvent) {
events = append(events, event)
},
})
<-done
if len(events) == 0 {
t.Fatal("expected debug events")
}
if events[0].Step != "auth" || events[0].Phase != "dial" {
t.Fatalf("unexpected first event: %+v", events[0])
}
if events[0].Endpoint != socketPath || events[0].Source != "identity-agent" {
t.Fatalf("unexpected endpoint event: %+v", events[0])
}
}
func tempUnixSocketPath(t *testing.T) string {
t.Helper()
path := t.TempDir() + "/agent.sock"
t.Cleanup(func() {
_ = os.Remove(path)
})
return path
}

668
sshagent_auth.go Normal file
View File

@ -0,0 +1,668 @@
package starssh
import (
"errors"
"fmt"
"io"
"sort"
"strconv"
"strings"
"sync"
"time"
"golang.org/x/crypto/ssh"
sshagent "golang.org/x/crypto/ssh/agent"
)
var errSSHAgentUnavailable = errors.New("ssh-agent unavailable")
var errRetrySSHAgentAuth = errors.New("retry ssh-agent auth")
var buildSSHAgentAuthMethodFunc = buildSSHAgentAuthMethod
type sshAgentTimeouts struct {
Dial time.Duration
Operation time.Duration
Forward time.Duration
Endpoint string
Resolved resolvedSSHAgentEndpoint
Debug SSHAgentDebugFunc
SkipFingerprints map[string]struct{}
SignFailure func(ssh.PublicKey, error)
}
type sshAgentAuthAttempt struct {
mu sync.Mutex
skipFingerprints map[string]struct{}
retryRequested bool
}
var defaultAuthOrder = []AuthMethodKind{
AuthMethodSSHAgent,
AuthMethodPrivateKey,
AuthMethodPassword,
AuthMethodKeyboardInteractive,
}
func effectiveSSHAgentTimeout(info LoginInput) time.Duration {
switch {
case info.SSHAgentTimeout < 0:
return 0
case info.SSHAgentTimeout > 0:
return info.SSHAgentTimeout
default:
return defaultSSHAgentTimeout
}
}
func effectiveSSHAgentTimeouts(info LoginInput) sshAgentTimeouts {
return sshAgentTimeouts{
Dial: effectiveDialTimeout(info),
Operation: effectiveSSHAgentTimeout(info),
Forward: effectiveSSHAgentForwardTimeout(info),
Endpoint: info.IdentityAgent,
Debug: info.SSHAgentDebug,
}
}
func effectiveSSHAgentForwardTimeout(info LoginInput) time.Duration {
if info.SSHAgentForwardTimeout > 0 {
return info.SSHAgentForwardTimeout
}
return 0
}
func buildAuthMethods(info LoginInput) ([]ssh.AuthMethod, func(), error) {
return buildAuthMethodsWithAgentAttempt(info, nil)
}
func buildAuthMethodsWithAgentAttempt(info LoginInput, agentAttempt *sshAgentAuthAttempt) ([]ssh.AuthMethod, func(), error) {
order, err := normalizeAuthOrder(info.AuthOrder)
if err != nil {
return nil, nil, err
}
auth := make([]ssh.AuthMethod, 0, len(order))
var agentErr error
var cleanupFuncs []func()
for _, methodKind := range order {
switch methodKind {
case AuthMethodPrivateKey:
method, err := buildPrivateKeyAuthMethod(info, agentAttempt)
if err != nil {
return nil, nil, err
}
if method != nil {
auth = append(auth, method)
}
case AuthMethodPassword:
method := buildPasswordAuthMethod(info.Password, info.PasswordCallback, agentAttempt)
if method != nil {
auth = append(auth, method)
}
case AuthMethodKeyboardInteractive:
method := buildKeyboardInteractiveAuthMethod(info.Password, info.PasswordCallback, info.KeyboardInteractiveCallback, agentAttempt)
if method != nil {
auth = append(auth, method)
}
case AuthMethodSSHAgent:
if info.DisableSSHAgent {
continue
}
timeouts := effectiveSSHAgentTimeouts(info)
if agentAttempt != nil {
timeouts.SkipFingerprints = agentAttempt.skipSnapshot()
timeouts.SignFailure = agentAttempt.recordSignFailure
}
agentMethod, cleanup, err := buildSSHAgentAuthMethodFunc(timeouts)
if err != nil {
agentErr = err
continue
}
if agentMethod != nil {
auth = append(auth, agentMethod)
}
if cleanup != nil {
cleanupFuncs = append(cleanupFuncs, cleanup)
}
}
}
if len(auth) == 0 {
if agentErr != nil {
return nil, nil, fmt.Errorf("no authentication method provided; ssh-agent unavailable: %w", agentErr)
}
return nil, nil, errors.New("no authentication method provided: password, private key, or ssh-agent is required")
}
return auth, composeCleanup(cleanupFuncs...), nil
}
func normalizeAuthOrder(order []AuthMethodKind) ([]AuthMethodKind, error) {
if len(order) == 0 {
return append([]AuthMethodKind(nil), defaultAuthOrder...), nil
}
normalized := make([]AuthMethodKind, 0, len(order))
seen := make(map[AuthMethodKind]struct{}, len(order))
for _, raw := range order {
kind := AuthMethodKind(strings.ToLower(strings.TrimSpace(string(raw))))
if kind == "" {
return nil, errors.New("auth order contains an empty auth method")
}
if !isSupportedAuthMethodKind(kind) {
return nil, fmt.Errorf("unsupported auth method %q", raw)
}
if _, exists := seen[kind]; exists {
continue
}
seen[kind] = struct{}{}
normalized = append(normalized, kind)
}
if len(normalized) == 0 {
return nil, errors.New("auth order is empty")
}
return normalized, nil
}
func isSupportedAuthMethodKind(kind AuthMethodKind) bool {
switch kind {
case AuthMethodPrivateKey, AuthMethodPassword, AuthMethodKeyboardInteractive, AuthMethodSSHAgent:
return true
default:
return false
}
}
func shouldRetrySSHAgentAuth(info LoginInput, order []AuthMethodKind) bool {
if info.DisableSSHAgent {
return false
}
for _, methodKind := range order {
if methodKind == AuthMethodSSHAgent {
return true
}
}
return false
}
func buildPrivateKeyAuthMethod(info LoginInput, agentAttempt *sshAgentAuthAttempt) (ssh.AuthMethod, error) {
if strings.TrimSpace(info.Prikey) == "" {
return nil, nil
}
pemBytes := []byte(info.Prikey)
if info.PrikeyPwd == "" {
signer, err := ssh.ParsePrivateKey(pemBytes)
if err != nil {
return nil, err
}
return ssh.PublicKeysCallback(privateKeySignersCallback(signer, agentAttempt)), nil
}
signer, err := ssh.ParsePrivateKeyWithPassphrase(pemBytes, []byte(info.PrikeyPwd))
if err != nil {
return nil, err
}
return ssh.PublicKeysCallback(privateKeySignersCallback(signer, agentAttempt)), nil
}
func privateKeySignersCallback(signer ssh.Signer, agentAttempt *sshAgentAuthAttempt) func() ([]ssh.Signer, error) {
return func() ([]ssh.Signer, error) {
if err := checkSSHAgentRetryPending(agentAttempt); err != nil {
return nil, err
}
return []ssh.Signer{signer}, nil
}
}
func buildPasswordAuthMethod(password string, callback func() (string, error), agentAttempt *sshAgentAuthAttempt) ssh.AuthMethod {
if password == "" && callback == nil {
return nil
}
return ssh.PasswordCallback(func() (string, error) {
if err := checkSSHAgentRetryPending(agentAttempt); err != nil {
return "", err
}
if password != "" {
return password, nil
}
return callback()
})
}
func buildKeyboardInteractiveAuthMethod(
password string,
passwordCallback func() (string, error),
challenge ssh.KeyboardInteractiveChallenge,
agentAttempt *sshAgentAuthAttempt,
) ssh.AuthMethod {
if challenge != nil {
return ssh.KeyboardInteractive(func(user, instruction string, questions []string, echos []bool) ([]string, error) {
if err := checkSSHAgentRetryPending(agentAttempt); err != nil {
return nil, err
}
return challenge(user, instruction, questions, echos)
})
}
if password == "" && passwordCallback == nil {
return nil
}
keyboardInteractiveChallenge := func(user, instruction string, questions []string, echos []bool) ([]string, error) {
if err := checkSSHAgentRetryPending(agentAttempt); err != nil {
return nil, err
}
if len(questions) == 0 {
return []string{}, nil
}
answer := password
if answer == "" {
var err error
answer, err = passwordCallback()
if err != nil {
return nil, err
}
}
answers := make([]string, len(questions))
for i := range questions {
answers[i] = answer
}
return answers, nil
}
return ssh.KeyboardInteractive(keyboardInteractiveChallenge)
}
func buildSSHAgentAuthMethod(timeouts sshAgentTimeouts) (ssh.AuthMethod, func(), error) {
conn, resolved, err := dialSSHAgentWithDebug("auth", timeouts)
if err != nil {
if errors.Is(err, errSSHAgentUnavailable) {
return nil, nil, nil
}
return nil, nil, err
}
if conn == nil {
return nil, nil, nil
}
conn = wrapSSHAgentConnWithDeadline(conn, timeouts.Operation)
started := time.Now()
signers, err := sshagent.NewClient(conn).Signers()
err = normalizeSSHAgentError(err)
logSSHAgentDebug(timeouts.Debug, SSHAgentDebugEvent{
Step: "auth",
Source: resolved.Source,
Endpoint: resolved.Endpoint,
Network: resolved.Network,
Phase: "list",
Status: debugStatus(err),
Duration: time.Since(started),
KeyCount: len(signers),
Err: err,
})
if err != nil {
_ = conn.Close()
return nil, nil, err
}
if len(signers) == 0 {
_ = conn.Close()
return nil, nil, errors.New("ssh-agent has no loaded keys")
}
timeouts.Resolved = resolved
orderedSigners := orderSSHAgentSigners(signers)
filteredSigners := filterSSHAgentSignersForRetry(orderedSigners, timeouts)
if len(filteredSigners) == 0 {
_ = conn.Close()
return nil, nil, errors.New("ssh-agent has no usable keys")
}
return ssh.PublicKeys(filteredSigners...), func() {
_ = conn.Close()
}, nil
}
func orderSSHAgentSigners(signers []ssh.Signer) []ssh.Signer {
type orderedSigner struct {
signer ssh.Signer
index int
score int
comment string
}
ordered := make([]orderedSigner, 0, len(signers))
for index, signer := range signers {
if signer == nil || signer.PublicKey() == nil {
continue
}
ordered = append(ordered, orderedSigner{
signer: signer,
index: index,
score: sshAgentSignerPriority(signer),
comment: sshAgentSignerComment(signer),
})
}
sort.SliceStable(ordered, func(i, j int) bool {
if ordered[i].score != ordered[j].score {
return ordered[i].score > ordered[j].score
}
return ordered[i].index < ordered[j].index
})
result := make([]ssh.Signer, 0, len(ordered))
for _, item := range ordered {
result = append(result, item.signer)
}
return result
}
func sshAgentSignerComment(signer ssh.Signer) string {
if signer == nil {
return ""
}
if key, ok := signer.PublicKey().(*sshagent.Key); ok {
return key.Comment
}
return ""
}
func sshAgentSignerPriority(signer ssh.Signer) int {
comment := strings.TrimSpace(sshAgentSignerComment(signer))
if comment == "" {
return 0
}
score := 0
if priority, ok := parseSSHAgentSignerPriority(comment); ok {
score += 100000 + priority*1000
}
lower := strings.ToLower(comment)
if strings.Contains(lower, "current") {
score += 400
}
if strings.Contains(lower, "cardno:") {
score += 300
}
if strings.Contains(lower, "card ") || strings.Contains(lower, " card") || strings.Contains(lower, "card:") {
score += 100
}
if strings.Contains(lower, "openpgp") || strings.Contains(lower, "gpg") {
score += 50
}
return score
}
func parseSSHAgentSignerPriority(comment string) (int, bool) {
lower := strings.ToLower(comment)
index := strings.Index(lower, "priority=")
if index < 0 {
return 0, false
}
value := strings.TrimSpace(comment[index+len("priority="):])
if value == "" {
return 0, false
}
end := 0
for end < len(value) {
ch := value[end]
if ch == '+' || ch == '-' || (ch >= '0' && ch <= '9') {
end++
continue
}
break
}
if end == 0 {
return 0, false
}
priority, err := strconv.Atoi(value[:end])
if err != nil {
return 0, false
}
return priority, true
}
func filterSSHAgentSignersForRetry(signers []ssh.Signer, timeouts sshAgentTimeouts) []ssh.Signer {
filteredSigners := make([]ssh.Signer, 0, len(signers))
for _, signer := range signers {
if signer == nil {
continue
}
publicKey := signer.PublicKey()
if publicKey == nil {
continue
}
if _, skip := timeouts.SkipFingerprints[ssh.FingerprintSHA256(publicKey)]; skip {
continue
}
if timeouts.SignFailure == nil && timeouts.Debug == nil {
filteredSigners = append(filteredSigners, signer)
continue
}
filteredSigners = append(filteredSigners, wrapSSHAgentSigner(signer, sshAgentSignerOptions{
Resolved: timeouts.Resolved,
Debug: timeouts.Debug,
SignFailure: timeouts.SignFailure,
}))
}
return filteredSigners
}
func newSSHAgentAuthAttempt() *sshAgentAuthAttempt {
return &sshAgentAuthAttempt{
skipFingerprints: make(map[string]struct{}),
}
}
func (a *sshAgentAuthAttempt) begin() {
if a == nil {
return
}
a.mu.Lock()
defer a.mu.Unlock()
a.retryRequested = false
}
func (a *sshAgentAuthAttempt) skipSnapshot() map[string]struct{} {
if a == nil {
return nil
}
a.mu.Lock()
defer a.mu.Unlock()
if len(a.skipFingerprints) == 0 {
return nil
}
snapshot := make(map[string]struct{}, len(a.skipFingerprints))
for fingerprint := range a.skipFingerprints {
snapshot[fingerprint] = struct{}{}
}
return snapshot
}
func (a *sshAgentAuthAttempt) recordSignFailure(publicKey ssh.PublicKey, err error) {
_ = err
if a == nil || publicKey == nil {
return
}
a.skipFingerprint(ssh.FingerprintSHA256(publicKey))
}
func (a *sshAgentAuthAttempt) skipFingerprint(fingerprint string) {
if a == nil {
return
}
a.mu.Lock()
defer a.mu.Unlock()
a.retryRequested = true
if fingerprint != "" {
a.skipFingerprints[fingerprint] = struct{}{}
}
}
func (a *sshAgentAuthAttempt) shouldRetry() bool {
if a == nil {
return false
}
a.mu.Lock()
defer a.mu.Unlock()
return a.retryRequested
}
func checkSSHAgentRetryPending(agentAttempt *sshAgentAuthAttempt) error {
if agentAttempt != nil && agentAttempt.shouldRetry() {
return errRetrySSHAgentAuth
}
return nil
}
type sshAgentRetrySigner struct {
signer ssh.Signer
publicKey ssh.PublicKey
options sshAgentSignerOptions
}
type sshAgentRetryAlgorithmSigner struct {
sshAgentRetrySigner
algorithmSigner ssh.AlgorithmSigner
}
type sshAgentRetryMultiAlgorithmSigner struct {
sshAgentRetryAlgorithmSigner
multiAlgorithmSigner ssh.MultiAlgorithmSigner
}
type sshAgentSignerOptions struct {
Resolved resolvedSSHAgentEndpoint
Debug SSHAgentDebugFunc
SignFailure func(ssh.PublicKey, error)
}
func wrapSSHAgentSignerForRetry(signer ssh.Signer, onFailure func(ssh.PublicKey, error)) ssh.Signer {
return wrapSSHAgentSigner(signer, sshAgentSignerOptions{SignFailure: onFailure})
}
func wrapSSHAgentSigner(signer ssh.Signer, options sshAgentSignerOptions) ssh.Signer {
publicKey := signer.PublicKey()
base := sshAgentRetrySigner{
signer: signer,
publicKey: publicKey,
options: options,
}
if multiAlgorithmSigner, ok := signer.(ssh.MultiAlgorithmSigner); ok {
return &sshAgentRetryMultiAlgorithmSigner{
sshAgentRetryAlgorithmSigner: sshAgentRetryAlgorithmSigner{
sshAgentRetrySigner: base,
algorithmSigner: multiAlgorithmSigner,
},
multiAlgorithmSigner: multiAlgorithmSigner,
}
}
if algorithmSigner, ok := signer.(ssh.AlgorithmSigner); ok {
return &sshAgentRetryAlgorithmSigner{
sshAgentRetrySigner: base,
algorithmSigner: algorithmSigner,
}
}
return &base
}
func (s *sshAgentRetrySigner) PublicKey() ssh.PublicKey {
return s.publicKey
}
func (s *sshAgentRetrySigner) Sign(rand io.Reader, data []byte) (*ssh.Signature, error) {
started := time.Now()
signature, err := s.signer.Sign(rand, data)
return signature, s.finishSign(started, err)
}
func (s *sshAgentRetrySigner) finishSign(started time.Time, err error) error {
err = normalizeSSHAgentError(err)
s.logSignDebug(started, err)
if err == nil {
return nil
}
if s.options.SignFailure != nil {
s.options.SignFailure(s.publicKey, err)
return wrapSSHAgentSignError(err)
}
return err
}
func (s *sshAgentRetrySigner) logSignDebug(started time.Time, err error) {
if s == nil || s.options.Debug == nil {
return
}
logSSHAgentDebug(s.options.Debug, SSHAgentDebugEvent{
Step: "auth",
Source: s.options.Resolved.Source,
Endpoint: s.options.Resolved.Endpoint,
Network: s.options.Resolved.Network,
Phase: "sign",
Status: debugStatus(err),
Duration: time.Since(started),
Err: err,
})
}
func (s *sshAgentRetryAlgorithmSigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*ssh.Signature, error) {
algorithm = preferredSSHAgentSignAlgorithm(s.publicKey, algorithm, nil)
started := time.Now()
signature, err := s.algorithmSigner.SignWithAlgorithm(rand, data, algorithm)
return signature, s.finishSign(started, err)
}
func (s *sshAgentRetryMultiAlgorithmSigner) Algorithms() []string {
return s.multiAlgorithmSigner.Algorithms()
}
func (s *sshAgentRetryMultiAlgorithmSigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*ssh.Signature, error) {
algorithm = preferredSSHAgentSignAlgorithm(s.publicKey, algorithm, s.multiAlgorithmSigner.Algorithms())
started := time.Now()
signature, err := s.multiAlgorithmSigner.SignWithAlgorithm(rand, data, algorithm)
return signature, s.finishSign(started, err)
}
func preferredSSHAgentSignAlgorithm(publicKey ssh.PublicKey, requested string, algorithms []string) string {
if publicKey == nil || publicKey.Type() != ssh.KeyAlgoRSA || requested != ssh.KeyAlgoRSA {
return requested
}
if len(algorithms) == 0 {
return ssh.KeyAlgoRSASHA256
}
for _, algorithm := range algorithms {
if algorithm == ssh.KeyAlgoRSA {
break
}
if algorithm == ssh.KeyAlgoRSASHA256 || algorithm == ssh.KeyAlgoRSASHA512 {
return algorithm
}
}
return requested
}
func wrapSSHAgentSignError(err error) error {
if err == nil {
return nil
}
return fmt.Errorf("%w: %v", errRetrySSHAgentAuth, normalizeSSHAgentError(err))
}
func composeCleanup(funcs ...func()) func() {
if len(funcs) == 0 {
return nil
}
return func() {
for i := len(funcs) - 1; i >= 0; i-- {
if funcs[i] != nil {
funcs[i]()
}
}
}
}

158
sshagent_conn.go Normal file
View File

@ -0,0 +1,158 @@
package starssh
import (
"errors"
"fmt"
"net"
"os"
"strings"
"time"
)
var ErrSSHAgentTimeout = errors.New("ssh-agent timeout")
var dialResolvedSSHAgentFunc = dialResolvedSSHAgent
type sshAgentDialOptions struct {
Endpoint string
Timeout time.Duration
}
type resolvedSSHAgentEndpoint struct {
Endpoint string
Source string
Network string
}
type deadlineAgentConn struct {
net.Conn
timeout time.Duration
}
func resolveSSHAgentEndpoint(options sshAgentDialOptions) (resolvedSSHAgentEndpoint, error) {
endpoint := strings.TrimSpace(options.Endpoint)
if endpoint != "" {
return resolvedSSHAgentEndpoint{
Endpoint: endpoint,
Source: "identity-agent",
Network: defaultSSHAgentNetwork(endpoint),
}, nil
}
endpoint = strings.TrimSpace(os.Getenv("SSH_AUTH_SOCK"))
if endpoint != "" {
return resolvedSSHAgentEndpoint{
Endpoint: endpoint,
Source: "SSH_AUTH_SOCK",
Network: defaultSSHAgentNetwork(endpoint),
}, nil
}
return defaultSSHAgentEndpoint()
}
func dialSSHAgent(options sshAgentDialOptions) (net.Conn, resolvedSSHAgentEndpoint, error) {
resolved, err := resolveSSHAgentEndpoint(options)
if err != nil {
return nil, resolvedSSHAgentEndpoint{}, err
}
conn, err := dialResolvedSSHAgentFunc(resolved, options.Timeout)
if isTimeoutError(err) {
err = fmt.Errorf("%w: %v", ErrSSHAgentTimeout, err)
}
if err != nil {
return nil, resolved, err
}
return conn, resolved, nil
}
func dialSSHAgentWithDebug(step string, timeouts sshAgentTimeouts) (net.Conn, resolvedSSHAgentEndpoint, error) {
options := sshAgentDialOptions{
Endpoint: timeouts.Endpoint,
Timeout: timeouts.Dial,
}
started := time.Now()
conn, resolved, err := dialSSHAgent(options)
logSSHAgentDebug(timeouts.Debug, SSHAgentDebugEvent{
Step: step,
Source: resolved.Source,
Endpoint: resolved.Endpoint,
Network: resolved.Network,
Phase: "dial",
Status: debugStatus(err),
Duration: time.Since(started),
Err: err,
})
return conn, resolved, err
}
func logSSHAgentDebug(debug SSHAgentDebugFunc, event SSHAgentDebugEvent) {
if debug == nil {
return
}
debug(event)
}
func debugStatus(err error) string {
if err != nil {
return "error"
}
return "ok"
}
func wrapSSHAgentConnWithDeadline(conn net.Conn, timeout time.Duration) net.Conn {
if conn == nil || timeout <= 0 {
return conn
}
return &deadlineAgentConn{Conn: conn, timeout: timeout}
}
func (c *deadlineAgentConn) Read(p []byte) (int, error) {
c.setDeadline()
n, err := c.Conn.Read(p)
return n, wrapSSHAgentConnError(err)
}
func (c *deadlineAgentConn) Write(p []byte) (int, error) {
c.setDeadline()
n, err := c.Conn.Write(p)
return n, wrapSSHAgentConnError(err)
}
func (c *deadlineAgentConn) setDeadline() {
if c == nil || c.timeout <= 0 || c.Conn == nil {
return
}
_ = c.Conn.SetDeadline(time.Now().Add(c.timeout))
}
func isTimeoutError(err error) bool {
if err == nil {
return false
}
if errors.Is(err, os.ErrDeadlineExceeded) {
return true
}
var netErr net.Error
return errors.As(err, &netErr) && netErr.Timeout()
}
func wrapSSHAgentConnError(err error) error {
if isTimeoutError(err) {
return fmt.Errorf("%w: %v", ErrSSHAgentTimeout, err)
}
return err
}
func normalizeSSHAgentError(err error) error {
if err == nil {
return nil
}
if errors.Is(err, ErrSSHAgentTimeout) {
return err
}
if strings.Contains(err.Error(), ErrSSHAgentTimeout.Error()) {
return fmt.Errorf("%w: %v", ErrSSHAgentTimeout, err)
}
return err
}

View File

@ -4,16 +4,19 @@ package starssh
import (
"net"
"os"
"strings"
"time"
)
func dialSSHAgent(timeout time.Duration) (net.Conn, error) {
agentSock := strings.TrimSpace(os.Getenv("SSH_AUTH_SOCK"))
if agentSock == "" {
return nil, errSSHAgentUnavailable
func defaultSSHAgentEndpoint() (resolvedSSHAgentEndpoint, error) {
return resolvedSSHAgentEndpoint{}, errSSHAgentUnavailable
}
func defaultSSHAgentNetwork(endpoint string) string {
return "unix"
}
func dialResolvedSSHAgent(resolved resolvedSSHAgentEndpoint, timeout time.Duration) (net.Conn, error) {
agentSock := resolved.Endpoint
if timeout > 0 {
return net.DialTimeout("unix", agentSock, timeout)
}

View File

@ -3,10 +3,16 @@
package starssh
import (
"bytes"
"context"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"os"
"path/filepath"
"strconv"
"strings"
"time"
@ -16,22 +22,40 @@ import (
const defaultWindowsSSHAgentPipe = `\\.\pipe\openssh-ssh-agent`
func dialSSHAgent(timeout time.Duration) (net.Conn, error) {
agentSock := strings.TrimSpace(os.Getenv("SSH_AUTH_SOCK"))
if agentSock != "" {
return dialWindowsSSHAgentEndpoint(agentSock, timeout)
}
return dialWindowsNamedPipe(defaultWindowsSSHAgentPipe, timeout, true)
var errInvalidGPGSocketInfo = errors.New("invalid gpg agent socket file")
type gpgSocketInfo struct {
port uint16
nonce []byte
cygwin bool
}
func dialWindowsSSHAgentEndpoint(endpoint string, timeout time.Duration) (net.Conn, error) {
if pipePath, ok := normalizeWindowsSSHAgentPipe(endpoint); ok {
return dialWindowsNamedPipe(pipePath, timeout, false)
func defaultSSHAgentEndpoint() (resolvedSSHAgentEndpoint, error) {
return resolvedSSHAgentEndpoint{
Endpoint: defaultWindowsSSHAgentPipe,
Source: "platform-default",
Network: "windows-pipe",
}, nil
}
if timeout > 0 {
return net.DialTimeout("unix", endpoint, timeout)
func defaultSSHAgentNetwork(endpoint string) string {
if _, ok := normalizeWindowsSSHAgentPipe(endpoint); ok {
return "windows-pipe"
}
return net.Dial("unix", endpoint)
if isAgentSSHSocketPath(endpoint) {
return "gpg-socket"
}
return "unix"
}
func dialResolvedSSHAgent(resolved resolvedSSHAgentEndpoint, timeout time.Duration) (net.Conn, error) {
if pipePath, ok := normalizeWindowsSSHAgentPipe(resolved.Endpoint); ok {
return dialWindowsNamedPipe(pipePath, timeout, resolved.Source == "platform-default")
}
if isAgentSSHSocketPath(resolved.Endpoint) {
return dialWindowsGPGSocketFile(resolved.Endpoint, timeout)
}
return dialWindowsUnixAgent(resolved.Endpoint, timeout)
}
func dialWindowsNamedPipe(path string, timeout time.Duration, unavailableOnNotFound bool) (net.Conn, error) {
@ -42,11 +66,7 @@ func dialWindowsNamedPipe(path string, timeout time.Duration, unavailableOnNotFo
}
defer cancel()
conn, err := winio.DialPipeContext(ctx, path)
if err != nil && unavailableOnNotFound && isWindowsPipeUnavailable(err) {
return nil, errSSHAgentUnavailable
}
return conn, err
return dialWindowsNamedPipeContext(ctx, path, unavailableOnNotFound)
}
func normalizeWindowsSSHAgentPipe(endpoint string) (string, bool) {
@ -68,3 +88,184 @@ func normalizeWindowsSSHAgentPipe(endpoint string) (string, bool) {
func isWindowsPipeUnavailable(err error) bool {
return errors.Is(err, windows.ERROR_FILE_NOT_FOUND) || errors.Is(err, windows.ERROR_PATH_NOT_FOUND)
}
func dialWindowsUnixAgent(endpoint string, timeout time.Duration) (net.Conn, error) {
if timeout > 0 {
return net.DialTimeout("unix", endpoint, timeout)
}
return net.Dial("unix", endpoint)
}
func dialWindowsGPGSocketFile(path string, timeout time.Duration) (net.Conn, error) {
ctx := context.Background()
cancel := func() {}
if timeout > 0 {
ctx, cancel = context.WithTimeout(ctx, timeout)
}
defer cancel()
return dialWindowsGPGSocketFileDepth(ctx, strings.TrimSpace(path), 0)
}
func dialWindowsGPGSocketFileDepth(ctx context.Context, path string, depth int) (net.Conn, error) {
if path == "" {
return nil, fmt.Errorf("gpg agent endpoint is empty")
}
if depth > 8 {
return nil, fmt.Errorf("gpg agent socket redirect loop at %s", path)
}
data, err := os.ReadFile(path)
if err != nil {
return nil, err
}
if target, ok := parseGPGAssuanSocketRedirect(data); ok {
target = resolveGPGSocketRedirectTarget(path, target)
if pipePath, ok := normalizeWindowsSSHAgentPipe(target); ok {
return dialWindowsNamedPipeContext(ctx, pipePath, false)
}
return dialWindowsGPGSocketFileDepth(ctx, target, depth+1)
}
info, err := parseGPGSocketInfo(path, data)
if err != nil {
return nil, err
}
return dialWindowsGPGSocketInfo(ctx, info)
}
func dialWindowsGPGSocketInfo(ctx context.Context, info gpgSocketInfo) (net.Conn, error) {
var dialer net.Dialer
conn, err := dialer.DialContext(ctx, "tcp", net.JoinHostPort("127.0.0.1", strconv.Itoa(int(info.port))))
if err != nil {
return nil, err
}
if deadline, ok := ctx.Deadline(); ok {
if err := conn.SetDeadline(deadline); err != nil {
_ = conn.Close()
return nil, err
}
}
if _, err := conn.Write(info.nonce); err != nil {
_ = conn.Close()
return nil, err
}
if info.cygwin {
var nonce [16]byte
if _, err := io.ReadFull(conn, nonce[:]); err != nil {
_ = conn.Close()
return nil, err
}
var credential [8]byte
binary.LittleEndian.PutUint32(credential[:4], uint32(os.Getpid()))
if _, err := conn.Write(credential[:]); err != nil {
_ = conn.Close()
return nil, err
}
if _, err := io.ReadFull(conn, credential[:]); err != nil {
_ = conn.Close()
return nil, err
}
}
_ = conn.SetDeadline(time.Time{})
return conn, nil
}
func resolveGPGSocketRedirectTarget(source string, target string) string {
target = strings.TrimSpace(target)
if target == "" || filepath.IsAbs(target) {
return target
}
if _, ok := normalizeWindowsSSHAgentPipe(target); ok {
return target
}
return filepath.Join(filepath.Dir(source), target)
}
func parseGPGSocketInfo(path string, data []byte) (gpgSocketInfo, error) {
if info, ok := parseGPGAssuanSocketInfo(data); ok {
return info, nil
}
if info, ok := parseGPGCygwinSocketInfo(data); ok {
return info, nil
}
return gpgSocketInfo{}, fmt.Errorf("%w %s: expected GnuPG port/nonce socket file; if SSH_AUTH_SOCK was set to this file, restart gpg-agent to recreate it", errInvalidGPGSocketInfo, path)
}
func parseGPGAssuanSocketRedirect(data []byte) (string, bool) {
text := strings.ReplaceAll(string(data), "\r\n", "\n")
text = strings.TrimSuffix(text, "\n")
lines := strings.Split(text, "\n")
if len(lines) != 2 || lines[0] != "%Assuan%" {
return "", false
}
target, ok := strings.CutPrefix(lines[1], "socket=")
if !ok || strings.TrimSpace(target) == "" {
return "", false
}
return os.ExpandEnv(target), true
}
func parseGPGAssuanSocketInfo(data []byte) (gpgSocketInfo, bool) {
newline := bytes.IndexByte(data, '\n')
if newline <= 0 || len(data)-newline-1 != 16 {
return gpgSocketInfo{}, false
}
port64, err := strconv.ParseUint(strings.TrimSpace(string(data[:newline])), 10, 16)
if err != nil || port64 == 0 {
return gpgSocketInfo{}, false
}
nonce := make([]byte, 16)
copy(nonce, data[newline+1:])
return gpgSocketInfo{port: uint16(port64), nonce: nonce}, true
}
func parseGPGCygwinSocketInfo(data []byte) (gpgSocketInfo, bool) {
if !bytes.HasPrefix(data, []byte("!<socket >")) {
return gpgSocketInfo{}, false
}
fields := strings.Fields(strings.TrimRight(string(data[10:]), "\x00"))
if len(fields) != 3 || fields[1] != "s" {
return gpgSocketInfo{}, false
}
port64, err := strconv.ParseUint(fields[0], 10, 16)
if err != nil || port64 == 0 {
return gpgSocketInfo{}, false
}
hexParts := strings.Split(fields[2], "-")
if len(hexParts) != 4 {
return gpgSocketInfo{}, false
}
nonce := make([]byte, 0, 16)
for _, part := range hexParts {
if len(part) != 8 {
return gpgSocketInfo{}, false
}
value, err := strconv.ParseUint(part, 16, 32)
if err != nil {
return gpgSocketInfo{}, false
}
var chunk [4]byte
binary.LittleEndian.PutUint32(chunk[:], uint32(value))
nonce = append(nonce, chunk[:]...)
}
return gpgSocketInfo{port: uint16(port64), nonce: nonce, cygwin: true}, true
}
func isAgentSSHSocketPath(endpoint string) bool {
normalized := strings.ToLower(strings.TrimSpace(endpoint))
return strings.HasSuffix(normalized, "s.gpg-agent.ssh")
}
func dialWindowsNamedPipeContext(ctx context.Context, path string, unavailableOnNotFound bool) (net.Conn, error) {
if ctx == nil {
ctx = context.Background()
}
conn, err := winio.DialPipeContext(ctx, path)
if err != nil && unavailableOnNotFound && isWindowsPipeUnavailable(err) {
return nil, errSSHAgentUnavailable
}
if err != nil {
return nil, err
}
return conn, nil
}

152
sshagent_windows_test.go Normal file
View File

@ -0,0 +1,152 @@
//go:build windows
package starssh
import (
"bytes"
"errors"
"io"
"net"
"os"
"path/filepath"
"strconv"
"testing"
"time"
)
func TestParseGPGAssuanSocketInfo(t *testing.T) {
info, ok := parseGPGAssuanSocketInfo([]byte("7247\n0123456789abcdef"))
if !ok {
t.Fatal("expected Assuan socket info to parse")
}
if info.port != 7247 || string(info.nonce) != "0123456789abcdef" || info.cygwin {
t.Fatalf("info=%+v nonce=%x", info, info.nonce)
}
}
func TestParseGPGCygwinSocketInfo(t *testing.T) {
info, ok := parseGPGCygwinSocketInfo([]byte("!<socket >7247 s 00000001-02030405-06070809-0a0b0c0d\x00"))
if !ok {
t.Fatal("expected Cygwin socket info to parse")
}
want := []byte{1, 0, 0, 0, 5, 4, 3, 2, 9, 8, 7, 6, 13, 12, 11, 10}
if info.port != 7247 || string(info.nonce) != string(want) || !info.cygwin {
t.Fatalf("info=%+v nonce=%x", info, info.nonce)
}
}
func TestParseGPGAssuanSocketRedirect(t *testing.T) {
t.Setenv("STARSSH_TEST_PIPE", `\\.\pipe\openssh-ssh-agent`)
target, ok := parseGPGAssuanSocketRedirect([]byte("%Assuan%\r\nsocket=${STARSSH_TEST_PIPE}\r\n"))
if !ok {
t.Fatal("expected Assuan redirect to parse")
}
if target != `\\.\pipe\openssh-ssh-agent` {
t.Fatalf("target=%q", target)
}
}
func TestReadInvalidAgentSSHSocketReturnsGPGSocketError(t *testing.T) {
path := t.TempDir() + "/S.gpg-agent.ssh"
if err := os.WriteFile(path, []byte("not a socket info file"), 0o600); err != nil {
t.Fatalf("write socket file: %v", err)
}
_, err := dialResolvedSSHAgent(resolvedSSHAgentEndpoint{
Endpoint: path,
Source: "SSH_AUTH_SOCK",
Network: defaultSSHAgentNetwork(path),
}, 0)
if !errors.Is(err, errInvalidGPGSocketInfo) {
t.Fatalf("err=%v want errInvalidGPGSocketInfo", err)
}
}
func TestMissingAgentSSHSocketReturnsReadError(t *testing.T) {
path := filepath.Join(t.TempDir(), "S.gpg-agent.ssh")
_, err := dialResolvedSSHAgent(resolvedSSHAgentEndpoint{
Endpoint: path,
Source: "identity-agent",
Network: defaultSSHAgentNetwork(path),
}, 0)
if err == nil {
t.Fatal("expected missing GPG socket file error")
}
if !errors.Is(err, os.ErrNotExist) {
t.Fatalf("err=%v want os.ErrNotExist", err)
}
}
func TestUnreadableAgentSSHSocketReturnsReadError(t *testing.T) {
path := filepath.Join(t.TempDir(), "S.gpg-agent.ssh")
if err := os.Mkdir(path, 0o700); err != nil {
t.Fatalf("mkdir socket path: %v", err)
}
_, err := dialResolvedSSHAgent(resolvedSSHAgentEndpoint{
Endpoint: path,
Source: "identity-agent",
Network: defaultSSHAgentNetwork(path),
}, 0)
if err == nil {
t.Fatal("expected unreadable GPG socket file error")
}
if errors.Is(err, errInvalidGPGSocketInfo) {
t.Fatalf("err=%v should expose read failure before parse", err)
}
}
func TestDialWindowsGPGSocketFilePerformsNonceHandshake(t *testing.T) {
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen tcp: %v", err)
}
defer listener.Close()
type handshakeResult struct {
nonce []byte
err error
}
resultCh := make(chan handshakeResult, 1)
go func() {
conn, err := listener.Accept()
if err != nil {
resultCh <- handshakeResult{err: err}
return
}
defer conn.Close()
nonce := make([]byte, 16)
if _, err := io.ReadFull(conn, nonce); err != nil {
resultCh <- handshakeResult{err: err}
return
}
resultCh <- handshakeResult{nonce: append([]byte(nil), nonce...)}
}()
socketPath := filepath.Join(t.TempDir(), "S.gpg-agent.ssh")
if err := os.WriteFile(socketPath, []byte(strconv.Itoa(listener.Addr().(*net.TCPAddr).Port)+"\n0123456789abcdef"), 0o600); err != nil {
t.Fatalf("write socket file: %v", err)
}
conn, err := dialWindowsGPGSocketFile(socketPath, time.Second)
if err != nil {
t.Fatalf("dialWindowsGPGSocketFile: %v", err)
}
_ = conn.Close()
var result handshakeResult
select {
case result = <-resultCh:
case <-time.After(time.Second):
t.Fatal("listener did not accept GPG socket connection")
}
if result.err != nil {
t.Fatalf("listener handshake error: %v", result.err)
}
if !bytes.Equal(result.nonce, []byte("0123456789abcdef")) {
t.Fatalf("nonce=%q", result.nonce)
}
}

View File

@ -16,6 +16,7 @@ import (
const (
defaultSSHPort = 22
defaultLoginTimeout = 5 * time.Second
defaultSSHAgentTimeout = 2 * time.Minute
defaultKeepAliveTimeout = 3 * time.Second
defaultShellPollInterval = 120 * time.Millisecond
defaultShellSetupDelay = 200 * time.Millisecond
@ -58,6 +59,20 @@ const (
AuthMethodSSHAgent AuthMethodKind = "ssh_agent"
)
type SSHAgentDebugFunc func(SSHAgentDebugEvent)
type SSHAgentDebugEvent struct {
Step string
Source string
Endpoint string
Network string
Phase string
Status string
Duration time.Duration
KeyCount int
Err error
}
type StarSSH struct {
stateMu sync.RWMutex
Client *ssh.Client
@ -92,6 +107,10 @@ type LoginInput struct {
DisableSSHAgent bool
ForwardSSHAgent bool
AuthOrder []AuthMethodKind
// IdentityAgent overrides the local ssh-agent endpoint used for authentication
// and agent forwarding. Empty uses SSH_AUTH_SOCK, or the platform default where
// one exists.
IdentityAgent string
Addr string
Port int
// Timeout limits the SSH handshake/authentication phase after a TCP connection has
@ -101,6 +120,17 @@ type LoginInput struct {
// local ssh-agent socket connect. Zero falls back to Timeout when set, otherwise
// uses the package default dial timeout. Negative disables the default dial timeout.
DialTimeout time.Duration
// SSHAgentTimeout limits ssh-agent protocol operations such as listing keys and
// signing challenges. Zero uses the package default, and negative disables the
// per-operation deadline. This is intentionally separate from Timeout and
// DialTimeout because hardware-backed agents may require a PIN or touch confirmation.
SSHAgentTimeout time.Duration
// SSHAgentForwardTimeout limits idle reads and writes on forwarded agent
// channels. Zero or negative leaves forwarded channels without an idle deadline.
SSHAgentForwardTimeout time.Duration
// SSHAgentDebug receives structured ssh-agent dial/protocol events. It is nil by
// default and must not log private key material.
SSHAgentDebug SSHAgentDebugFunc
DialContext DialContextFunc
Proxy *ProxyConfig
Jump *LoginInput