165 lines
4.5 KiB
Go
165 lines
4.5 KiB
Go
|
|
package starssh
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"io"
|
||
|
|
"net"
|
||
|
|
"sync/atomic"
|
||
|
|
"testing"
|
||
|
|
"time"
|
||
|
|
|
||
|
|
"golang.org/x/crypto/ssh"
|
||
|
|
)
|
||
|
|
|
||
|
|
func TestStartLocalForwardUsesExistingConnectionByDefault(t *testing.T) {
|
||
|
|
oldDialSSHClient := dialSSHClient
|
||
|
|
oldNewDetachedForwardClient := newDetachedForwardClient
|
||
|
|
oldCloseSSHClient := closeSSHClient
|
||
|
|
t.Cleanup(func() {
|
||
|
|
dialSSHClient = oldDialSSHClient
|
||
|
|
newDetachedForwardClient = oldNewDetachedForwardClient
|
||
|
|
closeSSHClient = oldCloseSSHClient
|
||
|
|
})
|
||
|
|
|
||
|
|
baseClient := &ssh.Client{}
|
||
|
|
star := &StarSSH{}
|
||
|
|
star.setTransport(baseClient, nil)
|
||
|
|
|
||
|
|
var detachedCalls atomic.Int32
|
||
|
|
newDetachedForwardClient = func(ctx context.Context, input LoginInput) (*StarSSH, error) {
|
||
|
|
detachedCalls.Add(1)
|
||
|
|
return nil, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) {
|
||
|
|
if client != baseClient {
|
||
|
|
t.Errorf("expected existing ssh client, got %p want %p", client, baseClient)
|
||
|
|
}
|
||
|
|
serverConn, clientConn := net.Pipe()
|
||
|
|
go echoForwardPipe(serverConn)
|
||
|
|
return clientConn, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
closeSSHClient = func(client sshClientRequester) error {
|
||
|
|
t.Fatal("default local forward should not close the main ssh client")
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
forwarder, err := star.StartLocalForward(ForwardRequest{
|
||
|
|
ListenAddr: "127.0.0.1:0",
|
||
|
|
TargetAddr: "example.internal:22",
|
||
|
|
})
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("start local forward: %v", err)
|
||
|
|
}
|
||
|
|
defer forwarder.Close()
|
||
|
|
|
||
|
|
reply := exerciseForwarder(t, forwarder.Addr().String(), []byte("ping"))
|
||
|
|
if string(reply) != "ping" {
|
||
|
|
t.Fatalf("unexpected forwarded reply: %q", string(reply))
|
||
|
|
}
|
||
|
|
if detachedCalls.Load() != 0 {
|
||
|
|
t.Fatalf("default local forward should not create detached ssh client, calls=%d", detachedCalls.Load())
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestStartLocalForwardDetachedUsesSeparateConnection(t *testing.T) {
|
||
|
|
oldDialSSHClient := dialSSHClient
|
||
|
|
oldNewDetachedForwardClient := newDetachedForwardClient
|
||
|
|
oldCloseSSHClient := closeSSHClient
|
||
|
|
t.Cleanup(func() {
|
||
|
|
dialSSHClient = oldDialSSHClient
|
||
|
|
newDetachedForwardClient = oldNewDetachedForwardClient
|
||
|
|
closeSSHClient = oldCloseSSHClient
|
||
|
|
})
|
||
|
|
|
||
|
|
baseClient := &ssh.Client{}
|
||
|
|
detachedClient := &ssh.Client{}
|
||
|
|
star := &StarSSH{LoginInfo: LoginInput{User: "tester", Addr: "127.0.0.1"}}
|
||
|
|
star.setTransport(baseClient, nil)
|
||
|
|
|
||
|
|
forwardClient := &StarSSH{}
|
||
|
|
forwardClient.setTransport(detachedClient, nil)
|
||
|
|
|
||
|
|
var detachedCalls atomic.Int32
|
||
|
|
newDetachedForwardClient = func(ctx context.Context, input LoginInput) (*StarSSH, error) {
|
||
|
|
detachedCalls.Add(1)
|
||
|
|
return forwardClient, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) {
|
||
|
|
if client != detachedClient {
|
||
|
|
t.Errorf("expected detached ssh client, got %p want %p", client, detachedClient)
|
||
|
|
}
|
||
|
|
serverConn, clientConn := net.Pipe()
|
||
|
|
go echoForwardPipe(serverConn)
|
||
|
|
return clientConn, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
var closeCalls atomic.Int32
|
||
|
|
closeSSHClient = func(client sshClientRequester) error {
|
||
|
|
closeCalls.Add(1)
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
forwarder, err := star.StartLocalForwardDetached(ForwardRequest{
|
||
|
|
ListenAddr: "127.0.0.1:0",
|
||
|
|
TargetAddr: "example.internal:22",
|
||
|
|
})
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("start detached local forward: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
reply := exerciseForwarder(t, forwarder.Addr().String(), []byte("pong"))
|
||
|
|
if string(reply) != "pong" {
|
||
|
|
t.Fatalf("unexpected detached forwarded reply: %q", string(reply))
|
||
|
|
}
|
||
|
|
|
||
|
|
if err := forwarder.Close(); err != nil {
|
||
|
|
t.Fatalf("close detached local forward: %v", err)
|
||
|
|
}
|
||
|
|
if detachedCalls.Load() != 1 {
|
||
|
|
t.Fatalf("expected one detached ssh login, got %d", detachedCalls.Load())
|
||
|
|
}
|
||
|
|
if closeCalls.Load() != 1 {
|
||
|
|
t.Fatalf("expected detached ssh client cleanup once, got %d", closeCalls.Load())
|
||
|
|
}
|
||
|
|
if got := star.snapshotSSHClient(); got != baseClient {
|
||
|
|
t.Fatal("detached local forward should not detach the main ssh client")
|
||
|
|
}
|
||
|
|
if got := forwardClient.snapshotSSHClient(); got != nil {
|
||
|
|
t.Fatal("detached local forward should close its detached ssh client")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// echoForwardPipe performs a single echo exchange on conn and then closes it.
// It reads one chunk of at most 4096 bytes and, only if that read succeeded,
// writes the same bytes back. A read error (including io.EOF) ends the
// exchange without writing anything.
func echoForwardPipe(conn net.Conn) {
	defer conn.Close()

	var scratch [4096]byte
	count, readErr := conn.Read(scratch[:])
	if readErr == nil {
		// The peer may already be gone; a failed echo is not reported.
		_, _ = conn.Write(scratch[:count])
	}
}
|
||
|
|
|
||
|
|
// exerciseForwarder opens a TCP connection to addr, sends payload, and returns
// exactly len(payload) bytes read back from the same connection.
//
// The dial is bounded by one second. After connecting, a two-second deadline
// covers both the write and the read. Any dial, write or read failure aborts
// the calling test via t.Fatalf.
func exerciseForwarder(t *testing.T, addr string, payload []byte) []byte {
	t.Helper()

	const (
		dialTimeout = time.Second
		ioTimeout   = 2 * time.Second
	)

	client, dialErr := net.DialTimeout("tcp", addr, dialTimeout)
	if dialErr != nil {
		t.Fatalf("dial forward listener: %v", dialErr)
	}
	defer client.Close()

	// Best effort: if the deadline cannot be set, the exchange is simply
	// unbounded rather than failing the test here.
	_ = client.SetDeadline(time.Now().Add(ioTimeout))

	if _, writeErr := client.Write(payload); writeErr != nil {
		t.Fatalf("write forwarded payload: %v", writeErr)
	}

	echoed := make([]byte, len(payload))
	if _, readErr := io.ReadFull(client, echoed); readErr != nil {
		t.Fatalf("read forwarded reply: %v", readErr)
	}
	return echoed
}
|