starssh/hostkey.go

291 lines
7.4 KiB
Go
Raw Normal View History

refactor: 重构 starssh 核心运行时并补强 ssh/exec/terminal/sftp 能力 - 拆分原有单体 ssh.go,按职责重组为 types、utils、transport、login、keepalive、session、exec、pool、shell、terminal、forward、hostkey、state 等模块,并补充平台相关实现 - 重做登录与连接运行时,补齐基于 context 的建连、jump/proxy 链路、可配置认证顺序,以及 Unix/Windows 下的 ssh-agent 支持 - 新增正式非交互执行模型 ExecRequest/ExecResult,支持流式输出、溢出统计、超时控制,以及 posix/powershell/cmd/raw 多方言执行 - 保留旧 shell 风格兼容接口,同时让路径/用户探测等 helper 具备跨 shell fallback,避免 Windows 目标继续硬依赖 POSIX 命令 - 新增 TerminalSession 作为原始交互终端基座,提供 IO attach、resize、signal/control、退出状态与关闭原因管理 - 重构端口转发语义,默认复用当前 SSH 连接,并显式提供 detached 的本地/动态转发模式承载隔离场景 - 梳理 keepalive 与取消语义,区分仅取消本次操作和关闭整条连接,并统一连接状态与传输关闭路径 - 围绕新的 session/连接生命周期重做执行池与运行时支撑 - 大幅增强 SFTP 传输链路,补齐更安全的原子替换、校验、进度回调、重试隔离、可复用 client 生命周期与失败语义 - 新增取消语义、keepalive、SFTP、forward、terminal input 等关键回归测试,提升核心链路稳定性
2026-04-26 10:45:39 +08:00
package starssh
import (
"errors"
"fmt"
"net"
"os"
"path/filepath"
"strings"
"sync"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/knownhosts"
)
var (
	// ErrKnownHostsFileRequired is returned when no usable known_hosts file
	// path was supplied (the path is empty or only whitespace).
	ErrKnownHostsFileRequired = errors.New("known_hosts file is required")
	// ErrHostFingerprintRequired is returned when no usable host key
	// fingerprint was supplied, or when a fingerprint prefix such as
	// "SHA256:" has no value after it.
	ErrHostFingerprintRequired = errors.New("host key fingerprint is required")
)

// HostKeyFingerprintMismatchError reports that the server presented a host
// key whose fingerprint matched none of the configured fingerprints.
type HostKeyFingerprintMismatchError struct {
	// Expected lists the configured fingerprints in normalized form.
	Expected []string
	// ActualSHA256 is the SHA256 fingerprint of the presented key.
	ActualSHA256 string
	// ActualLegacyMD5 is the normalized legacy MD5 fingerprint of the
	// presented key.
	ActualLegacyMD5 string
}

// Error implements the error interface. A nil receiver yields "".
func (e *HostKeyFingerprintMismatchError) Error() string {
	if e == nil {
		return ""
	}
	wanted := strings.Join(e.Expected, ", ")
	return fmt.Sprintf("host key fingerprint mismatch: want one of %s, got %s (%s)", wanted, e.ActualSHA256, e.ActualLegacyMD5)
}

// AcceptNewHostKeyOptions configures AcceptNewHostKeyCallbackWithOptions.
type AcceptNewHostKeyOptions struct {
	// KnownHostsFile is the known_hosts file to read and append to.
	// Surrounding whitespace is ignored, and it must not be empty.
	KnownHostsFile string
	// HashHosts records host patterns in hashed form instead of plain text.
	HashHosts bool
	// IncludeRemoteAddress also records the server's remote network address
	// next to the hostname used to connect.
	IncludeRemoteAddress bool
	// FileMode is used when creating the file and is re-applied after each
	// append. Zero selects 0600.
	FileMode os.FileMode
}
// KnownHostsHostKeyCallback returns a strict host key callback that verifies
// servers against the given known_hosts files using knownhosts.New. The files
// are only read, never modified.
//
// Entries that are blank after trimming whitespace are skipped. If none
// remain, ErrKnownHostsFileRequired is returned. Errors opening or parsing a
// file are returned as reported by knownhosts.New.
func KnownHostsHostKeyCallback(files ...string) (func(string, net.Addr, ssh.PublicKey) error, error) {
	var paths []string
	for i := range files {
		if p := strings.TrimSpace(files[i]); p != "" {
			paths = append(paths, p)
		}
	}
	if len(paths) == 0 {
		return nil, ErrKnownHostsFileRequired
	}
	return knownhosts.New(paths...)
}
// AcceptNewHostKeyCallback is AcceptNewHostKeyCallbackWithOptions with the
// default settings:
//   - host patterns are stored unhashed;
//   - the server's remote address is recorded alongside the hostname;
//   - the file mode is 0600.
func AcceptNewHostKeyCallback(file string) (func(string, net.Addr, ssh.PublicKey) error, error) {
	defaults := AcceptNewHostKeyOptions{IncludeRemoteAddress: true}
	defaults.KnownHostsFile = file
	return AcceptNewHostKeyCallbackWithOptions(defaults)
}
// AcceptNewHostKeyCallbackWithOptions returns a host key callback with
// "accept new" semantics backed by options.KnownHostsFile:
//   - hosts already recorded in the file must present a matching key;
//   - hosts with no entry are appended to the file and then accepted.
//
// options is normalized first: the path is trimmed, and a zero FileMode
// becomes 0600. An empty path yields ErrKnownHostsFileRequired.
//
// The file is loaded here so that an unreadable or malformed file is reported
// immediately. A file that does not exist yet is not an error; it is created
// when the first new host is recorded.
func AcceptNewHostKeyCallbackWithOptions(options AcceptNewHostKeyOptions) (func(string, net.Addr, ssh.PublicKey) error, error) {
	opts := normalizeAcceptNewHostKeyOptions(options)
	if opts.KnownHostsFile == "" {
		return nil, ErrKnownHostsFileRequired
	}

	state := new(acceptNewHostKeyState)
	state.file = opts.KnownHostsFile
	state.hashHosts = opts.HashHosts
	state.includeRemoteAddress = opts.IncludeRemoteAddress
	state.fileMode = opts.FileMode

	if err := state.reload(); err != nil {
		return nil, err
	}
	return state.checkHostKey, nil
}
// FingerprintHostKeyCallback returns a host key callback that pins the server
// key to one of the given fingerprints. Fingerprints may be in any spelling
// accepted by normalizeHostKeyFingerprint (SHA256 or legacy MD5). Blank and
// duplicate entries are skipped.
//
// It returns the normalization error for a malformed entry, and
// ErrHostFingerprintRequired when no usable fingerprint remains.
//
// The returned callback ignores the hostname and remote address. It accepts
// the key if either its SHA256 or its normalized legacy MD5 fingerprint is
// one of the configured values. Otherwise it fails with a
// *HostKeyFingerprintMismatchError listing the configured fingerprints in
// the order given.
func FingerprintHostKeyCallback(fingerprints ...string) (func(string, net.Addr, ssh.PublicKey) error, error) {
	allowed := make(map[string]struct{}, len(fingerprints))
	ordered := make([]string, 0, len(fingerprints))
	for _, raw := range fingerprints {
		fp, err := normalizeHostKeyFingerprint(raw)
		switch {
		case err != nil:
			return nil, err
		case fp == "":
			continue
		}
		if _, dup := allowed[fp]; dup {
			continue
		}
		allowed[fp] = struct{}{}
		ordered = append(ordered, fp)
	}
	if len(ordered) == 0 {
		return nil, ErrHostFingerprintRequired
	}

	return func(_ string, _ net.Addr, key ssh.PublicKey) error {
		sha := ssh.FingerprintSHA256(key)
		md5 := normalizeMD5Fingerprint(ssh.FingerprintLegacyMD5(key))
		_, shaOK := allowed[sha]
		_, md5OK := allowed[md5]
		if shaOK || md5OK {
			return nil
		}
		return &HostKeyFingerprintMismatchError{
			// Copy so callers cannot mutate the callback's configuration.
			Expected:        append([]string(nil), ordered...),
			ActualSHA256:    sha,
			ActualLegacyMD5: md5,
		}
	}, nil
}
// acceptNewHostKeyState backs the callback returned by
// AcceptNewHostKeyCallbackWithOptions.
type acceptNewHostKeyState struct {
	// Settings copied from the normalized AcceptNewHostKeyOptions.
	file                 string
	hashHosts            bool
	includeRemoteAddress bool
	fileMode             os.FileMode

	// mu is held for the whole of checkHostKey. It guards cb and serializes
	// appends to file.
	mu sync.Mutex
	// cb checks keys against the contents of file as of the last load. It is
	// nil when file did not exist at that time.
	cb ssh.HostKeyCallback
}
// checkHostKey is the host key callback returned by
// AcceptNewHostKeyCallbackWithOptions. It implements "accept new" semantics:
//   - A key that matches an entry in the known_hosts file is accepted.
//   - If knownhosts reports anything other than an unknown host, that error
//     is returned unchanged. This includes a KeyError with a non-empty Want,
//     i.e. the host is known but the key does not match.
//   - A host with no entries is recorded by appending to the file. The key is
//     then verified again against the reloaded file.
//
// The file is re-read on every call so the decision reflects what is on disk
// now, not a snapshot from when the callback was created or last appended.
// With a stale snapshot, an entry recorded in the meantime would look
// unknown. Sources of such entries include another callback sharing the file,
// another process, or a manual edit. The host would then be appended a second
// time. Worse, if the server now presented a key of a different type than the
// recorded one, the appended entry would let that key through instead of it
// being rejected as a mismatch.
//
// s.mu is held for the whole call, so concurrent handshakes for the same new
// host do not both append an entry.
func (s *acceptNewHostKeyState) checkHostKey(hostname string, remote net.Addr, key ssh.PublicKey) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	if err := s.reload(); err != nil {
		return err
	}
	if s.cb != nil {
		err := s.cb(hostname, remote, key)
		if err == nil {
			return nil
		}
		// Only an unknown host (KeyError with no wanted keys) may be added.
		// A mismatch or any other failure is final.
		var keyErr *knownhosts.KeyError
		if !errors.As(err, &keyErr) || len(keyErr.Want) != 0 {
			return err
		}
	}

	line, err := buildAcceptNewKnownHostsLine(hostname, remote, key, s.hashHosts, s.includeRemoteAddress)
	if err != nil {
		return err
	}
	if err := appendKnownHostsLine(s.file, line, s.fileMode); err != nil {
		return err
	}
	if err := s.reload(); err != nil {
		return err
	}
	if s.cb == nil {
		return errors.New("known_hosts callback is nil after reload")
	}
	// Verify against the file as written rather than assuming success, so an
	// entry that does not take effect surfaces as an error.
	return s.cb(hostname, remote, key)
}
// reload replaces s.cb with a callback freshly loaded from s.file. A missing
// file yields a nil callback. On error, s.cb is left unchanged.
//
// Callers are expected to hold s.mu (as checkHostKey does) or to own s
// exclusively (as during construction).
func (s *acceptNewHostKeyState) reload() error {
	fresh, err := loadKnownHostsCallback(s.file)
	if err == nil {
		s.cb = fresh
	}
	return err
}
// loadKnownHostsCallback parses file with knownhosts.New.
//
// A file that does not exist is not an error: it yields a nil callback and a
// nil error, meaning no hosts are known yet. Any other stat failure, and any
// error from knownhosts.New, is returned.
func loadKnownHostsCallback(file string) (ssh.HostKeyCallback, error) {
	_, err := os.Stat(file)
	switch {
	case err == nil:
		return knownhosts.New(file)
	case errors.Is(err, os.ErrNotExist):
		return nil, nil
	default:
		return nil, err
	}
}
// normalizeAcceptNewHostKeyOptions returns a copy of options with:
//   - surrounding whitespace removed from KnownHostsFile;
//   - a zero FileMode replaced by 0600.
func normalizeAcceptNewHostKeyOptions(options AcceptNewHostKeyOptions) AcceptNewHostKeyOptions {
	const defaultFileMode = 0o600

	out := options
	out.KnownHostsFile = strings.TrimSpace(out.KnownHostsFile)
	if out.FileMode == 0 {
		out.FileMode = defaultFileMode
	}
	return out
}
// buildAcceptNewKnownHostsLine renders the known_hosts text that records key
// for the addresses returned by collectKnownHostsAddresses. Each address is
// passed through knownhosts.Normalize, and also through
// knownhosts.HashHostname when hashHosts is set.
//
// Plain entries put all patterns on one comma-separated line. Hashed entries
// get one line per address, joined with "\n" and without a trailing newline.
// This is because the known_hosts format allows only one hashed hostname per
// line, and knownhosts fails to parse a host field that is a comma-joined
// list of hashes. Writing such a line would make every later load of the
// file fail.
//
// It returns an error when no hostname or remote address is available.
func buildAcceptNewKnownHostsLine(hostname string, remote net.Addr, key ssh.PublicKey, hashHosts bool, includeRemoteAddress bool) (string, error) {
	addresses := collectKnownHostsAddresses(hostname, remote, includeRemoteAddress)
	if len(addresses) == 0 {
		return "", errors.New("no hostname or remote address available for known_hosts entry")
	}
	authorizedKey := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(key)))

	if hashHosts {
		lines := make([]string, 0, len(addresses))
		for _, address := range addresses {
			hashed := knownhosts.HashHostname(knownhosts.Normalize(address))
			lines = append(lines, hashed+" "+authorizedKey)
		}
		return strings.Join(lines, "\n"), nil
	}

	patterns := make([]string, 0, len(addresses))
	for _, address := range addresses {
		patterns = append(patterns, knownhosts.Normalize(address))
	}
	return strings.Join(patterns, ",") + " " + authorizedKey, nil
}
// collectKnownHostsAddresses lists the addresses to record for a new host:
// hostname first, then remote.String() when includeRemoteAddress is set and
// remote is non-nil.
//
// Each candidate is whitespace-trimmed. Blank candidates are dropped, as are
// candidates whose knownhosts.Normalize form repeats an earlier one. The
// returned addresses are trimmed but not normalized.
func collectKnownHostsAddresses(hostname string, remote net.Addr, includeRemoteAddress bool) []string {
	candidates := []string{hostname}
	if includeRemoteAddress && remote != nil {
		candidates = append(candidates, remote.String())
	}

	result := make([]string, 0, len(candidates))
	seen := make(map[string]struct{}, len(candidates))
	for _, candidate := range candidates {
		trimmed := strings.TrimSpace(candidate)
		if trimmed == "" {
			continue
		}
		key := knownhosts.Normalize(trimmed)
		if _, dup := seen[key]; dup {
			continue
		}
		seen[key] = struct{}{}
		result = append(result, trimmed)
	}
	return result
}
// appendKnownHostsLine appends line (one or more newline-separated
// known_hosts entries, without a trailing newline) to file and terminates it
// with a newline.
//
// Missing parent directories are created with mode 0700. A missing file is
// created with mode, and mode is re-applied to the file after every append.
//
// If an existing file does not end in a newline (for example after a manual
// edit), a newline is written first. Otherwise the new entry would be glued
// onto the last existing entry, corrupting both.
//
// It returns:
//   - ErrKnownHostsFileRequired for a blank path;
//   - an error for a blank line;
//   - any I/O error, including the error from closing the file, since some
//     filesystems may report write failures only at close.
func appendKnownHostsLine(file string, line string, mode os.FileMode) (err error) {
	if strings.TrimSpace(file) == "" {
		return ErrKnownHostsFileRequired
	}
	if strings.TrimSpace(line) == "" {
		return errors.New("known_hosts line is empty")
	}
	if dir := filepath.Dir(file); dir != "." && dir != "" {
		if err := os.MkdirAll(dir, 0o700); err != nil {
			return err
		}
	}

	// O_RDWR rather than O_WRONLY so the last byte can be inspected.
	// O_APPEND still sends every write to the end of the file.
	handle, err := os.OpenFile(file, os.O_CREATE|os.O_APPEND|os.O_RDWR, mode)
	if err != nil {
		return err
	}
	defer func() {
		if closeErr := handle.Close(); closeErr != nil && err == nil {
			err = closeErr
		}
	}()

	payload := line + "\n"
	info, err := handle.Stat()
	if err != nil {
		return err
	}
	if size := info.Size(); size > 0 {
		last := make([]byte, 1)
		if _, err := handle.ReadAt(last, size-1); err != nil {
			return err
		}
		if last[0] != '\n' {
			payload = "\n" + payload
		}
	}
	if _, err := handle.WriteString(payload); err != nil {
		return err
	}
	return handle.Chmod(mode)
}
// normalizeHostKeyFingerprint converts a user-supplied host key fingerprint
// into the canonical form that FingerprintHostKeyCallback compares against:
//   - "SHA256:<base64>", matching ssh.FingerprintSHA256, which is unpadded;
//   - "MD5:<lowercase hex octets separated by colons>", via
//     normalizeMD5Fingerprint.
//
// Accepted inputs:
//   - "SHA256:..." or "MD5:..." with a case-insensitive prefix;
//   - bare colon-separated hex (at least two colons), taken as MD5;
//   - a bare 32-digit hex string without colons, taken as MD5;
//   - anything else, taken as a bare SHA256 base64 value.
//
// Two variants printed by other tools are accepted as well, so they still
// compare equal:
//   - trailing "=" base64 padding on SHA256 values is removed;
//   - colon-less 32-digit MD5 hex is split into colon-separated octets.
//
// A blank input yields ("", nil) so callers can skip it. A prefix with no
// value after it yields ErrHostFingerprintRequired.
func normalizeHostKeyFingerprint(raw string) (string, error) {
	const sha256Prefix, md5Prefix = "SHA256:", "MD5:"

	value := strings.TrimSpace(raw)
	if value == "" {
		return "", nil
	}

	// hasPrefixFold reports whether s starts with the ASCII prefix p,
	// ignoring case.
	hasPrefixFold := func(s, p string) bool {
		return len(s) >= len(p) && strings.EqualFold(s[:len(p)], p)
	}
	// isHexMD5 reports whether s is exactly 32 hex digits, i.e. an MD5
	// digest written without separators. A SHA256 base64 fingerprint is 43
	// characters long (44 when padded), so the two cannot be confused.
	isHexMD5 := func(s string) bool {
		if len(s) != 32 {
			return false
		}
		for i := 0; i < len(s); i++ {
			c := s[i]
			if !('0' <= c && c <= '9' || 'a' <= c && c <= 'f' || 'A' <= c && c <= 'F') {
				return false
			}
		}
		return true
	}
	toSHA256 := func(body string) (string, error) {
		body = strings.TrimRight(strings.TrimSpace(body), "=")
		if body == "" {
			return "", ErrHostFingerprintRequired
		}
		return sha256Prefix + body, nil
	}
	toMD5 := func(body string) (string, error) {
		body = strings.TrimSpace(body)
		if body == "" {
			return "", ErrHostFingerprintRequired
		}
		if isHexMD5(body) {
			octets := make([]string, 0, len(body)/2)
			for i := 0; i < len(body); i += 2 {
				octets = append(octets, body[i:i+2])
			}
			body = strings.Join(octets, ":")
		}
		return normalizeMD5Fingerprint(md5Prefix + body), nil
	}

	switch {
	case hasPrefixFold(value, sha256Prefix):
		return toSHA256(value[len(sha256Prefix):])
	case hasPrefixFold(value, md5Prefix):
		return toMD5(value[len(md5Prefix):])
	case strings.Count(value, ":") >= 2, isHexMD5(value):
		return toMD5(value)
	default:
		return toSHA256(value)
	}
}
// normalizeMD5Fingerprint returns value in the canonical "MD5:<lowercase>"
// form. A leading "MD5:" in any case is replaced by the canonical prefix;
// otherwise the prefix is added. The remainder is lowercased but not
// otherwise changed.
func normalizeMD5Fingerprint(value string) string {
	const prefix = "MD5:"
	rest := value
	if strings.HasPrefix(strings.ToUpper(rest), prefix) {
		rest = rest[len(prefix):]
	}
	return prefix + strings.ToLower(rest)
}