package starssh import ( "context" "io" "net" "sync/atomic" "testing" "time" "golang.org/x/crypto/ssh" ) func TestStartLocalForwardUsesExistingConnectionByDefault(t *testing.T) { oldDialSSHClient := dialSSHClient oldNewDetachedForwardClient := newDetachedForwardClient oldCloseSSHClient := closeSSHClient t.Cleanup(func() { dialSSHClient = oldDialSSHClient newDetachedForwardClient = oldNewDetachedForwardClient closeSSHClient = oldCloseSSHClient }) baseClient := &ssh.Client{} star := &StarSSH{} star.setTransport(baseClient, nil) var detachedCalls atomic.Int32 newDetachedForwardClient = func(ctx context.Context, input LoginInput) (*StarSSH, error) { detachedCalls.Add(1) return nil, nil } dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) { if client != baseClient { t.Errorf("expected existing ssh client, got %p want %p", client, baseClient) } serverConn, clientConn := net.Pipe() go echoForwardPipe(serverConn) return clientConn, nil } closeSSHClient = func(client sshClientRequester) error { t.Fatal("default local forward should not close the main ssh client") return nil } forwarder, err := star.StartLocalForward(ForwardRequest{ ListenAddr: "127.0.0.1:0", TargetAddr: "example.internal:22", }) if err != nil { t.Fatalf("start local forward: %v", err) } defer forwarder.Close() reply := exerciseForwarder(t, forwarder.Addr().String(), []byte("ping")) if string(reply) != "ping" { t.Fatalf("unexpected forwarded reply: %q", string(reply)) } if detachedCalls.Load() != 0 { t.Fatalf("default local forward should not create detached ssh client, calls=%d", detachedCalls.Load()) } } func TestStartLocalForwardDetachedUsesSeparateConnection(t *testing.T) { oldDialSSHClient := dialSSHClient oldNewDetachedForwardClient := newDetachedForwardClient oldCloseSSHClient := closeSSHClient t.Cleanup(func() { dialSSHClient = oldDialSSHClient newDetachedForwardClient = oldNewDetachedForwardClient closeSSHClient = oldCloseSSHClient }) baseClient := &ssh.Client{} detachedClient := &ssh.Client{} star := &StarSSH{LoginInfo: LoginInput{User: "tester", Addr: "127.0.0.1"}} star.setTransport(baseClient, nil) forwardClient := &StarSSH{} forwardClient.setTransport(detachedClient, nil) var detachedCalls atomic.Int32 newDetachedForwardClient = func(ctx context.Context, input LoginInput) (*StarSSH, error) { detachedCalls.Add(1) return forwardClient, nil } dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) { if client != detachedClient { t.Errorf("expected detached ssh client, got %p want %p", client, detachedClient) } serverConn, clientConn := net.Pipe() go echoForwardPipe(serverConn) return clientConn, nil } var closeCalls atomic.Int32 closeSSHClient = func(client sshClientRequester) error { closeCalls.Add(1) return nil } forwarder, err := star.StartLocalForwardDetached(ForwardRequest{ ListenAddr: "127.0.0.1:0", TargetAddr: "example.internal:22", }) if err != nil { t.Fatalf("start detached local forward: %v", err) } reply := exerciseForwarder(t, forwarder.Addr().String(), []byte("pong")) if string(reply) != "pong" { t.Fatalf("unexpected detached forwarded reply: %q", string(reply)) } if err := forwarder.Close(); err != nil { t.Fatalf("close detached local forward: %v", err) } if detachedCalls.Load() != 1 { t.Fatalf("expected one detached ssh login, got %d", detachedCalls.Load()) } if closeCalls.Load() != 1 { t.Fatalf("expected detached ssh client cleanup once, got %d", closeCalls.Load()) } if got := star.snapshotSSHClient(); got != baseClient { t.Fatal("detached local forward should not detach the main ssh client") } if got := forwardClient.snapshotSSHClient(); got != nil { t.Fatal("detached local forward should close its detached ssh client") } } func echoForwardPipe(conn net.Conn) { defer conn.Close() buf := make([]byte, 4096) n, err := conn.Read(buf) if err != nil { return } _, _ = conn.Write(buf[:n]) } func exerciseForwarder(t *testing.T, addr string, payload []byte) []byte { t.Helper() conn, err := net.DialTimeout("tcp", addr, time.Second) if err != nil { t.Fatalf("dial forward listener: %v", err) } defer conn.Close() _ = conn.SetDeadline(time.Now().Add(2 * time.Second)) if _, err := conn.Write(payload); err != nil { t.Fatalf("write forwarded payload: %v", err) } reply := make([]byte, len(payload)) if _, err := io.ReadFull(conn, reply); err != nil { t.Fatalf("read forwarded reply: %v", err) } return reply }