173 lines
4.7 KiB
Go
173 lines
4.7 KiB
Go
|
|
package starssh
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"errors"
|
||
|
|
"net"
|
||
|
|
"sync/atomic"
|
||
|
|
"testing"
|
||
|
|
"time"
|
||
|
|
|
||
|
|
"golang.org/x/crypto/ssh"
|
||
|
|
)
|
||
|
|
|
||
|
|
func TestPingContextDoesNotCloseConnectionOnCancel(t *testing.T) {
|
||
|
|
oldSendKeepAliveRequest := sendKeepAliveRequest
|
||
|
|
oldCloseSSHClient := closeSSHClient
|
||
|
|
t.Cleanup(func() {
|
||
|
|
sendKeepAliveRequest = oldSendKeepAliveRequest
|
||
|
|
closeSSHClient = oldCloseSSHClient
|
||
|
|
})
|
||
|
|
|
||
|
|
sendKeepAliveRequest = func(ctx context.Context, client sshClientRequester) error {
|
||
|
|
<-ctx.Done()
|
||
|
|
time.Sleep(20 * time.Millisecond)
|
||
|
|
return ctx.Err()
|
||
|
|
}
|
||
|
|
|
||
|
|
var closeCalls atomic.Int32
|
||
|
|
closeSSHClient = func(client sshClientRequester) error {
|
||
|
|
closeCalls.Add(1)
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
client := &ssh.Client{}
|
||
|
|
star := &StarSSH{}
|
||
|
|
star.setTransport(client, nil)
|
||
|
|
|
||
|
|
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
|
||
|
|
defer cancel()
|
||
|
|
|
||
|
|
err := star.PingContext(ctx)
|
||
|
|
if !errors.Is(err, context.DeadlineExceeded) {
|
||
|
|
t.Fatalf("expected deadline exceeded, got %v", err)
|
||
|
|
}
|
||
|
|
if closeCalls.Load() != 0 {
|
||
|
|
t.Fatalf("expected PingContext to keep connection open, close calls=%d", closeCalls.Load())
|
||
|
|
}
|
||
|
|
if got := star.snapshotSSHClient(); got != client {
|
||
|
|
t.Fatal("expected ssh client to remain attached after PingContext cancel")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestPingContextCloseOnCancelClosesConnection(t *testing.T) {
|
||
|
|
oldSendKeepAliveRequest := sendKeepAliveRequest
|
||
|
|
oldCloseSSHClient := closeSSHClient
|
||
|
|
t.Cleanup(func() {
|
||
|
|
sendKeepAliveRequest = oldSendKeepAliveRequest
|
||
|
|
closeSSHClient = oldCloseSSHClient
|
||
|
|
})
|
||
|
|
|
||
|
|
sendKeepAliveRequest = func(ctx context.Context, client sshClientRequester) error {
|
||
|
|
<-ctx.Done()
|
||
|
|
time.Sleep(20 * time.Millisecond)
|
||
|
|
return ctx.Err()
|
||
|
|
}
|
||
|
|
|
||
|
|
var closeCalls atomic.Int32
|
||
|
|
closeSSHClient = func(client sshClientRequester) error {
|
||
|
|
closeCalls.Add(1)
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
star := &StarSSH{}
|
||
|
|
star.setTransport(&ssh.Client{}, nil)
|
||
|
|
|
||
|
|
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
|
||
|
|
defer cancel()
|
||
|
|
|
||
|
|
err := star.PingContextCloseOnCancel(ctx)
|
||
|
|
if !errors.Is(err, context.DeadlineExceeded) {
|
||
|
|
t.Fatalf("expected deadline exceeded, got %v", err)
|
||
|
|
}
|
||
|
|
if closeCalls.Load() != 1 {
|
||
|
|
t.Fatalf("expected exactly one close call, got %d", closeCalls.Load())
|
||
|
|
}
|
||
|
|
if got := star.snapshotSSHClient(); got != nil {
|
||
|
|
t.Fatal("expected ssh client to be detached after PingContextCloseOnCancel")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestDialTCPContextDoesNotCloseConnectionOnCancel(t *testing.T) {
|
||
|
|
oldDialSSHClient := dialSSHClient
|
||
|
|
oldCloseSSHClient := closeSSHClient
|
||
|
|
t.Cleanup(func() {
|
||
|
|
dialSSHClient = oldDialSSHClient
|
||
|
|
closeSSHClient = oldCloseSSHClient
|
||
|
|
})
|
||
|
|
|
||
|
|
dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) {
|
||
|
|
<-ctx.Done()
|
||
|
|
time.Sleep(20 * time.Millisecond)
|
||
|
|
return nil, ctx.Err()
|
||
|
|
}
|
||
|
|
|
||
|
|
var closeCalls atomic.Int32
|
||
|
|
closeSSHClient = func(client sshClientRequester) error {
|
||
|
|
closeCalls.Add(1)
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
client := &ssh.Client{}
|
||
|
|
star := &StarSSH{}
|
||
|
|
star.setTransport(client, nil)
|
||
|
|
|
||
|
|
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
|
||
|
|
defer cancel()
|
||
|
|
|
||
|
|
conn, err := star.DialTCPContext(ctx, "tcp", "127.0.0.1:22")
|
||
|
|
if conn != nil {
|
||
|
|
t.Fatal("expected nil connection on canceled dial")
|
||
|
|
}
|
||
|
|
if !errors.Is(err, context.DeadlineExceeded) {
|
||
|
|
t.Fatalf("expected deadline exceeded, got %v", err)
|
||
|
|
}
|
||
|
|
if closeCalls.Load() != 0 {
|
||
|
|
t.Fatalf("expected DialTCPContext to keep connection open, close calls=%d", closeCalls.Load())
|
||
|
|
}
|
||
|
|
if got := star.snapshotSSHClient(); got != client {
|
||
|
|
t.Fatal("expected ssh client to remain attached after DialTCPContext cancel")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestDialTCPContextCloseOnCancelClosesConnection(t *testing.T) {
|
||
|
|
oldDialSSHClient := dialSSHClient
|
||
|
|
oldCloseSSHClient := closeSSHClient
|
||
|
|
t.Cleanup(func() {
|
||
|
|
dialSSHClient = oldDialSSHClient
|
||
|
|
closeSSHClient = oldCloseSSHClient
|
||
|
|
})
|
||
|
|
|
||
|
|
dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) {
|
||
|
|
<-ctx.Done()
|
||
|
|
time.Sleep(20 * time.Millisecond)
|
||
|
|
return nil, ctx.Err()
|
||
|
|
}
|
||
|
|
|
||
|
|
var closeCalls atomic.Int32
|
||
|
|
closeSSHClient = func(client sshClientRequester) error {
|
||
|
|
closeCalls.Add(1)
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
star := &StarSSH{}
|
||
|
|
star.setTransport(&ssh.Client{}, nil)
|
||
|
|
|
||
|
|
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
|
||
|
|
defer cancel()
|
||
|
|
|
||
|
|
conn, err := star.DialTCPContextCloseOnCancel(ctx, "tcp", "127.0.0.1:22")
|
||
|
|
if conn != nil {
|
||
|
|
t.Fatal("expected nil connection on canceled dial")
|
||
|
|
}
|
||
|
|
if !errors.Is(err, context.DeadlineExceeded) {
|
||
|
|
t.Fatalf("expected deadline exceeded, got %v", err)
|
||
|
|
}
|
||
|
|
if closeCalls.Load() != 1 {
|
||
|
|
t.Fatalf("expected exactly one close call, got %d", closeCalls.Load())
|
||
|
|
}
|
||
|
|
if got := star.snapshotSSHClient(); got != nil {
|
||
|
|
t.Fatal("expected ssh client to be detached after DialTCPContextCloseOnCancel")
|
||
|
|
}
|
||
|
|
}
|