starssh/login.go
starainrt f20eb653ae
refactor: 重构 starssh 核心运行时并补强 ssh/exec/terminal/sftp 能力
- 拆分原有单体 ssh.go,按职责重组为 types、utils、transport、login、keepalive、session、exec、pool、shell、terminal、forward、hostkey、state 等模块,并补充平台相关实现
  - 重做登录与连接运行时,补齐基于 context 的建连、jump/proxy 链路、可配置认证顺序,以及 Unix/Windows 下的 ssh-agent 支持
  - 新增正式非交互执行模型 ExecRequest/ExecResult,支持流式输出、溢出统计、超时控制,以及 posix/powershell/cmd/raw 多方言执行
  - 保留旧 shell 风格兼容接口,同时让路径/用户探测等 helper 具备跨 shell fallback,避免 Windows 目标继续硬依赖 POSIX 命令
  - 新增 TerminalSession 作为原始交互终端基座,提供 IO attach、resize、signal/control、退出状态与关闭原因管理
  - 重构端口转发语义,默认复用当前 SSH 连接,并显式提供 detached 的本地/动态转发模式承载隔离场景
  - 梳理 keepalive 与取消语义,区分仅取消本次操作和关闭整条连接,并统一连接状态与传输关闭路径
  - 围绕新的 session/连接生命周期重做执行池与运行时支撑
  - 大幅增强 SFTP 传输链路,补齐更安全的原子替换、校验、进度回调、重试隔离、可复用 client 生命周期与失败语义
  - 新增取消语义、keepalive、SFTP、forward、terminal input 等关键回归测试,提升核心链路稳定性
2026-04-26 10:45:39 +08:00

363 lines
8.5 KiB
Go

package starssh
import (
"context"
"encoding/base64"
"errors"
"fmt"
"net"
"os"
"strings"
"time"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
)
var (
	// ErrHostKeyCallbackRequired is returned by Login/LoginContext when
	// LoginInput.HostKeyCallback is nil. Callers that genuinely want to accept
	// any host key must opt in explicitly with DefaultAllowHostKeyCallback.
	ErrHostKeyCallbackRequired = errors.New("host key callback is required; use DefaultAllowHostKeyCallback to explicitly allow any host key")

	// errSSHAgentUnavailable is matched (via errors.Is) against the error from
	// dialSSHAgent; when it matches, the ssh-agent auth method is skipped
	// without being reported as a failure.
	errSSHAgentUnavailable = errors.New("ssh-agent unavailable")

	// defaultAuthOrder is the authentication sequence used when
	// LoginInput.AuthOrder is empty: agent first, then private key, password,
	// and finally keyboard-interactive.
	defaultAuthOrder = []AuthMethodKind{
		AuthMethodSSHAgent,
		AuthMethodPrivateKey,
		AuthMethodPassword,
		AuthMethodKeyboardInteractive,
	}
)
// DefaultAllowHostKeyCallback is an ssh.HostKeyCallback that accepts every
// host key without any verification.
//
// It exists so that disabling host key checking is an explicit, visible
// choice (a nil LoginInput.HostKeyCallback is rejected with
// ErrHostKeyCallbackRequired). It provides no protection against
// man-in-the-middle attacks; prefer a known_hosts based callback outside of
// trusted networks or tests.
func DefaultAllowHostKeyCallback(_ string, _ net.Addr, _ ssh.PublicKey) error {
	return nil
}
// LoginContext dials and authenticates to the SSH server described by info.
//
// ctx bounds the dial and handshake together with info.Timeout; a nil ctx is
// treated as context.Background(). info.HostKeyCallback must be set, otherwise
// ErrHostKeyCallbackRequired is returned.
//
// When the failure happens after the connection attempt has started (dial or
// handshake), the returned *StarSSH may be non-nil and carries whatever the
// host key and banner callbacks recorded.
func LoginContext(ctx context.Context, info LoginInput) (*StarSSH, error) {
	conn, err := loginWithContext(ctx, info)
	return conn, err
}
// Login is LoginContext with a background context, so only info.Timeout
// bounds the connection attempt.
func Login(info LoginInput) (*StarSSH, error) {
	background := context.Background()
	return LoginContext(background, info)
}
// loginWithContext is the implementation behind LoginContext and Login.
//
// Steps: normalize info, require a HostKeyCallback, build the auth methods,
// dial the target through dialTargetConn, run the SSH handshake on that
// connection, wrap the result in an *ssh.Client stored on the returned
// *StarSSH, and start the automatic keep-alive.
//
// Dial and handshake share loginCtx, which is ctx further bounded by
// info.Timeout (defaultLoginTimeout when the caller left it <= 0, see
// normalizeLoginInput).
//
// Return values: validation and auth-construction errors return a nil
// *StarSSH. Dial and handshake errors return the partially populated *StarSSH
// together with the error, so the fields recorded by the host key and banner
// callbacks (PublicKey, RemoteAddr, Hostname, Banner) remain inspectable.
func loginWithContext(ctx context.Context, info LoginInput) (*StarSSH, error) {
	info = normalizeLoginInput(info)
	// A host key policy is mandatory; skipping verification must be requested
	// explicitly by passing DefaultAllowHostKeyCallback.
	if info.HostKeyCallback == nil {
		return nil, ErrHostKeyCallbackRequired
	}
	// Bound the whole dial + handshake by info.Timeout, on top of any deadline
	// the caller's ctx already carries.
	loginCtx, cancel := contextWithLoginTimeout(ctx, info.Timeout)
	defer cancel()
	sshInfo := &StarSSH{
		LoginInfo: info,
	}
	// authCleanup (may be nil) releases resources opened while building the
	// auth methods, such as ssh-agent connections. It is deferred, so those
	// resources live only until this function returns, i.e. through the
	// handshake.
	auth, authCleanup, err := buildAuthMethods(info)
	if err != nil {
		return nil, err
	}
	if authCleanup != nil {
		defer authCleanup()
	}
	// Record the server identity before delegating, so it is kept on sshInfo
	// even when info.HostKeyCallback rejects the key.
	hostKeyCallback := func(hostname string, remote net.Addr, key ssh.PublicKey) error {
		sshInfo.PublicKey = key
		sshInfo.RemoteAddr = remote
		sshInfo.Hostname = hostname
		return info.HostKeyCallback(hostname, remote, key)
	}
	// Record the pre-auth banner and forward it to the caller's callback, if any.
	bannerCallback := func(banner string) error {
		sshInfo.Banner = banner
		if info.BannerCallback != nil {
			return info.BannerCallback(banner)
		}
		return nil
	}
	clientConfig := &ssh.ClientConfig{
		User:            info.User,
		Auth:            auth,
		Timeout:         info.Timeout,
		HostKeyCallback: hostKeyCallback,
		BannerCallback:  bannerCallback,
	}
	// Only override the algorithm lists when the caller set at least one of
	// them; lists left nil are filled with x/crypto/ssh defaults.
	if len(info.Ciphers) > 0 || len(info.MACs) > 0 || len(info.KeyExchanges) > 0 {
		clientConfig.Config = ssh.Config{
			Ciphers:      info.Ciphers,
			MACs:         info.MACs,
			KeyExchanges: info.KeyExchanges,
		}
	}
	// targetAddr is the address handed to the handshake (and therefore to the
	// host key callback); the actual byte stream comes from dialTargetConn.
	targetAddr := joinHostPort(info.Addr, info.Port)
	// upstream is presumably the jump/proxy hop the connection is tunnelled
	// through (nil for a direct dial) — confirm in dialTargetConn. It is closed
	// on handshake failure and otherwise handed to setTransport.
	rawConn, upstream, err := dialTargetConn(loginCtx, info)
	if err != nil {
		return sshInfo, err
	}
	// Put a deadline on rawConn for the handshake derived from loginCtx /
	// info.Timeout; restoreDeadline presumably clears it again when this
	// function returns — confirm in applyConnDeadline.
	restoreDeadline := applyConnDeadline(rawConn, loginCtx, info.Timeout)
	defer restoreDeadline()
	clientConn, chans, reqs, err := ssh.NewClientConn(rawConn, targetAddr, clientConfig)
	if err != nil {
		// Handshake failed: close everything that was dialed. Errors from
		// Close are ignored because the handshake error is the one reported.
		_ = rawConn.Close()
		if upstream != nil {
			_ = upstream.Close()
		}
		return sshInfo, err
	}
	client := ssh.NewClient(clientConn, chans, reqs)
	// Hand the client (and the upstream hop, if any) over to sshInfo.
	sshInfo.setTransport(client, upstream)
	// Cache a base64 copy of the wire-format host key.
	if sshInfo.PublicKey != nil {
		sshInfo.PubkeyBase64 = base64.StdEncoding.EncodeToString(sshInfo.PublicKey.Marshal())
	}
	sshInfo.startAutoKeepAlive()
	return sshInfo, nil
}
// contextWithLoginTimeout derives the context used for a login attempt.
//
// A nil ctx is replaced by context.Background(). With a positive timeout the
// result is context.WithTimeout(ctx, timeout); otherwise ctx is returned
// unchanged together with a no-op cancel func. The returned cancel func is
// never nil, so callers can always defer it.
func contextWithLoginTimeout(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) {
	base := ctx
	if base == nil {
		base = context.Background()
	}
	if timeout > 0 {
		return context.WithTimeout(base, timeout)
	}
	noop := func() {}
	return base, noop
}
// LoginSimple is a convenience wrapper around Login for the common
// password-or-key case.
//
// If prikeyPath is empty, passwd is used as the login password. Otherwise the
// key file is read, and passwd (possibly empty) is used only as the key's
// passphrase; it is not offered to the server as a password.
//
// LoginSimple accepts any host key (DefaultAllowHostKeyCallback). Use
// Login/LoginContext with a real HostKeyCallback when host verification
// matters.
func LoginSimple(host string, user string, passwd string, prikeyPath string, port int, timeout time.Duration) (*StarSSH, error) {
	info := LoginInput{
		Addr:            host,
		Port:            port,
		Timeout:         timeout,
		User:            user,
		HostKeyCallback: DefaultAllowHostKeyCallback,
	}
	if prikeyPath == "" {
		info.Password = passwd
		return Login(info)
	}
	keyBytes, err := os.ReadFile(prikeyPath)
	if err != nil {
		return nil, err
	}
	info.Prikey = string(keyBytes)
	info.PrikeyPwd = passwd
	return Login(info)
}
// normalizeLoginInput returns a copy of info with defaults applied: a
// non-positive Port becomes defaultSSHPort and a non-positive Timeout becomes
// defaultLoginTimeout. All other fields are passed through unchanged.
func normalizeLoginInput(info LoginInput) LoginInput {
	normalized := info
	if normalized.Port < 1 {
		normalized.Port = defaultSSHPort
	}
	if normalized.Timeout <= 0 {
		normalized.Timeout = defaultLoginTimeout
	}
	return normalized
}
// buildAuthMethods turns the credentials in info into the ordered list of
// ssh.AuthMethod values offered during the handshake.
//
// The order comes from normalizeAuthOrder(info.AuthOrder), which falls back to
// defaultAuthOrder when it is empty. Methods whose inputs are absent (builders
// returning nil) are skipped. Errors behave differently per method:
//   - an ssh-agent error is remembered and reported only when no method at
//     all could be built;
//   - a private-key parse error is fatal.
//
// The returned cleanup (possibly nil) releases resources such as ssh-agent
// connections and should be called once the handshake is over. On any error
// return, no cleanup is handed out. Everything opened so far has already been
// released, so nothing leaks.
func buildAuthMethods(info LoginInput) ([]ssh.AuthMethod, func(), error) {
	order, err := normalizeAuthOrder(info.AuthOrder)
	if err != nil {
		return nil, nil, err
	}

	auth := make([]ssh.AuthMethod, 0, len(order))
	var agentErr error
	var cleanupFuncs []func()
	// releaseAll runs the cleanups collected so far. Every error return below
	// must call it: the caller never receives a cleanup on error, so an
	// ssh-agent connection opened earlier in the order would otherwise leak.
	releaseAll := func() {
		if release := composeCleanup(cleanupFuncs...); release != nil {
			release()
		}
	}

	for _, methodKind := range order {
		switch methodKind {
		case AuthMethodPrivateKey:
			method, err := buildPrivateKeyAuthMethod(info)
			if err != nil {
				releaseAll()
				return nil, nil, err
			}
			if method != nil {
				auth = append(auth, method)
			}
		case AuthMethodPassword:
			method := buildPasswordAuthMethod(info.Password, info.PasswordCallback)
			if method != nil {
				auth = append(auth, method)
			}
		case AuthMethodKeyboardInteractive:
			method := buildKeyboardInteractiveAuthMethod(info.Password, info.PasswordCallback, info.KeyboardInteractiveCallback)
			if method != nil {
				auth = append(auth, method)
			}
		case AuthMethodSSHAgent:
			if info.DisableSSHAgent {
				continue
			}
			agentMethod, cleanup, err := buildSSHAgentAuthMethod(info.Timeout)
			if err != nil {
				// Agent problems are not fatal on their own; they only surface
				// when nothing else can authenticate.
				agentErr = err
				continue
			}
			if agentMethod != nil {
				auth = append(auth, agentMethod)
			}
			if cleanup != nil {
				cleanupFuncs = append(cleanupFuncs, cleanup)
			}
		}
	}

	if len(auth) == 0 {
		releaseAll()
		if agentErr != nil {
			return nil, nil, fmt.Errorf("no authentication method provided; ssh-agent unavailable: %w", agentErr)
		}
		return nil, nil, errors.New("no authentication method provided: password, private key, or ssh-agent is required")
	}
	return auth, composeCleanup(cleanupFuncs...), nil
}
// normalizeAuthOrder validates and canonicalizes a user-supplied auth order.
//
// An empty order yields a fresh copy of defaultAuthOrder. Otherwise each entry
// is trimmed and lower-cased. An empty entry or an unsupported kind is an
// error. Duplicates are dropped, keeping the first occurrence, so the relative
// order of distinct methods is preserved.
func normalizeAuthOrder(order []AuthMethodKind) ([]AuthMethodKind, error) {
	if len(order) == 0 {
		return append([]AuthMethodKind(nil), defaultAuthOrder...), nil
	}

	out := make([]AuthMethodKind, 0, len(order))
	dup := make(map[AuthMethodKind]struct{}, len(order))
	for _, raw := range order {
		kind := AuthMethodKind(strings.ToLower(strings.TrimSpace(string(raw))))
		switch {
		case kind == "":
			return nil, errors.New("auth order contains an empty auth method")
		case !isSupportedAuthMethodKind(kind):
			return nil, fmt.Errorf("unsupported auth method %q", raw)
		}
		if _, already := dup[kind]; !already {
			dup[kind] = struct{}{}
			out = append(out, kind)
		}
	}

	// Defensive guard: with a non-empty, fully validated input this cannot
	// trigger, but it keeps the "never return an empty order" invariant explicit.
	if len(out) == 0 {
		return nil, errors.New("auth order is empty")
	}
	return out, nil
}
// isSupportedAuthMethodKind reports whether kind is one of the authentication
// methods buildAuthMethods knows how to construct. kind is compared as-is, so
// callers are expected to normalize it first (see normalizeAuthOrder).
func isSupportedAuthMethodKind(kind AuthMethodKind) bool {
	supported := [...]AuthMethodKind{
		AuthMethodPrivateKey,
		AuthMethodPassword,
		AuthMethodKeyboardInteractive,
		AuthMethodSSHAgent,
	}
	for _, candidate := range supported {
		if kind == candidate {
			return true
		}
	}
	return false
}
// buildPrivateKeyAuthMethod parses info.Prikey (PEM/OpenSSH text) into a
// public-key auth method. It returns (nil, nil) when no key is configured.
//
// The key is first parsed without a passphrase. info.PrikeyPwd is applied only
// when that parse fails because the key is encrypted
// (*ssh.PassphraseMissingError). Passing a passphrase unconditionally would
// make ssh.ParsePrivateKeyWithPassphrase reject perfectly valid unencrypted
// keys. That is easy to trigger, e.g. via LoginSimple, which forwards its
// passwd argument as PrikeyPwd whenever a key path is given.
//
// Errors:
//   - an encrypted key with no passphrase returns *ssh.PassphraseMissingError;
//   - a wrong passphrase returns the error from ParsePrivateKeyWithPassphrase;
//   - any other parse error is returned unchanged.
func buildPrivateKeyAuthMethod(info LoginInput) (ssh.AuthMethod, error) {
	if strings.TrimSpace(info.Prikey) == "" {
		return nil, nil
	}
	pemBytes := []byte(info.Prikey)

	signer, err := ssh.ParsePrivateKey(pemBytes)
	if err != nil {
		var missing *ssh.PassphraseMissingError
		if info.PrikeyPwd == "" || !errors.As(err, &missing) {
			return nil, err
		}
		// The key is encrypted and we have a passphrase: retry with it.
		signer, err = ssh.ParsePrivateKeyWithPassphrase(pemBytes, []byte(info.PrikeyPwd))
		if err != nil {
			return nil, err
		}
	}
	return ssh.PublicKeys(signer), nil
}
// buildPasswordAuthMethod returns a "password" auth method.
//
// A static password takes precedence over callback. The callback is used only
// when password is empty. Returns nil when neither is available.
func buildPasswordAuthMethod(password string, callback func() (string, error)) ssh.AuthMethod {
	switch {
	case password != "":
		return ssh.Password(password)
	case callback != nil:
		return ssh.PasswordCallback(callback)
	default:
		return nil
	}
}
// buildKeyboardInteractiveAuthMethod returns a "keyboard-interactive" auth
// method.
//
// An explicit challenge handler is used as-is. Otherwise, if a password or a
// password callback is available, every question of every challenge round is
// answered with that same secret. The static password wins over the callback,
// and the callback is invoked once per round that has at least one question.
// Rounds without questions are answered with an empty slice. Returns nil when
// no input is available.
func buildKeyboardInteractiveAuthMethod(
	password string,
	passwordCallback func() (string, error),
	challenge ssh.KeyboardInteractiveChallenge,
) ssh.AuthMethod {
	if challenge != nil {
		return ssh.KeyboardInteractive(challenge)
	}
	if password == "" && passwordCallback == nil {
		return nil
	}

	// resolveSecret yields the answer for one challenge round.
	resolveSecret := func() (string, error) {
		if password != "" {
			return password, nil
		}
		return passwordCallback()
	}

	return ssh.KeyboardInteractive(func(user, instruction string, questions []string, echos []bool) ([]string, error) {
		if len(questions) == 0 {
			return []string{}, nil
		}
		secret, err := resolveSecret()
		if err != nil {
			return nil, err
		}
		answers := make([]string, 0, len(questions))
		for range questions {
			answers = append(answers, secret)
		}
		return answers, nil
	})
}
// buildSSHAgentAuthMethod offers the keys held by a local ssh-agent as a
// public-key auth method.
//
// Outcomes:
//   - dialSSHAgent reports errSSHAgentUnavailable, or returns a nil
//     connection: (nil, nil, nil), i.e. the agent is silently skipped;
//   - any other dial error, or a failure listing the agent's keys, is
//     returned;
//   - an agent with no keys is reported as an error.
//
// On success the returned cleanup closes the agent connection. The signers
// talk to the agent through it, so it must stay open while authenticating.
func buildSSHAgentAuthMethod(timeout time.Duration) (ssh.AuthMethod, func(), error) {
	conn, err := dialSSHAgent(timeout)
	switch {
	case errors.Is(err, errSSHAgentUnavailable):
		return nil, nil, nil
	case err != nil:
		return nil, nil, err
	case conn == nil:
		return nil, nil, nil
	}

	closeConn := func() {
		_ = conn.Close()
	}
	signers, err := agent.NewClient(conn).Signers()
	if err != nil {
		closeConn()
		return nil, nil, err
	}
	if len(signers) == 0 {
		closeConn()
		return nil, nil, errors.New("ssh-agent has no loaded keys")
	}
	return ssh.PublicKeys(signers...), closeConn, nil
}
// composeCleanup merges several cleanup funcs into one that runs them in
// reverse order (last registered first, like defer), skipping nil entries.
// It returns nil when called with no funcs at all.
func composeCleanup(funcs ...func()) func() {
	if len(funcs) == 0 {
		return nil
	}
	return func() {
		last := len(funcs) - 1
		for offset := range funcs {
			fn := funcs[last-offset]
			if fn == nil {
				continue
			}
			fn()
		}
	}
}