package starssh import ( "errors" "fmt" "net" "os" "path/filepath" "strings" "sync" "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/knownhosts" ) var ErrKnownHostsFileRequired = errors.New("known_hosts file is required") var ErrHostFingerprintRequired = errors.New("host key fingerprint is required") type HostKeyFingerprintMismatchError struct { Expected []string ActualSHA256 string ActualLegacyMD5 string } type AcceptNewHostKeyOptions struct { KnownHostsFile string HashHosts bool IncludeRemoteAddress bool FileMode os.FileMode } func (e *HostKeyFingerprintMismatchError) Error() string { if e == nil { return "" } return fmt.Sprintf("host key fingerprint mismatch: want one of %s, got %s (%s)", strings.Join(e.Expected, ", "), e.ActualSHA256, e.ActualLegacyMD5) } func KnownHostsHostKeyCallback(files ...string) (func(string, net.Addr, ssh.PublicKey) error, error) { trimmed := make([]string, 0, len(files)) for _, file := range files { file = strings.TrimSpace(file) if file == "" { continue } trimmed = append(trimmed, file) } if len(trimmed) == 0 { return nil, ErrKnownHostsFileRequired } return knownhosts.New(trimmed...) } func AcceptNewHostKeyCallback(file string) (func(string, net.Addr, ssh.PublicKey) error, error) { return AcceptNewHostKeyCallbackWithOptions(AcceptNewHostKeyOptions{ KnownHostsFile: file, IncludeRemoteAddress: true, }) } func AcceptNewHostKeyCallbackWithOptions(options AcceptNewHostKeyOptions) (func(string, net.Addr, ssh.PublicKey) error, error) { options = normalizeAcceptNewHostKeyOptions(options) if options.KnownHostsFile == "" { return nil, ErrKnownHostsFileRequired } state := &acceptNewHostKeyState{ file: options.KnownHostsFile, hashHosts: options.HashHosts, includeRemoteAddress: options.IncludeRemoteAddress, fileMode: options.FileMode, } if err := state.reload(); err != nil { return nil, err } return state.checkHostKey, nil } func FingerprintHostKeyCallback(fingerprints ...string) (func(string, net.Addr, ssh.PublicKey) error, error) { normalized := make([]string, 0, len(fingerprints)) seen := make(map[string]struct{}, len(fingerprints)) for _, raw := range fingerprints { fingerprint, err := normalizeHostKeyFingerprint(raw) if err != nil { return nil, err } if fingerprint == "" { continue } if _, exists := seen[fingerprint]; exists { continue } seen[fingerprint] = struct{}{} normalized = append(normalized, fingerprint) } if len(normalized) == 0 { return nil, ErrHostFingerprintRequired } return func(hostname string, remote net.Addr, key ssh.PublicKey) error { actualSHA256 := ssh.FingerprintSHA256(key) actualMD5 := normalizeMD5Fingerprint(ssh.FingerprintLegacyMD5(key)) for _, want := range normalized { if want == actualSHA256 || want == actualMD5 { return nil } } return &HostKeyFingerprintMismatchError{ Expected: append([]string(nil), normalized...), ActualSHA256: actualSHA256, ActualLegacyMD5: actualMD5, } }, nil } type acceptNewHostKeyState struct { file string hashHosts bool includeRemoteAddress bool fileMode os.FileMode mu sync.Mutex cb ssh.HostKeyCallback } func (s *acceptNewHostKeyState) checkHostKey(hostname string, remote net.Addr, key ssh.PublicKey) error { s.mu.Lock() defer s.mu.Unlock() if s.cb != nil { err := s.cb(hostname, remote, key) if err == nil { return nil } var keyErr *knownhosts.KeyError if !errors.As(err, &keyErr) || len(keyErr.Want) != 0 { return err } } line, err := buildAcceptNewKnownHostsLine(hostname, remote, key, s.hashHosts, s.includeRemoteAddress) if err != nil { return err } if err := appendKnownHostsLine(s.file, line, s.fileMode); err != nil { return err } if err := s.reload(); err != nil { return err } if s.cb == nil { return errors.New("known_hosts callback is nil after reload") } return s.cb(hostname, remote, key) } func (s *acceptNewHostKeyState) reload() error { callback, err := loadKnownHostsCallback(s.file) if err != nil { return err } s.cb = callback return nil } func loadKnownHostsCallback(file string) (ssh.HostKeyCallback, error) { _, err := os.Stat(file) if err != nil { if errors.Is(err, os.ErrNotExist) { return nil, nil } return nil, err } return knownhosts.New(file) } func normalizeAcceptNewHostKeyOptions(options AcceptNewHostKeyOptions) AcceptNewHostKeyOptions { options.KnownHostsFile = strings.TrimSpace(options.KnownHostsFile) if options.FileMode == 0 { options.FileMode = 0o600 } return options } func buildAcceptNewKnownHostsLine(hostname string, remote net.Addr, key ssh.PublicKey, hashHosts bool, includeRemoteAddress bool) (string, error) { addresses := collectKnownHostsAddresses(hostname, remote, includeRemoteAddress) if len(addresses) == 0 { return "", errors.New("no hostname or remote address available for known_hosts entry") } patterns := make([]string, 0, len(addresses)) for _, address := range addresses { normalized := knownhosts.Normalize(address) if hashHosts { patterns = append(patterns, knownhosts.HashHostname(normalized)) continue } patterns = append(patterns, normalized) } authorizedKey := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(key))) return strings.Join(patterns, ",") + " " + authorizedKey, nil } func collectKnownHostsAddresses(hostname string, remote net.Addr, includeRemoteAddress bool) []string { addresses := make([]string, 0, 2) seen := make(map[string]struct{}, 2) add := func(address string) { address = strings.TrimSpace(address) if address == "" { return } normalized := knownhosts.Normalize(address) if _, exists := seen[normalized]; exists { return } seen[normalized] = struct{}{} addresses = append(addresses, address) } add(hostname) if includeRemoteAddress && remote != nil { add(remote.String()) } return addresses } func appendKnownHostsLine(file string, line string, mode os.FileMode) error { if strings.TrimSpace(file) == "" { return ErrKnownHostsFileRequired } if strings.TrimSpace(line) == "" { return errors.New("known_hosts line is empty") } dir := filepath.Dir(file) if dir != "." && dir != "" { if err := os.MkdirAll(dir, 0o700); err != nil { return err } } handle, err := os.OpenFile(file, os.O_CREATE|os.O_APPEND|os.O_WRONLY, mode) if err != nil { return err } defer handle.Close() if _, err := handle.WriteString(line + "\n"); err != nil { return err } return handle.Chmod(mode) } func normalizeHostKeyFingerprint(raw string) (string, error) { value := strings.TrimSpace(raw) if value == "" { return "", nil } if strings.HasPrefix(strings.ToUpper(value), "SHA256:") { suffix := strings.TrimSpace(value[len("SHA256:"):]) if suffix == "" { return "", ErrHostFingerprintRequired } return "SHA256:" + suffix, nil } if strings.HasPrefix(strings.ToUpper(value), "MD5:") { suffix := strings.TrimSpace(value[len("MD5:"):]) if suffix == "" { return "", ErrHostFingerprintRequired } return normalizeMD5Fingerprint("MD5:" + suffix), nil } if strings.Count(value, ":") >= 2 { return normalizeMD5Fingerprint("MD5:" + value), nil } return "SHA256:" + value, nil } func normalizeMD5Fingerprint(value string) string { if !strings.HasPrefix(strings.ToUpper(value), "MD5:") { return "MD5:" + strings.ToLower(value) } return "MD5:" + strings.ToLower(value[len("MD5:"):]) }