feat: 增强 starssh 的 agent forwarding 与 tcp/unix 转发能力

- 为 LoginInput 增加 ForwardSSHAgent 配置,并在 Exec/PTTY 会话创建时按需自动请求 agent forwarding
- 新增 agent_forward 运行时,封装本地 ssh-agent 建连、转发注册、显式请求与 unavailable/denied 语义
- 自动 agent forwarding 改为 best-effort:本地 agent 不可用、转发被拒绝或初始化失败时不再打断会话创建
- 为 StarSSH 增加 closing 状态与 agent forwarder 生命周期回收,避免 Close 与会话创建并发时泄漏资源
- 扩展 ForwardRequest 为带网络归一化的转发模型,支持 tcp/tcp4/tcp6/unix 端点组合
- 新增本地/远端 tcp<->unix、unix<->unix 及 detached helper,补齐 streamlocal 场景下的常用 API
- 将显式网络地址编码收口为 tcp4://、tcp6://、unix://,消除 tcp:22 一类值的解析歧义
- 为本地 unix listener 增加 stale socket 探测、复用与关闭清理,避免遗留 socket 导致重启失败
- 补充 agent forwarding、关闭竞态、remote unix forward、local unix forward、stale socket 复用与端点解析等回归测试
This commit is contained in:
2026-04-26 20:27:10 +08:00
parent f20eb653ae
commit b29246a9c4
7 changed files with 1463 additions and 45 deletions
+282 -24
View File
@@ -3,21 +3,39 @@ package starssh
import (
"context"
"errors"
"fmt"
"io"
"net"
"os"
"strconv"
"strings"
"sync"
"syscall"
"time"
"golang.org/x/crypto/ssh"
)
type ForwardRequest struct {
// Keep the exported shape compatible with older positional literals:
// ForwardRequest{listenAddr, targetAddr, dialContext}.
//
// Non-default networks can be encoded with an explicit scheme-like prefix:
// "tcp4://127.0.0.1:22", "tcp6://[::1]:22", "unix:///tmp/socket".
// Bare values default to the "tcp" network.
ListenAddr string
TargetAddr string
DialContext DialContextFunc
}
type normalizedForwardRequest struct {
ListenNetwork string
ListenAddr string
TargetNetwork string
TargetAddr string
DialContext DialContextFunc
}
type DynamicForwardRequest struct {
ListenAddr string
}
@@ -41,10 +59,16 @@ type PortForwarder struct {
cleanupFns []func() error
}
const unixForwardProbeTimeout = 200 * time.Millisecond
var dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) {
return client.Dial(network, address)
}
var listenSSHClient = func(client *ssh.Client, network, address string) (net.Listener, error) {
return client.Listen(network, address)
}
var newDetachedForwardClient = func(ctx context.Context, input LoginInput) (*StarSSH, error) {
if ctx == nil {
ctx = context.Background()
@@ -64,6 +88,90 @@ func (s *StarSSH) DialTCPContextCloseOnCancel(ctx context.Context, network strin
return s.dialTCPContext(ctx, network, address, s.Close)
}
func (s *StarSSH) StartLocalTCPForward(listenAddr string, targetAddr string) (*PortForwarder, error) {
return s.StartLocalForward(ForwardRequest{
ListenAddr: listenAddr,
TargetAddr: targetAddr,
})
}
func (s *StarSSH) StartLocalTCPForwardDetached(listenAddr string, targetAddr string) (*PortForwarder, error) {
return s.StartLocalForwardDetached(ForwardRequest{
ListenAddr: listenAddr,
TargetAddr: targetAddr,
})
}
func (s *StarSSH) StartLocalTCPToUnixForward(listenAddr string, targetPath string) (*PortForwarder, error) {
return s.StartLocalForward(ForwardRequest{
ListenAddr: listenAddr,
TargetAddr: forwardEndpoint("unix", targetPath),
})
}
func (s *StarSSH) StartLocalTCPToUnixForwardDetached(listenAddr string, targetPath string) (*PortForwarder, error) {
return s.StartLocalForwardDetached(ForwardRequest{
ListenAddr: listenAddr,
TargetAddr: forwardEndpoint("unix", targetPath),
})
}
func (s *StarSSH) StartLocalUnixForward(listenPath string, targetAddr string) (*PortForwarder, error) {
return s.StartLocalForward(ForwardRequest{
ListenAddr: forwardEndpoint("unix", listenPath),
TargetAddr: targetAddr,
})
}
func (s *StarSSH) StartLocalUnixForwardDetached(listenPath string, targetAddr string) (*PortForwarder, error) {
return s.StartLocalForwardDetached(ForwardRequest{
ListenAddr: forwardEndpoint("unix", listenPath),
TargetAddr: targetAddr,
})
}
func (s *StarSSH) StartLocalUnixToUnixForward(listenPath string, targetPath string) (*PortForwarder, error) {
return s.StartLocalForward(ForwardRequest{
ListenAddr: forwardEndpoint("unix", listenPath),
TargetAddr: forwardEndpoint("unix", targetPath),
})
}
func (s *StarSSH) StartLocalUnixToUnixForwardDetached(listenPath string, targetPath string) (*PortForwarder, error) {
return s.StartLocalForwardDetached(ForwardRequest{
ListenAddr: forwardEndpoint("unix", listenPath),
TargetAddr: forwardEndpoint("unix", targetPath),
})
}
func (s *StarSSH) StartRemoteTCPForward(listenAddr string, targetAddr string) (*PortForwarder, error) {
return s.StartRemoteForward(ForwardRequest{
ListenAddr: listenAddr,
TargetAddr: targetAddr,
})
}
func (s *StarSSH) StartRemoteTCPToUnixForward(listenAddr string, targetPath string) (*PortForwarder, error) {
return s.StartRemoteForward(ForwardRequest{
ListenAddr: listenAddr,
TargetAddr: forwardEndpoint("unix", targetPath),
})
}
func (s *StarSSH) StartRemoteUnixForward(listenPath string, targetAddr string) (*PortForwarder, error) {
return s.StartRemoteForward(ForwardRequest{
ListenAddr: forwardEndpoint("unix", listenPath),
TargetAddr: targetAddr,
})
}
func (s *StarSSH) StartRemoteUnixToUnixForward(listenPath string, targetPath string) (*PortForwarder, error) {
return s.StartRemoteForward(ForwardRequest{
ListenAddr: forwardEndpoint("unix", listenPath),
TargetAddr: forwardEndpoint("unix", targetPath),
})
}
func (s *StarSSH) dialTCPContext(ctx context.Context, network string, address string, onCancel func() error) (net.Conn, error) {
if ctx == nil {
ctx = context.Background()
@@ -136,21 +244,22 @@ func (s *StarSSH) StartLocalForward(req ForwardRequest) (*PortForwarder, error)
if _, err := s.requireSSHClient(); err != nil {
return nil, err
}
if strings.TrimSpace(req.ListenAddr) == "" {
normalizedReq, err := normalizeForwardRequest(req)
if err != nil {
return nil, err
}
if strings.TrimSpace(normalizedReq.ListenAddr) == "" {
return nil, errors.New("local forward listen address is empty")
}
if strings.TrimSpace(req.TargetAddr) == "" {
return nil, errors.New("local forward target address is empty")
}
listener, err := net.Listen("tcp", req.ListenAddr)
listener, cleanup, err := prepareLocalForwardListener(normalizedReq.ListenNetwork, normalizedReq.ListenAddr)
if err != nil {
return nil, err
}
forwarder := newPortForwarder(listener)
forwarder.addCleanup(cleanup)
forwarder.serve(func(ctx context.Context) (net.Conn, error) {
return s.DialTCPContext(ctx, "tcp", req.TargetAddr)
return s.DialTCPContext(ctx, normalizedReq.TargetNetwork, normalizedReq.TargetAddr)
})
return forwarder, nil
}
@@ -159,14 +268,12 @@ func (s *StarSSH) StartLocalForwardDetached(req ForwardRequest) (*PortForwarder,
if _, err := s.requireSSHClient(); err != nil {
return nil, err
}
if strings.TrimSpace(req.ListenAddr) == "" {
return nil, errors.New("local forward listen address is empty")
}
if strings.TrimSpace(req.TargetAddr) == "" {
return nil, errors.New("local forward target address is empty")
normalizedReq, err := normalizeForwardRequest(req)
if err != nil {
return nil, err
}
listener, err := net.Listen("tcp", req.ListenAddr)
listener, cleanup, err := prepareLocalForwardListener(normalizedReq.ListenNetwork, normalizedReq.ListenAddr)
if err != nil {
return nil, err
}
@@ -174,15 +281,19 @@ func (s *StarSSH) StartLocalForwardDetached(req ForwardRequest) (*PortForwarder,
forwardClient, err := s.newForwardDialClient(context.Background())
if err != nil {
_ = listener.Close()
if cleanup != nil {
_ = cleanup()
}
return nil, err
}
forwarder := newPortForwarder(listener)
forwarder.addCleanup(cleanup)
forwarder.addCleanup(func() error {
return normalizeAlreadyClosedError(forwardClient.Close())
})
forwarder.serve(func(ctx context.Context) (net.Conn, error) {
return forwardClient.DialTCPContext(ctx, "tcp", req.TargetAddr)
return forwardClient.DialTCPContext(ctx, normalizedReq.TargetNetwork, normalizedReq.TargetAddr)
})
return forwarder, nil
}
@@ -192,19 +303,17 @@ func (s *StarSSH) StartRemoteForward(req ForwardRequest) (*PortForwarder, error)
if err != nil {
return nil, err
}
if strings.TrimSpace(req.ListenAddr) == "" {
return nil, errors.New("remote forward listen address is empty")
}
if strings.TrimSpace(req.TargetAddr) == "" {
return nil, errors.New("remote forward target address is empty")
}
listener, err := client.Listen("tcp", req.ListenAddr)
normalizedReq, err := normalizeForwardRequest(req)
if err != nil {
return nil, err
}
dialContext := req.DialContext
listener, err := listenSSHClient(client, normalizedReq.ListenNetwork, normalizedReq.ListenAddr)
if err != nil {
return nil, err
}
dialContext := normalizedReq.DialContext
if dialContext == nil {
dialer := &net.Dialer{
Timeout: defaultLoginTimeout,
@@ -214,7 +323,7 @@ func (s *StarSSH) StartRemoteForward(req ForwardRequest) (*PortForwarder, error)
forwarder := newPortForwarder(listener)
forwarder.serve(func(ctx context.Context) (net.Conn, error) {
return dialContext(ctx, "tcp", req.TargetAddr)
return dialContext(ctx, normalizedReq.TargetNetwork, normalizedReq.TargetAddr)
})
return forwarder, nil
}
@@ -239,6 +348,74 @@ func (s *StarSSH) StartDynamicForward(req DynamicForwardRequest) (*PortForwarder
return forwarder, nil
}
func normalizeForwardRequest(req ForwardRequest) (normalizedForwardRequest, error) {
normalized := normalizedForwardRequest{
DialContext: req.DialContext,
}
var err error
normalized.ListenNetwork, normalized.ListenAddr, err = parseForwardEndpoint(req.ListenAddr)
if err != nil {
return normalized, fmt.Errorf("normalize listen address: %w", err)
}
normalized.TargetNetwork, normalized.TargetAddr, err = parseForwardEndpoint(req.TargetAddr)
if err != nil {
return normalized, fmt.Errorf("normalize target address: %w", err)
}
if strings.TrimSpace(normalized.ListenAddr) == "" {
return normalized, errors.New("forward listen address is empty")
}
if strings.TrimSpace(normalized.TargetAddr) == "" {
return normalized, errors.New("forward target address is empty")
}
return normalized, nil
}
func normalizeForwardNetwork(network string) string {
network = strings.ToLower(strings.TrimSpace(network))
if network == "" {
return "tcp"
}
return network
}
func isSupportedForwardNetwork(network string) bool {
switch network {
case "tcp", "tcp4", "tcp6", "unix":
return true
default:
return false
}
}
func parseForwardEndpoint(value string) (network string, address string, err error) {
value = strings.TrimSpace(value)
if value == "" {
return "tcp", "", nil
}
lowerValue := strings.ToLower(value)
for _, prefix := range []string{"tcp4://", "tcp6://", "tcp://", "unix://"} {
if strings.HasPrefix(lowerValue, prefix) {
network = normalizeForwardNetwork(strings.TrimSuffix(prefix, "://"))
address = value[len(prefix):]
if !isSupportedForwardNetwork(network) {
return "", "", fmt.Errorf("unsupported forward network %q", network)
}
return network, address, nil
}
}
return "tcp", value, nil
}
func forwardEndpoint(network string, address string) string {
network = normalizeForwardNetwork(network)
if network == "tcp" {
return address
}
return network + "://" + address
}
func (s *StarSSH) StartDynamicForwardDetached(req DynamicForwardRequest) (*PortForwarder, error) {
if _, err := s.requireSSHClient(); err != nil {
return nil, err
@@ -344,6 +521,87 @@ func (f *PortForwarder) addCleanup(fn func() error) {
f.cleanupFns = append(f.cleanupFns, fn)
}
func prepareLocalForwardListener(network string, address string) (net.Listener, func() error, error) {
network = normalizeForwardNetwork(network)
if network != "unix" {
listener, err := net.Listen(network, address)
return listener, nil, err
}
if err := removeStaleUnixSocket(address); err != nil {
return nil, nil, err
}
listener, err := net.Listen(network, address)
if err != nil {
return nil, nil, err
}
cleanup, err := makeUnixSocketCleanup(address)
if err != nil {
_ = listener.Close()
_ = removeUnixSocketPath(address)
return nil, nil, err
}
return listener, cleanup, nil
}
func removeStaleUnixSocket(path string) error {
info, err := os.Lstat(path)
if errors.Is(err, os.ErrNotExist) {
return nil
}
if err != nil {
return err
}
if info.Mode()&os.ModeSocket == 0 {
return fmt.Errorf("local unix forward path %q already exists and is not a socket", path)
}
conn, err := net.DialTimeout("unix", path, unixForwardProbeTimeout)
if err == nil {
_ = conn.Close()
return fmt.Errorf("local unix forward path %q is already in use", path)
}
if !isStaleUnixSocketDialError(err) {
return fmt.Errorf("probe existing unix socket %q: %w", path, err)
}
return removeUnixSocketPath(path)
}
func isStaleUnixSocketDialError(err error) bool {
return errors.Is(err, syscall.ECONNREFUSED) || errors.Is(err, syscall.ENOENT)
}
func makeUnixSocketCleanup(path string) (func() error, error) {
info, err := os.Lstat(path)
if err != nil {
return nil, err
}
return func() error {
current, err := os.Lstat(path)
if errors.Is(err, os.ErrNotExist) {
return nil
}
if err != nil {
return err
}
if current.Mode()&os.ModeSocket == 0 || !os.SameFile(info, current) {
return nil
}
return removeUnixSocketPath(path)
}, nil
}
func removeUnixSocketPath(path string) error {
err := os.Remove(path)
if errors.Is(err, os.ErrNotExist) {
return nil
}
return err
}
func (f *PortForwarder) runCleanup() {
if f == nil {
return