package starssh import ( "context" "errors" "net" "sync/atomic" "testing" "time" "golang.org/x/crypto/ssh" ) func TestPingContextDoesNotCloseConnectionOnCancel(t *testing.T) { oldSendKeepAliveRequest := sendKeepAliveRequest oldCloseSSHClient := closeSSHClient t.Cleanup(func() { sendKeepAliveRequest = oldSendKeepAliveRequest closeSSHClient = oldCloseSSHClient }) sendKeepAliveRequest = func(ctx context.Context, client sshClientRequester) error { <-ctx.Done() time.Sleep(20 * time.Millisecond) return ctx.Err() } var closeCalls atomic.Int32 closeSSHClient = func(client sshClientRequester) error { closeCalls.Add(1) return nil } client := &ssh.Client{} star := &StarSSH{} star.setTransport(client, nil) ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond) defer cancel() err := star.PingContext(ctx) if !errors.Is(err, context.DeadlineExceeded) { t.Fatalf("expected deadline exceeded, got %v", err) } if closeCalls.Load() != 0 { t.Fatalf("expected PingContext to keep connection open, close calls=%d", closeCalls.Load()) } if got := star.snapshotSSHClient(); got != client { t.Fatal("expected ssh client to remain attached after PingContext cancel") } } func TestPingContextCloseOnCancelClosesConnection(t *testing.T) { oldSendKeepAliveRequest := sendKeepAliveRequest oldCloseSSHClient := closeSSHClient t.Cleanup(func() { sendKeepAliveRequest = oldSendKeepAliveRequest closeSSHClient = oldCloseSSHClient }) sendKeepAliveRequest = func(ctx context.Context, client sshClientRequester) error { <-ctx.Done() time.Sleep(20 * time.Millisecond) return ctx.Err() } var closeCalls atomic.Int32 closeSSHClient = func(client sshClientRequester) error { closeCalls.Add(1) return nil } star := &StarSSH{} star.setTransport(&ssh.Client{}, nil) ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond) defer cancel() err := star.PingContextCloseOnCancel(ctx) if !errors.Is(err, context.DeadlineExceeded) { t.Fatalf("expected deadline exceeded, got %v", err) } if closeCalls.Load() != 1 { t.Fatalf("expected exactly one close call, got %d", closeCalls.Load()) } if got := star.snapshotSSHClient(); got != nil { t.Fatal("expected ssh client to be detached after PingContextCloseOnCancel") } } func TestDialTCPContextDoesNotCloseConnectionOnCancel(t *testing.T) { oldDialSSHClient := dialSSHClient oldCloseSSHClient := closeSSHClient t.Cleanup(func() { dialSSHClient = oldDialSSHClient closeSSHClient = oldCloseSSHClient }) dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) { <-ctx.Done() time.Sleep(20 * time.Millisecond) return nil, ctx.Err() } var closeCalls atomic.Int32 closeSSHClient = func(client sshClientRequester) error { closeCalls.Add(1) return nil } client := &ssh.Client{} star := &StarSSH{} star.setTransport(client, nil) ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond) defer cancel() conn, err := star.DialTCPContext(ctx, "tcp", "127.0.0.1:22") if conn != nil { t.Fatal("expected nil connection on canceled dial") } if !errors.Is(err, context.DeadlineExceeded) { t.Fatalf("expected deadline exceeded, got %v", err) } if closeCalls.Load() != 0 { t.Fatalf("expected DialTCPContext to keep connection open, close calls=%d", closeCalls.Load()) } if got := star.snapshotSSHClient(); got != client { t.Fatal("expected ssh client to remain attached after DialTCPContext cancel") } } func TestDialTCPContextCloseOnCancelClosesConnection(t *testing.T) { oldDialSSHClient := dialSSHClient oldCloseSSHClient := closeSSHClient t.Cleanup(func() { dialSSHClient = oldDialSSHClient closeSSHClient = oldCloseSSHClient }) dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) { <-ctx.Done() time.Sleep(20 * time.Millisecond) return nil, ctx.Err() } var closeCalls atomic.Int32 closeSSHClient = func(client sshClientRequester) error { closeCalls.Add(1) return nil } star := &StarSSH{} star.setTransport(&ssh.Client{}, nil) ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond) defer cancel() conn, err := star.DialTCPContextCloseOnCancel(ctx, "tcp", "127.0.0.1:22") if conn != nil { t.Fatal("expected nil connection on canceled dial") } if !errors.Is(err, context.DeadlineExceeded) { t.Fatalf("expected deadline exceeded, got %v", err) } if closeCalls.Load() != 1 { t.Fatalf("expected exactly one close call, got %d", closeCalls.Load()) } if got := star.snapshotSSHClient(); got != nil { t.Fatal("expected ssh client to be detached after DialTCPContextCloseOnCancel") } }