feat: 增强 starssh 的 agent forwarding 与 tcp/unix 转发能力
- 为 LoginInput 增加 ForwardSSHAgent 配置,并在 Exec/PTTY 会话创建时按需自动请求 agent forwarding - 新增 agent_forward 运行时,封装本地 ssh-agent 建连、转发注册、显式请求与 unavailable/denied 语义 - 自动 agent forwarding 改为 best-effort:本地 agent 不可用、转发被拒绝或初始化失败时不再打断会话创建 - 为 StarSSH 增加 closing 状态与 agent forwarder 生命周期回收,避免 Close 与会话创建并发时泄漏资源 - 扩展 ForwardRequest 为带网络归一化的转发模型,支持 tcp/tcp4/tcp6/unix 端点组合 - 新增本地/远端 tcp<->unix、unix<->unix 及 detached helper,补齐 streamlocal 场景下的常用 API - 将显式网络地址编码收口为 tcp4://、tcp6://、unix://,消除 tcp:22 一类值的解析歧义 - 为本地 unix listener 增加 stale socket 探测、复用与关闭清理,避免遗留 socket 导致重启失败 - 补充 agent forwarding、关闭竞态、remote unix forward、local unix forward、stale socket 复用与端点解析等回归测试
This commit is contained in:
+282
-24
@@ -3,21 +3,39 @@ package starssh
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type ForwardRequest struct {
|
||||
// Keep the exported shape compatible with older positional literals:
|
||||
// ForwardRequest{listenAddr, targetAddr, dialContext}.
|
||||
//
|
||||
// Non-default networks can be encoded with an explicit scheme-like prefix:
|
||||
// "tcp4://127.0.0.1:22", "tcp6://[::1]:22", "unix:///tmp/socket".
|
||||
// Bare values default to the "tcp" network.
|
||||
ListenAddr string
|
||||
TargetAddr string
|
||||
DialContext DialContextFunc
|
||||
}
|
||||
|
||||
type normalizedForwardRequest struct {
|
||||
ListenNetwork string
|
||||
ListenAddr string
|
||||
TargetNetwork string
|
||||
TargetAddr string
|
||||
DialContext DialContextFunc
|
||||
}
|
||||
|
||||
type DynamicForwardRequest struct {
|
||||
ListenAddr string
|
||||
}
|
||||
@@ -41,10 +59,16 @@ type PortForwarder struct {
|
||||
cleanupFns []func() error
|
||||
}
|
||||
|
||||
const unixForwardProbeTimeout = 200 * time.Millisecond
|
||||
|
||||
var dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) {
|
||||
return client.Dial(network, address)
|
||||
}
|
||||
|
||||
var listenSSHClient = func(client *ssh.Client, network, address string) (net.Listener, error) {
|
||||
return client.Listen(network, address)
|
||||
}
|
||||
|
||||
var newDetachedForwardClient = func(ctx context.Context, input LoginInput) (*StarSSH, error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
@@ -64,6 +88,90 @@ func (s *StarSSH) DialTCPContextCloseOnCancel(ctx context.Context, network strin
|
||||
return s.dialTCPContext(ctx, network, address, s.Close)
|
||||
}
|
||||
|
||||
func (s *StarSSH) StartLocalTCPForward(listenAddr string, targetAddr string) (*PortForwarder, error) {
|
||||
return s.StartLocalForward(ForwardRequest{
|
||||
ListenAddr: listenAddr,
|
||||
TargetAddr: targetAddr,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *StarSSH) StartLocalTCPForwardDetached(listenAddr string, targetAddr string) (*PortForwarder, error) {
|
||||
return s.StartLocalForwardDetached(ForwardRequest{
|
||||
ListenAddr: listenAddr,
|
||||
TargetAddr: targetAddr,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *StarSSH) StartLocalTCPToUnixForward(listenAddr string, targetPath string) (*PortForwarder, error) {
|
||||
return s.StartLocalForward(ForwardRequest{
|
||||
ListenAddr: listenAddr,
|
||||
TargetAddr: forwardEndpoint("unix", targetPath),
|
||||
})
|
||||
}
|
||||
|
||||
func (s *StarSSH) StartLocalTCPToUnixForwardDetached(listenAddr string, targetPath string) (*PortForwarder, error) {
|
||||
return s.StartLocalForwardDetached(ForwardRequest{
|
||||
ListenAddr: listenAddr,
|
||||
TargetAddr: forwardEndpoint("unix", targetPath),
|
||||
})
|
||||
}
|
||||
|
||||
func (s *StarSSH) StartLocalUnixForward(listenPath string, targetAddr string) (*PortForwarder, error) {
|
||||
return s.StartLocalForward(ForwardRequest{
|
||||
ListenAddr: forwardEndpoint("unix", listenPath),
|
||||
TargetAddr: targetAddr,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *StarSSH) StartLocalUnixForwardDetached(listenPath string, targetAddr string) (*PortForwarder, error) {
|
||||
return s.StartLocalForwardDetached(ForwardRequest{
|
||||
ListenAddr: forwardEndpoint("unix", listenPath),
|
||||
TargetAddr: targetAddr,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *StarSSH) StartLocalUnixToUnixForward(listenPath string, targetPath string) (*PortForwarder, error) {
|
||||
return s.StartLocalForward(ForwardRequest{
|
||||
ListenAddr: forwardEndpoint("unix", listenPath),
|
||||
TargetAddr: forwardEndpoint("unix", targetPath),
|
||||
})
|
||||
}
|
||||
|
||||
func (s *StarSSH) StartLocalUnixToUnixForwardDetached(listenPath string, targetPath string) (*PortForwarder, error) {
|
||||
return s.StartLocalForwardDetached(ForwardRequest{
|
||||
ListenAddr: forwardEndpoint("unix", listenPath),
|
||||
TargetAddr: forwardEndpoint("unix", targetPath),
|
||||
})
|
||||
}
|
||||
|
||||
func (s *StarSSH) StartRemoteTCPForward(listenAddr string, targetAddr string) (*PortForwarder, error) {
|
||||
return s.StartRemoteForward(ForwardRequest{
|
||||
ListenAddr: listenAddr,
|
||||
TargetAddr: targetAddr,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *StarSSH) StartRemoteTCPToUnixForward(listenAddr string, targetPath string) (*PortForwarder, error) {
|
||||
return s.StartRemoteForward(ForwardRequest{
|
||||
ListenAddr: listenAddr,
|
||||
TargetAddr: forwardEndpoint("unix", targetPath),
|
||||
})
|
||||
}
|
||||
|
||||
func (s *StarSSH) StartRemoteUnixForward(listenPath string, targetAddr string) (*PortForwarder, error) {
|
||||
return s.StartRemoteForward(ForwardRequest{
|
||||
ListenAddr: forwardEndpoint("unix", listenPath),
|
||||
TargetAddr: targetAddr,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *StarSSH) StartRemoteUnixToUnixForward(listenPath string, targetPath string) (*PortForwarder, error) {
|
||||
return s.StartRemoteForward(ForwardRequest{
|
||||
ListenAddr: forwardEndpoint("unix", listenPath),
|
||||
TargetAddr: forwardEndpoint("unix", targetPath),
|
||||
})
|
||||
}
|
||||
|
||||
func (s *StarSSH) dialTCPContext(ctx context.Context, network string, address string, onCancel func() error) (net.Conn, error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
@@ -136,21 +244,22 @@ func (s *StarSSH) StartLocalForward(req ForwardRequest) (*PortForwarder, error)
|
||||
if _, err := s.requireSSHClient(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if strings.TrimSpace(req.ListenAddr) == "" {
|
||||
normalizedReq, err := normalizeForwardRequest(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if strings.TrimSpace(normalizedReq.ListenAddr) == "" {
|
||||
return nil, errors.New("local forward listen address is empty")
|
||||
}
|
||||
if strings.TrimSpace(req.TargetAddr) == "" {
|
||||
return nil, errors.New("local forward target address is empty")
|
||||
}
|
||||
|
||||
listener, err := net.Listen("tcp", req.ListenAddr)
|
||||
listener, cleanup, err := prepareLocalForwardListener(normalizedReq.ListenNetwork, normalizedReq.ListenAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
forwarder := newPortForwarder(listener)
|
||||
forwarder.addCleanup(cleanup)
|
||||
forwarder.serve(func(ctx context.Context) (net.Conn, error) {
|
||||
return s.DialTCPContext(ctx, "tcp", req.TargetAddr)
|
||||
return s.DialTCPContext(ctx, normalizedReq.TargetNetwork, normalizedReq.TargetAddr)
|
||||
})
|
||||
return forwarder, nil
|
||||
}
|
||||
@@ -159,14 +268,12 @@ func (s *StarSSH) StartLocalForwardDetached(req ForwardRequest) (*PortForwarder,
|
||||
if _, err := s.requireSSHClient(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if strings.TrimSpace(req.ListenAddr) == "" {
|
||||
return nil, errors.New("local forward listen address is empty")
|
||||
}
|
||||
if strings.TrimSpace(req.TargetAddr) == "" {
|
||||
return nil, errors.New("local forward target address is empty")
|
||||
normalizedReq, err := normalizeForwardRequest(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
listener, err := net.Listen("tcp", req.ListenAddr)
|
||||
listener, cleanup, err := prepareLocalForwardListener(normalizedReq.ListenNetwork, normalizedReq.ListenAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -174,15 +281,19 @@ func (s *StarSSH) StartLocalForwardDetached(req ForwardRequest) (*PortForwarder,
|
||||
forwardClient, err := s.newForwardDialClient(context.Background())
|
||||
if err != nil {
|
||||
_ = listener.Close()
|
||||
if cleanup != nil {
|
||||
_ = cleanup()
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
forwarder := newPortForwarder(listener)
|
||||
forwarder.addCleanup(cleanup)
|
||||
forwarder.addCleanup(func() error {
|
||||
return normalizeAlreadyClosedError(forwardClient.Close())
|
||||
})
|
||||
forwarder.serve(func(ctx context.Context) (net.Conn, error) {
|
||||
return forwardClient.DialTCPContext(ctx, "tcp", req.TargetAddr)
|
||||
return forwardClient.DialTCPContext(ctx, normalizedReq.TargetNetwork, normalizedReq.TargetAddr)
|
||||
})
|
||||
return forwarder, nil
|
||||
}
|
||||
@@ -192,19 +303,17 @@ func (s *StarSSH) StartRemoteForward(req ForwardRequest) (*PortForwarder, error)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if strings.TrimSpace(req.ListenAddr) == "" {
|
||||
return nil, errors.New("remote forward listen address is empty")
|
||||
}
|
||||
if strings.TrimSpace(req.TargetAddr) == "" {
|
||||
return nil, errors.New("remote forward target address is empty")
|
||||
}
|
||||
|
||||
listener, err := client.Listen("tcp", req.ListenAddr)
|
||||
normalizedReq, err := normalizeForwardRequest(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dialContext := req.DialContext
|
||||
listener, err := listenSSHClient(client, normalizedReq.ListenNetwork, normalizedReq.ListenAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dialContext := normalizedReq.DialContext
|
||||
if dialContext == nil {
|
||||
dialer := &net.Dialer{
|
||||
Timeout: defaultLoginTimeout,
|
||||
@@ -214,7 +323,7 @@ func (s *StarSSH) StartRemoteForward(req ForwardRequest) (*PortForwarder, error)
|
||||
|
||||
forwarder := newPortForwarder(listener)
|
||||
forwarder.serve(func(ctx context.Context) (net.Conn, error) {
|
||||
return dialContext(ctx, "tcp", req.TargetAddr)
|
||||
return dialContext(ctx, normalizedReq.TargetNetwork, normalizedReq.TargetAddr)
|
||||
})
|
||||
return forwarder, nil
|
||||
}
|
||||
@@ -239,6 +348,74 @@ func (s *StarSSH) StartDynamicForward(req DynamicForwardRequest) (*PortForwarder
|
||||
return forwarder, nil
|
||||
}
|
||||
|
||||
func normalizeForwardRequest(req ForwardRequest) (normalizedForwardRequest, error) {
|
||||
normalized := normalizedForwardRequest{
|
||||
DialContext: req.DialContext,
|
||||
}
|
||||
|
||||
var err error
|
||||
normalized.ListenNetwork, normalized.ListenAddr, err = parseForwardEndpoint(req.ListenAddr)
|
||||
if err != nil {
|
||||
return normalized, fmt.Errorf("normalize listen address: %w", err)
|
||||
}
|
||||
normalized.TargetNetwork, normalized.TargetAddr, err = parseForwardEndpoint(req.TargetAddr)
|
||||
if err != nil {
|
||||
return normalized, fmt.Errorf("normalize target address: %w", err)
|
||||
}
|
||||
if strings.TrimSpace(normalized.ListenAddr) == "" {
|
||||
return normalized, errors.New("forward listen address is empty")
|
||||
}
|
||||
if strings.TrimSpace(normalized.TargetAddr) == "" {
|
||||
return normalized, errors.New("forward target address is empty")
|
||||
}
|
||||
return normalized, nil
|
||||
}
|
||||
|
||||
func normalizeForwardNetwork(network string) string {
|
||||
network = strings.ToLower(strings.TrimSpace(network))
|
||||
if network == "" {
|
||||
return "tcp"
|
||||
}
|
||||
return network
|
||||
}
|
||||
|
||||
func isSupportedForwardNetwork(network string) bool {
|
||||
switch network {
|
||||
case "tcp", "tcp4", "tcp6", "unix":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func parseForwardEndpoint(value string) (network string, address string, err error) {
|
||||
value = strings.TrimSpace(value)
|
||||
if value == "" {
|
||||
return "tcp", "", nil
|
||||
}
|
||||
|
||||
lowerValue := strings.ToLower(value)
|
||||
for _, prefix := range []string{"tcp4://", "tcp6://", "tcp://", "unix://"} {
|
||||
if strings.HasPrefix(lowerValue, prefix) {
|
||||
network = normalizeForwardNetwork(strings.TrimSuffix(prefix, "://"))
|
||||
address = value[len(prefix):]
|
||||
if !isSupportedForwardNetwork(network) {
|
||||
return "", "", fmt.Errorf("unsupported forward network %q", network)
|
||||
}
|
||||
return network, address, nil
|
||||
}
|
||||
}
|
||||
return "tcp", value, nil
|
||||
}
|
||||
|
||||
func forwardEndpoint(network string, address string) string {
|
||||
network = normalizeForwardNetwork(network)
|
||||
if network == "tcp" {
|
||||
return address
|
||||
}
|
||||
return network + "://" + address
|
||||
}
|
||||
|
||||
func (s *StarSSH) StartDynamicForwardDetached(req DynamicForwardRequest) (*PortForwarder, error) {
|
||||
if _, err := s.requireSSHClient(); err != nil {
|
||||
return nil, err
|
||||
@@ -344,6 +521,87 @@ func (f *PortForwarder) addCleanup(fn func() error) {
|
||||
f.cleanupFns = append(f.cleanupFns, fn)
|
||||
}
|
||||
|
||||
func prepareLocalForwardListener(network string, address string) (net.Listener, func() error, error) {
|
||||
network = normalizeForwardNetwork(network)
|
||||
if network != "unix" {
|
||||
listener, err := net.Listen(network, address)
|
||||
return listener, nil, err
|
||||
}
|
||||
|
||||
if err := removeStaleUnixSocket(address); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
listener, err := net.Listen(network, address)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
cleanup, err := makeUnixSocketCleanup(address)
|
||||
if err != nil {
|
||||
_ = listener.Close()
|
||||
_ = removeUnixSocketPath(address)
|
||||
return nil, nil, err
|
||||
}
|
||||
return listener, cleanup, nil
|
||||
}
|
||||
|
||||
func removeStaleUnixSocket(path string) error {
|
||||
info, err := os.Lstat(path)
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if info.Mode()&os.ModeSocket == 0 {
|
||||
return fmt.Errorf("local unix forward path %q already exists and is not a socket", path)
|
||||
}
|
||||
|
||||
conn, err := net.DialTimeout("unix", path, unixForwardProbeTimeout)
|
||||
if err == nil {
|
||||
_ = conn.Close()
|
||||
return fmt.Errorf("local unix forward path %q is already in use", path)
|
||||
}
|
||||
if !isStaleUnixSocketDialError(err) {
|
||||
return fmt.Errorf("probe existing unix socket %q: %w", path, err)
|
||||
}
|
||||
return removeUnixSocketPath(path)
|
||||
}
|
||||
|
||||
func isStaleUnixSocketDialError(err error) bool {
|
||||
return errors.Is(err, syscall.ECONNREFUSED) || errors.Is(err, syscall.ENOENT)
|
||||
}
|
||||
|
||||
func makeUnixSocketCleanup(path string) (func() error, error) {
|
||||
info, err := os.Lstat(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return func() error {
|
||||
current, err := os.Lstat(path)
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if current.Mode()&os.ModeSocket == 0 || !os.SameFile(info, current) {
|
||||
return nil
|
||||
}
|
||||
return removeUnixSocketPath(path)
|
||||
}, nil
|
||||
}
|
||||
|
||||
func removeUnixSocketPath(path string) error {
|
||||
err := os.Remove(path)
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (f *PortForwarder) runCleanup() {
|
||||
if f == nil {
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user