package starssh

import (
	"context"
	"errors"
	"io"
	"net"
	"strconv"
	"strings"
	"sync"

	"golang.org/x/crypto/ssh"
)

// ForwardRequest configures a local (StartLocalForward, StartLocalForwardDetached)
// or remote (StartRemoteForward) TCP port forward.
type ForwardRequest struct {
	// ListenAddr is the address to listen on: a local address for local
	// forwards, an address on the SSH server for remote forwards.
	ListenAddr string
	// TargetAddr is the address every accepted connection is forwarded to.
	TargetAddr string
	// DialContext is only consulted by StartRemoteForward, to reach TargetAddr
	// from this side. When nil, a net.Dialer with Timeout defaultLoginTimeout
	// is used.
	DialContext DialContextFunc
}

// DynamicForwardRequest configures a local SOCKS5 proxy (no-auth, CONNECT
// only) whose outbound connections are dialed through the SSH connection.
type DynamicForwardRequest struct {
	// ListenAddr is the local address the SOCKS5 proxy listens on.
	ListenAddr string
}

// PortForwarder owns a listener and every connection accepted from it.
// Call Close to stop forwarding, or Wait to block until the accept loop has
// exited and all forwarded connections have finished.
type PortForwarder struct {
	listener   net.Listener
	ctx        context.Context // canceled by Close to abort in-flight target dials
	cancel     context.CancelFunc
	acceptDone chan struct{} // closed when the accept goroutine exits
	connWG     sync.WaitGroup
	closeOnce  sync.Once

	connMu sync.Mutex
	conns  map[net.Conn]struct{}
	// closing is set under connMu once closeActiveConnections has swept
	// conns; any connection tracked after that point is closed immediately
	// so it cannot keep Close blocked on connWG.
	closing bool

	errMu sync.Mutex
	err   error

	cleanupOnce sync.Once
	cleanupFns  []func() error
}

// dialSSHClient opens a connection through client. It is a package variable,
// presumably so it can be swapped out (e.g. by tests). The default
// implementation ignores ctx because ssh.Client.Dial takes no context;
// cancellation is enforced by dialTCPContext instead.
var dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) {
	return client.Dial(network, address)
}

// newDetachedForwardClient logs in again with input and returns the new,
// independent session used by the *Detached forwarders. A nil ctx is treated
// as context.Background().
var newDetachedForwardClient = func(ctx context.Context, input LoginInput) (*StarSSH, error) {
	if ctx == nil {
		ctx = context.Background()
	}
	return LoginContext(ctx, input)
}

// DialTCP dials address through the SSH connection without a deadline.
// See DialTCPContext.
func (s *StarSSH) DialTCP(network string, address string) (net.Conn, error) {
	return s.DialTCPContext(context.Background(), network, address)
}

// DialTCPContext dials address through the SSH connection. An empty network
// defaults to "tcp". If ctx is done before the dial finishes, ctx.Err() is
// returned and any connection that is established later is closed.
func (s *StarSSH) DialTCPContext(ctx context.Context, network string, address string) (net.Conn, error) {
	return s.dialTCPContext(ctx, network, address, nil)
}

// DialTCPContextCloseOnCancel behaves like DialTCPContext. In addition, if ctx
// is done while the dial is still in progress, it closes the whole StarSSH
// client (s.Close) to unblock the pending dial.
func (s *StarSSH) DialTCPContextCloseOnCancel(ctx context.Context, network string, address string) (net.Conn, error) {
	return s.dialTCPContext(ctx, network, address, s.Close)
}

// dialTCPContext is the shared implementation behind the DialTCP* methods.
// onCancel, when non-nil, runs at most once if ctx is done before this
// function returns.
func (s *StarSSH) dialTCPContext(ctx context.Context, network string, address string, onCancel func() error) (net.Conn, error) {
	if ctx == nil {
		ctx = context.Background()
	}
	if strings.TrimSpace(network) == "" {
		network = "tcp"
	}
	if strings.TrimSpace(address) == "" {
		return nil, errors.New("forward address is empty")
	}
	client, err := s.requireSSHClient()
	if err != nil {
		return nil, err
	}

	runCancel := func() {}
	if onCancel != nil {
		var cancelOnce sync.Once
		runCancel = func() {
			cancelOnce.Do(func() { _ = onCancel() })
		}
		// Fire onCancel as soon as ctx is done, even while the dial below is
		// still blocked; stop watching once this function returns.
		cancelDone := make(chan struct{})
		defer close(cancelDone)
		go func() {
			select {
			case <-ctx.Done():
				runCancel()
			case <-cancelDone:
			}
		}()
	}

	type dialResult struct {
		conn net.Conn
		err  error
	}
	// Unbuffered on purpose: a result is only handed over while this function
	// is still waiting for it. Otherwise the dialing goroutine keeps ownership
	// and closes the connection, so a conn can never be stranded in a buffer.
	resultCh := make(chan dialResult)
	dialFunc := dialSSHClient
	go func() {
		conn, err := dialFunc(ctx, client, network, address)
		if ctx.Err() != nil {
			if conn != nil {
				_ = conn.Close()
				conn = nil
			}
			// Never report success without a connection: callers would
			// otherwise receive (nil, nil).
			if err == nil {
				err = ctx.Err()
			}
		}
		select {
		case resultCh <- dialResult{conn: conn, err: err}:
		case <-ctx.Done():
			if conn != nil {
				_ = conn.Close()
			}
		}
	}()

	select {
	case result := <-resultCh:
		return result.conn, result.err
	case <-ctx.Done():
		runCancel()
		return nil, ctx.Err()
	}
}

// StartLocalForward listens on req.ListenAddr locally and forwards every
// accepted connection to req.TargetAddr through this SSH connection.
func (s *StarSSH) StartLocalForward(req ForwardRequest) (*PortForwarder, error) {
	if _, err := s.requireSSHClient(); err != nil {
		return nil, err
	}
	if strings.TrimSpace(req.ListenAddr) == "" {
		return nil, errors.New("local forward listen address is empty")
	}
	if strings.TrimSpace(req.TargetAddr) == "" {
		return nil, errors.New("local forward target address is empty")
	}
	listener, err := net.Listen("tcp", req.ListenAddr)
	if err != nil {
		return nil, err
	}
	forwarder := newPortForwarder(listener)
	forwarder.serve(func(ctx context.Context) (net.Conn, error) {
		return s.DialTCPContext(ctx, "tcp", req.TargetAddr)
	})
	return forwarder, nil
}

// StartLocalForwardDetached is like StartLocalForward, but it dials through a
// separate SSH session created from s.LoginInfo. That session is closed when
// the forwarder's cleanup runs (in Close or Wait).
func (s *StarSSH) StartLocalForwardDetached(req ForwardRequest) (*PortForwarder, error) {
	if _, err := s.requireSSHClient(); err != nil {
		return nil, err
	}
	if strings.TrimSpace(req.ListenAddr) == "" {
		return nil, errors.New("local forward listen address is empty")
	}
	if strings.TrimSpace(req.TargetAddr) == "" {
		return nil, errors.New("local forward target address is empty")
	}
	listener, err := net.Listen("tcp", req.ListenAddr)
	if err != nil {
		return nil, err
	}
	forwardClient, err := s.newForwardDialClient(context.Background())
	if err != nil {
		_ = listener.Close()
		return nil, err
	}
	forwarder := newPortForwarder(listener)
	forwarder.addCleanup(func() error {
		return normalizeAlreadyClosedError(forwardClient.Close())
	})
	forwarder.serve(func(ctx context.Context) (net.Conn, error) {
		return forwardClient.DialTCPContext(ctx, "tcp", req.TargetAddr)
	})
	return forwarder, nil
}

// StartRemoteForward asks the SSH server to listen on req.ListenAddr and
// forwards every connection accepted there to req.TargetAddr. The target is
// dialed from this side using req.DialContext, or a net.Dialer with Timeout
// defaultLoginTimeout when that is nil.
func (s *StarSSH) StartRemoteForward(req ForwardRequest) (*PortForwarder, error) {
	client, err := s.requireSSHClient()
	if err != nil {
		return nil, err
	}
	if strings.TrimSpace(req.ListenAddr) == "" {
		return nil, errors.New("remote forward listen address is empty")
	}
	if strings.TrimSpace(req.TargetAddr) == "" {
		return nil, errors.New("remote forward target address is empty")
	}
	listener, err := client.Listen("tcp", req.ListenAddr)
	if err != nil {
		return nil, err
	}
	dialContext := req.DialContext
	if dialContext == nil {
		dialer := &net.Dialer{
			Timeout: defaultLoginTimeout,
		}
		dialContext = dialer.DialContext
	}
	forwarder := newPortForwarder(listener)
	forwarder.serve(func(ctx context.Context) (net.Conn, error) {
		return dialContext(ctx, "tcp", req.TargetAddr)
	})
	return forwarder, nil
}

// StartDynamicForward runs a local SOCKS5 proxy (no-auth, CONNECT only) on
// req.ListenAddr. Each requested target is dialed through this SSH connection.
func (s *StarSSH) StartDynamicForward(req DynamicForwardRequest) (*PortForwarder, error) {
	if _, err := s.requireSSHClient(); err != nil {
		return nil, err
	}
	if strings.TrimSpace(req.ListenAddr) == "" {
		return nil, errors.New("dynamic forward listen address is empty")
	}
	listener, err := net.Listen("tcp", req.ListenAddr)
	if err != nil {
		return nil, err
	}
	forwarder := newPortForwarder(listener)
	forwarder.serveDynamic(func(ctx context.Context, targetAddr string) (net.Conn, error) {
		return s.DialTCPContext(ctx, "tcp", targetAddr)
	})
	return forwarder, nil
}

// StartDynamicForwardDetached is like StartDynamicForward, but it dials through
// a separate SSH session created from s.LoginInfo. That session is closed when
// the forwarder's cleanup runs (in Close or Wait).
func (s *StarSSH) StartDynamicForwardDetached(req DynamicForwardRequest) (*PortForwarder, error) {
	if _, err := s.requireSSHClient(); err != nil {
		return nil, err
	}
	if strings.TrimSpace(req.ListenAddr) == "" {
		return nil, errors.New("dynamic forward listen address is empty")
	}
	listener, err := net.Listen("tcp", req.ListenAddr)
	if err != nil {
		return nil, err
	}
	forwardClient, err := s.newForwardDialClient(context.Background())
	if err != nil {
		_ = listener.Close()
		return nil, err
	}
	forwarder := newPortForwarder(listener)
	forwarder.addCleanup(func() error {
		return normalizeAlreadyClosedError(forwardClient.Close())
	})
	forwarder.serveDynamic(func(ctx context.Context, targetAddr string) (net.Conn, error) {
		return forwardClient.DialTCPContext(ctx, "tcp", targetAddr)
	})
	return forwarder, nil
}

// Addr returns the listener's address, or nil if there is no listener.
func (f *PortForwarder) Addr() net.Addr {
	if f == nil || f.listener == nil {
		return nil
	}
	return f.listener.Addr()
}

// Wait blocks until the accept loop has exited and every forwarded connection
// has finished. It then runs the cleanup functions and returns the first
// recorded error. Wait does not close the listener; use Close for that.
func (f *PortForwarder) Wait() error {
	if f == nil {
		return nil
	}
	<-f.acceptDone
	f.connWG.Wait()
	f.runCleanup()
	return f.Err()
}

// Err returns the first recorded error that was not filtered out by
// shouldIgnoreError, or nil.
func (f *PortForwarder) Err() error {
	if f == nil {
		return nil
	}
	f.errMu.Lock()
	defer f.errMu.Unlock()
	return f.err
}

// Close stops forwarding. It cancels in-flight target dials, closes the
// listener and all tracked connections, waits for every handler to finish,
// and then runs the cleanup functions. The first call returns the listener's
// close error if there was one; otherwise Close returns Err().
func (f *PortForwarder) Close() error {
	if f == nil {
		return nil
	}
	var closeErr error
	f.closeOnce.Do(func() {
		if f.cancel != nil {
			f.cancel()
		}
		if f.listener != nil {
			closeErr = normalizeAlreadyClosedError(f.listener.Close())
		}
		f.closeActiveConnections()
	})
	<-f.acceptDone
	f.connWG.Wait()
	f.runCleanup()
	if closeErr != nil {
		return closeErr
	}
	return f.Err()
}

// newPortForwarder wraps listener in a PortForwarder with its own cancelable
// context. The accept loop is started separately by serve or serveDynamic.
func newPortForwarder(listener net.Listener) *PortForwarder {
	ctx, cancel := context.WithCancel(context.Background())
	return &PortForwarder{
		listener:   listener,
		ctx:        ctx,
		cancel:     cancel,
		acceptDone: make(chan struct{}),
		conns:      make(map[net.Conn]struct{}),
	}
}

// newForwardDialClient opens an independent SSH session by logging in again
// with s.LoginInfo.
func (s *StarSSH) newForwardDialClient(ctx context.Context) (*StarSSH, error) {
	if s == nil {
		return nil, errors.New("ssh client is nil")
	}
	return newDetachedForwardClient(ctx, s.LoginInfo)
}

// addCleanup registers fn to run once when the forwarder is closed or waited.
// It is not synchronized and is only called before the accept loop starts.
func (f *PortForwarder) addCleanup(fn func() error) {
	if f == nil || fn == nil {
		return
	}
	f.cleanupFns = append(f.cleanupFns, fn)
}

// runCleanup runs the registered cleanup functions exactly once, in
// registration order, and records their errors via setError.
func (f *PortForwarder) runCleanup() {
	if f == nil {
		return
	}
	f.cleanupOnce.Do(func() {
		for _, fn := range f.cleanupFns {
			f.setError(normalizeAlreadyClosedError(fn()))
		}
	})
}

// acceptLoop starts the accept goroutine. Each accepted connection is tracked
// and passed to handle on its own goroutine, which is counted in connWG.
// acceptDone is closed when the loop exits. A closed listener ends the loop
// silently; any other Accept error is recorded and also ends the loop.
func (f *PortForwarder) acceptLoop(handle func(src net.Conn)) {
	go func() {
		defer close(f.acceptDone)
		for {
			conn, err := f.listener.Accept()
			if err != nil {
				if !isClosedListenerError(err) {
					f.setError(err)
				}
				return
			}
			f.trackConn(conn)
			f.connWG.Add(1)
			go func(src net.Conn) {
				defer f.connWG.Done()
				defer f.untrackConn(src)
				handle(src)
			}(conn)
		}
	}()
}

// serve forwards each accepted connection to the connection returned by
// targetDial. targetDial receives the forwarder's context, which Close cancels.
func (f *PortForwarder) serve(targetDial func(context.Context) (net.Conn, error)) {
	f.acceptLoop(func(src net.Conn) {
		dst, err := targetDial(f.ctx)
		if err != nil {
			f.setError(err)
			_ = src.Close()
			return
		}
		f.trackConn(dst)
		defer f.untrackConn(dst)
		f.setError(pipeForwardConnections(src, dst))
	})
}

// serveDynamic treats each accepted connection as a SOCKS5 client and dials
// the requested target with targetDial.
func (f *PortForwarder) serveDynamic(targetDial func(context.Context, string) (net.Conn, error)) {
	f.acceptLoop(func(src net.Conn) {
		if err := handleDynamicForwardConn(f.ctx, src, targetDial, f.trackConn, f.untrackConn); err != nil {
			f.setError(err)
		}
	})
}

// setError records err as the forwarder's error. Only the first error is
// kept, and errors accepted by shouldIgnoreError are dropped.
func (f *PortForwarder) setError(err error) {
	if f == nil || err == nil || f.shouldIgnoreError(err) {
		return
	}
	f.errMu.Lock()
	defer f.errMu.Unlock()
	if f.err == nil {
		f.err = err
	}
}

// pipeForwardConnections copies data in both directions until both copies
// finish, then closes both connections.
//
// When a direction ends cleanly (EOF or an already-closed error), the write
// side of its destination is half-closed if supported. When a direction fails
// with a real error, both connections are closed so the opposite copy cannot
// stay blocked. The first non-nil copy error, in completion order, is returned.
func pipeForwardConnections(left net.Conn, right net.Conn) error {
	if left == nil || right == nil {
		if left != nil {
			_ = left.Close()
		}
		if right != nil {
			_ = right.Close()
		}
		return errors.New("forward connection endpoint is nil")
	}
	defer left.Close()
	defer right.Close()

	var copyWG sync.WaitGroup
	errCh := make(chan error, 2)
	copyHalf := func(dst, src net.Conn) {
		defer copyWG.Done()
		_, err := io.Copy(dst, src)
		err = normalizeAlreadyClosedError(err)
		// Publish the error before tearing anything down, so this root-cause
		// error precedes whatever the other direction reports as a result.
		errCh <- err
		if err != nil {
			_ = src.Close()
			_ = dst.Close()
			return
		}
		closeWrite(dst)
	}
	copyWG.Add(2)
	go copyHalf(right, left)
	go copyHalf(left, right)
	copyWG.Wait()
	close(errCh)
	for err := range errCh {
		if err != nil {
			return err
		}
	}
	return nil
}

// closeWrite half-closes conn if it implements CloseWrite; otherwise it does
// nothing. Errors are ignored.
func closeWrite(conn net.Conn) {
	if conn == nil {
		return
	}
	type closeWriter interface {
		CloseWrite() error
	}
	if writer, ok := conn.(closeWriter); ok {
		_ = writer.CloseWrite()
	}
}

// handleDynamicForwardConn serves one SOCKS5 client on src. It negotiates
// no-auth, reads a CONNECT request, dials the target with targetDial, sends
// the reply and then pipes data both ways. src is always closed on return.
// trackConn and untrackConn, when non-nil, are called for the dialed
// connection.
func handleDynamicForwardConn(
	ctx context.Context,
	src net.Conn,
	targetDial func(context.Context, string) (net.Conn, error),
	trackConn func(net.Conn),
	untrackConn func(net.Conn),
) error {
	if src == nil {
		return errors.New("dynamic forward source connection is nil")
	}
	defer src.Close()

	if err := negotiateSOCKS5NoAuth(src); err != nil {
		return err
	}
	targetAddr, replyCode, err := readSOCKS5ConnectTarget(src)
	if err != nil {
		_ = writeSOCKS5ServerReply(src, replyCode, nil)
		return err
	}
	dst, err := targetDial(ctx, targetAddr)
	if err != nil {
		_ = writeSOCKS5ServerReply(src, 0x01, nil)
		return err
	}
	if dst == nil {
		// Guard against a dialer returning (nil, nil); dst.LocalAddr below
		// would otherwise panic.
		_ = writeSOCKS5ServerReply(src, 0x01, nil)
		return errors.New("dynamic forward target connection is nil")
	}
	if trackConn != nil {
		trackConn(dst)
	}
	if untrackConn != nil {
		defer untrackConn(dst)
	}
	if err := writeSOCKS5ServerReply(src, 0x00, dst.LocalAddr()); err != nil {
		_ = dst.Close()
		return err
	}
	return pipeForwardConnections(src, dst)
}

// trackConn registers conn so that Close can close it. If the forwarder has
// already swept its connections (see closeActiveConnections), conn is closed
// immediately instead. This keeps a late arrival from holding Close in
// connWG.Wait forever.
func (f *PortForwarder) trackConn(conn net.Conn) {
	if f == nil || conn == nil {
		return
	}
	f.connMu.Lock()
	if f.closing {
		f.connMu.Unlock()
		_ = conn.Close()
		return
	}
	f.conns[conn] = struct{}{}
	f.connMu.Unlock()
}

// untrackConn removes conn from the tracked set. It does not close conn.
func (f *PortForwarder) untrackConn(conn net.Conn) {
	if f == nil || conn == nil {
		return
	}
	f.connMu.Lock()
	defer f.connMu.Unlock()
	delete(f.conns, conn)
}

// closeActiveConnections marks the forwarder as closing and closes every
// connection tracked so far. The connections are closed outside the lock;
// connections tracked later are closed by trackConn itself.
func (f *PortForwarder) closeActiveConnections() {
	if f == nil {
		return
	}
	f.connMu.Lock()
	f.closing = true
	conns := make([]net.Conn, 0, len(f.conns))
	for conn := range f.conns {
		conns = append(conns, conn)
	}
	f.connMu.Unlock()
	for _, conn := range conns {
		_ = conn.Close()
	}
}

// shouldIgnoreError reports whether err is not worth recording. That covers
// nil, errors that normalizeAlreadyClosedError maps to nil, and context
// cancellation or deadline errors.
func (f *PortForwarder) shouldIgnoreError(err error) bool {
	if err == nil {
		return true
	}
	if normalizeAlreadyClosedError(err) == nil {
		return true
	}
	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
		return true
	}
	return false
}

// negotiateSOCKS5NoAuth reads the SOCKS5 greeting (version, method list). It
// selects method 0x00 (no authentication) if the client offers it and
// otherwise replies 0xFF and returns an error.
func negotiateSOCKS5NoAuth(conn net.Conn) error {
	header := make([]byte, 2)
	if _, err := io.ReadFull(conn, header); err != nil {
		return err
	}
	if header[0] != 0x05 {
		return errors.New("invalid socks5 version")
	}
	methodCount := int(header[1])
	methods := make([]byte, methodCount)
	if _, err := io.ReadFull(conn, methods); err != nil {
		return err
	}
	method := byte(0xFF)
	for _, candidate := range methods {
		if candidate == 0x00 {
			method = 0x00
			break
		}
	}
	if _, err := conn.Write([]byte{0x05, method}); err != nil {
		return err
	}
	if method == 0xFF {
		return errors.New("socks5 client does not support no-auth method")
	}
	return nil
}

// readSOCKS5ConnectTarget reads a SOCKS5 request and returns the target as
// "host:port". On failure it also returns the reply code to send back:
// 0x01 general failure, 0x07 command not supported, or 0x08 address type not
// supported.
func readSOCKS5ConnectTarget(conn net.Conn) (string, byte, error) {
	header := make([]byte, 4)
	if _, err := io.ReadFull(conn, header); err != nil {
		return "", 0x01, err
	}
	if header[0] != 0x05 {
		return "", 0x01, errors.New("invalid socks5 request version")
	}
	if header[1] != 0x01 {
		return "", 0x07, errors.New("unsupported socks5 command")
	}
	host, err := readSOCKS5RequestHost(conn, header[3])
	if err != nil {
		return "", 0x08, err
	}
	portBytes := make([]byte, 2)
	if _, err := io.ReadFull(conn, portBytes); err != nil {
		return "", 0x01, err
	}
	port := int(portBytes[0])<<8 | int(portBytes[1])
	return net.JoinHostPort(host, strconv.Itoa(port)), 0x00, nil
}

// readSOCKS5RequestHost reads the destination address for addressType:
// 0x01 IPv4, 0x03 length-prefixed domain name, or 0x04 IPv6.
func readSOCKS5RequestHost(conn net.Conn, addressType byte) (string, error) {
	switch addressType {
	case 0x01:
		buffer := make([]byte, 4)
		if _, err := io.ReadFull(conn, buffer); err != nil {
			return "", err
		}
		return net.IP(buffer).String(), nil
	case 0x03:
		size := make([]byte, 1)
		if _, err := io.ReadFull(conn, size); err != nil {
			return "", err
		}
		buffer := make([]byte, int(size[0]))
		if _, err := io.ReadFull(conn, buffer); err != nil {
			return "", err
		}
		return string(buffer), nil
	case 0x04:
		buffer := make([]byte, 16)
		if _, err := io.ReadFull(conn, buffer); err != nil {
			return "", err
		}
		return net.IP(buffer).String(), nil
	default:
		return "", errors.New("unsupported socks5 address type")
	}
}

// writeSOCKS5ServerReply writes a SOCKS5 reply with replyCode. The bound
// address is taken from addr when it is a *net.TCPAddr with a usable IP;
// otherwise 0.0.0.0:0 is sent.
func writeSOCKS5ServerReply(conn net.Conn, replyCode byte, addr net.Addr) error {
	reply := []byte{0x05, replyCode, 0x00}
	if tcpAddr, ok := addr.(*net.TCPAddr); ok && tcpAddr != nil {
		if ip4 := tcpAddr.IP.To4(); ip4 != nil {
			reply = append(reply, 0x01)
			reply = append(reply, ip4...)
		} else if ip16 := tcpAddr.IP.To16(); ip16 != nil {
			reply = append(reply, 0x04)
			reply = append(reply, ip16...)
		}
		if len(reply) > 3 {
			port := tcpAddr.Port
			reply = append(reply, byte(port>>8), byte(port))
			_, err := conn.Write(reply)
			return err
		}
	}
	reply = append(reply, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00)
	_, err := conn.Write(reply)
	return err
}

// isClosedListenerError reports whether an Accept error means the listener
// was closed: io.EOF, net.ErrClosed, or the legacy "use of closed network
// connection" message.
func isClosedListenerError(err error) bool {
	if err == nil {
		return false
	}
	if errors.Is(err, io.EOF) {
		return true
	}
	if errors.Is(err, net.ErrClosed) {
		return true
	}
	return strings.Contains(err.Error(), "use of closed network connection")
}