//go:build windows package starssh import ( "bytes" "context" "encoding/binary" "errors" "fmt" "io" "net" "os" "path/filepath" "strconv" "strings" "time" "github.com/Microsoft/go-winio" "golang.org/x/sys/windows" ) const defaultWindowsSSHAgentPipe = `\\.\pipe\openssh-ssh-agent` var errInvalidGPGSocketInfo = errors.New("invalid gpg agent socket file") type gpgSocketInfo struct { port uint16 nonce []byte cygwin bool } func defaultSSHAgentEndpoint() (resolvedSSHAgentEndpoint, error) { return resolvedSSHAgentEndpoint{ Endpoint: defaultWindowsSSHAgentPipe, Source: "platform-default", Network: "windows-pipe", }, nil } func defaultSSHAgentNetwork(endpoint string) string { if _, ok := normalizeWindowsSSHAgentPipe(endpoint); ok { return "windows-pipe" } if isAgentSSHSocketPath(endpoint) { return "gpg-socket" } return "unix" } func dialResolvedSSHAgent(resolved resolvedSSHAgentEndpoint, timeout time.Duration) (net.Conn, error) { if pipePath, ok := normalizeWindowsSSHAgentPipe(resolved.Endpoint); ok { return dialWindowsNamedPipe(pipePath, timeout, resolved.Source == "platform-default") } if isAgentSSHSocketPath(resolved.Endpoint) { return dialWindowsGPGSocketFile(resolved.Endpoint, timeout) } return dialWindowsUnixAgent(resolved.Endpoint, timeout) } func dialWindowsNamedPipe(path string, timeout time.Duration, unavailableOnNotFound bool) (net.Conn, error) { ctx := context.Background() cancel := func() {} if timeout > 0 { ctx, cancel = context.WithTimeout(ctx, timeout) } defer cancel() return dialWindowsNamedPipeContext(ctx, path, unavailableOnNotFound) } func normalizeWindowsSSHAgentPipe(endpoint string) (string, bool) { trimmed := strings.TrimSpace(endpoint) if trimmed == "" { return "", false } normalized := trimmed if strings.HasPrefix(normalized, "//./pipe/") { normalized = `\\.\pipe\` + strings.TrimPrefix(normalized, "//./pipe/") } if strings.HasPrefix(normalized, `\\.\pipe\`) { return normalized, true } return "", false } func isWindowsPipeUnavailable(err error) bool { return errors.Is(err, windows.ERROR_FILE_NOT_FOUND) || errors.Is(err, windows.ERROR_PATH_NOT_FOUND) } func dialWindowsUnixAgent(endpoint string, timeout time.Duration) (net.Conn, error) { if timeout > 0 { return net.DialTimeout("unix", endpoint, timeout) } return net.Dial("unix", endpoint) } func dialWindowsGPGSocketFile(path string, timeout time.Duration) (net.Conn, error) { ctx := context.Background() cancel := func() {} if timeout > 0 { ctx, cancel = context.WithTimeout(ctx, timeout) } defer cancel() return dialWindowsGPGSocketFileDepth(ctx, strings.TrimSpace(path), 0) } func dialWindowsGPGSocketFileDepth(ctx context.Context, path string, depth int) (net.Conn, error) { if path == "" { return nil, fmt.Errorf("gpg agent endpoint is empty") } if depth > 8 { return nil, fmt.Errorf("gpg agent socket redirect loop at %s", path) } data, err := os.ReadFile(path) if err != nil { return nil, err } if target, ok := parseGPGAssuanSocketRedirect(data); ok { target = resolveGPGSocketRedirectTarget(path, target) if pipePath, ok := normalizeWindowsSSHAgentPipe(target); ok { return dialWindowsNamedPipeContext(ctx, pipePath, false) } return dialWindowsGPGSocketFileDepth(ctx, target, depth+1) } info, err := parseGPGSocketInfo(path, data) if err != nil { return nil, err } return dialWindowsGPGSocketInfo(ctx, info) } func dialWindowsGPGSocketInfo(ctx context.Context, info gpgSocketInfo) (net.Conn, error) { var dialer net.Dialer conn, err := dialer.DialContext(ctx, "tcp", net.JoinHostPort("127.0.0.1", strconv.Itoa(int(info.port)))) if err != nil { return nil, err } if deadline, ok := ctx.Deadline(); ok { if err := conn.SetDeadline(deadline); err != nil { _ = conn.Close() return nil, err } } if _, err := conn.Write(info.nonce); err != nil { _ = conn.Close() return nil, err } if info.cygwin { var nonce [16]byte if _, err := io.ReadFull(conn, nonce[:]); err != nil { _ = conn.Close() return nil, err } var credential [8]byte binary.LittleEndian.PutUint32(credential[:4], uint32(os.Getpid())) if _, err := conn.Write(credential[:]); err != nil { _ = conn.Close() return nil, err } if _, err := io.ReadFull(conn, credential[:]); err != nil { _ = conn.Close() return nil, err } } _ = conn.SetDeadline(time.Time{}) return conn, nil } func resolveGPGSocketRedirectTarget(source string, target string) string { target = strings.TrimSpace(target) if target == "" || filepath.IsAbs(target) { return target } if _, ok := normalizeWindowsSSHAgentPipe(target); ok { return target } return filepath.Join(filepath.Dir(source), target) } func parseGPGSocketInfo(path string, data []byte) (gpgSocketInfo, error) { if info, ok := parseGPGAssuanSocketInfo(data); ok { return info, nil } if info, ok := parseGPGCygwinSocketInfo(data); ok { return info, nil } return gpgSocketInfo{}, fmt.Errorf("%w %s: expected GnuPG port/nonce socket file; if SSH_AUTH_SOCK was set to this file, restart gpg-agent to recreate it", errInvalidGPGSocketInfo, path) } func parseGPGAssuanSocketRedirect(data []byte) (string, bool) { text := strings.ReplaceAll(string(data), "\r\n", "\n") text = strings.TrimSuffix(text, "\n") lines := strings.Split(text, "\n") if len(lines) != 2 || lines[0] != "%Assuan%" { return "", false } target, ok := strings.CutPrefix(lines[1], "socket=") if !ok || strings.TrimSpace(target) == "" { return "", false } return os.ExpandEnv(target), true } func parseGPGAssuanSocketInfo(data []byte) (gpgSocketInfo, bool) { newline := bytes.IndexByte(data, '\n') if newline <= 0 || len(data)-newline-1 != 16 { return gpgSocketInfo{}, false } port64, err := strconv.ParseUint(strings.TrimSpace(string(data[:newline])), 10, 16) if err != nil || port64 == 0 { return gpgSocketInfo{}, false } nonce := make([]byte, 16) copy(nonce, data[newline+1:]) return gpgSocketInfo{port: uint16(port64), nonce: nonce}, true } func parseGPGCygwinSocketInfo(data []byte) (gpgSocketInfo, bool) { if !bytes.HasPrefix(data, []byte("!")) { return gpgSocketInfo{}, false } fields := strings.Fields(strings.TrimRight(string(data[10:]), "\x00")) if len(fields) != 3 || fields[1] != "s" { return gpgSocketInfo{}, false } port64, err := strconv.ParseUint(fields[0], 10, 16) if err != nil || port64 == 0 { return gpgSocketInfo{}, false } hexParts := strings.Split(fields[2], "-") if len(hexParts) != 4 { return gpgSocketInfo{}, false } nonce := make([]byte, 0, 16) for _, part := range hexParts { if len(part) != 8 { return gpgSocketInfo{}, false } value, err := strconv.ParseUint(part, 16, 32) if err != nil { return gpgSocketInfo{}, false } var chunk [4]byte binary.LittleEndian.PutUint32(chunk[:], uint32(value)) nonce = append(nonce, chunk[:]...) } return gpgSocketInfo{port: uint16(port64), nonce: nonce, cygwin: true}, true } func isAgentSSHSocketPath(endpoint string) bool { normalized := strings.ToLower(strings.TrimSpace(endpoint)) return strings.HasSuffix(normalized, "s.gpg-agent.ssh") } func dialWindowsNamedPipeContext(ctx context.Context, path string, unavailableOnNotFound bool) (net.Conn, error) { if ctx == nil { ctx = context.Background() } conn, err := winio.DialPipeContext(ctx, path) if err != nil && unavailableOnNotFound && isWindowsPipeUnavailable(err) { return nil, errSSHAgentUnavailable } if err != nil { return nil, err } return conn, nil }