starssh/login.go
starainrt 0c23e7d4bf
feat: 增强 ssh-agent 认证与转发可靠性
- 拆分 ssh-agent 认证、连接与 endpoint 解析逻辑
- 新增 IdentityAgent、SSHAgentTimeout、SSHAgentForwardTimeout 和调试事件
- 为 agent list/sign 操作增加独立 deadline,避免硬件 agent 卡死登录
- 支持 agent signer 失败后跳过坏 key 并重试后续 key
- 优先处理 RSA-SHA2 签名,兼容现代 OpenSSH 认证要求
- 增强 agent forwarding 的探测、通道空闲超时和关闭清理
- 补充 Windows OpenSSH pipe 与 GPG S.gpg-agent.ssh socket 文件支持
- 增加相关回归测试和 Windows 编译验证覆盖
2026-05-27 13:10:35 +08:00

194 lines
4.5 KiB
Go

package starssh
import (
"context"
"encoding/base64"
"errors"
"net"
"os"
"time"
"golang.org/x/crypto/ssh"
)
var (
	// ErrHostKeyCallbackRequired is returned by the login entry points when
	// LoginInput.HostKeyCallback is nil. Host key verification is never
	// skipped implicitly: callers that really want to accept any key must
	// pass DefaultAllowHostKeyCallback themselves.
	ErrHostKeyCallbackRequired = errors.New("host key callback is required; use DefaultAllowHostKeyCallback to explicitly allow any host key")
)
// DefaultAllowHostKeyCallback is a host key callback (same signature as
// ssh.HostKeyCallback) that accepts every server host key without any
// verification.
//
// It exists so that skipping host key checks is an explicit opt-in (a nil
// LoginInput.HostKeyCallback is rejected with ErrHostKeyCallbackRequired).
// Accepting any key leaves the connection open to man-in-the-middle attacks;
// prefer a known_hosts-based callback outside of tests and trusted networks.
func DefaultAllowHostKeyCallback(hostname string, remote net.Addr, key ssh.PublicKey) error {
	// All arguments are deliberately ignored: every key is trusted.
	return nil
}
// LoginContext connects to info.Addr on info.Port and performs the SSH
// handshake and authentication described by info. A non-positive port is
// replaced by defaultSSHPort.
//
// ctx controls cancellation. When info.Timeout is positive, the whole login
// is additionally bounded by it. info.HostKeyCallback is mandatory; without
// one the call fails with ErrHostKeyCallbackRequired. Pass
// DefaultAllowHostKeyCallback to explicitly accept any host key.
//
// On error the returned *StarSSH may still be non-nil and carry details
// gathered before the failure, such as the server's host key or banner.
func LoginContext(ctx context.Context, info LoginInput) (*StarSSH, error) {
	return loginWithContext(ctx, info)
}
// Login is LoginContext with context.Background(). The attempt cannot be
// cancelled by the caller; it is bounded only by the timeouts configured in
// info (Timeout / DialTimeout).
func Login(info LoginInput) (*StarSSH, error) {
	return LoginContext(context.Background(), info)
}
// loginWithContext is the shared implementation behind Login and
// LoginContext.
//
// It performs these steps in order:
//   - normalizes info;
//   - rejects a nil HostKeyCallback with ErrHostKeyCallbackRequired;
//   - derives a login context bounded by info.Timeout (when positive);
//   - validates info.AuthOrder.
//
// When shouldRetrySSHAgentAuth applies, every attempt that fails with
// errRetrySSHAgentAuth is retried on a fresh connection. Retrying continues
// until an attempt succeeds, fails for another reason, or the login context
// ends. If the context ends, the context error is joined into the returned
// error, so errors.Is(err, context.DeadlineExceeded) / context.Canceled works
// for callers.
//
// On failure the returned *StarSSH may be non-nil and carry whatever the last
// attempt recorded (e.g. server host key or banner).
func loginWithContext(ctx context.Context, info LoginInput) (*StarSSH, error) {
	info = normalizeLoginInput(info)
	if info.HostKeyCallback == nil {
		return nil, ErrHostKeyCallbackRequired
	}
	authTimeout := effectiveLoginTimeout(info)
	loginCtx, cancel := contextWithLoginTimeout(ctx, authTimeout)
	defer cancel()
	order, err := normalizeAuthOrder(info.AuthOrder)
	if err != nil {
		return nil, err
	}
	if !shouldRetrySSHAgentAuth(info, order) {
		return loginOnceWithContext(loginCtx, info, authTimeout, nil)
	}
	// NOTE(review): with no login timeout and a ctx without deadline, this
	// loop terminates only because agentAttempt eventually stops requesting
	// retries (e.g. once bad keys are exhausted) — confirm in sshAgentAuthAttempt.
	agentAttempt := newSSHAgentAuthAttempt()
	for {
		agentAttempt.begin()
		sshInfo, err := loginOnceWithContext(loginCtx, info, authTimeout, agentAttempt)
		if err == nil {
			return sshInfo, nil
		}
		if !errors.Is(err, errRetrySSHAgentAuth) {
			return sshInfo, err
		}
		if ctxErr := loginCtx.Err(); ctxErr != nil {
			// The retry marker is internal; without the context error the
			// caller could not tell that the login timed out or was cancelled.
			return sshInfo, errors.Join(ctxErr, err)
		}
	}
}
// loginOnceWithContext performs a single dial + SSH handshake attempt.
//
// It builds the auth methods for info; agentAttempt may be nil, as in the
// non-retrying path of loginWithContext. It then dials the target via
// dialTargetConn and bounds the handshake with applyConnDeadline. On success
// it hands the resulting *ssh.Client to the returned *StarSSH and starts its
// keep-alive.
//
// Return values:
//   - auth method setup fails: (nil, err).
//   - dial or handshake fails: (sshInfo, err). sshInfo already holds
//     LoginInfo plus anything the callbacks recorded (host key, remote
//     address, hostname, banner). On handshake failure the dialed
//     connection and any upstream are closed before returning.
//   - success: (sshInfo, nil).
func loginOnceWithContext(ctx context.Context, info LoginInput, authTimeout time.Duration, agentAttempt *sshAgentAuthAttempt) (*StarSSH, error) {
sshInfo := &StarSSH{
LoginInfo: info,
}
auth, authCleanup, err := buildAuthMethodsWithAgentAttempt(info, agentAttempt)
if err != nil {
return nil, err
}
// Release auth resources when this attempt returns, whether it succeeded or not.
if authCleanup != nil {
defer authCleanup()
}
// Record the server identity on sshInfo before delegating to the caller's
// callback, so it is available even when the caller rejects the key.
hostKeyCallback := func(hostname string, remote net.Addr, key ssh.PublicKey) error {
sshInfo.PublicKey = key
sshInfo.RemoteAddr = remote
sshInfo.Hostname = hostname
return info.HostKeyCallback(hostname, remote, key)
}
// Keep the banner on sshInfo and forward it to the caller's callback, if any.
bannerCallback := func(banner string) error {
sshInfo.Banner = banner
if info.BannerCallback != nil {
return info.BannerCallback(banner)
}
return nil
}
// NOTE(review): x/crypto/ssh documents ClientConfig.Timeout as the TCP connect
// timeout used by ssh.Dial. The connection here is dialed by dialTargetConn and
// passed to NewClientConn, so this field is likely unused; the handshake is
// bounded by applyConnDeadline below. Confirm.
clientConfig := &ssh.ClientConfig{
User: info.User,
Auth: auth,
Timeout: authTimeout,
HostKeyCallback: hostKeyCallback,
BannerCallback: bannerCallback,
}
// Override algorithm lists only when the caller customized at least one;
// otherwise the zero ssh.Config (library defaults) is kept.
if len(info.Ciphers) > 0 || len(info.MACs) > 0 || len(info.KeyExchanges) > 0 {
clientConfig.Config = ssh.Config{
Ciphers: info.Ciphers,
MACs: info.MACs,
KeyExchanges: info.KeyExchanges,
}
}
// Per x/crypto/ssh, NewClientConn passes this address to HostKeyCallback as hostname.
targetAddr := joinHostPort(info.Addr, info.Port)
// upstream may be nil. When non-nil it is closed alongside rawConn on handshake
// failure (presumably a proxy / jump-host connection — confirm in dialTargetConn).
rawConn, upstream, err := dialTargetConn(ctx, info)
if err != nil {
return sshInfo, err
}
// Bound the handshake by ctx/authTimeout. The returned func undoes this when the
// attempt returns. Deferred calls run LIFO, so this runs before authCleanup.
restoreDeadline := applyConnDeadline(rawConn, ctx, authTimeout)
defer restoreDeadline()
clientConn, chans, reqs, err := ssh.NewClientConn(rawConn, targetAddr, clientConfig)
if err != nil {
// Nothing owns the connections yet, so close them here. Close errors are
// ignored because the handshake error is the one reported.
_ = rawConn.Close()
if upstream != nil {
_ = upstream.Close()
}
return sshInfo, err
}
client := ssh.NewClient(clientConn, chans, reqs)
sshInfo.setTransport(client, upstream)
// PublicKey was captured by hostKeyCallback during the handshake.
if sshInfo.PublicKey != nil {
sshInfo.PubkeyBase64 = base64.StdEncoding.EncodeToString(sshInfo.PublicKey.Marshal())
}
sshInfo.startAutoKeepAlive()
return sshInfo, nil
}
// contextWithLoginTimeout derives the context that bounds one login.
//
// A nil ctx is replaced by context.Background(). A positive timeout wraps the
// parent with context.WithTimeout. Otherwise the parent is returned unchanged
// together with a no-op cancel func, so callers can always defer the
// returned CancelFunc.
func contextWithLoginTimeout(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) {
	parent := ctx
	if parent == nil {
		parent = context.Background()
	}
	if timeout > 0 {
		return context.WithTimeout(parent, timeout)
	}
	noop := func() {}
	return parent, noop
}
// LoginSimple is a convenience wrapper around Login for the common
// password-or-private-key case.
//
// If prikeyPath is non-empty, the key file is read from disk and passwd is
// used as the key's passphrase. LoginInput.Password is left empty in that
// case. If prikeyPath is empty, passwd is used as LoginInput.Password.
// timeout is used for both LoginInput.Timeout and LoginInput.DialTimeout; a
// non-positive value disables the overall login timeout. An error reading
// the key file is returned unchanged.
//
// WARNING: LoginSimple accepts any server host key (it installs
// DefaultAllowHostKeyCallback), so it gives no protection against
// man-in-the-middle attacks. Use Login with a real HostKeyCallback when host
// authenticity matters.
func LoginSimple(host string, user string, passwd string, prikeyPath string, port int, timeout time.Duration) (*StarSSH, error) {
	info := LoginInput{
		Addr:            host,
		Port:            port,
		Timeout:         timeout,
		DialTimeout:     timeout,
		User:            user,
		HostKeyCallback: DefaultAllowHostKeyCallback,
	}
	if prikeyPath == "" {
		info.Password = passwd
		return Login(info)
	}
	prikey, err := os.ReadFile(prikeyPath)
	if err != nil {
		return nil, err
	}
	info.Prikey = string(prikey)
	// Assigning unconditionally is equivalent to guarding on passwd != "":
	// the field's zero value is already the empty string.
	info.PrikeyPwd = passwd
	return Login(info)
}
// normalizeLoginInput returns info with defaults applied: a non-positive Port
// becomes defaultSSHPort. info is received by value, so the caller's struct
// is never modified.
func normalizeLoginInput(info LoginInput) LoginInput {
	if info.Port > 0 {
		return info
	}
	info.Port = defaultSSHPort
	return info
}
// effectiveLoginTimeout returns the overall login timeout from info.Timeout.
// Any non-positive value is normalized to 0, meaning "no login timeout".
func effectiveLoginTimeout(info LoginInput) time.Duration {
	if limit := info.Timeout; limit > 0 {
		return limit
	}
	return 0
}
// effectiveDialTimeout picks the timeout for establishing the TCP connection:
//   - a negative DialTimeout explicitly disables it (0);
//   - a positive DialTimeout is used as-is;
//   - otherwise a positive Timeout is reused;
//   - failing all of that, defaultLoginTimeout applies.
func effectiveDialTimeout(info LoginInput) time.Duration {
	dial := info.DialTimeout
	if dial < 0 {
		return 0
	}
	if dial > 0 {
		return dial
	}
	if overall := info.Timeout; overall > 0 {
		return overall
	}
	return defaultLoginTimeout
}