package starssh import ( "context" "errors" "io" "net" "os" "path/filepath" "runtime" "sync" "sync/atomic" "testing" "time" "golang.org/x/crypto/ssh" ) type stubListener struct { addr net.Addr acceptCh chan net.Conn closeCh chan struct{} closeOnce sync.Once } type dialRecord struct { network string addr string } func newStubListener(addr net.Addr) *stubListener { return &stubListener{ addr: addr, acceptCh: make(chan net.Conn, 1), closeCh: make(chan struct{}), } } func (l *stubListener) Accept() (net.Conn, error) { select { case conn, ok := <-l.acceptCh: if !ok { return nil, io.EOF } return conn, nil case <-l.closeCh: return nil, net.ErrClosed } } func (l *stubListener) Close() error { l.closeOnce.Do(func() { close(l.closeCh) close(l.acceptCh) }) return nil } func (l *stubListener) Addr() net.Addr { return l.addr } func (l *stubListener) Push(conn net.Conn) error { select { case <-l.closeCh: return net.ErrClosed case l.acceptCh <- conn: return nil } } func TestStartLocalForwardUsesExistingConnectionByDefault(t *testing.T) { oldDialSSHClient := dialSSHClient oldNewDetachedForwardClient := newDetachedForwardClient oldCloseSSHClient := closeSSHClient t.Cleanup(func() { dialSSHClient = oldDialSSHClient newDetachedForwardClient = oldNewDetachedForwardClient closeSSHClient = oldCloseSSHClient }) baseClient := &ssh.Client{} star := &StarSSH{} star.setTransport(baseClient, nil) var detachedCalls atomic.Int32 newDetachedForwardClient = func(ctx context.Context, input LoginInput) (*StarSSH, error) { detachedCalls.Add(1) return nil, nil } dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) { if client != baseClient { t.Errorf("expected existing ssh client, got %p want %p", client, baseClient) } serverConn, clientConn := net.Pipe() go echoForwardPipe(serverConn) return clientConn, nil } closeSSHClient = func(client sshClientRequester) error { t.Fatal("default local forward should not close the main ssh client") return nil } forwarder, err := star.StartLocalForward(ForwardRequest{ ListenAddr: "127.0.0.1:0", TargetAddr: "example.internal:22", }) if err != nil { t.Fatalf("start local forward: %v", err) } defer forwarder.Close() reply := exerciseForwarder(t, forwarder.Addr().String(), []byte("ping")) if string(reply) != "ping" { t.Fatalf("unexpected forwarded reply: %q", string(reply)) } if detachedCalls.Load() != 0 { t.Fatalf("default local forward should not create detached ssh client, calls=%d", detachedCalls.Load()) } } func TestForwardRequestLegacyPositionalLiteralDefaultsToTCP(t *testing.T) { dialer := func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil } req, err := normalizeForwardRequest(ForwardRequest{ "127.0.0.1:10022", "example.internal:22", dialer, }) if err != nil { t.Fatalf("normalizeForwardRequest: %v", err) } if req.ListenNetwork != "tcp" { t.Fatalf("ListenNetwork=%q want tcp", req.ListenNetwork) } if req.TargetNetwork != "tcp" { t.Fatalf("TargetNetwork=%q want tcp", req.TargetNetwork) } if req.ListenAddr != "127.0.0.1:10022" || req.TargetAddr != "example.internal:22" { t.Fatalf("unexpected normalized request: %+v", req) } if req.DialContext == nil { t.Fatal("expected DialContext to be preserved") } } func TestParseForwardEndpointTreatsTCPPrefixLikePlainAddress(t *testing.T) { network, address, err := parseForwardEndpoint("tcp:22") if err != nil { t.Fatalf("parseForwardEndpoint: %v", err) } if network != "tcp" { t.Fatalf("network=%q want tcp", network) } if address != "tcp:22" { t.Fatalf("address=%q want tcp:22", address) } } func TestParseForwardEndpointSupportsExplicitSchemes(t *testing.T) { network, address, err := parseForwardEndpoint("unix:///tmp/test-forward.sock") if err != nil { t.Fatalf("parseForwardEndpoint unix: %v", err) } if network != "unix" || address != "/tmp/test-forward.sock" { t.Fatalf("unexpected unix endpoint parse: network=%q address=%q", network, address) } network, address, err = parseForwardEndpoint("tcp6://[::1]:2222") if err != nil { t.Fatalf("parseForwardEndpoint tcp6: %v", err) } if network != "tcp6" || address != "[::1]:2222" { t.Fatalf("unexpected tcp6 endpoint parse: network=%q address=%q", network, address) } } func TestStartLocalForwardDetachedUsesSeparateConnection(t *testing.T) { oldDialSSHClient := dialSSHClient oldNewDetachedForwardClient := newDetachedForwardClient oldCloseSSHClient := closeSSHClient t.Cleanup(func() { dialSSHClient = oldDialSSHClient newDetachedForwardClient = oldNewDetachedForwardClient closeSSHClient = oldCloseSSHClient }) baseClient := &ssh.Client{} detachedClient := &ssh.Client{} star := &StarSSH{LoginInfo: LoginInput{User: "tester", Addr: "127.0.0.1"}} star.setTransport(baseClient, nil) forwardClient := &StarSSH{} forwardClient.setTransport(detachedClient, nil) var detachedCalls atomic.Int32 newDetachedForwardClient = func(ctx context.Context, input LoginInput) (*StarSSH, error) { detachedCalls.Add(1) return forwardClient, nil } dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) { if client != detachedClient { t.Errorf("expected detached ssh client, got %p want %p", client, detachedClient) } serverConn, clientConn := net.Pipe() go echoForwardPipe(serverConn) return clientConn, nil } var closeCalls atomic.Int32 closeSSHClient = func(client sshClientRequester) error { closeCalls.Add(1) return nil } forwarder, err := star.StartLocalForwardDetached(ForwardRequest{ ListenAddr: "127.0.0.1:0", TargetAddr: "example.internal:22", }) if err != nil { t.Fatalf("start detached local forward: %v", err) } reply := exerciseForwarder(t, forwarder.Addr().String(), []byte("pong")) if string(reply) != "pong" { t.Fatalf("unexpected detached forwarded reply: %q", string(reply)) } if err := forwarder.Close(); err != nil { t.Fatalf("close detached local forward: %v", err) } if detachedCalls.Load() != 1 { t.Fatalf("expected one detached ssh login, got %d", detachedCalls.Load()) } if closeCalls.Load() != 1 { t.Fatalf("expected detached ssh client cleanup once, got %d", closeCalls.Load()) } if got := star.snapshotSSHClient(); got != baseClient { t.Fatal("detached local forward should not detach the main ssh client") } if got := forwardClient.snapshotSSHClient(); got != nil { t.Fatal("detached local forward should close its detached ssh client") } } func TestStartRemoteForwardSupportsUnixListenAndTCPTarget(t *testing.T) { oldListenSSHClient := listenSSHClient t.Cleanup(func() { listenSSHClient = oldListenSSHClient }) baseClient := &ssh.Client{} star := &StarSSH{} star.setTransport(baseClient, nil) listener := newStubListener(&net.UnixAddr{ Name: "/run/user/0/gnupg/S.gpg-agent", Net: "unix", }) var listenedNetwork string var listenedAddr string listenSSHClient = func(client *ssh.Client, network, address string) (net.Listener, error) { if client != baseClient { t.Fatalf("unexpected ssh client %p", client) } listenedNetwork = network listenedAddr = address return listener, nil } var targetNetwork string var targetAddr string forwarder, err := star.StartRemoteForward(ForwardRequest{ ListenAddr: forwardEndpoint("unix", "/run/user/0/gnupg/S.gpg-agent"), TargetAddr: "127.0.0.1:4321", DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { targetNetwork = network targetAddr = address serverConn, clientConn := net.Pipe() go echoForwardPipe(serverConn) return clientConn, nil }, }) if err != nil { t.Fatalf("start remote unix forward: %v", err) } defer forwarder.Close() srcPeer, forwardedConn := net.Pipe() defer srcPeer.Close() if err := listener.Push(forwardedConn); err != nil { t.Fatalf("push forwarded connection: %v", err) } payload := []byte("unix-forward") done := make(chan []byte, 1) go func() { reply := make([]byte, len(payload)) _, _ = io.ReadFull(srcPeer, reply) done <- reply }() if _, err := srcPeer.Write(payload); err != nil { t.Fatalf("write source payload: %v", err) } select { case reply := <-done: if string(reply) != string(payload) { t.Fatalf("unexpected remote unix forward reply: %q", string(reply)) } case <-time.After(2 * time.Second): t.Fatal("remote unix forward did not relay payload") } if listenedNetwork != "unix" || listenedAddr != "/run/user/0/gnupg/S.gpg-agent" { t.Fatalf("unexpected remote listen request: network=%q addr=%q", listenedNetwork, listenedAddr) } if targetNetwork != "tcp" || targetAddr != "127.0.0.1:4321" { t.Fatalf("unexpected local dial target: network=%q addr=%q", targetNetwork, targetAddr) } } func TestStartLocalUnixForwardUsesUnixListenerAndTCPTarget(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("unix socket smoke test is exercised in WSL/Linux CI path") } oldDialSSHClient := dialSSHClient t.Cleanup(func() { dialSSHClient = oldDialSSHClient }) baseClient := &ssh.Client{} star := &StarSSH{} star.setTransport(baseClient, nil) var targetNetwork string var targetAddr string dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) { if client != baseClient { t.Fatalf("unexpected ssh client %p", client) } targetNetwork = network targetAddr = address serverConn, clientConn := net.Pipe() go echoForwardPipe(serverConn) return clientConn, nil } socketPath := filepath.Join(t.TempDir(), "forward.sock") forwarder, err := star.StartLocalUnixForward(socketPath, "127.0.0.1:4321") if err != nil { t.Fatalf("start local unix forward: %v", err) } defer func() { closeErr := forwarder.Close() if closeErr != nil && !errors.Is(closeErr, net.ErrClosed) { t.Fatalf("close local unix forward: %v", closeErr) } }() conn, err := net.DialTimeout("unix", socketPath, time.Second) if err != nil { t.Fatalf("dial unix forward listener: %v", err) } defer conn.Close() _ = conn.SetDeadline(time.Now().Add(2 * time.Second)) payload := []byte("unix-local-forward") if _, err := conn.Write(payload); err != nil { t.Fatalf("write unix forward payload: %v", err) } reply := make([]byte, len(payload)) if _, err := io.ReadFull(conn, reply); err != nil { t.Fatalf("read unix forward reply: %v", err) } if string(reply) != string(payload) { t.Fatalf("unexpected unix forward reply: %q", string(reply)) } if targetNetwork != "tcp" || targetAddr != "127.0.0.1:4321" { t.Fatalf("unexpected remote dial target: network=%q addr=%q", targetNetwork, targetAddr) } } func TestStartLocalUnixForwardRemovesSocketOnClose(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("unix socket smoke test is exercised in WSL/Linux CI path") } oldDialSSHClient := dialSSHClient t.Cleanup(func() { dialSSHClient = oldDialSSHClient }) baseClient := &ssh.Client{} star := &StarSSH{} star.setTransport(baseClient, nil) dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) { serverConn, clientConn := net.Pipe() go echoForwardPipe(serverConn) return clientConn, nil } socketPath := filepath.Join(t.TempDir(), "cleanup.sock") forwarder, err := star.StartLocalUnixForward(socketPath, "127.0.0.1:4321") if err != nil { t.Fatalf("start local unix forward: %v", err) } if _, err := os.Lstat(socketPath); err != nil { t.Fatalf("socket should exist while forward is running: %v", err) } if err := forwarder.Close(); err != nil && !errors.Is(err, net.ErrClosed) { t.Fatalf("close local unix forward: %v", err) } if _, err := os.Lstat(socketPath); !errors.Is(err, os.ErrNotExist) { t.Fatalf("socket path should be removed on close, got err=%v", err) } } func TestStartLocalUnixForwardReusesStaleSocketPath(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("unix socket smoke test is exercised in WSL/Linux CI path") } oldDialSSHClient := dialSSHClient t.Cleanup(func() { dialSSHClient = oldDialSSHClient }) baseClient := &ssh.Client{} star := &StarSSH{} star.setTransport(baseClient, nil) dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) { serverConn, clientConn := net.Pipe() go echoForwardPipe(serverConn) return clientConn, nil } socketPath := filepath.Join(t.TempDir(), "stale.sock") staleListener, err := net.ListenUnix("unix", &net.UnixAddr{ Name: socketPath, Net: "unix", }) if err != nil { t.Fatalf("create stale unix socket: %v", err) } staleListener.SetUnlinkOnClose(false) if err := staleListener.Close(); err != nil { t.Fatalf("close stale unix socket listener: %v", err) } if _, err := os.Lstat(socketPath); err != nil { t.Fatalf("expected stale unix socket path to remain after close: %v", err) } forwarder, err := star.StartLocalUnixForward(socketPath, "127.0.0.1:4321") if err != nil { t.Fatalf("start local unix forward on stale socket path: %v", err) } defer func() { closeErr := forwarder.Close() if closeErr != nil && !errors.Is(closeErr, net.ErrClosed) { t.Fatalf("close local unix forward: %v", closeErr) } }() reply := make([]byte, len("stale-reuse")) conn, err := net.DialTimeout("unix", socketPath, time.Second) if err != nil { t.Fatalf("dial reused unix forward listener: %v", err) } defer conn.Close() _ = conn.SetDeadline(time.Now().Add(2 * time.Second)) if _, err := conn.Write([]byte("stale-reuse")); err != nil { t.Fatalf("write reused unix forward payload: %v", err) } if _, err := io.ReadFull(conn, reply); err != nil { t.Fatalf("read reused unix forward reply: %v", err) } if string(reply) != "stale-reuse" { t.Fatalf("unexpected reply on reused unix forward: %q", string(reply)) } } func TestStartLocalUnixToUnixForwardUsesUnixTarget(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("unix socket smoke test is exercised in WSL/Linux CI path") } oldDialSSHClient := dialSSHClient t.Cleanup(func() { dialSSHClient = oldDialSSHClient }) baseClient := &ssh.Client{} star := &StarSSH{} star.setTransport(baseClient, nil) targetSocketPath := filepath.Join(t.TempDir(), "target.sock") targetListener, err := net.Listen("unix", targetSocketPath) if err != nil { t.Fatalf("listen target unix socket: %v", err) } defer targetListener.Close() done := make(chan []byte, 1) go func() { conn, acceptErr := targetListener.Accept() if acceptErr != nil { done <- nil return } defer conn.Close() buf := make([]byte, 64) n, _ := conn.Read(buf) _, _ = conn.Write(buf[:n]) done <- buf[:n] }() dialRecordCh := make(chan dialRecord, 1) dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) { if client != baseClient { t.Fatalf("unexpected ssh client %p", client) } dialRecordCh <- dialRecord{network: network, addr: address} var dialer net.Dialer return dialer.DialContext(ctx, network, address) } listenSocketPath := filepath.Join(t.TempDir(), "listen.sock") forwarder, err := star.StartLocalUnixToUnixForward(listenSocketPath, targetSocketPath) if err != nil { t.Fatalf("start local unix-to-unix forward: %v", err) } defer func() { closeErr := forwarder.Close() if closeErr != nil && !errors.Is(closeErr, net.ErrClosed) { t.Fatalf("close local unix-to-unix forward: %v", closeErr) } }() conn, err := net.DialTimeout("unix", listenSocketPath, time.Second) if err != nil { t.Fatalf("dial unix-to-unix listener: %v", err) } defer conn.Close() _ = conn.SetDeadline(time.Now().Add(2 * time.Second)) payload := []byte("unix-to-unix") if _, err := conn.Write(payload); err != nil { t.Fatalf("write unix-to-unix payload: %v", err) } reply := make([]byte, len(payload)) if _, err := io.ReadFull(conn, reply); err != nil { t.Fatalf("read unix-to-unix reply: %v", err) } if string(reply) != string(payload) { t.Fatalf("unexpected unix-to-unix reply: %q", string(reply)) } select { case got := <-done: if string(got) != string(payload) { t.Fatalf("unexpected payload seen by target unix socket: %q", string(got)) } case <-time.After(2 * time.Second): t.Fatal("target unix socket did not receive forwarded payload") } select { case got := <-dialRecordCh: if got.network != "unix" || got.addr != targetSocketPath { t.Fatalf("unexpected unix target dial: network=%q addr=%q", got.network, got.addr) } case <-time.After(2 * time.Second): t.Fatal("did not observe unix target dial") } } func TestStartLocalTCPToUnixForwardUsesUnixTarget(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("unix socket smoke test is exercised in WSL/Linux CI path") } oldDialSSHClient := dialSSHClient t.Cleanup(func() { dialSSHClient = oldDialSSHClient }) baseClient := &ssh.Client{} star := &StarSSH{} star.setTransport(baseClient, nil) targetSocketPath := filepath.Join(t.TempDir(), "target-tcp-to-unix.sock") targetListener, err := net.Listen("unix", targetSocketPath) if err != nil { t.Fatalf("listen target unix socket: %v", err) } defer targetListener.Close() done := make(chan []byte, 1) go func() { conn, acceptErr := targetListener.Accept() if acceptErr != nil { done <- nil return } defer conn.Close() buf := make([]byte, 64) n, _ := conn.Read(buf) _, _ = conn.Write(buf[:n]) done <- buf[:n] }() dialRecordCh := make(chan dialRecord, 1) dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) { if client != baseClient { t.Fatalf("unexpected ssh client %p", client) } dialRecordCh <- dialRecord{network: network, addr: address} var dialer net.Dialer return dialer.DialContext(ctx, network, address) } forwarder, err := star.StartLocalTCPToUnixForward("127.0.0.1:0", targetSocketPath) if err != nil { t.Fatalf("start local tcp-to-unix forward: %v", err) } defer func() { closeErr := forwarder.Close() if closeErr != nil && !errors.Is(closeErr, net.ErrClosed) { t.Fatalf("close local tcp-to-unix forward: %v", closeErr) } }() reply := exerciseForwarder(t, forwarder.Addr().String(), []byte("tcp-to-unix")) if string(reply) != "tcp-to-unix" { t.Fatalf("unexpected tcp-to-unix reply: %q", string(reply)) } select { case got := <-done: if string(got) != "tcp-to-unix" { t.Fatalf("unexpected payload seen by unix target: %q", string(got)) } case <-time.After(2 * time.Second): t.Fatal("unix target did not receive forwarded tcp payload") } select { case got := <-dialRecordCh: if got.network != "unix" || got.addr != targetSocketPath { t.Fatalf("unexpected unix target dial: network=%q addr=%q", got.network, got.addr) } case <-time.After(2 * time.Second): t.Fatal("did not observe unix target dial") } } func echoForwardPipe(conn net.Conn) { defer conn.Close() buf := make([]byte, 4096) n, err := conn.Read(buf) if err != nil { return } _, _ = conn.Write(buf[:n]) } func exerciseForwarder(t *testing.T, addr string, payload []byte) []byte { t.Helper() conn, err := net.DialTimeout("tcp", addr, time.Second) if err != nil { t.Fatalf("dial forward listener: %v", err) } defer conn.Close() _ = conn.SetDeadline(time.Now().Add(2 * time.Second)) if _, err := conn.Write(payload); err != nil { t.Fatalf("write forwarded payload: %v", err) } reply := make([]byte, len(payload)) if _, err := io.ReadFull(conn, reply); err != nil { t.Fatalf("read forwarded reply: %v", err) } return reply }