starssh/agent_forward.go
starainrt 1625997d8f
fix: 拆分 starssh 的拨号超时与认证超时语义
- 为 LoginInput 新增 DialTimeout,明确区分【TCP/proxy/ssh-agent 拨号超时】和【SSH 握手/认证超时】
- 将 Timeout 收口为握手/认证阶段超时,0 表示不限制,不再在登录入口自动回填默认值
- 新增 effectiveLoginTimeout/effectiveDialTimeout,统一超时决策逻辑
- 调整 login 流程,仅对 login context、ssh.ClientConfig 和握手阶段连接 deadline 使用认证超时
- 调整 transport 拨号链路,默认 TCP dial、proxy dial 与 ssh-agent 建连统一改用 DialTimeout
- 修正 agent forwarding 初始化仍错误复用 LoginInfo.Timeout 的问题
- 保持 LoginSimple 的直观行为:传入 timeout 时同时映射到 Timeout 和 DialTimeout
- 新增 login_timeout_test,覆盖零值不回填、DialTimeout 优先级,以及 ssh-agent 认证路径使用拨号超时的回归测试
2026-04-26 23:29:36 +08:00

152 lines
3.6 KiB
Go

package starssh
import (
"errors"
"fmt"
"io"
"strings"
"time"
"golang.org/x/crypto/ssh"
sshagent "golang.org/x/crypto/ssh/agent"
)
var requestSSHAgentForwarding = func(session *ssh.Session) error {
return sshagent.RequestAgentForwarding(session)
}
var routeSSHAgentForwarding = func(client *ssh.Client, keyring sshagent.Agent) error {
return sshagent.ForwardToAgent(client, keyring)
}
// newSSHAgentForwarder dials the local ssh-agent (bounded by timeout) and
// returns an agent client over that connection together with the connection
// itself as the io.Closer to release it. Dial failures and an empty
// connection are reported wrapped as errSSHAgentForwardingUnavailable.
var newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
	agentConn, dialErr := dialSSHAgent(timeout)
	switch {
	case dialErr != nil:
		return nil, nil, wrapSSHAgentForwardingUnavailable(dialErr)
	case agentConn == nil:
		return nil, nil, wrapSSHAgentForwardingUnavailable(errors.New("empty agent connection"))
	}
	return sshagent.NewClient(agentConn), agentConn, nil
}
var (
	// errSSHAgentForwardingDenied tags a refused agent-forwarding request on a
	// session.
	errSSHAgentForwardingDenied = errors.New("ssh-agent forwarding request denied")

	// errSSHAgentForwardingUnavailable tags failures to obtain a local
	// ssh-agent connection to forward.
	errSSHAgentForwardingUnavailable = errors.New("ssh-agent forwarding unavailable")
)
// RequestAgentForwarding asks the server to enable ssh-agent forwarding on
// session. It first makes sure the local agent has been routed into the SSH
// client (see ensureAgentForwarding) and returns that step's error unchanged.
//
// If the server refuses the request, the returned error matches
// errSSHAgentForwardingDenied via errors.Is. The original error is kept in the
// chain as well, so callers can still inspect it. Any other failure is
// returned as-is.
func (s *StarSSH) RequestAgentForwarding(session *ssh.Session) error {
	if s == nil {
		return errors.New("ssh client is nil")
	}
	if session == nil {
		return errors.New("ssh session is nil")
	}
	if err := s.ensureAgentForwarding(); err != nil {
		return err
	}
	err := requestSSHAgentForwarding(session)
	switch {
	case err == nil:
		return nil
	case errors.Is(err, errSSHAgentForwardingDenied):
		// Already tagged; wrapping again would only duplicate the prefix.
		return err
	case isSSHAgentForwardingDeniedError(err):
		// Wrap both so errors.Is works for the sentinel and for the cause.
		return fmt.Errorf("%w: %w", errSSHAgentForwardingDenied, err)
	default:
		return err
	}
}
// maybeRequestAgentForwarding is the best-effort variant of
// RequestAgentForwarding. It does nothing unless LoginInfo.ForwardSSHAgent is
// set. When forwarding is requested, a server refusal (denied) or a missing
// local agent (unavailable) is swallowed and nil is returned. Every other
// error is propagated.
func (s *StarSSH) maybeRequestAgentForwarding(session *ssh.Session) error {
	if s != nil && s.LoginInfo.ForwardSSHAgent {
		reqErr := s.RequestAgentForwarding(session)
		if !isSSHAgentForwardingDeniedError(reqErr) && !isSSHAgentForwardingUnavailableError(reqErr) {
			return reqErr
		}
	}
	return nil
}
// ensureAgentForwarding sets up, at most once per StarSSH, the routing of the
// SSH client's agent requests to the local ssh-agent.
//
// It runs under agentForwardMu. A non-nil s.agentForwarder means setup has
// already completed. Steps:
//  1. dial the local agent (bounded by the dial timeout, not the
//     handshake/auth timeout);
//  2. route the client's agent requests to it;
//  3. record the agent's closer if the client can still accept it.
//
// Failure to obtain the agent is wrapped as errSSHAgentForwardingUnavailable.
// errSSHClientClosing is returned if the client starts shutting down
// mid-setup. In every failure path the agent connection is closed.
func (s *StarSSH) ensureAgentForwarding() error {
	if s == nil {
		return errors.New("ssh client is nil")
	}
	s.agentForwardMu.Lock()
	defer s.agentForwardMu.Unlock()
	if s.agentForwarder != nil {
		// Already set up for this StarSSH; reuse it.
		return nil
	}
	client, err := s.requireSSHClient()
	if err != nil {
		return err
	}
	// Connecting to the agent is a dial, so it uses DialTimeout semantics
	// rather than LoginInfo.Timeout (handshake/auth).
	keyring, closer, err := newSSHAgentForwarder(effectiveDialTimeout(s.LoginInfo))
	if err != nil {
		return wrapSSHAgentForwardingUnavailable(err)
	}
	// newSSHAgentForwarder is a swappable var. Normalise a nil closer so the
	// Close calls below cannot panic, and so the stored value still marks
	// setup as done.
	if closer == nil {
		closer = nopAgentForwarderCloser{}
	}
	if keyring == nil {
		// Close errors on failure paths are ignored on purpose: the
		// returned error is the one the caller needs.
		_ = closer.Close()
		return wrapSSHAgentForwardingUnavailable(errors.New("empty agent keyring"))
	}
	if s.closing.Load() {
		_ = closer.Close()
		return errSSHClientClosing
	}
	if err := routeSSHAgentForwarding(client, keyring); err != nil {
		_ = closer.Close()
		return err
	}
	// Re-check after routing: the client may have been replaced or started
	// closing meanwhile, in which case the closer must not be attached.
	if !s.canAttachAgentForwarder(client) {
		_ = closer.Close()
		return errSSHClientClosing
	}
	s.agentForwarder = closer
	return nil
}

// nopAgentForwarderCloser stands in for a missing agent closer; Close is a no-op.
type nopAgentForwarderCloser struct{}

// Close implements io.Closer and always returns nil.
func (nopAgentForwarderCloser) Close() error { return nil }
// takeAgentForwarder detaches the stored agent-forwarder closer and returns
// it. It returns nil if s is nil or nothing is stored. The field is cleared
// under agentForwardMu, so afterwards s no longer references the closer
// (presumably the caller is responsible for closing it).
func (s *StarSSH) takeAgentForwarder() io.Closer {
	var taken io.Closer
	if s != nil {
		s.agentForwardMu.Lock()
		taken, s.agentForwarder = s.agentForwarder, nil
		s.agentForwardMu.Unlock()
	}
	return taken
}
// isSSHAgentForwardingDeniedError reports whether err means the server refused
// agent forwarding. That is the case if err wraps errSSHAgentForwardingDenied,
// or if its message contains a known refusal phrase (case-insensitive).
func isSSHAgentForwardingDeniedError(err error) bool {
	switch {
	case err == nil:
		return false
	case errors.Is(err, errSSHAgentForwardingDenied):
		return true
	}
	lowered := strings.ToLower(err.Error())
	for _, marker := range [...]string{"forwarding request denied", "agent forwarding disabled"} {
		if strings.Contains(lowered, marker) {
			return true
		}
	}
	return false
}
// isSSHAgentForwardingUnavailableError reports whether err means no local
// ssh-agent could be used for forwarding. That is the case if err wraps
// errSSHAgentForwardingUnavailable, or if its message contains a known
// "unavailable" phrase (case-insensitive).
func isSSHAgentForwardingUnavailableError(err error) bool {
	switch {
	case err == nil:
		return false
	case errors.Is(err, errSSHAgentForwardingUnavailable):
		return true
	}
	lowered := strings.ToLower(err.Error())
	for _, marker := range [...]string{"ssh-agent forwarding unavailable", "ssh-agent unavailable"} {
		if strings.Contains(lowered, marker) {
			return true
		}
	}
	return false
}
// wrapSSHAgentForwardingUnavailable tags err as errSSHAgentForwardingUnavailable.
// It returns nil for a nil err and leaves an already-tagged err untouched.
//
// Both the sentinel and the original cause are wrapped with %w. errors.Is and
// errors.As therefore still reach the underlying failure (errSSHAgentUnavailable,
// dial timeouts, net errors, ...), not just the errSSHAgentUnavailable case.
// The rendered message is "<sentinel>: <cause>".
func wrapSSHAgentForwardingUnavailable(err error) error {
	if err == nil || errors.Is(err, errSSHAgentForwardingUnavailable) {
		return err
	}
	return fmt.Errorf("%w: %w", errSSHAgentForwardingUnavailable, err)
}