2026-04-26 10:45:39 +08:00
|
|
|
package starssh
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
"context"
|
|
|
|
|
"errors"
|
2026-04-26 20:27:10 +08:00
|
|
|
"fmt"
|
2026-04-26 10:45:39 +08:00
|
|
|
"io"
|
|
|
|
|
"net"
|
2026-04-26 20:27:10 +08:00
|
|
|
"os"
|
2026-04-26 10:45:39 +08:00
|
|
|
"strconv"
|
|
|
|
|
"strings"
|
|
|
|
|
"sync"
|
2026-04-26 20:27:10 +08:00
|
|
|
"syscall"
|
|
|
|
|
"time"
|
2026-04-26 10:45:39 +08:00
|
|
|
|
|
|
|
|
"golang.org/x/crypto/ssh"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
type ForwardRequest struct {
|
2026-04-26 20:27:10 +08:00
|
|
|
// Keep the exported shape compatible with older positional literals:
|
|
|
|
|
// ForwardRequest{listenAddr, targetAddr, dialContext}.
|
|
|
|
|
//
|
|
|
|
|
// Non-default networks can be encoded with an explicit scheme-like prefix:
|
|
|
|
|
// "tcp4://127.0.0.1:22", "tcp6://[::1]:22", "unix:///tmp/socket".
|
|
|
|
|
// Bare values default to the "tcp" network.
|
2026-04-26 10:45:39 +08:00
|
|
|
ListenAddr string
|
|
|
|
|
TargetAddr string
|
|
|
|
|
DialContext DialContextFunc
|
|
|
|
|
}
|
|
|
|
|
|
2026-04-26 20:27:10 +08:00
|
|
|
type normalizedForwardRequest struct {
|
|
|
|
|
ListenNetwork string
|
|
|
|
|
ListenAddr string
|
|
|
|
|
TargetNetwork string
|
|
|
|
|
TargetAddr string
|
|
|
|
|
DialContext DialContextFunc
|
|
|
|
|
}
|
|
|
|
|
|
2026-04-26 10:45:39 +08:00
|
|
|
// DynamicForwardRequest configures a dynamic (SOCKS5) forward.
type DynamicForwardRequest struct {
	// ListenAddr is the local TCP address the SOCKS5 server listens on.
	ListenAddr string
}
|
|
|
|
|
|
|
|
|
|
// PortForwarder owns one forwarding listener together with the connections it
// has accepted or dialed. Instances are built by newPortForwarder, started by
// serve or serveDynamic, and released with Close.
type PortForwarder struct {
	// listener produces inbound connections for the accept loop.
	listener net.Listener

	// ctx is cancelled by Close; per-connection target dials receive it.
	ctx    context.Context
	cancel context.CancelFunc

	// acceptDone is closed when the accept-loop goroutine exits.
	acceptDone chan struct{}

	// connWG counts running per-connection handler goroutines.
	connWG    sync.WaitGroup
	closeOnce sync.Once

	// connMu guards conns, the set of live connections Close tears down.
	connMu sync.Mutex
	conns  map[net.Conn]struct{}

	// errMu guards err, the first non-ignorable error that was recorded.
	errMu sync.Mutex
	err   error

	// cleanupFns run exactly once, after the handlers have finished.
	cleanupOnce sync.Once
	cleanupFns  []func() error
}
|
|
|
|
|
|
2026-04-26 20:27:10 +08:00
|
|
|
// unixForwardProbeTimeout bounds how long removeStaleUnixSocket waits when
// probing whether an existing unix socket still has a live listener.
const unixForwardProbeTimeout time.Duration = 200 * time.Millisecond
|
|
|
|
|
|
2026-04-26 10:45:39 +08:00
|
|
|
var dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) {
|
|
|
|
|
return client.Dial(network, address)
|
|
|
|
|
}
|
|
|
|
|
|
2026-04-26 20:27:10 +08:00
|
|
|
var listenSSHClient = func(client *ssh.Client, network, address string) (net.Listener, error) {
|
|
|
|
|
return client.Listen(network, address)
|
|
|
|
|
}
|
|
|
|
|
|
2026-04-26 10:45:39 +08:00
|
|
|
var newDetachedForwardClient = func(ctx context.Context, input LoginInput) (*StarSSH, error) {
|
|
|
|
|
if ctx == nil {
|
|
|
|
|
ctx = context.Background()
|
|
|
|
|
}
|
|
|
|
|
return LoginContext(ctx, input)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (s *StarSSH) DialTCP(network string, address string) (net.Conn, error) {
|
|
|
|
|
return s.DialTCPContext(context.Background(), network, address)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (s *StarSSH) DialTCPContext(ctx context.Context, network string, address string) (net.Conn, error) {
|
|
|
|
|
return s.dialTCPContext(ctx, network, address, nil)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (s *StarSSH) DialTCPContextCloseOnCancel(ctx context.Context, network string, address string) (net.Conn, error) {
|
|
|
|
|
return s.dialTCPContext(ctx, network, address, s.Close)
|
|
|
|
|
}
|
|
|
|
|
|
2026-04-26 20:27:10 +08:00
|
|
|
func (s *StarSSH) StartLocalTCPForward(listenAddr string, targetAddr string) (*PortForwarder, error) {
|
|
|
|
|
return s.StartLocalForward(ForwardRequest{
|
|
|
|
|
ListenAddr: listenAddr,
|
|
|
|
|
TargetAddr: targetAddr,
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (s *StarSSH) StartLocalTCPForwardDetached(listenAddr string, targetAddr string) (*PortForwarder, error) {
|
|
|
|
|
return s.StartLocalForwardDetached(ForwardRequest{
|
|
|
|
|
ListenAddr: listenAddr,
|
|
|
|
|
TargetAddr: targetAddr,
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (s *StarSSH) StartLocalTCPToUnixForward(listenAddr string, targetPath string) (*PortForwarder, error) {
|
|
|
|
|
return s.StartLocalForward(ForwardRequest{
|
|
|
|
|
ListenAddr: listenAddr,
|
|
|
|
|
TargetAddr: forwardEndpoint("unix", targetPath),
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (s *StarSSH) StartLocalTCPToUnixForwardDetached(listenAddr string, targetPath string) (*PortForwarder, error) {
|
|
|
|
|
return s.StartLocalForwardDetached(ForwardRequest{
|
|
|
|
|
ListenAddr: listenAddr,
|
|
|
|
|
TargetAddr: forwardEndpoint("unix", targetPath),
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (s *StarSSH) StartLocalUnixForward(listenPath string, targetAddr string) (*PortForwarder, error) {
|
|
|
|
|
return s.StartLocalForward(ForwardRequest{
|
|
|
|
|
ListenAddr: forwardEndpoint("unix", listenPath),
|
|
|
|
|
TargetAddr: targetAddr,
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (s *StarSSH) StartLocalUnixForwardDetached(listenPath string, targetAddr string) (*PortForwarder, error) {
|
|
|
|
|
return s.StartLocalForwardDetached(ForwardRequest{
|
|
|
|
|
ListenAddr: forwardEndpoint("unix", listenPath),
|
|
|
|
|
TargetAddr: targetAddr,
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (s *StarSSH) StartLocalUnixToUnixForward(listenPath string, targetPath string) (*PortForwarder, error) {
|
|
|
|
|
return s.StartLocalForward(ForwardRequest{
|
|
|
|
|
ListenAddr: forwardEndpoint("unix", listenPath),
|
|
|
|
|
TargetAddr: forwardEndpoint("unix", targetPath),
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (s *StarSSH) StartLocalUnixToUnixForwardDetached(listenPath string, targetPath string) (*PortForwarder, error) {
|
|
|
|
|
return s.StartLocalForwardDetached(ForwardRequest{
|
|
|
|
|
ListenAddr: forwardEndpoint("unix", listenPath),
|
|
|
|
|
TargetAddr: forwardEndpoint("unix", targetPath),
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (s *StarSSH) StartRemoteTCPForward(listenAddr string, targetAddr string) (*PortForwarder, error) {
|
|
|
|
|
return s.StartRemoteForward(ForwardRequest{
|
|
|
|
|
ListenAddr: listenAddr,
|
|
|
|
|
TargetAddr: targetAddr,
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (s *StarSSH) StartRemoteTCPToUnixForward(listenAddr string, targetPath string) (*PortForwarder, error) {
|
|
|
|
|
return s.StartRemoteForward(ForwardRequest{
|
|
|
|
|
ListenAddr: listenAddr,
|
|
|
|
|
TargetAddr: forwardEndpoint("unix", targetPath),
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (s *StarSSH) StartRemoteUnixForward(listenPath string, targetAddr string) (*PortForwarder, error) {
|
|
|
|
|
return s.StartRemoteForward(ForwardRequest{
|
|
|
|
|
ListenAddr: forwardEndpoint("unix", listenPath),
|
|
|
|
|
TargetAddr: targetAddr,
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (s *StarSSH) StartRemoteUnixToUnixForward(listenPath string, targetPath string) (*PortForwarder, error) {
|
|
|
|
|
return s.StartRemoteForward(ForwardRequest{
|
|
|
|
|
ListenAddr: forwardEndpoint("unix", listenPath),
|
|
|
|
|
TargetAddr: forwardEndpoint("unix", targetPath),
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
|
2026-04-26 10:45:39 +08:00
|
|
|
func (s *StarSSH) dialTCPContext(ctx context.Context, network string, address string, onCancel func() error) (net.Conn, error) {
|
|
|
|
|
if ctx == nil {
|
|
|
|
|
ctx = context.Background()
|
|
|
|
|
}
|
|
|
|
|
if strings.TrimSpace(network) == "" {
|
|
|
|
|
network = "tcp"
|
|
|
|
|
}
|
|
|
|
|
if strings.TrimSpace(address) == "" {
|
|
|
|
|
return nil, errors.New("forward address is empty")
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
type dialResult struct {
|
|
|
|
|
conn net.Conn
|
|
|
|
|
err error
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
client, err := s.requireSSHClient()
|
|
|
|
|
if err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
runCancel := func() {}
|
|
|
|
|
if onCancel != nil {
|
|
|
|
|
var cancelOnce sync.Once
|
|
|
|
|
runCancel = func() {
|
|
|
|
|
cancelOnce.Do(func() {
|
|
|
|
|
_ = onCancel()
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
cancelDone := make(chan struct{})
|
|
|
|
|
defer close(cancelDone)
|
|
|
|
|
go func() {
|
|
|
|
|
select {
|
|
|
|
|
case <-ctx.Done():
|
|
|
|
|
runCancel()
|
|
|
|
|
case <-cancelDone:
|
|
|
|
|
}
|
|
|
|
|
}()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
dialFunc := dialSSHClient
|
|
|
|
|
resultCh := make(chan dialResult, 1)
|
|
|
|
|
go func() {
|
|
|
|
|
conn, err := dialFunc(ctx, client, network, address)
|
|
|
|
|
if ctx.Err() != nil && conn != nil {
|
|
|
|
|
_ = conn.Close()
|
|
|
|
|
conn = nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
select {
|
|
|
|
|
case resultCh <- dialResult{conn: conn, err: err}:
|
|
|
|
|
default:
|
|
|
|
|
if conn != nil {
|
|
|
|
|
_ = conn.Close()
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}()
|
|
|
|
|
|
|
|
|
|
select {
|
|
|
|
|
case result := <-resultCh:
|
|
|
|
|
return result.conn, result.err
|
|
|
|
|
case <-ctx.Done():
|
|
|
|
|
runCancel()
|
|
|
|
|
return nil, ctx.Err()
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (s *StarSSH) StartLocalForward(req ForwardRequest) (*PortForwarder, error) {
|
|
|
|
|
if _, err := s.requireSSHClient(); err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
2026-04-26 20:27:10 +08:00
|
|
|
normalizedReq, err := normalizeForwardRequest(req)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return nil, err
|
2026-04-26 10:45:39 +08:00
|
|
|
}
|
2026-04-26 20:27:10 +08:00
|
|
|
if strings.TrimSpace(normalizedReq.ListenAddr) == "" {
|
|
|
|
|
return nil, errors.New("local forward listen address is empty")
|
2026-04-26 10:45:39 +08:00
|
|
|
}
|
2026-04-26 20:27:10 +08:00
|
|
|
listener, cleanup, err := prepareLocalForwardListener(normalizedReq.ListenNetwork, normalizedReq.ListenAddr)
|
2026-04-26 10:45:39 +08:00
|
|
|
if err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
forwarder := newPortForwarder(listener)
|
2026-04-26 20:27:10 +08:00
|
|
|
forwarder.addCleanup(cleanup)
|
2026-04-26 10:45:39 +08:00
|
|
|
forwarder.serve(func(ctx context.Context) (net.Conn, error) {
|
2026-04-26 20:27:10 +08:00
|
|
|
return s.DialTCPContext(ctx, normalizedReq.TargetNetwork, normalizedReq.TargetAddr)
|
2026-04-26 10:45:39 +08:00
|
|
|
})
|
|
|
|
|
return forwarder, nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (s *StarSSH) StartLocalForwardDetached(req ForwardRequest) (*PortForwarder, error) {
|
|
|
|
|
if _, err := s.requireSSHClient(); err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
2026-04-26 20:27:10 +08:00
|
|
|
normalizedReq, err := normalizeForwardRequest(req)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return nil, err
|
2026-04-26 10:45:39 +08:00
|
|
|
}
|
|
|
|
|
|
2026-04-26 20:27:10 +08:00
|
|
|
listener, cleanup, err := prepareLocalForwardListener(normalizedReq.ListenNetwork, normalizedReq.ListenAddr)
|
2026-04-26 10:45:39 +08:00
|
|
|
if err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
forwardClient, err := s.newForwardDialClient(context.Background())
|
|
|
|
|
if err != nil {
|
|
|
|
|
_ = listener.Close()
|
2026-04-26 20:27:10 +08:00
|
|
|
if cleanup != nil {
|
|
|
|
|
_ = cleanup()
|
|
|
|
|
}
|
2026-04-26 10:45:39 +08:00
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
forwarder := newPortForwarder(listener)
|
2026-04-26 20:27:10 +08:00
|
|
|
forwarder.addCleanup(cleanup)
|
2026-04-26 10:45:39 +08:00
|
|
|
forwarder.addCleanup(func() error {
|
|
|
|
|
return normalizeAlreadyClosedError(forwardClient.Close())
|
|
|
|
|
})
|
|
|
|
|
forwarder.serve(func(ctx context.Context) (net.Conn, error) {
|
2026-04-26 20:27:10 +08:00
|
|
|
return forwardClient.DialTCPContext(ctx, normalizedReq.TargetNetwork, normalizedReq.TargetAddr)
|
2026-04-26 10:45:39 +08:00
|
|
|
})
|
|
|
|
|
return forwarder, nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (s *StarSSH) StartRemoteForward(req ForwardRequest) (*PortForwarder, error) {
|
|
|
|
|
client, err := s.requireSSHClient()
|
|
|
|
|
if err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
2026-04-26 20:27:10 +08:00
|
|
|
normalizedReq, err := normalizeForwardRequest(req)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return nil, err
|
2026-04-26 10:45:39 +08:00
|
|
|
}
|
|
|
|
|
|
2026-04-26 20:27:10 +08:00
|
|
|
listener, err := listenSSHClient(client, normalizedReq.ListenNetwork, normalizedReq.ListenAddr)
|
2026-04-26 10:45:39 +08:00
|
|
|
if err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
|
2026-04-26 20:27:10 +08:00
|
|
|
dialContext := normalizedReq.DialContext
|
2026-04-26 10:45:39 +08:00
|
|
|
if dialContext == nil {
|
|
|
|
|
dialer := &net.Dialer{
|
|
|
|
|
Timeout: defaultLoginTimeout,
|
|
|
|
|
}
|
|
|
|
|
dialContext = dialer.DialContext
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
forwarder := newPortForwarder(listener)
|
|
|
|
|
forwarder.serve(func(ctx context.Context) (net.Conn, error) {
|
2026-04-26 20:27:10 +08:00
|
|
|
return dialContext(ctx, normalizedReq.TargetNetwork, normalizedReq.TargetAddr)
|
2026-04-26 10:45:39 +08:00
|
|
|
})
|
|
|
|
|
return forwarder, nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (s *StarSSH) StartDynamicForward(req DynamicForwardRequest) (*PortForwarder, error) {
|
|
|
|
|
if _, err := s.requireSSHClient(); err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
if strings.TrimSpace(req.ListenAddr) == "" {
|
|
|
|
|
return nil, errors.New("dynamic forward listen address is empty")
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
listener, err := net.Listen("tcp", req.ListenAddr)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
forwarder := newPortForwarder(listener)
|
|
|
|
|
forwarder.serveDynamic(func(ctx context.Context, targetAddr string) (net.Conn, error) {
|
|
|
|
|
return s.DialTCPContext(ctx, "tcp", targetAddr)
|
|
|
|
|
})
|
|
|
|
|
return forwarder, nil
|
|
|
|
|
}
|
|
|
|
|
|
2026-04-26 20:27:10 +08:00
|
|
|
func normalizeForwardRequest(req ForwardRequest) (normalizedForwardRequest, error) {
|
|
|
|
|
normalized := normalizedForwardRequest{
|
|
|
|
|
DialContext: req.DialContext,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
var err error
|
|
|
|
|
normalized.ListenNetwork, normalized.ListenAddr, err = parseForwardEndpoint(req.ListenAddr)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return normalized, fmt.Errorf("normalize listen address: %w", err)
|
|
|
|
|
}
|
|
|
|
|
normalized.TargetNetwork, normalized.TargetAddr, err = parseForwardEndpoint(req.TargetAddr)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return normalized, fmt.Errorf("normalize target address: %w", err)
|
|
|
|
|
}
|
|
|
|
|
if strings.TrimSpace(normalized.ListenAddr) == "" {
|
|
|
|
|
return normalized, errors.New("forward listen address is empty")
|
|
|
|
|
}
|
|
|
|
|
if strings.TrimSpace(normalized.TargetAddr) == "" {
|
|
|
|
|
return normalized, errors.New("forward target address is empty")
|
|
|
|
|
}
|
|
|
|
|
return normalized, nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// normalizeForwardNetwork trims and lower-cases a network name, mapping the
// empty string to "tcp". It does not check that the result is supported.
func normalizeForwardNetwork(network string) string {
	if cleaned := strings.ToLower(strings.TrimSpace(network)); cleaned != "" {
		return cleaned
	}
	return "tcp"
}
|
|
|
|
|
|
|
|
|
|
// isSupportedForwardNetwork reports whether network, which must already be
// normalised, is one forwarding can handle: tcp, tcp4, tcp6 or unix.
func isSupportedForwardNetwork(network string) bool {
	return network == "tcp" || network == "tcp4" || network == "tcp6" || network == "unix"
}
|
|
|
|
|
|
|
|
|
|
func parseForwardEndpoint(value string) (network string, address string, err error) {
|
|
|
|
|
value = strings.TrimSpace(value)
|
|
|
|
|
if value == "" {
|
|
|
|
|
return "tcp", "", nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
lowerValue := strings.ToLower(value)
|
|
|
|
|
for _, prefix := range []string{"tcp4://", "tcp6://", "tcp://", "unix://"} {
|
|
|
|
|
if strings.HasPrefix(lowerValue, prefix) {
|
|
|
|
|
network = normalizeForwardNetwork(strings.TrimSuffix(prefix, "://"))
|
|
|
|
|
address = value[len(prefix):]
|
|
|
|
|
if !isSupportedForwardNetwork(network) {
|
|
|
|
|
return "", "", fmt.Errorf("unsupported forward network %q", network)
|
|
|
|
|
}
|
|
|
|
|
return network, address, nil
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return "tcp", value, nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func forwardEndpoint(network string, address string) string {
|
|
|
|
|
network = normalizeForwardNetwork(network)
|
|
|
|
|
if network == "tcp" {
|
|
|
|
|
return address
|
|
|
|
|
}
|
|
|
|
|
return network + "://" + address
|
|
|
|
|
}
|
|
|
|
|
|
2026-04-26 10:45:39 +08:00
|
|
|
func (s *StarSSH) StartDynamicForwardDetached(req DynamicForwardRequest) (*PortForwarder, error) {
|
|
|
|
|
if _, err := s.requireSSHClient(); err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
if strings.TrimSpace(req.ListenAddr) == "" {
|
|
|
|
|
return nil, errors.New("dynamic forward listen address is empty")
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
listener, err := net.Listen("tcp", req.ListenAddr)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
forwardClient, err := s.newForwardDialClient(context.Background())
|
|
|
|
|
if err != nil {
|
|
|
|
|
_ = listener.Close()
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
forwarder := newPortForwarder(listener)
|
|
|
|
|
forwarder.addCleanup(func() error {
|
|
|
|
|
return normalizeAlreadyClosedError(forwardClient.Close())
|
|
|
|
|
})
|
|
|
|
|
forwarder.serveDynamic(func(ctx context.Context, targetAddr string) (net.Conn, error) {
|
|
|
|
|
return forwardClient.DialTCPContext(ctx, "tcp", targetAddr)
|
|
|
|
|
})
|
|
|
|
|
return forwarder, nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (f *PortForwarder) Addr() net.Addr {
|
|
|
|
|
if f == nil || f.listener == nil {
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
return f.listener.Addr()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (f *PortForwarder) Wait() error {
|
|
|
|
|
if f == nil {
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
<-f.acceptDone
|
|
|
|
|
f.connWG.Wait()
|
|
|
|
|
f.runCleanup()
|
|
|
|
|
return f.Err()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (f *PortForwarder) Err() error {
|
|
|
|
|
if f == nil {
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
f.errMu.Lock()
|
|
|
|
|
defer f.errMu.Unlock()
|
|
|
|
|
return f.err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (f *PortForwarder) Close() error {
|
|
|
|
|
if f == nil {
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
var closeErr error
|
|
|
|
|
f.closeOnce.Do(func() {
|
|
|
|
|
if f.cancel != nil {
|
|
|
|
|
f.cancel()
|
|
|
|
|
}
|
|
|
|
|
if f.listener != nil {
|
|
|
|
|
closeErr = normalizeAlreadyClosedError(f.listener.Close())
|
|
|
|
|
}
|
|
|
|
|
f.closeActiveConnections()
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
<-f.acceptDone
|
|
|
|
|
f.connWG.Wait()
|
|
|
|
|
f.runCleanup()
|
|
|
|
|
if closeErr != nil {
|
|
|
|
|
return closeErr
|
|
|
|
|
}
|
|
|
|
|
return f.Err()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func newPortForwarder(listener net.Listener) *PortForwarder {
|
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
|
return &PortForwarder{
|
|
|
|
|
listener: listener,
|
|
|
|
|
ctx: ctx,
|
|
|
|
|
cancel: cancel,
|
|
|
|
|
acceptDone: make(chan struct{}),
|
|
|
|
|
conns: make(map[net.Conn]struct{}),
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (s *StarSSH) newForwardDialClient(ctx context.Context) (*StarSSH, error) {
|
|
|
|
|
if s == nil {
|
|
|
|
|
return nil, errors.New("ssh client is nil")
|
|
|
|
|
}
|
|
|
|
|
return newDetachedForwardClient(ctx, s.LoginInfo)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (f *PortForwarder) addCleanup(fn func() error) {
|
|
|
|
|
if f == nil || fn == nil {
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
f.cleanupFns = append(f.cleanupFns, fn)
|
|
|
|
|
}
|
|
|
|
|
|
2026-04-26 20:27:10 +08:00
|
|
|
func prepareLocalForwardListener(network string, address string) (net.Listener, func() error, error) {
|
|
|
|
|
network = normalizeForwardNetwork(network)
|
|
|
|
|
if network != "unix" {
|
|
|
|
|
listener, err := net.Listen(network, address)
|
|
|
|
|
return listener, nil, err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if err := removeStaleUnixSocket(address); err != nil {
|
|
|
|
|
return nil, nil, err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
listener, err := net.Listen(network, address)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return nil, nil, err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
cleanup, err := makeUnixSocketCleanup(address)
|
|
|
|
|
if err != nil {
|
|
|
|
|
_ = listener.Close()
|
|
|
|
|
_ = removeUnixSocketPath(address)
|
|
|
|
|
return nil, nil, err
|
|
|
|
|
}
|
|
|
|
|
return listener, cleanup, nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func removeStaleUnixSocket(path string) error {
|
|
|
|
|
info, err := os.Lstat(path)
|
|
|
|
|
if errors.Is(err, os.ErrNotExist) {
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
if info.Mode()&os.ModeSocket == 0 {
|
|
|
|
|
return fmt.Errorf("local unix forward path %q already exists and is not a socket", path)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
conn, err := net.DialTimeout("unix", path, unixForwardProbeTimeout)
|
|
|
|
|
if err == nil {
|
|
|
|
|
_ = conn.Close()
|
|
|
|
|
return fmt.Errorf("local unix forward path %q is already in use", path)
|
|
|
|
|
}
|
|
|
|
|
if !isStaleUnixSocketDialError(err) {
|
|
|
|
|
return fmt.Errorf("probe existing unix socket %q: %w", path, err)
|
|
|
|
|
}
|
|
|
|
|
return removeUnixSocketPath(path)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// isStaleUnixSocketDialError reports whether a failed unix-socket dial means
// nobody is listening (ECONNREFUSED) or the path vanished (ENOENT), in which
// case the socket file is considered stale.
func isStaleUnixSocketDialError(err error) bool {
	for _, errno := range []syscall.Errno{syscall.ECONNREFUSED, syscall.ENOENT} {
		if errors.Is(err, errno) {
			return true
		}
	}
	return false
}
|
|
|
|
|
|
|
|
|
|
func makeUnixSocketCleanup(path string) (func() error, error) {
|
|
|
|
|
info, err := os.Lstat(path)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return func() error {
|
|
|
|
|
current, err := os.Lstat(path)
|
|
|
|
|
if errors.Is(err, os.ErrNotExist) {
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
if current.Mode()&os.ModeSocket == 0 || !os.SameFile(info, current) {
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
return removeUnixSocketPath(path)
|
|
|
|
|
}, nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// removeUnixSocketPath deletes path, treating an already-missing path as
// success. Any other removal error is returned unchanged.
func removeUnixSocketPath(path string) error {
	if err := os.Remove(path); err != nil && !errors.Is(err, os.ErrNotExist) {
		return err
	}
	return nil
}
|
|
|
|
|
|
2026-04-26 10:45:39 +08:00
|
|
|
func (f *PortForwarder) runCleanup() {
|
|
|
|
|
if f == nil {
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
f.cleanupOnce.Do(func() {
|
|
|
|
|
for _, fn := range f.cleanupFns {
|
|
|
|
|
f.setError(normalizeAlreadyClosedError(fn()))
|
|
|
|
|
}
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (f *PortForwarder) serve(targetDial func(context.Context) (net.Conn, error)) {
|
|
|
|
|
go func() {
|
|
|
|
|
defer close(f.acceptDone)
|
|
|
|
|
|
|
|
|
|
for {
|
|
|
|
|
conn, err := f.listener.Accept()
|
|
|
|
|
if err != nil {
|
|
|
|
|
if isClosedListenerError(err) {
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
f.setError(err)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
f.trackConn(conn)
|
|
|
|
|
f.connWG.Add(1)
|
|
|
|
|
go func(src net.Conn) {
|
|
|
|
|
defer f.connWG.Done()
|
|
|
|
|
defer f.untrackConn(src)
|
|
|
|
|
|
|
|
|
|
dst, err := targetDial(f.ctx)
|
|
|
|
|
if err != nil {
|
|
|
|
|
f.setError(err)
|
|
|
|
|
_ = src.Close()
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
f.trackConn(dst)
|
|
|
|
|
defer f.untrackConn(dst)
|
|
|
|
|
|
|
|
|
|
f.setError(pipeForwardConnections(src, dst))
|
|
|
|
|
}(conn)
|
|
|
|
|
}
|
|
|
|
|
}()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (f *PortForwarder) serveDynamic(targetDial func(context.Context, string) (net.Conn, error)) {
|
|
|
|
|
go func() {
|
|
|
|
|
defer close(f.acceptDone)
|
|
|
|
|
|
|
|
|
|
for {
|
|
|
|
|
conn, err := f.listener.Accept()
|
|
|
|
|
if err != nil {
|
|
|
|
|
if isClosedListenerError(err) {
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
f.setError(err)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
f.trackConn(conn)
|
|
|
|
|
f.connWG.Add(1)
|
|
|
|
|
go func(src net.Conn) {
|
|
|
|
|
defer f.connWG.Done()
|
|
|
|
|
defer f.untrackConn(src)
|
|
|
|
|
if err := handleDynamicForwardConn(f.ctx, src, targetDial, f.trackConn, f.untrackConn); err != nil {
|
|
|
|
|
f.setError(err)
|
|
|
|
|
}
|
|
|
|
|
}(conn)
|
|
|
|
|
}
|
|
|
|
|
}()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (f *PortForwarder) setError(err error) {
|
|
|
|
|
if f == nil || err == nil || f.shouldIgnoreError(err) {
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
f.errMu.Lock()
|
|
|
|
|
defer f.errMu.Unlock()
|
|
|
|
|
if f.err == nil {
|
|
|
|
|
f.err = err
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func pipeForwardConnections(left net.Conn, right net.Conn) error {
|
|
|
|
|
if left == nil || right == nil {
|
|
|
|
|
if left != nil {
|
|
|
|
|
_ = left.Close()
|
|
|
|
|
}
|
|
|
|
|
if right != nil {
|
|
|
|
|
_ = right.Close()
|
|
|
|
|
}
|
|
|
|
|
return errors.New("forward connection endpoint is nil")
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
defer left.Close()
|
|
|
|
|
defer right.Close()
|
|
|
|
|
|
|
|
|
|
var copyWG sync.WaitGroup
|
|
|
|
|
errCh := make(chan error, 2)
|
|
|
|
|
copyWG.Add(2)
|
|
|
|
|
go func() {
|
|
|
|
|
defer copyWG.Done()
|
|
|
|
|
_, err := io.Copy(right, left)
|
|
|
|
|
errCh <- normalizeAlreadyClosedError(err)
|
|
|
|
|
closeWrite(right)
|
|
|
|
|
}()
|
|
|
|
|
go func() {
|
|
|
|
|
defer copyWG.Done()
|
|
|
|
|
_, err := io.Copy(left, right)
|
|
|
|
|
errCh <- normalizeAlreadyClosedError(err)
|
|
|
|
|
closeWrite(left)
|
|
|
|
|
}()
|
|
|
|
|
copyWG.Wait()
|
|
|
|
|
close(errCh)
|
|
|
|
|
|
|
|
|
|
for err := range errCh {
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// closeWrite half-closes conn's write side when its concrete type has a
// CloseWrite method (e.g. *net.TCPConn, *net.UnixConn). Connections without
// one, and nil, are left untouched. The CloseWrite error is intentionally
// ignored because this is only a best-effort EOF signal to the peer.
func closeWrite(conn net.Conn) {
	if halfCloser, ok := conn.(interface{ CloseWrite() error }); ok {
		_ = halfCloser.CloseWrite()
	}
}
|
|
|
|
|
|
|
|
|
|
func handleDynamicForwardConn(
|
|
|
|
|
ctx context.Context,
|
|
|
|
|
src net.Conn,
|
|
|
|
|
targetDial func(context.Context, string) (net.Conn, error),
|
|
|
|
|
trackConn func(net.Conn),
|
|
|
|
|
untrackConn func(net.Conn),
|
|
|
|
|
) error {
|
|
|
|
|
if src == nil {
|
|
|
|
|
return errors.New("dynamic forward source connection is nil")
|
|
|
|
|
}
|
|
|
|
|
defer src.Close()
|
|
|
|
|
|
|
|
|
|
if err := negotiateSOCKS5NoAuth(src); err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
targetAddr, replyCode, err := readSOCKS5ConnectTarget(src)
|
|
|
|
|
if err != nil {
|
|
|
|
|
_ = writeSOCKS5ServerReply(src, replyCode, nil)
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
dst, err := targetDial(ctx, targetAddr)
|
|
|
|
|
if err != nil {
|
|
|
|
|
_ = writeSOCKS5ServerReply(src, 0x01, nil)
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
if trackConn != nil {
|
|
|
|
|
trackConn(dst)
|
|
|
|
|
defer untrackConn(dst)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if err := writeSOCKS5ServerReply(src, 0x00, dst.LocalAddr()); err != nil {
|
|
|
|
|
_ = dst.Close()
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return pipeForwardConnections(src, dst)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (f *PortForwarder) trackConn(conn net.Conn) {
|
|
|
|
|
if f == nil || conn == nil {
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
f.connMu.Lock()
|
|
|
|
|
defer f.connMu.Unlock()
|
|
|
|
|
f.conns[conn] = struct{}{}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (f *PortForwarder) untrackConn(conn net.Conn) {
|
|
|
|
|
if f == nil || conn == nil {
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
f.connMu.Lock()
|
|
|
|
|
defer f.connMu.Unlock()
|
|
|
|
|
delete(f.conns, conn)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (f *PortForwarder) closeActiveConnections() {
|
|
|
|
|
if f == nil {
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
f.connMu.Lock()
|
|
|
|
|
conns := make([]net.Conn, 0, len(f.conns))
|
|
|
|
|
for conn := range f.conns {
|
|
|
|
|
conns = append(conns, conn)
|
|
|
|
|
}
|
|
|
|
|
f.connMu.Unlock()
|
|
|
|
|
|
|
|
|
|
for _, conn := range conns {
|
|
|
|
|
_ = conn.Close()
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (f *PortForwarder) shouldIgnoreError(err error) bool {
|
|
|
|
|
if err == nil {
|
|
|
|
|
return true
|
|
|
|
|
}
|
|
|
|
|
if normalizeAlreadyClosedError(err) == nil {
|
|
|
|
|
return true
|
|
|
|
|
}
|
|
|
|
|
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
|
|
|
|
return true
|
|
|
|
|
}
|
|
|
|
|
return false
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// negotiateSOCKS5NoAuth performs SOCKS5 method selection (RFC 1928 §3).
//
// It reads the VER/NMETHODS header and the offered methods, then replies
// selecting 0x00 ("no authentication required") if the client offered it, or
// 0xFF ("no acceptable methods") otherwise. In the 0xFF case an error is
// returned after the reply has been written.
func negotiateSOCKS5NoAuth(conn net.Conn) error {
	const noAuth, noAcceptable = 0x00, 0xFF

	var greeting [2]byte
	if _, err := io.ReadFull(conn, greeting[:]); err != nil {
		return err
	}
	if greeting[0] != 0x05 {
		return errors.New("invalid socks5 version")
	}

	offered := make([]byte, greeting[1])
	if _, err := io.ReadFull(conn, offered); err != nil {
		return err
	}

	selected := byte(noAcceptable)
	for _, m := range offered {
		if m == noAuth {
			selected = noAuth
			break
		}
	}

	if _, err := conn.Write([]byte{0x05, selected}); err != nil {
		return err
	}
	if selected == noAcceptable {
		return errors.New("socks5 client does not support no-auth method")
	}
	return nil
}
|
|
|
|
|
|
|
|
|
|
func readSOCKS5ConnectTarget(conn net.Conn) (string, byte, error) {
|
|
|
|
|
header := make([]byte, 4)
|
|
|
|
|
if _, err := io.ReadFull(conn, header); err != nil {
|
|
|
|
|
return "", 0x01, err
|
|
|
|
|
}
|
|
|
|
|
if header[0] != 0x05 {
|
|
|
|
|
return "", 0x01, errors.New("invalid socks5 request version")
|
|
|
|
|
}
|
|
|
|
|
if header[1] != 0x01 {
|
|
|
|
|
return "", 0x07, errors.New("unsupported socks5 command")
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
host, err := readSOCKS5RequestHost(conn, header[3])
|
|
|
|
|
if err != nil {
|
|
|
|
|
return "", 0x08, err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
portBytes := make([]byte, 2)
|
|
|
|
|
if _, err := io.ReadFull(conn, portBytes); err != nil {
|
|
|
|
|
return "", 0x01, err
|
|
|
|
|
}
|
|
|
|
|
port := int(portBytes[0])<<8 | int(portBytes[1])
|
|
|
|
|
return net.JoinHostPort(host, strconv.Itoa(port)), 0x00, nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// readSOCKS5RequestHost reads the DST.ADDR field of a SOCKS5 request for the
// given ATYP:
//   - 0x01: 4-byte IPv4
//   - 0x03: length-prefixed domain name
//   - 0x04: 16-byte IPv6
//
// IP addresses are returned in textual form and domain names verbatim. A
// zero-length domain name is rejected, since it would otherwise produce a
// host-less target such as ":80".
func readSOCKS5RequestHost(conn net.Conn, addressType byte) (string, error) {
	switch addressType {
	case 0x01, 0x04:
		size := net.IPv4len
		if addressType == 0x04 {
			size = net.IPv6len
		}
		raw := make([]byte, size)
		if _, err := io.ReadFull(conn, raw); err != nil {
			return "", err
		}
		return net.IP(raw).String(), nil
	case 0x03:
		var length [1]byte
		if _, err := io.ReadFull(conn, length[:]); err != nil {
			return "", err
		}
		if length[0] == 0 {
			return "", errors.New("empty socks5 domain name")
		}
		name := make([]byte, length[0])
		if _, err := io.ReadFull(conn, name); err != nil {
			return "", err
		}
		return string(name), nil
	default:
		return "", errors.New("unsupported socks5 address type")
	}
}
|
|
|
|
|
|
|
|
|
|
// writeSOCKS5ServerReply writes a SOCKS5 reply (RFC 1928 §6) with REP set to
// replyCode.
//
// When addr is a *net.TCPAddr with a usable IP, BND.ADDR/BND.PORT carry that
// address: ATYP 0x01 for IPv4, 0x04 for IPv6. Otherwise the reply advertises
// the all-zero IPv4 address with port 0.
func writeSOCKS5ServerReply(conn net.Conn, replyCode byte, addr net.Addr) error {
	msg := []byte{0x05, replyCode, 0x00}

	tcpAddr, isTCP := addr.(*net.TCPAddr)
	if isTCP && tcpAddr != nil {
		var atyp byte
		var ip net.IP
		if v4 := tcpAddr.IP.To4(); v4 != nil {
			atyp, ip = 0x01, v4
		} else if v6 := tcpAddr.IP.To16(); v6 != nil {
			atyp, ip = 0x04, v6
		}
		if ip != nil {
			msg = append(msg, atyp)
			msg = append(msg, ip...)
			msg = append(msg, byte(tcpAddr.Port>>8), byte(tcpAddr.Port))
			_, err := conn.Write(msg)
			return err
		}
	}

	// Fallback BND.ADDR/BND.PORT: IPv4 0.0.0.0, port 0.
	msg = append(msg, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00)
	_, err := conn.Write(msg)
	return err
}
|
|
|
|
|
|
|
|
|
|
// isClosedListenerError reports whether an Accept error just means the
// listener was closed. That covers io.EOF and net.ErrClosed, plus, as a
// fallback for errors that only carry the text, any message containing
// "use of closed network connection". A nil error is not a closed-listener
// error.
func isClosedListenerError(err error) bool {
	switch {
	case err == nil:
		return false
	case errors.Is(err, io.EOF), errors.Is(err, net.ErrClosed):
		return true
	default:
		return strings.Contains(err.Error(), "use of closed network connection")
	}
}
|