starssh/sshagent_windows.go
starainrt 0c23e7d4bf
feat: 增强 ssh-agent 认证与转发可靠性
- 拆分 ssh-agent 认证、连接与 endpoint 解析逻辑
- 新增 IdentityAgent、SSHAgentTimeout、SSHAgentForwardTimeout 和调试事件
- 为 agent list/sign 操作增加独立 deadline,避免硬件 agent 卡死登录
- 支持 agent signer 失败后跳过坏 key 并重试后续 key
- 优先处理 RSA-SHA2 签名,兼容现代 OpenSSH 认证要求
- 增强 agent forwarding 的探测、通道空闲超时和关闭清理
- 补充 Windows OpenSSH pipe 与 GPG S.gpg-agent.ssh socket 文件支持
- 增加相关回归测试和 Windows 编译验证覆盖
2026-05-27 13:10:35 +08:00

272 lines
7.4 KiB
Go

//go:build windows
package starssh
import (
"bytes"
"context"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"os"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/Microsoft/go-winio"
"golang.org/x/sys/windows"
)
// defaultWindowsSSHAgentPipe is the named pipe of the Windows OpenSSH agent;
// defaultSSHAgentEndpoint returns it as the platform-default endpoint.
const defaultWindowsSSHAgentPipe = `\\.\pipe\openssh-ssh-agent`
// errInvalidGPGSocketInfo is wrapped by parseGPGSocketInfo when a socket file
// matches neither the libassuan port/nonce format nor the Cygwin format.
var errInvalidGPGSocketInfo = errors.New("invalid gpg agent socket file")
// gpgSocketInfo is the loopback TCP endpoint described by a GnuPG socket file
// on Windows, as produced by parseGPGAssuanSocketInfo or
// parseGPGCygwinSocketInfo.
type gpgSocketInfo struct {
	// port is the TCP port on 127.0.0.1 that the agent listens on.
	port uint16
	// nonce is the 16-byte secret written to the connection right after it
	// is established.
	nonce []byte
	// cygwin selects the extra Cygwin handshake (peer secret plus credential
	// exchange) performed after the nonce is written.
	cygwin bool
}
// defaultSSHAgentEndpoint returns the platform-default agent endpoint on
// Windows: the OpenSSH agent named pipe, tagged with source
// "platform-default" and network "windows-pipe". It never fails.
func defaultSSHAgentEndpoint() (resolvedSSHAgentEndpoint, error) {
	var endpoint resolvedSSHAgentEndpoint
	endpoint.Endpoint = defaultWindowsSSHAgentPipe
	endpoint.Source = "platform-default"
	endpoint.Network = "windows-pipe"
	return endpoint, nil
}
// defaultSSHAgentNetwork classifies endpoint by its form. Named-pipe paths
// give "windows-pipe" and GnuPG S.gpg-agent.ssh socket files give
// "gpg-socket". Anything else is treated as a "unix" socket path.
func defaultSSHAgentNetwork(endpoint string) string {
	_, isPipe := normalizeWindowsSSHAgentPipe(endpoint)
	switch {
	case isPipe:
		return "windows-pipe"
	case isAgentSSHSocketPath(endpoint):
		return "gpg-socket"
	default:
		return "unix"
	}
}
// dialResolvedSSHAgent connects to resolved.Endpoint and picks the transport
// from the endpoint's form: a Windows named pipe, a GnuPG S.gpg-agent.ssh
// socket file, or otherwise an AF_UNIX socket. A positive timeout bounds the
// dial.
func dialResolvedSSHAgent(resolved resolvedSSHAgentEndpoint, timeout time.Duration) (net.Conn, error) {
	endpoint := resolved.Endpoint
	switch pipePath, isPipe := normalizeWindowsSSHAgentPipe(endpoint); {
	case isPipe:
		// Only the platform-default pipe maps "pipe not found" to
		// errSSHAgentUnavailable; explicitly configured pipes keep the raw error.
		return dialWindowsNamedPipe(pipePath, timeout, resolved.Source == "platform-default")
	case isAgentSSHSocketPath(endpoint):
		return dialWindowsGPGSocketFile(endpoint, timeout)
	default:
		return dialWindowsUnixAgent(endpoint, timeout)
	}
}
// dialWindowsNamedPipe connects to the named pipe at path.
//
// A positive timeout bounds the connection attempt. Otherwise no deadline is
// applied here. When unavailableOnNotFound is true, a missing pipe is
// reported as errSSHAgentUnavailable.
func dialWindowsNamedPipe(path string, timeout time.Duration, unavailableOnNotFound bool) (net.Conn, error) {
	if timeout <= 0 {
		return dialWindowsNamedPipeContext(context.Background(), path, unavailableOnNotFound)
	}
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()
	return dialWindowsNamedPipeContext(ctx, path, unavailableOnNotFound)
}
// normalizeWindowsSSHAgentPipe reports whether endpoint names a local Windows
// named pipe. If it does, the pipe is returned in canonical `\\.\pipe\<name>`
// form.
//
// Surrounding whitespace is ignored. Both the native `\\.\pipe\` prefix and
// the forward-slash spelling `//./pipe/` are accepted. The prefix is matched
// case-insensitively because Windows pipe names are not case-sensitive, so
// `\\.\PIPE\openssh-ssh-agent` is recognised as a pipe. Previously it fell
// through to the unix-socket dialer. The part after the prefix is returned
// unchanged. Any other input yields ("", false).
func normalizeWindowsSSHAgentPipe(endpoint string) (string, bool) {
	const canonicalPrefix = `\\.\pipe\`
	trimmed := strings.TrimSpace(endpoint)
	for _, prefix := range [...]string{canonicalPrefix, "//./pipe/"} {
		if len(trimmed) >= len(prefix) && strings.EqualFold(trimmed[:len(prefix)], prefix) {
			return canonicalPrefix + trimmed[len(prefix):], true
		}
	}
	return "", false
}
// isWindowsPipeUnavailable reports whether err, or any error it wraps, is
// ERROR_FILE_NOT_FOUND or ERROR_PATH_NOT_FOUND. Both mean the pipe does not
// exist. A nil err yields false.
func isWindowsPipeUnavailable(err error) bool {
	for _, notFound := range []error{windows.ERROR_FILE_NOT_FOUND, windows.ERROR_PATH_NOT_FOUND} {
		if errors.Is(err, notFound) {
			return true
		}
	}
	return false
}
// dialWindowsUnixAgent connects to the AF_UNIX socket at endpoint. A positive
// timeout bounds the dial; zero or negative means no dial timeout.
func dialWindowsUnixAgent(endpoint string, timeout time.Duration) (net.Conn, error) {
	var dialer net.Dialer
	if timeout > 0 {
		dialer.Timeout = timeout
	}
	return dialer.Dial("unix", endpoint)
}
// dialWindowsGPGSocketFile connects to the GnuPG agent described by the
// socket file at path, after trimming surrounding whitespace. A positive
// timeout bounds the whole dial, including redirects and the handshake.
func dialWindowsGPGSocketFile(path string, timeout time.Duration) (net.Conn, error) {
	trimmed := strings.TrimSpace(path)
	if timeout <= 0 {
		return dialWindowsGPGSocketFileDepth(context.Background(), trimmed, 0)
	}
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()
	return dialWindowsGPGSocketFileDepth(ctx, trimmed, 0)
}
// dialWindowsGPGSocketFileDepth reads the GnuPG socket file at path and
// connects to the agent it describes.
//
// Libassuan "%Assuan%\nsocket=..." redirects are followed. Named-pipe targets
// are dialled directly, and any other target is read as another socket file.
// depth is the number of redirects already followed; exceeding 8 is reported
// as a redirect loop. A non-redirect file must be a libassuan port/nonce file
// or a Cygwin socket file.
func dialWindowsGPGSocketFileDepth(ctx context.Context, path string, depth int) (net.Conn, error) {
	for ; ; depth++ {
		if path == "" {
			return nil, fmt.Errorf("gpg agent endpoint is empty")
		}
		if depth > 8 {
			return nil, fmt.Errorf("gpg agent socket redirect loop at %s", path)
		}
		data, err := os.ReadFile(path)
		if err != nil {
			return nil, err
		}
		target, redirected := parseGPGAssuanSocketRedirect(data)
		if !redirected {
			info, err := parseGPGSocketInfo(path, data)
			if err != nil {
				return nil, err
			}
			return dialWindowsGPGSocketInfo(ctx, info)
		}
		target = resolveGPGSocketRedirectTarget(path, target)
		if pipePath, isPipe := normalizeWindowsSSHAgentPipe(target); isPipe {
			return dialWindowsNamedPipeContext(ctx, pipePath, false)
		}
		// Follow the redirect to the next socket file on the next iteration.
		path = target
	}
}
// dialWindowsGPGSocketInfo connects to the loopback TCP port described by a
// parsed GnuPG socket file. It then performs the handshake the agent expects
// before any agent traffic:
//
//   - every socket: write the 16-byte nonce;
//   - Cygwin-style sockets, additionally: read the peer's 16-byte secret,
//     write our credentials, then read the peer's credentials.
//
// Cygwin's AF_UNIX emulation exchanges credentials as a struct ucred
// {pid, uid, gid}: three 32-bit little-endian values, 12 bytes in each
// direction. Sending fewer bytes leaves the peer waiting for the remainder.
// The handshake then stalls until the deadline, or forever if there is none.
//
// If ctx carries a deadline, it bounds the handshake. The deadline is cleared
// before the connection is returned. On any failure the connection is closed
// and the underlying error is returned unwrapped.
func dialWindowsGPGSocketInfo(ctx context.Context, info gpgSocketInfo) (net.Conn, error) {
	const (
		cygwinSecretSize     = 16
		cygwinCredentialSize = 12 // sizeof(struct ucred): pid, uid, gid
	)

	var dialer net.Dialer
	conn, err := dialer.DialContext(ctx, "tcp", net.JoinHostPort("127.0.0.1", strconv.Itoa(int(info.port))))
	if err != nil {
		return nil, err
	}
	fail := func(err error) (net.Conn, error) {
		_ = conn.Close()
		return nil, err
	}

	if deadline, ok := ctx.Deadline(); ok {
		if err := conn.SetDeadline(deadline); err != nil {
			return fail(err)
		}
	}
	if _, err := conn.Write(info.nonce); err != nil {
		return fail(err)
	}
	if info.cygwin {
		// The peer answers with a 16-byte secret; it is read and discarded.
		var secret [cygwinSecretSize]byte
		if _, err := io.ReadFull(conn, secret[:]); err != nil {
			return fail(err)
		}
		// Only the pid is filled in; the uid and gid slots are sent as zero.
		var credential [cygwinCredentialSize]byte
		binary.LittleEndian.PutUint32(credential[:4], uint32(os.Getpid()))
		if _, err := conn.Write(credential[:]); err != nil {
			return fail(err)
		}
		// The peer's credentials are consumed but not inspected.
		if _, err := io.ReadFull(conn, credential[:]); err != nil {
			return fail(err)
		}
	}
	// Best effort: the handshake deadline must not apply to later agent traffic.
	_ = conn.SetDeadline(time.Time{})
	return conn, nil
}
// resolveGPGSocketRedirectTarget turns the target of a socket redirect found
// in the file source into a usable path. The target is whitespace-trimmed.
// Empty, absolute, and named-pipe targets are returned as-is. Relative
// targets are resolved against the directory containing source.
func resolveGPGSocketRedirectTarget(source string, target string) string {
	candidate := strings.TrimSpace(target)
	_, isPipe := normalizeWindowsSSHAgentPipe(candidate)
	switch {
	case candidate == "", filepath.IsAbs(candidate), isPipe:
		return candidate
	default:
		return filepath.Join(filepath.Dir(source), candidate)
	}
}
// parseGPGSocketInfo decodes the contents of the GnuPG socket file at path.
// It tries the libassuan port/nonce format first and the Cygwin format
// second. If neither matches, it returns an error wrapping
// errInvalidGPGSocketInfo that names path.
func parseGPGSocketInfo(path string, data []byte) (gpgSocketInfo, error) {
	parsers := []func([]byte) (gpgSocketInfo, bool){
		parseGPGAssuanSocketInfo,
		parseGPGCygwinSocketInfo,
	}
	for _, parse := range parsers {
		if info, ok := parse(data); ok {
			return info, nil
		}
	}
	return gpgSocketInfo{}, fmt.Errorf("%w %s: expected GnuPG port/nonce socket file; if SSH_AUTH_SOCK was set to this file, restart gpg-agent to recreate it", errInvalidGPGSocketInfo, path)
}
// parseGPGAssuanSocketRedirect recognises a libassuan redirect file: exactly
// two lines, "%Assuan%" then "socket=<target>", with LF or CRLF line endings
// and an optional final newline. It returns the target with environment
// variables expanded via os.ExpandEnv. A missing or blank target, or any other
// layout, yields ("", false).
func parseGPGAssuanSocketRedirect(data []byte) (string, bool) {
	normalized := strings.TrimSuffix(strings.ReplaceAll(string(data), "\r\n", "\n"), "\n")
	header, rest, found := strings.Cut(normalized, "\n")
	if !found || header != "%Assuan%" || strings.Contains(rest, "\n") {
		return "", false
	}
	name, hasKey := strings.CutPrefix(rest, "socket=")
	if !hasKey || strings.TrimSpace(name) == "" {
		return "", false
	}
	return os.ExpandEnv(name), true
}
// parseGPGAssuanSocketInfo decodes a libassuan port/nonce socket file. The
// format is a decimal TCP port line (surrounding whitespace allowed)
// terminated by LF, followed by exactly 16 raw nonce bytes. Only the first LF
// ends the port line, so the nonce may itself contain LF bytes. The returned
// nonce is a copy of the input bytes. Port 0, out-of-range ports, and any
// other layout yield false.
func parseGPGAssuanSocketInfo(data []byte) (gpgSocketInfo, bool) {
	const nonceLen = 16
	portLine, nonceBytes, found := bytes.Cut(data, []byte{'\n'})
	if !found || len(portLine) == 0 || len(nonceBytes) != nonceLen {
		return gpgSocketInfo{}, false
	}
	port, err := strconv.ParseUint(strings.TrimSpace(string(portLine)), 10, 16)
	if err != nil || port == 0 {
		return gpgSocketInfo{}, false
	}
	return gpgSocketInfo{
		port:  uint16(port),
		nonce: append(make([]byte, 0, nonceLen), nonceBytes...),
	}, true
}
// parseGPGCygwinSocketInfo decodes a Cygwin-style socket file of the form
// "!<socket >PORT s XXXXXXXX-XXXXXXXX-XXXXXXXX-XXXXXXXX", optionally
// NUL-terminated. Each 8-digit hex group is stored little-endian into the
// 16-byte nonce, and the result is marked as cygwin. Port 0, out-of-range
// ports, a type other than "s", or malformed groups yield false.
func parseGPGCygwinSocketInfo(data []byte) (gpgSocketInfo, bool) {
	const magic = "!<socket >"
	body, ok := strings.CutPrefix(string(data), magic)
	if !ok {
		return gpgSocketInfo{}, false
	}
	fields := strings.Fields(strings.TrimRight(body, "\x00"))
	if len(fields) != 3 || fields[1] != "s" {
		return gpgSocketInfo{}, false
	}
	port, err := strconv.ParseUint(fields[0], 10, 16)
	if err != nil || port == 0 {
		return gpgSocketInfo{}, false
	}
	groups := strings.Split(fields[2], "-")
	if len(groups) != 4 {
		return gpgSocketInfo{}, false
	}
	nonce := make([]byte, 16)
	for i, group := range groups {
		if len(group) != 8 {
			return gpgSocketInfo{}, false
		}
		word, err := strconv.ParseUint(group, 16, 32)
		if err != nil {
			return gpgSocketInfo{}, false
		}
		binary.LittleEndian.PutUint32(nonce[i*4:], uint32(word))
	}
	return gpgSocketInfo{port: uint16(port), nonce: nonce, cygwin: true}, true
}
// isAgentSSHSocketPath reports whether endpoint, after trimming whitespace
// and lowercasing, ends in "s.gpg-agent.ssh", the GnuPG ssh-agent socket
// file name.
func isAgentSSHSocketPath(endpoint string) bool {
	const gpgAgentSSHSocketSuffix = "s.gpg-agent.ssh"
	lowered := strings.ToLower(strings.TrimSpace(endpoint))
	return strings.HasSuffix(lowered, gpgAgentSSHSocketSuffix)
}
// dialWindowsNamedPipeContext connects to the named pipe at path using
// go-winio, bounded by ctx; a nil ctx is replaced by context.Background().
// When unavailableOnNotFound is true and the pipe does not exist, the error
// is replaced by errSSHAgentUnavailable. Other errors are returned unchanged.
func dialWindowsNamedPipeContext(ctx context.Context, path string, unavailableOnNotFound bool) (net.Conn, error) {
	if ctx == nil {
		ctx = context.Background()
	}
	conn, err := winio.DialPipeContext(ctx, path)
	switch {
	case err == nil:
		return conn, nil
	case unavailableOnNotFound && isWindowsPipeUnavailable(err):
		return nil, errSSHAgentUnavailable
	default:
		return nil, err
	}
}