Files
starssh/hostkey.go
T

291 lines
7.4 KiB
Go
Raw Permalink Normal View History

package starssh
import (
"errors"
"fmt"
"net"
"os"
"path/filepath"
"strings"
"sync"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/knownhosts"
)
// Sentinel errors returned by the host key helpers in this package.
var (
	// ErrKnownHostsFileRequired is returned when no known_hosts path remains
	// after trimming surrounding whitespace.
	ErrKnownHostsFileRequired = errors.New("known_hosts file is required")

	// ErrHostFingerprintRequired is returned when no usable host key
	// fingerprint was supplied (none given, or a "SHA256:"/"MD5:" prefix with
	// nothing after it).
	ErrHostFingerprintRequired = errors.New("host key fingerprint is required")
)
// HostKeyFingerprintMismatchError is returned by the callback built by
// FingerprintHostKeyCallback when the server key's fingerprints match none of
// the configured fingerprints.
type HostKeyFingerprintMismatchError struct {
	// Expected holds the normalized fingerprints that would have been accepted.
	Expected []string
	// ActualSHA256 is the server key's fingerprint as produced by
	// ssh.FingerprintSHA256.
	ActualSHA256 string
	// ActualLegacyMD5 is the server key's legacy MD5 fingerprint, normalized to
	// the "MD5:<lowercase hex>" form.
	ActualLegacyMD5 string
}
// AcceptNewHostKeyOptions configures AcceptNewHostKeyCallbackWithOptions.
type AcceptNewHostKeyOptions struct {
	// KnownHostsFile is the known_hosts path that is read and appended to.
	// Surrounding whitespace is trimmed; an empty value is rejected with
	// ErrKnownHostsFileRequired.
	KnownHostsFile string
	// HashHosts writes new entries with hashed host names
	// (knownhosts.HashHostname) instead of plain host names.
	HashHosts bool
	// IncludeRemoteAddress also records the connection's remote address, in
	// addition to the hostname passed to the callback, when adding an entry.
	IncludeRemoteAddress bool
	// FileMode is applied to the known_hosts file whenever an entry is
	// appended. Zero means 0o600.
	FileMode os.FileMode
}
// Error implements the error interface. The message lists the accepted
// fingerprints followed by the server's actual SHA256 and MD5 fingerprints.
// Calling Error on a nil receiver returns the empty string.
func (e *HostKeyFingerprintMismatchError) Error() string {
	if e != nil {
		wanted := strings.Join(e.Expected, ", ")
		return fmt.Sprintf("host key fingerprint mismatch: want one of %s, got %s (%s)", wanted, e.ActualSHA256, e.ActualLegacyMD5)
	}
	return ""
}
// KnownHostsHostKeyCallback returns a host key callback backed by the given
// known_hosts files, built with knownhosts.New. Each path is trimmed of
// surrounding whitespace and blank entries are skipped. If no path remains,
// ErrKnownHostsFileRequired is returned.
func KnownHostsHostKeyCallback(files ...string) (func(string, net.Addr, ssh.PublicKey) error, error) {
	var paths []string
	for i := range files {
		if p := strings.TrimSpace(files[i]); p != "" {
			paths = append(paths, p)
		}
	}
	if len(paths) == 0 {
		return nil, ErrKnownHostsFileRequired
	}
	return knownhosts.New(paths...)
}
// AcceptNewHostKeyCallback is shorthand for AcceptNewHostKeyCallbackWithOptions
// with the given known_hosts file and IncludeRemoteAddress enabled. All other
// options keep their zero values: host names are not hashed, and the file mode
// falls back to the default.
func AcceptNewHostKeyCallback(file string) (func(string, net.Addr, ssh.PublicKey) error, error) {
	opts := AcceptNewHostKeyOptions{IncludeRemoteAddress: true}
	opts.KnownHostsFile = file
	return AcceptNewHostKeyCallbackWithOptions(opts)
}
// AcceptNewHostKeyCallbackWithOptions returns a host key callback with
// "accept-new" behavior for options.KnownHostsFile:
//   - a key already recorded for the host is verified;
//   - a host with no recorded key has its key appended to the file and is
//     accepted;
//   - a host whose recorded key differs is rejected with the knownhosts error.
//
// The file is loaded once here (a missing file is not an error) and reloaded
// after every append. ErrKnownHostsFileRequired is returned if the path is
// blank.
func AcceptNewHostKeyCallbackWithOptions(options AcceptNewHostKeyOptions) (func(string, net.Addr, ssh.PublicKey) error, error) {
	opts := normalizeAcceptNewHostKeyOptions(options)
	if opts.KnownHostsFile == "" {
		return nil, ErrKnownHostsFileRequired
	}

	st := new(acceptNewHostKeyState)
	st.file = opts.KnownHostsFile
	st.hashHosts = opts.HashHosts
	st.includeRemoteAddress = opts.IncludeRemoteAddress
	st.fileMode = opts.FileMode

	if err := st.reload(); err != nil {
		return nil, err
	}
	return st.checkHostKey, nil
}
// FingerprintHostKeyCallback returns a host key callback that accepts a server
// key only when its SHA256 fingerprint or its normalized legacy MD5
// fingerprint equals one of the given fingerprints.
//
// Each input is normalized with normalizeHostKeyFingerprint:
//   - blank entries and duplicates are dropped;
//   - the first normalization error is returned as-is;
//   - if nothing usable remains, ErrHostFingerprintRequired is returned.
//
// The callback ignores the hostname and remote address. On mismatch it returns
// a *HostKeyFingerprintMismatchError.
func FingerprintHostKeyCallback(fingerprints ...string) (func(string, net.Addr, ssh.PublicKey) error, error) {
	accepted := make([]string, 0, len(fingerprints))
	dedup := make(map[string]bool, len(fingerprints))
	for _, raw := range fingerprints {
		fp, err := normalizeHostKeyFingerprint(raw)
		if err != nil {
			return nil, err
		}
		if fp != "" && !dedup[fp] {
			dedup[fp] = true
			accepted = append(accepted, fp)
		}
	}
	if len(accepted) == 0 {
		return nil, ErrHostFingerprintRequired
	}

	callback := func(_ string, _ net.Addr, key ssh.PublicKey) error {
		sha := ssh.FingerprintSHA256(key)
		legacy := normalizeMD5Fingerprint(ssh.FingerprintLegacyMD5(key))
		for i := range accepted {
			if accepted[i] == sha || accepted[i] == legacy {
				return nil
			}
		}
		// Hand out a copy so callers cannot mutate the callback's allow-list.
		return &HostKeyFingerprintMismatchError{
			Expected:        append([]string(nil), accepted...),
			ActualSHA256:    sha,
			ActualLegacyMD5: legacy,
		}
	}
	return callback, nil
}
// acceptNewHostKeyState is the configuration and cached known_hosts callback
// behind the callback returned by AcceptNewHostKeyCallbackWithOptions.
type acceptNewHostKeyState struct {
	// file is the known_hosts path that is loaded and appended to.
	file string
	// hashHosts hashes host names in newly written entries.
	hashHosts bool
	// includeRemoteAddress also records the remote address in new entries.
	includeRemoteAddress bool
	// fileMode is applied to file whenever an entry is appended.
	fileMode os.FileMode
	// mu serializes checkHostKey (lookup, append and reload).
	mu sync.Mutex
	// cb checks keys against the last loaded contents of file. It is nil when
	// file did not exist at the last successful load.
	cb ssh.HostKeyCallback
}
// checkHostKey is the host key callback handed out by
// AcceptNewHostKeyCallbackWithOptions. While holding s.mu it does two steps.
//
//  1. Verify key against the cached known_hosts callback. A match returns nil.
//     Any failure other than "host unknown" (a *knownhosts.KeyError with an
//     empty Want list) is returned unchanged, so mismatched keys are rejected.
//  2. Otherwise (host unknown, or no known_hosts file yet), append an entry
//     for key, reload the file, and verify again against the fresh callback.
func (s *acceptNewHostKeyState) checkHostKey(hostname string, remote net.Addr, key ssh.PublicKey) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	if s.cb != nil {
		verifyErr := s.cb(hostname, remote, key)
		if verifyErr == nil {
			return nil
		}
		var keyErr *knownhosts.KeyError
		unknownHost := errors.As(verifyErr, &keyErr) && len(keyErr.Want) == 0
		if !unknownHost {
			return verifyErr
		}
	}

	entry, err := buildAcceptNewKnownHostsLine(hostname, remote, key, s.hashHosts, s.includeRemoteAddress)
	if err != nil {
		return err
	}
	if err = appendKnownHostsLine(s.file, entry, s.fileMode); err != nil {
		return err
	}
	if err = s.reload(); err != nil {
		return err
	}
	if s.cb != nil {
		return s.cb(hostname, remote, key)
	}
	return errors.New("known_hosts callback is nil after reload")
}
// reload re-reads s.file and replaces the cached callback s.cb. If the file
// does not exist, s.cb becomes nil. On error, the previous callback is kept.
func (s *acceptNewHostKeyState) reload() error {
	cb, err := loadKnownHostsCallback(s.file)
	if err == nil {
		s.cb = cb
	}
	return err
}
// loadKnownHostsCallback builds a callback from file with knownhosts.New.
// A file that does not exist yields (nil, nil). Any other stat failure is
// returned as the error.
func loadKnownHostsCallback(file string) (ssh.HostKeyCallback, error) {
	_, err := os.Stat(file)
	switch {
	case err == nil:
		return knownhosts.New(file)
	case errors.Is(err, os.ErrNotExist):
		return nil, nil
	default:
		return nil, err
	}
}
// normalizeAcceptNewHostKeyOptions returns a copy of options with
// KnownHostsFile trimmed of surrounding whitespace and a zero FileMode
// replaced by 0o600. All other fields are passed through unchanged.
func normalizeAcceptNewHostKeyOptions(options AcceptNewHostKeyOptions) AcceptNewHostKeyOptions {
	normalized := options
	normalized.KnownHostsFile = strings.TrimSpace(normalized.KnownHostsFile)
	if normalized.FileMode != 0 {
		return normalized
	}
	normalized.FileMode = 0o600
	return normalized
}
// buildAcceptNewKnownHostsLine renders the known_hosts text that records key
// for the addresses returned by collectKnownHostsAddresses(hostname, remote,
// includeRemoteAddress). Each address is first passed through
// knownhosts.Normalize.
//
// The output shape depends on hashHosts:
//   - Plain entries share a single comma-separated line.
//   - Hashed entries are written one per line, joined by "\n". The known_hosts
//     format allows only one hashed hostname per line. The knownhosts parser
//     treats a pattern starting with '|' as a single hash, so a comma-joined
//     "|1|..|..,|1|..|.." line would make the whole file fail to load.
//
// The returned text has no trailing newline. An error is returned when there
// is no hostname or remote address to record.
func buildAcceptNewKnownHostsLine(hostname string, remote net.Addr, key ssh.PublicKey, hashHosts bool, includeRemoteAddress bool) (string, error) {
	addresses := collectKnownHostsAddresses(hostname, remote, includeRemoteAddress)
	if len(addresses) == 0 {
		return "", errors.New("no hostname or remote address available for known_hosts entry")
	}
	authorizedKey := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(key)))

	if !hashHosts {
		patterns := make([]string, 0, len(addresses))
		for _, address := range addresses {
			patterns = append(patterns, knownhosts.Normalize(address))
		}
		return strings.Join(patterns, ",") + " " + authorizedKey, nil
	}

	// HashHostname does not normalize its input, so normalize first to match
	// how knownhosts looks up host:port pairs.
	lines := make([]string, 0, len(addresses))
	for _, address := range addresses {
		hashed := knownhosts.HashHostname(knownhosts.Normalize(address))
		lines = append(lines, hashed+" "+authorizedKey)
	}
	return strings.Join(lines, "\n"), nil
}
// collectKnownHostsAddresses returns the addresses a new known_hosts entry
// should cover:
//   - hostname;
//   - remote.String(), only when includeRemoteAddress is set and remote is
//     non-nil.
//
// Candidates are trimmed of whitespace and blank ones are skipped. A candidate
// whose knownhosts.Normalize form was already collected is dropped. The
// returned addresses are trimmed but not normalized.
func collectKnownHostsAddresses(hostname string, remote net.Addr, includeRemoteAddress bool) []string {
	candidates := []string{hostname}
	if includeRemoteAddress && remote != nil {
		candidates = append(candidates, remote.String())
	}

	result := make([]string, 0, len(candidates))
	seen := make(map[string]bool, len(candidates))
	for _, candidate := range candidates {
		trimmed := strings.TrimSpace(candidate)
		if trimmed == "" {
			continue
		}
		key := knownhosts.Normalize(trimmed)
		if seen[key] {
			continue
		}
		seen[key] = true
		result = append(result, trimmed)
	}
	return result
}
// appendKnownHostsLine appends line plus a trailing newline to file. It
// creates the file (and its parent directory, with 0o700) if needed, then
// applies mode to the file.
//
// If the existing file is non-empty and does not end in a newline, a newline
// is written first. Otherwise the new entry would be glued onto the previous
// last line, corrupting both entries.
//
// Returns ErrKnownHostsFileRequired for a blank path and an error for a blank
// line. Errors from Close are reported too.
func appendKnownHostsLine(file string, line string, mode os.FileMode) (err error) {
	if strings.TrimSpace(file) == "" {
		return ErrKnownHostsFileRequired
	}
	if strings.TrimSpace(line) == "" {
		return errors.New("known_hosts line is empty")
	}
	if dir := filepath.Dir(file); dir != "." && dir != "" {
		if mkErr := os.MkdirAll(dir, 0o700); mkErr != nil {
			return mkErr
		}
	}

	// O_RDWR (not O_WRONLY) so the current last byte can be inspected. Read
	// access is needed anyway, because the file is re-parsed after appending.
	handle, err := os.OpenFile(file, os.O_CREATE|os.O_APPEND|os.O_RDWR, mode)
	if err != nil {
		return err
	}
	// Close can surface delayed write errors (e.g. on network file systems).
	// Report its error unless an earlier one is already being returned.
	defer func() {
		if closeErr := handle.Close(); err == nil {
			err = closeErr
		}
	}()

	payload := line + "\n"
	unterminated, err := missingTrailingNewline(handle)
	if err != nil {
		return err
	}
	if unterminated {
		payload = "\n" + payload
	}
	if _, err = handle.WriteString(payload); err != nil {
		return err
	}
	return handle.Chmod(mode)
}

// missingTrailingNewline reports whether f is non-empty and its last byte is
// not '\n'.
func missingTrailingNewline(f *os.File) (bool, error) {
	info, err := f.Stat()
	if err != nil {
		return false, err
	}
	if info.Size() == 0 {
		return false, nil
	}
	last := make([]byte, 1)
	if _, err := f.ReadAt(last, info.Size()-1); err != nil {
		return false, err
	}
	return last[0] != '\n', nil
}
// normalizeHostKeyFingerprint converts a user-supplied fingerprint into the
// canonical form that FingerprintHostKeyCallback compares against:
//   - "SHA256:<base64>": the prefix is matched case-insensitively, and any
//     base64 '=' padding is stripped, because ssh.FingerprintSHA256 emits
//     unpadded base64;
//   - "MD5:<hex pairs>": the prefix is matched case-insensitively and the
//     value is lower-cased;
//   - a bare value containing two or more ':' is treated as a legacy MD5
//     fingerprint;
//   - any other bare value is treated as SHA256 base64.
//
// Blank input yields ("", nil) so callers can skip it. A recognized prefix
// with nothing after it yields ErrHostFingerprintRequired.
func normalizeHostKeyFingerprint(raw string) (string, error) {
	value := strings.TrimSpace(raw)
	if value == "" {
		return "", nil
	}
	if suffix, ok := cutPrefixFold(value, "SHA256:"); ok {
		suffix = strings.TrimRight(strings.TrimSpace(suffix), "=")
		if suffix == "" {
			return "", ErrHostFingerprintRequired
		}
		return "SHA256:" + suffix, nil
	}
	if suffix, ok := cutPrefixFold(value, "MD5:"); ok {
		suffix = strings.TrimSpace(suffix)
		if suffix == "" {
			return "", ErrHostFingerprintRequired
		}
		return normalizeMD5Fingerprint("MD5:" + suffix), nil
	}
	if strings.Count(value, ":") >= 2 {
		return normalizeMD5Fingerprint("MD5:" + value), nil
	}
	return "SHA256:" + strings.TrimRight(value, "="), nil
}

// cutPrefixFold is like strings.CutPrefix but matches prefix
// case-insensitively. It compares the leading bytes of s directly, so the
// remainder is always sliced at the correct offset. (Upper-casing s first can
// change its byte length for some non-ASCII runes.)
func cutPrefixFold(s, prefix string) (string, bool) {
	if len(s) < len(prefix) || !strings.EqualFold(s[:len(prefix)], prefix) {
		return s, false
	}
	return s[len(prefix):], true
}
// normalizeMD5Fingerprint returns value in the canonical "MD5:<lowercase>"
// form. An existing "MD5:" prefix (matched case-insensitively) is replaced;
// otherwise the prefix is added.
func normalizeMD5Fingerprint(value string) string {
	body := value
	if strings.HasPrefix(strings.ToUpper(value), "MD5:") {
		body = value[len("MD5:"):]
	}
	return "MD5:" + strings.ToLower(body)
}