starssh/sshagent_auth.go

669 lines
17 KiB
Go
Raw Normal View History

package starssh
import (
"errors"
"fmt"
"io"
"sort"
"strconv"
"strings"
"sync"
"time"
"golang.org/x/crypto/ssh"
sshagent "golang.org/x/crypto/ssh/agent"
)
// Package-level sentinels and hooks for the ssh-agent authentication path.
var (
	// errSSHAgentUnavailable is the sentinel buildSSHAgentAuthMethod checks
	// for (via errors.Is) to treat a failed dial as "no agent method" instead
	// of an error.
	errSSHAgentUnavailable = errors.New("ssh-agent unavailable")

	// errRetrySSHAgentAuth is returned by the auth callbacks once a signing
	// failure has been recorded, and is wrapped around agent signing errors.
	errRetrySSHAgentAuth = errors.New("retry ssh-agent auth")

	// buildSSHAgentAuthMethodFunc is the indirection through which the agent
	// auth method is built; presumably a seam so tests can substitute a fake
	// (confirm against the test files).
	buildSSHAgentAuthMethodFunc = buildSSHAgentAuthMethod
)
// sshAgentTimeouts bundles the timeouts and per-attempt options that are
// passed to the ssh-agent helpers in this file.
type sshAgentTimeouts struct {
// Dial is filled from effectiveDialTimeout by effectiveSSHAgentTimeouts.
Dial time.Duration
// Operation is handed to wrapSSHAgentConnWithDeadline for the agent
// connection used during authentication.
Operation time.Duration
// Forward is filled from effectiveSSHAgentForwardTimeout; its consumer is
// not in the visible part of this file.
Forward time.Duration
// Endpoint is copied from LoginInput.IdentityAgent.
Endpoint string
// Resolved is set by buildSSHAgentAuthMethod to the endpoint returned by
// dialSSHAgentWithDebug and forwarded to wrapped signers for debug events.
Resolved resolvedSSHAgentEndpoint
// Debug, when non-nil, receives SSHAgentDebugEvent values.
Debug SSHAgentDebugFunc
// SkipFingerprints holds SHA256 fingerprints of agent keys that
// filterSSHAgentSignersForRetry must leave out.
SkipFingerprints map[string]struct{}
// SignFailure, when non-nil, is invoked by wrapped signers with the key and
// the error after a failed signature.
SignFailure func(ssh.PublicKey, error)
}
// sshAgentAuthAttempt records, across authentication attempts, which agent
// keys failed to sign and whether a retry has been requested.
type sshAgentAuthAttempt struct {
// mu guards the fields below.
mu sync.Mutex
// skipFingerprints is the set of SHA256 key fingerprints recorded by
// skipFingerprint.
skipFingerprints map[string]struct{}
// retryRequested is set by skipFingerprint and cleared by begin.
retryRequested bool
}
// defaultAuthOrder is the sequence of auth methods normalizeAuthOrder returns
// (as a copy) when the caller supplies an empty order.
var defaultAuthOrder = []AuthMethodKind{
AuthMethodSSHAgent,
AuthMethodPrivateKey,
AuthMethodPassword,
AuthMethodKeyboardInteractive,
}
// effectiveSSHAgentTimeout resolves the ssh-agent operation timeout from
// info.SSHAgentTimeout: a positive value is used as given, a negative value
// yields 0, and zero selects defaultSSHAgentTimeout.
func effectiveSSHAgentTimeout(info LoginInput) time.Duration {
	configured := info.SSHAgentTimeout
	if configured > 0 {
		return configured
	}
	if configured < 0 {
		return 0
	}
	return defaultSSHAgentTimeout
}
// effectiveSSHAgentTimeouts assembles the sshAgentTimeouts for a login from
// the resolved dial, operation and forward timeouts plus the configured
// identity-agent endpoint and debug hook. The remaining fields are left at
// their zero values.
func effectiveSSHAgentTimeouts(info LoginInput) sshAgentTimeouts {
	var timeouts sshAgentTimeouts
	timeouts.Dial = effectiveDialTimeout(info)
	timeouts.Operation = effectiveSSHAgentTimeout(info)
	timeouts.Forward = effectiveSSHAgentForwardTimeout(info)
	timeouts.Endpoint = info.IdentityAgent
	timeouts.Debug = info.SSHAgentDebug
	return timeouts
}
// effectiveSSHAgentForwardTimeout returns info.SSHAgentForwardTimeout when it
// is positive and 0 for zero or negative values.
func effectiveSSHAgentForwardTimeout(info LoginInput) time.Duration {
	if info.SSHAgentForwardTimeout <= 0 {
		return 0
	}
	return info.SSHAgentForwardTimeout
}
// buildAuthMethods builds the auth method list for info without any
// ssh-agent retry tracking (a nil attempt is passed through).
func buildAuthMethods(info LoginInput) ([]ssh.AuthMethod, func(), error) {
	var noAttempt *sshAgentAuthAttempt
	return buildAuthMethodsWithAgentAttempt(info, noAttempt)
}
// buildAuthMethodsWithAgentAttempt builds the ssh.AuthMethod list for info in
// the normalized auth order.
//
// agentAttempt may be nil. When set, its skip snapshot and sign-failure
// recorder are passed to the ssh-agent method builder, and the other methods
// consult it before answering.
//
// It returns the methods, a cleanup function (nil when nothing needs to be
// released) and an error. An ssh-agent build error is not fatal on its own:
// it is only reported when no other method could be built. On every error
// return, cleanups collected so far are run before returning so that an
// already opened agent connection is not leaked.
func buildAuthMethodsWithAgentAttempt(info LoginInput, agentAttempt *sshAgentAuthAttempt) ([]ssh.AuthMethod, func(), error) {
	order, err := normalizeAuthOrder(info.AuthOrder)
	if err != nil {
		return nil, nil, err
	}

	auth := make([]ssh.AuthMethod, 0, len(order))
	var agentErr error
	var cleanupFuncs []func()

	// releaseCollected runs the cleanups gathered so far; it is used on the
	// failure paths where the caller never receives a cleanup function.
	releaseCollected := func() {
		if release := composeCleanup(cleanupFuncs...); release != nil {
			release()
		}
	}

	for _, methodKind := range order {
		switch methodKind {
		case AuthMethodPrivateKey:
			method, err := buildPrivateKeyAuthMethod(info, agentAttempt)
			if err != nil {
				// The agent method may already hold an open connection.
				releaseCollected()
				return nil, nil, err
			}
			if method != nil {
				auth = append(auth, method)
			}
		case AuthMethodPassword:
			method := buildPasswordAuthMethod(info.Password, info.PasswordCallback, agentAttempt)
			if method != nil {
				auth = append(auth, method)
			}
		case AuthMethodKeyboardInteractive:
			method := buildKeyboardInteractiveAuthMethod(info.Password, info.PasswordCallback, info.KeyboardInteractiveCallback, agentAttempt)
			if method != nil {
				auth = append(auth, method)
			}
		case AuthMethodSSHAgent:
			if info.DisableSSHAgent {
				continue
			}
			timeouts := effectiveSSHAgentTimeouts(info)
			if agentAttempt != nil {
				timeouts.SkipFingerprints = agentAttempt.skipSnapshot()
				timeouts.SignFailure = agentAttempt.recordSignFailure
			}
			agentMethod, cleanup, err := buildSSHAgentAuthMethodFunc(timeouts)
			if err != nil {
				// Remember the failure but keep going: other methods may
				// still be available.
				agentErr = err
				continue
			}
			if agentMethod != nil {
				auth = append(auth, agentMethod)
			}
			if cleanup != nil {
				cleanupFuncs = append(cleanupFuncs, cleanup)
			}
		}
	}

	if len(auth) == 0 {
		releaseCollected()
		if agentErr != nil {
			return nil, nil, fmt.Errorf("no authentication method provided; ssh-agent unavailable: %w", agentErr)
		}
		return nil, nil, errors.New("no authentication method provided: password, private key, or ssh-agent is required")
	}
	return auth, composeCleanup(cleanupFuncs...), nil
}
// normalizeAuthOrder canonicalizes a caller-supplied auth order.
//
// An empty order yields a copy of defaultAuthOrder. Otherwise each entry is
// trimmed and lowercased; empty or unsupported entries produce an error, and
// repeated entries are dropped while keeping the first occurrence's position.
func normalizeAuthOrder(order []AuthMethodKind) ([]AuthMethodKind, error) {
	if len(order) == 0 {
		return append([]AuthMethodKind(nil), defaultAuthOrder...), nil
	}

	result := make([]AuthMethodKind, 0, len(order))
	already := make(map[AuthMethodKind]bool, len(order))
	for _, entry := range order {
		canonical := AuthMethodKind(strings.ToLower(strings.TrimSpace(string(entry))))
		switch {
		case canonical == "":
			return nil, errors.New("auth order contains an empty auth method")
		case !isSupportedAuthMethodKind(canonical):
			return nil, fmt.Errorf("unsupported auth method %q", entry)
		case already[canonical]:
			// Duplicate: the first occurrence wins.
			continue
		}
		already[canonical] = true
		result = append(result, canonical)
	}

	if len(result) == 0 {
		return nil, errors.New("auth order is empty")
	}
	return result, nil
}
// isSupportedAuthMethodKind reports whether kind is one of the four auth
// method kinds this file knows how to build.
func isSupportedAuthMethodKind(kind AuthMethodKind) bool {
	return kind == AuthMethodPrivateKey ||
		kind == AuthMethodPassword ||
		kind == AuthMethodKeyboardInteractive ||
		kind == AuthMethodSSHAgent
}
// shouldRetrySSHAgentAuth reports whether an ssh-agent retry makes sense:
// the agent must not be disabled and order must list AuthMethodSSHAgent.
func shouldRetrySSHAgentAuth(info LoginInput, order []AuthMethodKind) bool {
	agentListed := false
	for i := range order {
		if order[i] == AuthMethodSSHAgent {
			agentListed = true
			break
		}
	}
	return agentListed && !info.DisableSSHAgent
}
// buildPrivateKeyAuthMethod builds a public-key auth method from info.Prikey.
//
// It returns (nil, nil) when no key is configured (blank after trimming).
// When info.PrikeyPwd is set the key is parsed with that passphrase,
// otherwise as an unencrypted key; parse errors are returned unchanged.
func buildPrivateKeyAuthMethod(info LoginInput, agentAttempt *sshAgentAuthAttempt) (ssh.AuthMethod, error) {
	if strings.TrimSpace(info.Prikey) == "" {
		return nil, nil
	}

	var (
		signer ssh.Signer
		err    error
	)
	keyPEM := []byte(info.Prikey)
	if passphrase := info.PrikeyPwd; passphrase != "" {
		signer, err = ssh.ParsePrivateKeyWithPassphrase(keyPEM, []byte(passphrase))
	} else {
		signer, err = ssh.ParsePrivateKey(keyPEM)
	}
	if err != nil {
		return nil, err
	}
	return ssh.PublicKeysCallback(privateKeySignersCallback(signer, agentAttempt)), nil
}
// privateKeySignersCallback returns a signers callback that yields signer,
// unless an ssh-agent retry is pending on agentAttempt, in which case it
// returns that retry error instead.
func privateKeySignersCallback(signer ssh.Signer, agentAttempt *sshAgentAuthAttempt) func() ([]ssh.Signer, error) {
	provide := func() ([]ssh.Signer, error) {
		retryErr := checkSSHAgentRetryPending(agentAttempt)
		if retryErr != nil {
			return nil, retryErr
		}
		return []ssh.Signer{signer}, nil
	}
	return provide
}
// buildPasswordAuthMethod builds a password auth method, or returns nil when
// neither a static password nor a callback is available.
//
// The static password takes precedence over the callback. Before answering,
// the method fails with the retry error if an ssh-agent retry is pending.
func buildPasswordAuthMethod(password string, callback func() (string, error), agentAttempt *sshAgentAuthAttempt) ssh.AuthMethod {
	hasSource := password != "" || callback != nil
	if !hasSource {
		return nil
	}

	provide := func() (string, error) {
		if retryErr := checkSSHAgentRetryPending(agentAttempt); retryErr != nil {
			return "", retryErr
		}
		if password == "" {
			return callback()
		}
		return password, nil
	}
	return ssh.PasswordCallback(provide)
}
// buildKeyboardInteractiveAuthMethod builds a keyboard-interactive auth
// method.
//
// When challenge is provided it is used as-is, guarded by the pending-retry
// check. Otherwise, if a password or password callback exists, every question
// in a round is answered with that same secret (the callback is consulted
// once per round when no static password is set). With no challenge and no
// password source the function returns nil.
func buildKeyboardInteractiveAuthMethod(
	password string,
	passwordCallback func() (string, error),
	challenge ssh.KeyboardInteractiveChallenge,
	agentAttempt *sshAgentAuthAttempt,
) ssh.AuthMethod {
	switch {
	case challenge != nil:
		guarded := func(user, instruction string, questions []string, echos []bool) ([]string, error) {
			if retryErr := checkSSHAgentRetryPending(agentAttempt); retryErr != nil {
				return nil, retryErr
			}
			return challenge(user, instruction, questions, echos)
		}
		return ssh.KeyboardInteractive(guarded)
	case password == "" && passwordCallback == nil:
		return nil
	}

	return ssh.KeyboardInteractive(func(user, instruction string, questions []string, echos []bool) ([]string, error) {
		if retryErr := checkSSHAgentRetryPending(agentAttempt); retryErr != nil {
			return nil, retryErr
		}
		if len(questions) == 0 {
			return []string{}, nil
		}

		secret := password
		if secret == "" {
			fetched, err := passwordCallback()
			if err != nil {
				return nil, err
			}
			secret = fetched
		}

		replies := make([]string, 0, len(questions))
		for range questions {
			replies = append(replies, secret)
		}
		return replies, nil
	})
}
// buildSSHAgentAuthMethod dials the ssh-agent, lists its signers and returns
// a public-key auth method backed by them plus a cleanup function that closes
// the agent connection.
//
// It returns (nil, nil, nil) when the agent is unavailable (the dial error
// matches errSSHAgentUnavailable, or no connection is returned). Listing
// failures, an agent with no keys, and an agent whose keys are all filtered
// out are reported as errors after the connection has been closed.
func buildSSHAgentAuthMethod(timeouts sshAgentTimeouts) (ssh.AuthMethod, func(), error) {
	conn, resolved, err := dialSSHAgentWithDebug("auth", timeouts)
	switch {
	case err != nil && errors.Is(err, errSSHAgentUnavailable):
		return nil, nil, nil
	case err != nil:
		return nil, nil, err
	case conn == nil:
		return nil, nil, nil
	}

	conn = wrapSSHAgentConnWithDeadline(conn, timeouts.Operation)
	closeConn := func() {
		_ = conn.Close()
	}
	// abort closes the agent connection and reports cause to the caller.
	abort := func(cause error) (ssh.AuthMethod, func(), error) {
		closeConn()
		return nil, nil, cause
	}

	listStarted := time.Now()
	agentSigners, listErr := sshagent.NewClient(conn).Signers()
	listErr = normalizeSSHAgentError(listErr)
	// The debug event is emitted for both successful and failed listings.
	logSSHAgentDebug(timeouts.Debug, SSHAgentDebugEvent{
		Step:     "auth",
		Source:   resolved.Source,
		Endpoint: resolved.Endpoint,
		Network:  resolved.Network,
		Phase:    "list",
		Status:   debugStatus(listErr),
		Duration: time.Since(listStarted),
		KeyCount: len(agentSigners),
		Err:      listErr,
	})
	if listErr != nil {
		return abort(listErr)
	}
	if len(agentSigners) == 0 {
		return abort(errors.New("ssh-agent has no loaded keys"))
	}

	timeouts.Resolved = resolved
	usable := filterSSHAgentSignersForRetry(orderSSHAgentSigners(agentSigners), timeouts)
	if len(usable) == 0 {
		return abort(errors.New("ssh-agent has no usable keys"))
	}
	return ssh.PublicKeys(usable...), closeConn, nil
}
// orderSSHAgentSigners returns the usable signers sorted by descending
// sshAgentSignerPriority score. Nil signers and signers without a public key
// are dropped. Signers with equal scores keep their input order, which the
// stable sort guarantees without an explicit index tie-break.
func orderSSHAgentSigners(signers []ssh.Signer) []ssh.Signer {
	type scoredSigner struct {
		signer ssh.Signer
		score  int
	}

	scored := make([]scoredSigner, 0, len(signers))
	for _, signer := range signers {
		if signer == nil || signer.PublicKey() == nil {
			continue
		}
		scored = append(scored, scoredSigner{
			signer: signer,
			score:  sshAgentSignerPriority(signer),
		})
	}

	sort.SliceStable(scored, func(i, j int) bool {
		return scored[i].score > scored[j].score
	})

	result := make([]ssh.Signer, 0, len(scored))
	for _, item := range scored {
		result = append(result, item.signer)
	}
	return result
}
// sshAgentSignerComment returns the agent key comment for signer, or "" when
// signer is nil or its public key is not an *sshagent.Key.
func sshAgentSignerComment(signer ssh.Signer) string {
	if signer != nil {
		if agentKey, isAgentKey := signer.PublicKey().(*sshagent.Key); isAgentKey {
			return agentKey.Comment
		}
	}
	return ""
}
// sshAgentSignerPriority scores a signer from its agent key comment; higher
// scores are tried first by orderSSHAgentSigners.
//
// An explicit "priority=N" marker contributes 100000 + N*1000. Keyword
// bonuses are then added on a case-insensitive basis: "current" (+400),
// "cardno:" (+300), a "card" token followed/preceded by a space or followed
// by a colon (+100), and "openpgp"/"gpg" (+50). A blank comment scores 0.
func sshAgentSignerPriority(signer ssh.Signer) int {
	comment := strings.TrimSpace(sshAgentSignerComment(signer))
	if comment == "" {
		return 0
	}

	total := 0
	if explicit, found := parseSSHAgentSignerPriority(comment); found {
		total = 100000 + explicit*1000
	}

	folded := strings.ToLower(comment)
	// Each rule grants its bonus once if any of its needles occurs.
	keywordBonuses := []struct {
		anyOf []string
		bonus int
	}{
		{[]string{"current"}, 400},
		{[]string{"cardno:"}, 300},
		{[]string{"card ", " card", "card:"}, 100},
		{[]string{"openpgp", "gpg"}, 50},
	}
	for _, rule := range keywordBonuses {
		for _, needle := range rule.anyOf {
			if strings.Contains(folded, needle) {
				total += rule.bonus
				break
			}
		}
	}
	return total
}
// parseSSHAgentSignerPriority extracts the integer that follows the first
// case-insensitive "priority=" marker in comment.
//
// It returns (N, true) when the marker is followed (after optional
// whitespace) by a run of sign/digit characters that strconv.Atoi accepts,
// and (0, false) otherwise.
func parseSSHAgentSignerPriority(comment string) (int, bool) {
	const marker = "priority="

	lower := strings.ToLower(comment)
	index := strings.Index(lower, marker)
	if index < 0 {
		return 0, false
	}

	// Slice the lowered copy rather than comment: strings.ToLower can change
	// the byte length (case mappings of different UTF-8 width, invalid bytes
	// replaced by U+FFFD), so an index found in lower is not a valid offset
	// into comment. The sign and digit characters parsed below are unaffected
	// by lowercasing.
	value := strings.TrimSpace(lower[index+len(marker):])

	end := 0
	for end < len(value) {
		ch := value[end]
		if ch != '+' && ch != '-' && (ch < '0' || ch > '9') {
			break
		}
		end++
	}
	if end == 0 {
		return 0, false
	}

	priority, err := strconv.Atoi(value[:end])
	if err != nil {
		return 0, false
	}
	return priority, true
}
// filterSSHAgentSignersForRetry drops nil signers, signers without a public
// key, and signers whose SHA256 fingerprint is in timeouts.SkipFingerprints.
// Remaining signers are wrapped with wrapSSHAgentSigner when a sign-failure
// callback or a debug hook is configured, and returned as-is otherwise.
func filterSSHAgentSignersForRetry(signers []ssh.Signer, timeouts sshAgentTimeouts) []ssh.Signer {
	needsWrap := timeouts.SignFailure != nil || timeouts.Debug != nil
	wrapOptions := sshAgentSignerOptions{
		Resolved:    timeouts.Resolved,
		Debug:       timeouts.Debug,
		SignFailure: timeouts.SignFailure,
	}

	kept := make([]ssh.Signer, 0, len(signers))
	for _, candidate := range signers {
		if candidate == nil {
			continue
		}
		key := candidate.PublicKey()
		if key == nil {
			continue
		}
		if _, skipped := timeouts.SkipFingerprints[ssh.FingerprintSHA256(key)]; skipped {
			continue
		}
		if needsWrap {
			candidate = wrapSSHAgentSigner(candidate, wrapOptions)
		}
		kept = append(kept, candidate)
	}
	return kept
}
// newSSHAgentAuthAttempt returns an attempt tracker with an empty, ready to
// use fingerprint skip set.
func newSSHAgentAuthAttempt() *sshAgentAuthAttempt {
	attempt := new(sshAgentAuthAttempt)
	attempt.skipFingerprints = map[string]struct{}{}
	return attempt
}
// begin clears the retry-requested flag; the recorded skip fingerprints are
// kept. It is a no-op on a nil receiver.
func (a *sshAgentAuthAttempt) begin() {
	if a != nil {
		a.mu.Lock()
		a.retryRequested = false
		a.mu.Unlock()
	}
}
// skipSnapshot returns an independent copy of the recorded skip fingerprints,
// or nil when the receiver is nil or nothing has been recorded.
func (a *sshAgentAuthAttempt) skipSnapshot() map[string]struct{} {
	if a == nil {
		return nil
	}

	a.mu.Lock()
	defer a.mu.Unlock()

	count := len(a.skipFingerprints)
	if count == 0 {
		return nil
	}
	copied := make(map[string]struct{}, count)
	for fp := range a.skipFingerprints {
		copied[fp] = struct{}{}
	}
	return copied
}
// recordSignFailure marks publicKey's SHA256 fingerprint as one to skip.
// The error itself is intentionally unused; a nil receiver or nil key is
// ignored.
func (a *sshAgentAuthAttempt) recordSignFailure(publicKey ssh.PublicKey, err error) {
	if a != nil && publicKey != nil {
		a.skipFingerprint(ssh.FingerprintSHA256(publicKey))
	}
	_ = err
}
// skipFingerprint requests a retry and, when fingerprint is non-empty, adds
// it to the skip set. It is a no-op on a nil receiver.
//
// The skip set is created on first use so that a zero-value
// sshAgentAuthAttempt (one not built by newSSHAgentAuthAttempt) does not
// panic on a nil-map write.
func (a *sshAgentAuthAttempt) skipFingerprint(fingerprint string) {
	if a == nil {
		return
	}
	a.mu.Lock()
	defer a.mu.Unlock()

	// A retry is requested even when there is no fingerprint to record.
	a.retryRequested = true
	if fingerprint == "" {
		return
	}
	if a.skipFingerprints == nil {
		a.skipFingerprints = make(map[string]struct{})
	}
	a.skipFingerprints[fingerprint] = struct{}{}
}
// shouldRetry reports whether a retry has been requested since the last
// begin; a nil receiver never requests one.
func (a *sshAgentAuthAttempt) shouldRetry() bool {
	if a == nil {
		return false
	}
	a.mu.Lock()
	requested := a.retryRequested
	a.mu.Unlock()
	return requested
}
// checkSSHAgentRetryPending returns errRetrySSHAgentAuth when agentAttempt is
// non-nil and has a retry pending, and nil otherwise.
func checkSSHAgentRetryPending(agentAttempt *sshAgentAuthAttempt) error {
	if agentAttempt == nil || !agentAttempt.shouldRetry() {
		return nil
	}
	return errRetrySSHAgentAuth
}
// sshAgentRetrySigner wraps a signer so that every Sign call is timed,
// optionally reported to the debug hook, and reported to the sign-failure
// callback on error.
type sshAgentRetrySigner struct {
// signer is the wrapped signer that performs the actual signing.
signer ssh.Signer
// publicKey is the wrapped signer's public key as captured by
// wrapSSHAgentSigner.
publicKey ssh.PublicKey
// options carries the resolved endpoint, debug hook and failure callback.
options sshAgentSignerOptions
}
// sshAgentRetryAlgorithmSigner extends sshAgentRetrySigner for wrapped
// signers that implement ssh.AlgorithmSigner.
type sshAgentRetryAlgorithmSigner struct {
sshAgentRetrySigner
// algorithmSigner is the wrapped signer viewed as an ssh.AlgorithmSigner.
algorithmSigner ssh.AlgorithmSigner
}
// sshAgentRetryMultiAlgorithmSigner extends sshAgentRetryAlgorithmSigner for
// wrapped signers that implement ssh.MultiAlgorithmSigner.
type sshAgentRetryMultiAlgorithmSigner struct {
sshAgentRetryAlgorithmSigner
// multiAlgorithmSigner is the wrapped signer viewed as an
// ssh.MultiAlgorithmSigner.
multiAlgorithmSigner ssh.MultiAlgorithmSigner
}
// sshAgentSignerOptions configures the signer wrappers created by
// wrapSSHAgentSigner.
type sshAgentSignerOptions struct {
// Resolved supplies the Source, Endpoint and Network values reported in
// sign debug events.
Resolved resolvedSSHAgentEndpoint
// Debug, when non-nil, receives a debug event for every sign call.
Debug SSHAgentDebugFunc
// SignFailure, when non-nil, is called with the key and error after a
// failed sign; the returned error is then wrapped as a retry error.
SignFailure func(ssh.PublicKey, error)
}
// wrapSSHAgentSignerForRetry wraps signer with only a sign-failure callback
// configured (no debug hook, zero-value resolved endpoint).
func wrapSSHAgentSignerForRetry(signer ssh.Signer, onFailure func(ssh.PublicKey, error)) ssh.Signer {
	options := sshAgentSignerOptions{SignFailure: onFailure}
	return wrapSSHAgentSigner(signer, options)
}
// wrapSSHAgentSigner wraps signer in the richest retry wrapper it supports:
// the multi-algorithm wrapper when signer implements
// ssh.MultiAlgorithmSigner, the algorithm wrapper when it implements
// ssh.AlgorithmSigner, and the plain wrapper otherwise. signer must be
// non-nil; its public key is read once and cached in the wrapper.
func wrapSSHAgentSigner(signer ssh.Signer, options sshAgentSignerOptions) ssh.Signer {
	core := sshAgentRetrySigner{
		signer:    signer,
		publicKey: signer.PublicKey(),
		options:   options,
	}

	// Case order matters: the multi-algorithm interface is checked before
	// the plain algorithm interface.
	switch typed := signer.(type) {
	case ssh.MultiAlgorithmSigner:
		return &sshAgentRetryMultiAlgorithmSigner{
			sshAgentRetryAlgorithmSigner: sshAgentRetryAlgorithmSigner{
				sshAgentRetrySigner: core,
				algorithmSigner:     typed,
			},
			multiAlgorithmSigner: typed,
		}
	case ssh.AlgorithmSigner:
		return &sshAgentRetryAlgorithmSigner{
			sshAgentRetrySigner: core,
			algorithmSigner:     typed,
		}
	default:
		return &core
	}
}
// PublicKey returns the public key captured when the wrapper was created.
func (s *sshAgentRetrySigner) PublicKey() ssh.PublicKey {
	cached := s.publicKey
	return cached
}
// Sign delegates to the wrapped signer and routes the outcome through
// finishSign for debug reporting and failure handling.
func (s *sshAgentRetrySigner) Sign(rand io.Reader, data []byte) (*ssh.Signature, error) {
	begun := time.Now()
	sig, signErr := s.signer.Sign(rand, data)
	signErr = s.finishSign(begun, signErr)
	return sig, signErr
}
// finishSign post-processes the result of a sign call: it normalizes err,
// emits the sign debug event, and on failure notifies the sign-failure
// callback (if any) and wraps the error as a retry error. Without a callback
// the normalized error is returned unchanged; on success it returns nil.
func (s *sshAgentRetrySigner) finishSign(started time.Time, err error) error {
	err = normalizeSSHAgentError(err)
	s.logSignDebug(started, err)

	switch {
	case err == nil:
		return nil
	case s.options.SignFailure == nil:
		return err
	}
	s.options.SignFailure(s.publicKey, err)
	return wrapSSHAgentSignError(err)
}
// logSignDebug emits a "sign" phase debug event for a sign call that began
// at started and ended with err. It does nothing for a nil receiver or when
// no debug hook is configured.
func (s *sshAgentRetrySigner) logSignDebug(started time.Time, err error) {
	if s == nil {
		return
	}
	debug := s.options.Debug
	if debug == nil {
		return
	}

	event := SSHAgentDebugEvent{
		Step:     "auth",
		Source:   s.options.Resolved.Source,
		Endpoint: s.options.Resolved.Endpoint,
		Network:  s.options.Resolved.Network,
		Phase:    "sign",
		Status:   debugStatus(err),
		Duration: time.Since(started),
		Err:      err,
	}
	logSSHAgentDebug(debug, event)
}
// SignWithAlgorithm signs with the wrapped algorithm signer. The requested
// algorithm is first passed through preferredSSHAgentSignAlgorithm with no
// advertised algorithm list, and the outcome is routed through finishSign.
func (s *sshAgentRetryAlgorithmSigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*ssh.Signature, error) {
	chosen := preferredSSHAgentSignAlgorithm(s.publicKey, algorithm, nil)
	begun := time.Now()
	sig, signErr := s.algorithmSigner.SignWithAlgorithm(rand, data, chosen)
	signErr = s.finishSign(begun, signErr)
	return sig, signErr
}
// Algorithms reports the algorithms advertised by the wrapped
// multi-algorithm signer, unchanged.
func (s *sshAgentRetryMultiAlgorithmSigner) Algorithms() []string {
	advertised := s.multiAlgorithmSigner.Algorithms()
	return advertised
}
// SignWithAlgorithm signs with the wrapped multi-algorithm signer. The
// requested algorithm is first passed through preferredSSHAgentSignAlgorithm
// together with the signer's advertised algorithms, and the outcome is routed
// through finishSign.
func (s *sshAgentRetryMultiAlgorithmSigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*ssh.Signature, error) {
	advertised := s.multiAlgorithmSigner.Algorithms()
	chosen := preferredSSHAgentSignAlgorithm(s.publicKey, algorithm, advertised)
	begun := time.Now()
	sig, signErr := s.multiAlgorithmSigner.SignWithAlgorithm(rand, data, chosen)
	signErr = s.finishSign(begun, signErr)
	return sig, signErr
}
// preferredSSHAgentSignAlgorithm decides which signature algorithm to use
// for a sign request.
//
// The request is only rewritten for an RSA key asked to sign with
// ssh.KeyAlgoRSA. In that case, with no advertised algorithms the result is
// ssh.KeyAlgoRSASHA256; with an advertised list, the first entry decides:
// reaching ssh.KeyAlgoRSA first keeps the request, reaching an RSA SHA-2
// algorithm first selects it. In every other case requested is returned.
func preferredSSHAgentSignAlgorithm(publicKey ssh.PublicKey, requested string, algorithms []string) string {
	legacyRSARequest := publicKey != nil &&
		publicKey.Type() == ssh.KeyAlgoRSA &&
		requested == ssh.KeyAlgoRSA
	if !legacyRSARequest {
		return requested
	}
	if len(algorithms) == 0 {
		return ssh.KeyAlgoRSASHA256
	}

	for _, candidate := range algorithms {
		switch candidate {
		case ssh.KeyAlgoRSA:
			// The signer lists the requested algorithm ahead of any SHA-2 one.
			return requested
		case ssh.KeyAlgoRSASHA256, ssh.KeyAlgoRSASHA512:
			return candidate
		}
	}
	return requested
}
// wrapSSHAgentSignError turns a signing error into one that matches
// errRetrySSHAgentAuth via errors.Is, with the normalized error's text
// appended. A nil error stays nil.
func wrapSSHAgentSignError(err error) error {
	if err != nil {
		return fmt.Errorf("%w: %v", errRetrySSHAgentAuth, normalizeSSHAgentError(err))
	}
	return nil
}
// composeCleanup merges cleanup functions into one that runs them in reverse
// order (last registered first), skipping nil entries. It returns nil when
// called with no functions.
func composeCleanup(funcs ...func()) func() {
	if len(funcs) == 0 {
		return nil
	}
	return func() {
		last := len(funcs) - 1
		for offset := range funcs {
			cleanup := funcs[last-offset]
			if cleanup == nil {
				continue
			}
			cleanup()
		}
	}
}