Files
starssh/cancel_semantics_test.go
T

173 lines
4.7 KiB
Go
Raw Normal View History

package starssh
import (
"context"
"errors"
"net"
"sync/atomic"
"testing"
"time"
"golang.org/x/crypto/ssh"
)
func TestPingContextDoesNotCloseConnectionOnCancel(t *testing.T) {
oldSendKeepAliveRequest := sendKeepAliveRequest
oldCloseSSHClient := closeSSHClient
t.Cleanup(func() {
sendKeepAliveRequest = oldSendKeepAliveRequest
closeSSHClient = oldCloseSSHClient
})
sendKeepAliveRequest = func(ctx context.Context, client sshClientRequester) error {
<-ctx.Done()
time.Sleep(20 * time.Millisecond)
return ctx.Err()
}
var closeCalls atomic.Int32
closeSSHClient = func(client sshClientRequester) error {
closeCalls.Add(1)
return nil
}
client := &ssh.Client{}
star := &StarSSH{}
star.setTransport(client, nil)
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
defer cancel()
err := star.PingContext(ctx)
if !errors.Is(err, context.DeadlineExceeded) {
t.Fatalf("expected deadline exceeded, got %v", err)
}
if closeCalls.Load() != 0 {
t.Fatalf("expected PingContext to keep connection open, close calls=%d", closeCalls.Load())
}
if got := star.snapshotSSHClient(); got != client {
t.Fatal("expected ssh client to remain attached after PingContext cancel")
}
}
func TestPingContextCloseOnCancelClosesConnection(t *testing.T) {
oldSendKeepAliveRequest := sendKeepAliveRequest
oldCloseSSHClient := closeSSHClient
t.Cleanup(func() {
sendKeepAliveRequest = oldSendKeepAliveRequest
closeSSHClient = oldCloseSSHClient
})
sendKeepAliveRequest = func(ctx context.Context, client sshClientRequester) error {
<-ctx.Done()
time.Sleep(20 * time.Millisecond)
return ctx.Err()
}
var closeCalls atomic.Int32
closeSSHClient = func(client sshClientRequester) error {
closeCalls.Add(1)
return nil
}
star := &StarSSH{}
star.setTransport(&ssh.Client{}, nil)
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
defer cancel()
err := star.PingContextCloseOnCancel(ctx)
if !errors.Is(err, context.DeadlineExceeded) {
t.Fatalf("expected deadline exceeded, got %v", err)
}
if closeCalls.Load() != 1 {
t.Fatalf("expected exactly one close call, got %d", closeCalls.Load())
}
if got := star.snapshotSSHClient(); got != nil {
t.Fatal("expected ssh client to be detached after PingContextCloseOnCancel")
}
}
func TestDialTCPContextDoesNotCloseConnectionOnCancel(t *testing.T) {
oldDialSSHClient := dialSSHClient
oldCloseSSHClient := closeSSHClient
t.Cleanup(func() {
dialSSHClient = oldDialSSHClient
closeSSHClient = oldCloseSSHClient
})
dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) {
<-ctx.Done()
time.Sleep(20 * time.Millisecond)
return nil, ctx.Err()
}
var closeCalls atomic.Int32
closeSSHClient = func(client sshClientRequester) error {
closeCalls.Add(1)
return nil
}
client := &ssh.Client{}
star := &StarSSH{}
star.setTransport(client, nil)
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
defer cancel()
conn, err := star.DialTCPContext(ctx, "tcp", "127.0.0.1:22")
if conn != nil {
t.Fatal("expected nil connection on canceled dial")
}
if !errors.Is(err, context.DeadlineExceeded) {
t.Fatalf("expected deadline exceeded, got %v", err)
}
if closeCalls.Load() != 0 {
t.Fatalf("expected DialTCPContext to keep connection open, close calls=%d", closeCalls.Load())
}
if got := star.snapshotSSHClient(); got != client {
t.Fatal("expected ssh client to remain attached after DialTCPContext cancel")
}
}
func TestDialTCPContextCloseOnCancelClosesConnection(t *testing.T) {
oldDialSSHClient := dialSSHClient
oldCloseSSHClient := closeSSHClient
t.Cleanup(func() {
dialSSHClient = oldDialSSHClient
closeSSHClient = oldCloseSSHClient
})
dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) {
<-ctx.Done()
time.Sleep(20 * time.Millisecond)
return nil, ctx.Err()
}
var closeCalls atomic.Int32
closeSSHClient = func(client sshClientRequester) error {
closeCalls.Add(1)
return nil
}
star := &StarSSH{}
star.setTransport(&ssh.Client{}, nil)
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
defer cancel()
conn, err := star.DialTCPContextCloseOnCancel(ctx, "tcp", "127.0.0.1:22")
if conn != nil {
t.Fatal("expected nil connection on canceled dial")
}
if !errors.Is(err, context.DeadlineExceeded) {
t.Fatalf("expected deadline exceeded, got %v", err)
}
if closeCalls.Load() != 1 {
t.Fatalf("expected exactly one close call, got %d", closeCalls.Load())
}
if got := star.snapshotSSHClient(); got != nil {
t.Fatal("expected ssh client to be detached after DialTCPContextCloseOnCancel")
}
}