package starssh

import (
	"errors"
	"fmt"
	"io"
	"sort"
	"strconv"
	"strings"
	"sync"
	"time"

	"golang.org/x/crypto/ssh"
	sshagent "golang.org/x/crypto/ssh/agent"
)

// errSSHAgentUnavailable marks a dial failure meaning "there is no agent to
// talk to". buildSSHAgentAuthMethod treats it as "skip agent auth" rather than
// as an error.
var errSSHAgentUnavailable = errors.New("ssh-agent unavailable")

// errRetrySSHAgentAuth is returned (directly, or wrapped by
// wrapSSHAgentSignError) once an ssh-agent key has failed to sign during an
// attempt tracked by sshAgentAuthAttempt. After that, the other auth callbacks
// bail out with it instead of using other credentials. Presumably this lets
// the caller reconnect without the failing key; the retry loop is not shown
// here.
var errRetrySSHAgentAuth = errors.New("retry ssh-agent auth")

// buildSSHAgentAuthMethodFunc is the ssh-agent auth builder used by
// buildAuthMethodsWithAgentAttempt. It is a variable, presumably so tests can
// substitute a fake agent.
var buildSSHAgentAuthMethodFunc = buildSSHAgentAuthMethod

// sshAgentTimeouts bundles the settings used to dial and talk to the ssh-agent
// while authenticating.
type sshAgentTimeouts struct {
	// Dial, Operation and Forward come from effectiveDialTimeout,
	// effectiveSSHAgentTimeout and effectiveSSHAgentForwardTimeout.
	Dial      time.Duration
	Operation time.Duration
	Forward   time.Duration
	// Endpoint is the configured agent endpoint (LoginInput.IdentityAgent).
	// Resolved is filled in by buildSSHAgentAuthMethod after dialing.
	Endpoint string
	Resolved resolvedSSHAgentEndpoint
	// Debug, when non-nil, is handed to logSSHAgentDebug with debug events.
	Debug SSHAgentDebugFunc
	// SkipFingerprints holds SHA256 fingerprints of agent keys that must not
	// be offered (keys recorded as failing on an earlier attempt).
	SkipFingerprints map[string]struct{}
	// SignFailure, when non-nil, is called whenever an agent key fails to
	// sign.
	SignFailure func(ssh.PublicKey, error)
}

// sshAgentAuthAttempt carries state across repeated authentication attempts:
// the fingerprints of agent keys that failed to sign, and whether a retry has
// been requested. All fields are guarded by mu.
type sshAgentAuthAttempt struct {
	mu               sync.Mutex
	skipFingerprints map[string]struct{}
	retryRequested   bool
}

// defaultAuthOrder is the order in which auth methods are tried when
// LoginInput.AuthOrder is empty.
var defaultAuthOrder = []AuthMethodKind{
	AuthMethodSSHAgent,
	AuthMethodPrivateKey,
	AuthMethodPassword,
	AuthMethodKeyboardInteractive,
}

// effectiveSSHAgentTimeout returns the per-operation ssh-agent timeout.
//
// The mapping is:
//   - negative SSHAgentTimeout: 0 (presumably "no deadline"; see
//     wrapSSHAgentConnWithDeadline);
//   - positive: the value as given;
//   - zero: defaultSSHAgentTimeout.
func effectiveSSHAgentTimeout(info LoginInput) time.Duration {
	switch {
	case info.SSHAgentTimeout < 0:
		return 0
	case info.SSHAgentTimeout > 0:
		return info.SSHAgentTimeout
	default:
		return defaultSSHAgentTimeout
	}
}

// effectiveSSHAgentTimeouts derives the ssh-agent settings from info. Resolved,
// SkipFingerprints and SignFailure are left zero; callers fill them in.
func effectiveSSHAgentTimeouts(info LoginInput) sshAgentTimeouts {
	return sshAgentTimeouts{
		Dial:      effectiveDialTimeout(info),
		Operation: effectiveSSHAgentTimeout(info),
		Forward:   effectiveSSHAgentForwardTimeout(info),
		Endpoint:  info.IdentityAgent,
		Debug:     info.SSHAgentDebug,
	}
}

// effectiveSSHAgentForwardTimeout returns info.SSHAgentForwardTimeout when it is
// positive, and 0 otherwise.
func effectiveSSHAgentForwardTimeout(info LoginInput) time.Duration {
	if info.SSHAgentForwardTimeout > 0 {
		return info.SSHAgentForwardTimeout
	}
	return 0
}

// buildAuthMethods builds the auth methods for info without retry tracking.
// See buildAuthMethodsWithAgentAttempt.
func buildAuthMethods(info LoginInput) ([]ssh.AuthMethod, func(), error) {
	return buildAuthMethodsWithAgentAttempt(info, nil)
}

// buildAuthMethodsWithAgentAttempt builds ssh.AuthMethods in the order given by
// info.AuthOrder (or defaultAuthOrder).
//
// When agentAttempt is non-nil:
//   - agent keys it has recorded as failing are skipped;
//   - new signing failures are reported back to it;
//   - the other methods refuse to run once a retry is pending.
//
// An ssh-agent failure is not fatal on its own. It is only reported when no
// other auth method could be built.
//
// The returned cleanup closes any ssh-agent connections that were opened. It
// may be nil, and it must be called by the caller once authentication is
// done. On error nothing needs cleaning up: any connections opened before the
// error are closed here.
func buildAuthMethodsWithAgentAttempt(info LoginInput, agentAttempt *sshAgentAuthAttempt) ([]ssh.AuthMethod, func(), error) {
	order, err := normalizeAuthOrder(info.AuthOrder)
	if err != nil {
		return nil, nil, err
	}

	auth := make([]ssh.AuthMethod, 0, len(order))
	var agentErr error
	var cleanupFuncs []func()

	// releaseAll closes everything opened so far. On the error paths no
	// cleanup func is handed to the caller, so without this the agent
	// connection(s) would leak.
	releaseAll := func() {
		if cleanup := composeCleanup(cleanupFuncs...); cleanup != nil {
			cleanup()
		}
	}

	for _, methodKind := range order {
		switch methodKind {
		case AuthMethodPrivateKey:
			method, err := buildPrivateKeyAuthMethod(info, agentAttempt)
			if err != nil {
				releaseAll()
				return nil, nil, err
			}
			if method != nil {
				auth = append(auth, method)
			}
		case AuthMethodPassword:
			method := buildPasswordAuthMethod(info.Password, info.PasswordCallback, agentAttempt)
			if method != nil {
				auth = append(auth, method)
			}
		case AuthMethodKeyboardInteractive:
			method := buildKeyboardInteractiveAuthMethod(info.Password, info.PasswordCallback, info.KeyboardInteractiveCallback, agentAttempt)
			if method != nil {
				auth = append(auth, method)
			}
		case AuthMethodSSHAgent:
			if info.DisableSSHAgent {
				continue
			}
			timeouts := effectiveSSHAgentTimeouts(info)
			if agentAttempt != nil {
				timeouts.SkipFingerprints = agentAttempt.skipSnapshot()
				timeouts.SignFailure = agentAttempt.recordSignFailure
			}
			agentMethod, cleanup, err := buildSSHAgentAuthMethodFunc(timeouts)
			if err != nil {
				// Remember the agent error but keep going: the other
				// methods may still succeed.
				agentErr = err
				continue
			}
			if agentMethod != nil {
				auth = append(auth, agentMethod)
			}
			if cleanup != nil {
				cleanupFuncs = append(cleanupFuncs, cleanup)
			}
		}
	}

	if len(auth) == 0 {
		releaseAll()
		if agentErr != nil {
			return nil, nil, fmt.Errorf("no authentication method provided; ssh-agent unavailable: %w", agentErr)
		}
		return nil, nil, errors.New("no authentication method provided: password, private key, or ssh-agent is required")
	}
	return auth, composeCleanup(cleanupFuncs...), nil
}

// normalizeAuthOrder returns a canonical auth order. For an empty order it
// returns a copy of defaultAuthOrder. Otherwise each entry is trimmed and
// lowercased, and later duplicates are dropped.
//
// It returns an error for an empty entry or an unsupported method.
func normalizeAuthOrder(order []AuthMethodKind) ([]AuthMethodKind, error) {
	if len(order) == 0 {
		return append([]AuthMethodKind(nil), defaultAuthOrder...), nil
	}

	normalized := make([]AuthMethodKind, 0, len(order))
	seen := make(map[AuthMethodKind]struct{}, len(order))
	for _, raw := range order {
		kind := AuthMethodKind(strings.ToLower(strings.TrimSpace(string(raw))))
		if kind == "" {
			return nil, errors.New("auth order contains an empty auth method")
		}
		if !isSupportedAuthMethodKind(kind) {
			return nil, fmt.Errorf("unsupported auth method %q", raw)
		}
		if _, exists := seen[kind]; exists {
			continue
		}
		seen[kind] = struct{}{}
		normalized = append(normalized, kind)
	}
	if len(normalized) == 0 {
		return nil, errors.New("auth order is empty")
	}
	return normalized, nil
}

// isSupportedAuthMethodKind reports whether kind is one of the four auth
// methods this package can build.
func isSupportedAuthMethodKind(kind AuthMethodKind) bool {
	switch kind {
	case AuthMethodPrivateKey, AuthMethodPassword, AuthMethodKeyboardInteractive, AuthMethodSSHAgent:
		return true
	default:
		return false
	}
}

// shouldRetrySSHAgentAuth reports whether an ssh-agent retry makes sense: the
// agent is not disabled and order contains AuthMethodSSHAgent. The comparison
// is exact, so order is presumably already normalized (TODO confirm at the
// call site).
func shouldRetrySSHAgentAuth(info LoginInput, order []AuthMethodKind) bool {
	if info.DisableSSHAgent {
		return false
	}
	for _, methodKind := range order {
		if methodKind == AuthMethodSSHAgent {
			return true
		}
	}
	return false
}

// buildPrivateKeyAuthMethod parses info.Prikey and returns a public-key auth
// method for it. info.PrikeyPwd is used as the passphrase when it is set.
//
// It returns (nil, nil) when no key is configured, and the parse error when
// the key (or passphrase) is invalid.
func buildPrivateKeyAuthMethod(info LoginInput, agentAttempt *sshAgentAuthAttempt) (ssh.AuthMethod, error) {
	if strings.TrimSpace(info.Prikey) == "" {
		return nil, nil
	}

	pemBytes := []byte(info.Prikey)
	var (
		signer ssh.Signer
		err    error
	)
	if info.PrikeyPwd == "" {
		signer, err = ssh.ParsePrivateKey(pemBytes)
	} else {
		signer, err = ssh.ParsePrivateKeyWithPassphrase(pemBytes, []byte(info.PrikeyPwd))
	}
	if err != nil {
		return nil, err
	}
	return ssh.PublicKeysCallback(privateKeySignersCallback(signer, agentAttempt)), nil
}

// privateKeySignersCallback returns a signers callback that yields signer. It
// returns errRetrySSHAgentAuth instead while an agent retry is pending.
func privateKeySignersCallback(signer ssh.Signer, agentAttempt *sshAgentAuthAttempt) func() ([]ssh.Signer, error) {
	return func() ([]ssh.Signer, error) {
		if err := checkSSHAgentRetryPending(agentAttempt); err != nil {
			return nil, err
		}
		return []ssh.Signer{signer}, nil
	}
}

// buildPasswordAuthMethod returns a password auth method, or nil when neither
// a static password nor a callback is available. A non-empty password takes
// precedence over callback.
func buildPasswordAuthMethod(password string, callback func() (string, error), agentAttempt *sshAgentAuthAttempt) ssh.AuthMethod {
	if password == "" && callback == nil {
		return nil
	}
	return ssh.PasswordCallback(func() (string, error) {
		if err := checkSSHAgentRetryPending(agentAttempt); err != nil {
			return "", err
		}
		if password != "" {
			return password, nil
		}
		return callback()
	})
}

// buildKeyboardInteractiveAuthMethod returns a keyboard-interactive auth
// method, or nil when there is nothing to answer with.
//
// A caller-supplied challenge handler takes precedence. Otherwise every
// question is answered with the password: the static password if set, else
// one value from passwordCallback per challenge round.
func buildKeyboardInteractiveAuthMethod(
	password string,
	passwordCallback func() (string, error),
	challenge ssh.KeyboardInteractiveChallenge,
	agentAttempt *sshAgentAuthAttempt,
) ssh.AuthMethod {
	if challenge != nil {
		return ssh.KeyboardInteractive(func(user, instruction string, questions []string, echos []bool) ([]string, error) {
			if err := checkSSHAgentRetryPending(agentAttempt); err != nil {
				return nil, err
			}
			return challenge(user, instruction, questions, echos)
		})
	}
	if password == "" && passwordCallback == nil {
		return nil
	}

	keyboardInteractiveChallenge := func(user, instruction string, questions []string, echos []bool) ([]string, error) {
		if err := checkSSHAgentRetryPending(agentAttempt); err != nil {
			return nil, err
		}
		if len(questions) == 0 {
			return []string{}, nil
		}
		answer := password
		if answer == "" {
			var err error
			answer, err = passwordCallback()
			if err != nil {
				return nil, err
			}
		}
		answers := make([]string, len(questions))
		for i := range questions {
			answers[i] = answer
		}
		return answers, nil
	}
	return ssh.KeyboardInteractive(keyboardInteractiveChallenge)
}

// buildSSHAgentAuthMethod dials the ssh-agent, lists its keys, and returns a
// public-key auth method over them. Keys are ordered by
// orderSSHAgentSigners and filtered by filterSSHAgentSignersForRetry.
//
// It returns (nil, nil, nil) when the agent is unavailable. It returns an
// error when listing fails, or when no usable keys remain. On success the
// cleanup func closes the agent connection.
func buildSSHAgentAuthMethod(timeouts sshAgentTimeouts) (ssh.AuthMethod, func(), error) {
	conn, resolved, err := dialSSHAgentWithDebug("auth", timeouts)
	if err != nil {
		if errors.Is(err, errSSHAgentUnavailable) {
			return nil, nil, nil
		}
		return nil, nil, err
	}
	if conn == nil {
		return nil, nil, nil
	}
	conn = wrapSSHAgentConnWithDeadline(conn, timeouts.Operation)

	started := time.Now()
	signers, err := sshagent.NewClient(conn).Signers()
	err = normalizeSSHAgentError(err)
	logSSHAgentDebug(timeouts.Debug, SSHAgentDebugEvent{
		Step:     "auth",
		Source:   resolved.Source,
		Endpoint: resolved.Endpoint,
		Network:  resolved.Network,
		Phase:    "list",
		Status:   debugStatus(err),
		Duration: time.Since(started),
		KeyCount: len(signers),
		Err:      err,
	})
	if err != nil {
		_ = conn.Close() // best effort; the list error is what matters
		return nil, nil, err
	}
	if len(signers) == 0 {
		_ = conn.Close()
		return nil, nil, errors.New("ssh-agent has no loaded keys")
	}

	timeouts.Resolved = resolved
	orderedSigners := orderSSHAgentSigners(signers)
	filteredSigners := filterSSHAgentSignersForRetry(orderedSigners, timeouts)
	if len(filteredSigners) == 0 {
		_ = conn.Close()
		return nil, nil, errors.New("ssh-agent has no usable keys")
	}
	return ssh.PublicKeys(filteredSigners...), func() { _ = conn.Close() }, nil
}

// orderSSHAgentSigners drops nil signers (and signers with a nil public key).
// It returns the rest sorted by sshAgentSignerPriority, highest first. Ties
// keep the agent's original order.
func orderSSHAgentSigners(signers []ssh.Signer) []ssh.Signer {
	type orderedSigner struct {
		signer ssh.Signer
		index  int
		score  int
	}

	ordered := make([]orderedSigner, 0, len(signers))
	for index, signer := range signers {
		if signer == nil || signer.PublicKey() == nil {
			continue
		}
		ordered = append(ordered, orderedSigner{
			signer: signer,
			index:  index,
			score:  sshAgentSignerPriority(signer),
		})
	}
	sort.SliceStable(ordered, func(i, j int) bool {
		if ordered[i].score != ordered[j].score {
			return ordered[i].score > ordered[j].score
		}
		return ordered[i].index < ordered[j].index
	})

	result := make([]ssh.Signer, 0, len(ordered))
	for _, item := range ordered {
		result = append(result, item.signer)
	}
	return result
}

// sshAgentSignerComment returns the agent-side comment of signer's key. It
// returns "" when the public key is not an *sshagent.Key.
func sshAgentSignerComment(signer ssh.Signer) string {
	if signer == nil {
		return ""
	}
	if key, ok := signer.PublicKey().(*sshagent.Key); ok {
		return key.Comment
	}
	return ""
}

// maxSSHAgentSignerPriority bounds an explicit "priority=N" value. This keeps
// priority*1000 well inside the range of a 32-bit int.
const maxSSHAgentSignerPriority = 1000000

// sshAgentSignerPriority scores a signer from its key comment. Higher scores
// are tried first.
//
// Contributions:
//   - an explicit "priority=N": 100000 + N*1000, with N clamped to
//     ±maxSSHAgentSignerPriority;
//   - "current": +400;
//   - "cardno:": +300;
//   - a standalone "card": +100;
//   - "openpgp"/"gpg": +50.
//
// Keyword matching is case-insensitive.
func sshAgentSignerPriority(signer ssh.Signer) int {
	comment := strings.TrimSpace(sshAgentSignerComment(signer))
	if comment == "" {
		return 0
	}

	score := 0
	if priority, ok := parseSSHAgentSignerPriority(comment); ok {
		// Clamp so priority*1000 cannot overflow int.
		if priority > maxSSHAgentSignerPriority {
			priority = maxSSHAgentSignerPriority
		} else if priority < -maxSSHAgentSignerPriority {
			priority = -maxSSHAgentSignerPriority
		}
		score += 100000 + priority*1000
	}

	lower := strings.ToLower(comment)
	if strings.Contains(lower, "current") {
		score += 400
	}
	if strings.Contains(lower, "cardno:") {
		score += 300
	}
	if strings.Contains(lower, "card ") || strings.Contains(lower, " card") || strings.Contains(lower, "card:") {
		score += 100
	}
	if strings.Contains(lower, "openpgp") || strings.Contains(lower, "gpg") {
		score += 50
	}
	return score
}

// parseSSHAgentSignerPriority extracts N from the first "priority=N" in comment
// (case-insensitive). N is an optionally signed decimal integer. Anything
// after the digits is ignored.
//
// It reports false when the marker is absent, no digits follow it, or the
// number does not fit in an int.
func parseSSHAgentSignerPriority(comment string) (int, bool) {
	const marker = "priority="

	// Search and slice the SAME lowered string. strings.ToLower can change
	// the byte length of some runes (e.g. 'Ⱥ' is 2 bytes, 'ⱥ' is 3), so an
	// index found in the lowered string is not a valid index into comment.
	// It could even point past its end. Lowering leaves signs and ASCII
	// digits untouched.
	lower := strings.ToLower(comment)
	index := strings.Index(lower, marker)
	if index < 0 {
		return 0, false
	}
	value := strings.TrimSpace(lower[index+len(marker):])

	end := 0
	if end < len(value) && (value[end] == '+' || value[end] == '-') {
		end++ // only a single leading sign is allowed
	}
	digitsStart := end
	for end < len(value) && value[end] >= '0' && value[end] <= '9' {
		end++
	}
	if end == digitsStart {
		return 0, false
	}

	priority, err := strconv.Atoi(value[:end])
	if err != nil {
		return 0, false
	}
	return priority, true
}

// filterSSHAgentSignersForRetry removes nil signers, signers with a nil public
// key, and keys whose SHA256 fingerprint is in timeouts.SkipFingerprints.
//
// Remaining signers are wrapped (see wrapSSHAgentSigner) when sign-failure
// reporting or debug logging is requested. Otherwise they are returned
// unwrapped.
func filterSSHAgentSignersForRetry(signers []ssh.Signer, timeouts sshAgentTimeouts) []ssh.Signer {
	filteredSigners := make([]ssh.Signer, 0, len(signers))
	for _, signer := range signers {
		if signer == nil {
			continue
		}
		publicKey := signer.PublicKey()
		if publicKey == nil {
			continue
		}
		if _, skip := timeouts.SkipFingerprints[ssh.FingerprintSHA256(publicKey)]; skip {
			continue
		}
		if timeouts.SignFailure == nil && timeouts.Debug == nil {
			filteredSigners = append(filteredSigners, signer)
			continue
		}
		filteredSigners = append(filteredSigners, wrapSSHAgentSigner(signer, sshAgentSignerOptions{
			Resolved:    timeouts.Resolved,
			Debug:       timeouts.Debug,
			SignFailure: timeouts.SignFailure,
		}))
	}
	return filteredSigners
}

// newSSHAgentAuthAttempt returns an attempt tracker with an empty skip set.
func newSSHAgentAuthAttempt() *sshAgentAuthAttempt {
	return &sshAgentAuthAttempt{
		skipFingerprints: make(map[string]struct{}),
	}
}

// begin clears the retry flag at the start of an attempt. Recorded skip
// fingerprints are kept. It is a no-op on a nil receiver.
func (a *sshAgentAuthAttempt) begin() {
	if a == nil {
		return
	}
	a.mu.Lock()
	defer a.mu.Unlock()
	a.retryRequested = false
}

// skipSnapshot returns a copy of the recorded skip fingerprints. It returns
// nil when there are none, or on a nil receiver.
func (a *sshAgentAuthAttempt) skipSnapshot() map[string]struct{} {
	if a == nil {
		return nil
	}
	a.mu.Lock()
	defer a.mu.Unlock()
	if len(a.skipFingerprints) == 0 {
		return nil
	}
	snapshot := make(map[string]struct{}, len(a.skipFingerprints))
	for fingerprint := range a.skipFingerprints {
		snapshot[fingerprint] = struct{}{}
	}
	return snapshot
}

// recordSignFailure marks publicKey as failing and requests a retry. The error
// itself is not used.
func (a *sshAgentAuthAttempt) recordSignFailure(publicKey ssh.PublicKey, err error) {
	_ = err
	if a == nil || publicKey == nil {
		return
	}
	a.skipFingerprint(ssh.FingerprintSHA256(publicKey))
}

// skipFingerprint requests a retry and, when fingerprint is non-empty, adds it
// to the skip set. It is a no-op on a nil receiver.
func (a *sshAgentAuthAttempt) skipFingerprint(fingerprint string) {
	if a == nil {
		return
	}
	a.mu.Lock()
	defer a.mu.Unlock()
	a.retryRequested = true
	if fingerprint == "" {
		return
	}
	// Allocate lazily so a zero-value attempt (one not built by
	// newSSHAgentAuthAttempt) does not panic on a nil-map write.
	if a.skipFingerprints == nil {
		a.skipFingerprints = make(map[string]struct{})
	}
	a.skipFingerprints[fingerprint] = struct{}{}
}

// shouldRetry reports whether a retry has been requested since the last begin.
func (a *sshAgentAuthAttempt) shouldRetry() bool {
	if a == nil {
		return false
	}
	a.mu.Lock()
	defer a.mu.Unlock()
	return a.retryRequested
}

// checkSSHAgentRetryPending returns errRetrySSHAgentAuth while agentAttempt has
// a retry pending, and nil otherwise (including for a nil attempt).
func checkSSHAgentRetryPending(agentAttempt *sshAgentAuthAttempt) error {
	if agentAttempt != nil && agentAttempt.shouldRetry() {
		return errRetrySSHAgentAuth
	}
	return nil
}

// sshAgentRetrySigner wraps an agent signer. It adds debug logging and
// sign-failure reporting around each signature.
type sshAgentRetrySigner struct {
	signer    ssh.Signer
	publicKey ssh.PublicKey
	options   sshAgentSignerOptions
}

// sshAgentRetryAlgorithmSigner additionally forwards SignWithAlgorithm.
type sshAgentRetryAlgorithmSigner struct {
	sshAgentRetrySigner
	algorithmSigner ssh.AlgorithmSigner
}

// sshAgentRetryMultiAlgorithmSigner additionally forwards Algorithms.
type sshAgentRetryMultiAlgorithmSigner struct {
	sshAgentRetryAlgorithmSigner
	multiAlgorithmSigner ssh.MultiAlgorithmSigner
}

// sshAgentSignerOptions configures the wrappers created by wrapSSHAgentSigner.
type sshAgentSignerOptions struct {
	Resolved    resolvedSSHAgentEndpoint
	Debug       SSHAgentDebugFunc
	SignFailure func(ssh.PublicKey, error)
}

// Compile-time checks that each wrapper keeps the signer interface it is
// returned as.
var (
	_ ssh.Signer               = (*sshAgentRetrySigner)(nil)
	_ ssh.AlgorithmSigner      = (*sshAgentRetryAlgorithmSigner)(nil)
	_ ssh.MultiAlgorithmSigner = (*sshAgentRetryMultiAlgorithmSigner)(nil)
)

// wrapSSHAgentSignerForRetry wraps signer so that onFailure is called when it
// fails to sign.
func wrapSSHAgentSignerForRetry(signer ssh.Signer, onFailure func(ssh.PublicKey, error)) ssh.Signer {
	return wrapSSHAgentSigner(signer, sshAgentSignerOptions{SignFailure: onFailure})
}

// wrapSSHAgentSigner wraps signer while preserving its capability level. A
// MultiAlgorithmSigner stays a MultiAlgorithmSigner, an AlgorithmSigner stays
// an AlgorithmSigner, and anything else becomes a plain Signer.
func wrapSSHAgentSigner(signer ssh.Signer, options sshAgentSignerOptions) ssh.Signer {
	publicKey := signer.PublicKey()
	base := sshAgentRetrySigner{
		signer:    signer,
		publicKey: publicKey,
		options:   options,
	}
	if multiAlgorithmSigner, ok := signer.(ssh.MultiAlgorithmSigner); ok {
		return &sshAgentRetryMultiAlgorithmSigner{
			sshAgentRetryAlgorithmSigner: sshAgentRetryAlgorithmSigner{
				sshAgentRetrySigner: base,
				algorithmSigner:     multiAlgorithmSigner,
			},
			multiAlgorithmSigner: multiAlgorithmSigner,
		}
	}
	if algorithmSigner, ok := signer.(ssh.AlgorithmSigner); ok {
		return &sshAgentRetryAlgorithmSigner{
			sshAgentRetrySigner: base,
			algorithmSigner:     algorithmSigner,
		}
	}
	return &base
}

// PublicKey returns the wrapped signer's public key, captured at wrap time.
func (s *sshAgentRetrySigner) PublicKey() ssh.PublicKey {
	return s.publicKey
}

// Sign delegates to the wrapped signer and post-processes the error.
func (s *sshAgentRetrySigner) Sign(rand io.Reader, data []byte) (*ssh.Signature, error) {
	started := time.Now()
	signature, err := s.signer.Sign(rand, data)
	return signature, s.finishSign(started, err)
}

// finishSign normalizes err and logs the sign step. On failure it notifies
// options.SignFailure (when set) and wraps the error with
// errRetrySSHAgentAuth; without SignFailure the normalized error is returned
// as-is.
func (s *sshAgentRetrySigner) finishSign(started time.Time, err error) error {
	err = normalizeSSHAgentError(err)
	s.logSignDebug(started, err)
	if err == nil {
		return nil
	}
	if s.options.SignFailure != nil {
		s.options.SignFailure(s.publicKey, err)
		return wrapSSHAgentSignError(err)
	}
	return err
}

// logSignDebug emits a "sign" debug event when a debug func is configured.
func (s *sshAgentRetrySigner) logSignDebug(started time.Time, err error) {
	if s == nil || s.options.Debug == nil {
		return
	}
	logSSHAgentDebug(s.options.Debug, SSHAgentDebugEvent{
		Step:     "auth",
		Source:   s.options.Resolved.Source,
		Endpoint: s.options.Resolved.Endpoint,
		Network:  s.options.Resolved.Network,
		Phase:    "sign",
		Status:   debugStatus(err),
		Duration: time.Since(started),
		Err:      err,
	})
}

// SignWithAlgorithm signs with algorithm, possibly adjusted by
// preferredSSHAgentSignAlgorithm.
func (s *sshAgentRetryAlgorithmSigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*ssh.Signature, error) {
	algorithm = preferredSSHAgentSignAlgorithm(s.publicKey, algorithm, nil)
	started := time.Now()
	signature, err := s.algorithmSigner.SignWithAlgorithm(rand, data, algorithm)
	return signature, s.finishSign(started, err)
}

// Algorithms returns the wrapped signer's supported algorithms.
func (s *sshAgentRetryMultiAlgorithmSigner) Algorithms() []string {
	return s.multiAlgorithmSigner.Algorithms()
}

// SignWithAlgorithm signs with algorithm, possibly adjusted by
// preferredSSHAgentSignAlgorithm using the signer's advertised algorithms.
func (s *sshAgentRetryMultiAlgorithmSigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*ssh.Signature, error) {
	algorithm = preferredSSHAgentSignAlgorithm(s.publicKey, algorithm, s.multiAlgorithmSigner.Algorithms())
	started := time.Now()
	signature, err := s.multiAlgorithmSigner.SignWithAlgorithm(rand, data, algorithm)
	return signature, s.finishSign(started, err)
}

// preferredSSHAgentSignAlgorithm may replace an "ssh-rsa" request for an RSA
// key with an rsa-sha2 variant. Any other request is returned unchanged.
//
// With an empty algorithms list the result is rsa-sha2-256. Otherwise it is
// the first rsa-sha2-256/512 entry that appears before "ssh-rsa" in the list,
// or requested when there is none.
func preferredSSHAgentSignAlgorithm(publicKey ssh.PublicKey, requested string, algorithms []string) string {
	if publicKey == nil || publicKey.Type() != ssh.KeyAlgoRSA || requested != ssh.KeyAlgoRSA {
		return requested
	}
	if len(algorithms) == 0 {
		return ssh.KeyAlgoRSASHA256
	}
	for _, algorithm := range algorithms {
		if algorithm == ssh.KeyAlgoRSA {
			break
		}
		if algorithm == ssh.KeyAlgoRSASHA256 || algorithm == ssh.KeyAlgoRSASHA512 {
			return algorithm
		}
	}
	return requested
}

// wrapSSHAgentSignError wraps err so that errors.Is(result,
// errRetrySSHAgentAuth) holds. The normalized original message is kept as
// text. It returns nil for a nil err.
func wrapSSHAgentSignError(err error) error {
	if err == nil {
		return nil
	}
	return fmt.Errorf("%w: %v", errRetrySSHAgentAuth, normalizeSSHAgentError(err))
}

// composeCleanup combines funcs into one func that runs them in reverse order,
// skipping nil entries. It returns nil when funcs is empty.
func composeCleanup(funcs ...func()) func() {
	if len(funcs) == 0 {
		return nil
	}
	return func() {
		for i := len(funcs) - 1; i >= 0; i-- {
			if funcs[i] != nil {
				funcs[i]()
			}
		}
	}
}