291 lines
7.4 KiB
Go
291 lines
7.4 KiB
Go
|
|
package starssh
|
||
|
|
|
||
|
|
import (
|
||
|
|
"errors"
|
||
|
|
"fmt"
|
||
|
|
"net"
|
||
|
|
"os"
|
||
|
|
"path/filepath"
|
||
|
|
"strings"
|
||
|
|
"sync"
|
||
|
|
|
||
|
|
"golang.org/x/crypto/ssh"
|
||
|
|
"golang.org/x/crypto/ssh/knownhosts"
|
||
|
|
)
|
||
|
|
|
||
|
|
// Sentinel errors returned by this package's host key helpers. They are
// returned unwrapped, so callers can compare them with errors.Is.
var (
	// ErrKnownHostsFileRequired reports that no usable (non-blank)
	// known_hosts path was supplied.
	ErrKnownHostsFileRequired = errors.New("known_hosts file is required")

	// ErrHostFingerprintRequired reports that no usable host key fingerprint
	// was supplied, or that a fingerprint carried a recognized prefix with an
	// empty value after it.
	ErrHostFingerprintRequired = errors.New("host key fingerprint is required")
)
|
||
|
|
|
||
|
|
// HostKeyFingerprintMismatchError is returned by the callback built by
// FingerprintHostKeyCallback when the presented host key matches none of the
// configured fingerprints.
type HostKeyFingerprintMismatchError struct {
	// Expected lists the accepted fingerprints in their normalized form
	// ("SHA256:..." or "MD5:" followed by the lower-cased digest).
	Expected []string

	// ActualSHA256 is the presented key's fingerprint as rendered by
	// ssh.FingerprintSHA256.
	ActualSHA256 string

	// ActualLegacyMD5 is the presented key's legacy MD5 fingerprint,
	// normalized to "MD5:" followed by the lower-cased digest.
	ActualLegacyMD5 string
}
|
||
|
|
|
||
|
|
// AcceptNewHostKeyOptions configures AcceptNewHostKeyCallbackWithOptions.
type AcceptNewHostKeyOptions struct {
	// KnownHostsFile is the known_hosts path that is read and appended to.
	// Surrounding whitespace is ignored; a blank value is rejected with
	// ErrKnownHostsFileRequired.
	KnownHostsFile string

	// HashHosts writes newly learned host patterns in hashed form
	// (knownhosts.HashHostname) instead of plain text.
	HashHosts bool

	// IncludeRemoteAddress records the connection's remote address
	// (remote.String()) in addition to the hostname for newly learned hosts.
	IncludeRemoteAddress bool

	// FileMode is the permission used when the file is created and re-applied
	// after every append. Zero selects 0o600.
	FileMode os.FileMode
}
|
||
|
|
|
||
|
|
func (e *HostKeyFingerprintMismatchError) Error() string {
|
||
|
|
if e == nil {
|
||
|
|
return ""
|
||
|
|
}
|
||
|
|
return fmt.Sprintf("host key fingerprint mismatch: want one of %s, got %s (%s)", strings.Join(e.Expected, ", "), e.ActualSHA256, e.ActualLegacyMD5)
|
||
|
|
}
|
||
|
|
|
||
|
|
func KnownHostsHostKeyCallback(files ...string) (func(string, net.Addr, ssh.PublicKey) error, error) {
|
||
|
|
trimmed := make([]string, 0, len(files))
|
||
|
|
for _, file := range files {
|
||
|
|
file = strings.TrimSpace(file)
|
||
|
|
if file == "" {
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
trimmed = append(trimmed, file)
|
||
|
|
}
|
||
|
|
if len(trimmed) == 0 {
|
||
|
|
return nil, ErrKnownHostsFileRequired
|
||
|
|
}
|
||
|
|
return knownhosts.New(trimmed...)
|
||
|
|
}
|
||
|
|
|
||
|
|
func AcceptNewHostKeyCallback(file string) (func(string, net.Addr, ssh.PublicKey) error, error) {
|
||
|
|
return AcceptNewHostKeyCallbackWithOptions(AcceptNewHostKeyOptions{
|
||
|
|
KnownHostsFile: file,
|
||
|
|
IncludeRemoteAddress: true,
|
||
|
|
})
|
||
|
|
}
|
||
|
|
|
||
|
|
func AcceptNewHostKeyCallbackWithOptions(options AcceptNewHostKeyOptions) (func(string, net.Addr, ssh.PublicKey) error, error) {
|
||
|
|
options = normalizeAcceptNewHostKeyOptions(options)
|
||
|
|
if options.KnownHostsFile == "" {
|
||
|
|
return nil, ErrKnownHostsFileRequired
|
||
|
|
}
|
||
|
|
|
||
|
|
state := &acceptNewHostKeyState{
|
||
|
|
file: options.KnownHostsFile,
|
||
|
|
hashHosts: options.HashHosts,
|
||
|
|
includeRemoteAddress: options.IncludeRemoteAddress,
|
||
|
|
fileMode: options.FileMode,
|
||
|
|
}
|
||
|
|
if err := state.reload(); err != nil {
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
|
||
|
|
return state.checkHostKey, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func FingerprintHostKeyCallback(fingerprints ...string) (func(string, net.Addr, ssh.PublicKey) error, error) {
|
||
|
|
normalized := make([]string, 0, len(fingerprints))
|
||
|
|
seen := make(map[string]struct{}, len(fingerprints))
|
||
|
|
for _, raw := range fingerprints {
|
||
|
|
fingerprint, err := normalizeHostKeyFingerprint(raw)
|
||
|
|
if err != nil {
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
if fingerprint == "" {
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
if _, exists := seen[fingerprint]; exists {
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
seen[fingerprint] = struct{}{}
|
||
|
|
normalized = append(normalized, fingerprint)
|
||
|
|
}
|
||
|
|
if len(normalized) == 0 {
|
||
|
|
return nil, ErrHostFingerprintRequired
|
||
|
|
}
|
||
|
|
|
||
|
|
return func(hostname string, remote net.Addr, key ssh.PublicKey) error {
|
||
|
|
actualSHA256 := ssh.FingerprintSHA256(key)
|
||
|
|
actualMD5 := normalizeMD5Fingerprint(ssh.FingerprintLegacyMD5(key))
|
||
|
|
for _, want := range normalized {
|
||
|
|
if want == actualSHA256 || want == actualMD5 {
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
}
|
||
|
|
return &HostKeyFingerprintMismatchError{
|
||
|
|
Expected: append([]string(nil), normalized...),
|
||
|
|
ActualSHA256: actualSHA256,
|
||
|
|
ActualLegacyMD5: actualMD5,
|
||
|
|
}
|
||
|
|
}, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
type acceptNewHostKeyState struct {
|
||
|
|
file string
|
||
|
|
hashHosts bool
|
||
|
|
includeRemoteAddress bool
|
||
|
|
fileMode os.FileMode
|
||
|
|
|
||
|
|
mu sync.Mutex
|
||
|
|
cb ssh.HostKeyCallback
|
||
|
|
}
|
||
|
|
|
||
|
|
func (s *acceptNewHostKeyState) checkHostKey(hostname string, remote net.Addr, key ssh.PublicKey) error {
|
||
|
|
s.mu.Lock()
|
||
|
|
defer s.mu.Unlock()
|
||
|
|
|
||
|
|
if s.cb != nil {
|
||
|
|
err := s.cb(hostname, remote, key)
|
||
|
|
if err == nil {
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
var keyErr *knownhosts.KeyError
|
||
|
|
if !errors.As(err, &keyErr) || len(keyErr.Want) != 0 {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
line, err := buildAcceptNewKnownHostsLine(hostname, remote, key, s.hashHosts, s.includeRemoteAddress)
|
||
|
|
if err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
if err := appendKnownHostsLine(s.file, line, s.fileMode); err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
if err := s.reload(); err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
if s.cb == nil {
|
||
|
|
return errors.New("known_hosts callback is nil after reload")
|
||
|
|
}
|
||
|
|
return s.cb(hostname, remote, key)
|
||
|
|
}
|
||
|
|
|
||
|
|
func (s *acceptNewHostKeyState) reload() error {
|
||
|
|
callback, err := loadKnownHostsCallback(s.file)
|
||
|
|
if err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
s.cb = callback
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func loadKnownHostsCallback(file string) (ssh.HostKeyCallback, error) {
|
||
|
|
_, err := os.Stat(file)
|
||
|
|
if err != nil {
|
||
|
|
if errors.Is(err, os.ErrNotExist) {
|
||
|
|
return nil, nil
|
||
|
|
}
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
return knownhosts.New(file)
|
||
|
|
}
|
||
|
|
|
||
|
|
func normalizeAcceptNewHostKeyOptions(options AcceptNewHostKeyOptions) AcceptNewHostKeyOptions {
|
||
|
|
options.KnownHostsFile = strings.TrimSpace(options.KnownHostsFile)
|
||
|
|
if options.FileMode == 0 {
|
||
|
|
options.FileMode = 0o600
|
||
|
|
}
|
||
|
|
return options
|
||
|
|
}
|
||
|
|
|
||
|
|
func buildAcceptNewKnownHostsLine(hostname string, remote net.Addr, key ssh.PublicKey, hashHosts bool, includeRemoteAddress bool) (string, error) {
|
||
|
|
addresses := collectKnownHostsAddresses(hostname, remote, includeRemoteAddress)
|
||
|
|
if len(addresses) == 0 {
|
||
|
|
return "", errors.New("no hostname or remote address available for known_hosts entry")
|
||
|
|
}
|
||
|
|
|
||
|
|
patterns := make([]string, 0, len(addresses))
|
||
|
|
for _, address := range addresses {
|
||
|
|
normalized := knownhosts.Normalize(address)
|
||
|
|
if hashHosts {
|
||
|
|
patterns = append(patterns, knownhosts.HashHostname(normalized))
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
patterns = append(patterns, normalized)
|
||
|
|
}
|
||
|
|
|
||
|
|
authorizedKey := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(key)))
|
||
|
|
return strings.Join(patterns, ",") + " " + authorizedKey, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func collectKnownHostsAddresses(hostname string, remote net.Addr, includeRemoteAddress bool) []string {
|
||
|
|
addresses := make([]string, 0, 2)
|
||
|
|
seen := make(map[string]struct{}, 2)
|
||
|
|
|
||
|
|
add := func(address string) {
|
||
|
|
address = strings.TrimSpace(address)
|
||
|
|
if address == "" {
|
||
|
|
return
|
||
|
|
}
|
||
|
|
normalized := knownhosts.Normalize(address)
|
||
|
|
if _, exists := seen[normalized]; exists {
|
||
|
|
return
|
||
|
|
}
|
||
|
|
seen[normalized] = struct{}{}
|
||
|
|
addresses = append(addresses, address)
|
||
|
|
}
|
||
|
|
|
||
|
|
add(hostname)
|
||
|
|
if includeRemoteAddress && remote != nil {
|
||
|
|
add(remote.String())
|
||
|
|
}
|
||
|
|
|
||
|
|
return addresses
|
||
|
|
}
|
||
|
|
|
||
|
|
func appendKnownHostsLine(file string, line string, mode os.FileMode) error {
|
||
|
|
if strings.TrimSpace(file) == "" {
|
||
|
|
return ErrKnownHostsFileRequired
|
||
|
|
}
|
||
|
|
if strings.TrimSpace(line) == "" {
|
||
|
|
return errors.New("known_hosts line is empty")
|
||
|
|
}
|
||
|
|
|
||
|
|
dir := filepath.Dir(file)
|
||
|
|
if dir != "." && dir != "" {
|
||
|
|
if err := os.MkdirAll(dir, 0o700); err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
handle, err := os.OpenFile(file, os.O_CREATE|os.O_APPEND|os.O_WRONLY, mode)
|
||
|
|
if err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
defer handle.Close()
|
||
|
|
|
||
|
|
if _, err := handle.WriteString(line + "\n"); err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
return handle.Chmod(mode)
|
||
|
|
}
|
||
|
|
|
||
|
|
func normalizeHostKeyFingerprint(raw string) (string, error) {
|
||
|
|
value := strings.TrimSpace(raw)
|
||
|
|
if value == "" {
|
||
|
|
return "", nil
|
||
|
|
}
|
||
|
|
|
||
|
|
if strings.HasPrefix(strings.ToUpper(value), "SHA256:") {
|
||
|
|
suffix := strings.TrimSpace(value[len("SHA256:"):])
|
||
|
|
if suffix == "" {
|
||
|
|
return "", ErrHostFingerprintRequired
|
||
|
|
}
|
||
|
|
return "SHA256:" + suffix, nil
|
||
|
|
}
|
||
|
|
if strings.HasPrefix(strings.ToUpper(value), "MD5:") {
|
||
|
|
suffix := strings.TrimSpace(value[len("MD5:"):])
|
||
|
|
if suffix == "" {
|
||
|
|
return "", ErrHostFingerprintRequired
|
||
|
|
}
|
||
|
|
return normalizeMD5Fingerprint("MD5:" + suffix), nil
|
||
|
|
}
|
||
|
|
if strings.Count(value, ":") >= 2 {
|
||
|
|
return normalizeMD5Fingerprint("MD5:" + value), nil
|
||
|
|
}
|
||
|
|
return "SHA256:" + value, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// normalizeMD5Fingerprint rewrites value into the canonical legacy MD5 form:
// "MD5:" followed by the lower-cased digest text. A leading "MD5:" in any
// letter case is replaced by the canonical prefix; otherwise the prefix is
// prepended to the lower-cased value. No whitespace trimming is done.
func normalizeMD5Fingerprint(value string) string {
	const prefix = "MD5:"
	digest := value
	if len(digest) >= len(prefix) && strings.EqualFold(digest[:len(prefix)], prefix) {
		digest = digest[len(prefix):]
	}
	return prefix + strings.ToLower(digest)
}
|