// notify/conn_injection_test.go

package notify
import (
"context"
"errors"
"net"
"sync"
"sync/atomic"
"testing"
"time"
)
// errInMemoryListenerClosed is the error Accept returns once the in-memory
// listener has been closed.
var errInMemoryListenerClosed = errors.New("in-memory listener closed")

// inMemoryAddr is a string-backed address used by inMemoryListener.
type inMemoryAddr string

// Network reports the fixed network name "in-memory".
func (a inMemoryAddr) Network() string { return "in-memory" }

// String returns the address text unchanged.
func (a inMemoryAddr) String() string { return string(a) }

// inMemoryListener provides Accept, Close and Addr without any real transport.
// It never yields a connection: Accept blocks until Close is called.
type inMemoryListener struct {
	closed chan struct{} // closed by the first Close call
	once   sync.Once     // makes repeated Close calls safe
}

// newInMemoryListener returns an open listener with its close signal ready.
func newInMemoryListener() *inMemoryListener {
	l := new(inMemoryListener)
	l.closed = make(chan struct{})
	return l
}

// Accept waits for the listener to be closed and then reports
// errInMemoryListenerClosed with a nil connection.
func (l *inMemoryListener) Accept() (net.Conn, error) {
	<-l.closed
	return nil, errInMemoryListenerClosed
}

// Close releases every pending and future Accept call. Only the first call
// closes the signal channel; the result is always nil.
func (l *inMemoryListener) Close() error {
	l.once.Do(l.signalClosed)
	return nil
}

// signalClosed closes the channel Accept is waiting on.
func (l *inMemoryListener) signalClosed() { close(l.closed) }

// Addr returns the constant in-memory address of the listener.
func (l *inMemoryListener) Addr() net.Addr {
	const name inMemoryAddr = "in-memory-listener"
	return name
}
// TestConnectByConnRequiresModernPSK checks that ConnectByConn on a client
// straight from NewClient, with no secret key set, fails with
// errModernPSKRequired.
func TestConnectByConnRequiresModernPSK(t *testing.T) {
	clientSide, peerSide := net.Pipe()
	defer clientSide.Close()
	defer peerSide.Close()

	c := NewClient()
	if got := c.ConnectByConn(clientSide); !errors.Is(got, errModernPSKRequired) {
		t.Fatalf("ConnectByConn error = %v, want %v", got, errModernPSKRequired)
	}
}
// TestConnectByConnWithConfiguredSecurity checks that ConnectByConn succeeds
// over a net.Pipe when client and server share the same secret key, and that
// the client can then be stopped without error.
func TestConnectByConnWithConfiguredSecurity(t *testing.T) {
	psk := []byte("0123456789abcdef0123456789abcdef")
	c := NewClient().(*ClientCommon)

	clientSide, serverSide := net.Pipe()
	defer serverSide.Close()

	srv := newRunningPeerAttachServerForTest(t, func(s *ServerCommon) {
		s.SetSecretKey(psk)
	})
	bootstrapPeerAttachConnForTest(t, srv, serverSide)

	c.SetSecretKey(psk)
	err := c.ConnectByConn(clientSide)
	if err != nil {
		t.Fatalf("ConnectByConn failed: %v", err)
	}

	// NOTE(review): presumably marks the session as ended by the server so
	// Stop does not attempt a goodbye exchange — confirm against the client.
	c.setByeFromServer(true)
	err = c.Stop()
	if err != nil {
		t.Fatalf("Stop failed: %v", err)
	}
}
// TestConnectByFactoryRequiresModernPSK checks that ConnectByFactory on an
// unconfigured client fails with errModernPSKRequired and never invokes the
// supplied dial function.
func TestConnectByFactoryRequiresModernPSK(t *testing.T) {
	dialed := false
	dial := func(context.Context) (net.Conn, error) {
		dialed = true
		clientSide, peerSide := net.Pipe()
		_ = peerSide.Close()
		return clientSide, nil
	}

	err := NewClient().ConnectByFactory(context.Background(), dial)
	switch {
	case !errors.Is(err, errModernPSKRequired):
		t.Fatalf("ConnectByFactory error = %v, want %v", err, errModernPSKRequired)
	case dialed:
		t.Fatal("dialFn should not be called before security validation passes")
	}
}
// TestConnectByFactoryRejectsNilDialFn checks that a nil dial function is
// rejected with the error text "dialFn is nil" once a secret key is set.
func TestConnectByFactoryRejectsNilDialFn(t *testing.T) {
	c := NewClient().(*ClientCommon)
	c.SetSecretKey([]byte("0123456789abcdef0123456789abcdef"))

	err := c.ConnectByFactory(context.Background(), nil)
	if err != nil && err.Error() == "dialFn is nil" {
		return
	}
	t.Fatalf("ConnectByFactory nil dialFn error = %v, want dialFn is nil", err)
}
// TestConnectByFactoryPropagatesDialError checks that an error returned by the
// dial function is visible through errors.Is on the ConnectByFactory result.
func TestConnectByFactoryPropagatesDialError(t *testing.T) {
	dialErr := errors.New("dial failed")
	failingDial := func(context.Context) (net.Conn, error) {
		return nil, dialErr
	}

	c := NewClient().(*ClientCommon)
	c.SetSecretKey([]byte("0123456789abcdef0123456789abcdef"))

	if got := c.ConnectByFactory(context.Background(), failingDial); !errors.Is(got, dialErr) {
		t.Fatalf("ConnectByFactory error = %v, want %v", got, dialErr)
	}
}
// TestConnectByFactoryWithConfiguredSecurity checks that ConnectByFactory
// succeeds with a shared secret key, and that a nil context passed by the
// caller reaches the dial function as a non-nil context.
func TestConnectByFactoryWithConfiguredSecurity(t *testing.T) {
	psk := []byte("0123456789abcdef0123456789abcdef")
	c := NewClient().(*ClientCommon)

	clientSide, serverSide := net.Pipe()
	defer serverSide.Close()

	srv := newRunningPeerAttachServerForTest(t, func(s *ServerCommon) {
		s.SetSecretKey(psk)
	})
	bootstrapPeerAttachConnForTest(t, srv, serverSide)
	c.SetSecretKey(psk)

	dial := func(ctx context.Context) (net.Conn, error) {
		if ctx == nil {
			t.Fatal("ConnectByFactory should normalize nil context")
		}
		return clientSide, nil
	}

	// The nil context is deliberate: it is the input under test.
	err := c.ConnectByFactory(nil, dial)
	if err != nil {
		t.Fatalf("ConnectByFactory failed: %v", err)
	}

	c.setByeFromServer(true)
	err = c.Stop()
	if err != nil {
		t.Fatalf("Stop failed: %v", err)
	}
}
// TestConnectByFactoryRejectsConcurrentStart verifies three things in order:
// a second ConnectByFactory issued while the first is still inside its dial
// function fails with "client already run" without running its own dial
// function; the first call returns context.Canceled once its context is
// canceled; and a later ConnectByFactory call runs its dial function again.
func TestConnectByFactoryRejectsConcurrentStart(t *testing.T) {
client := NewClient().(*ClientCommon)
client.SetSecretKey([]byte("0123456789abcdef0123456789abcdef"))
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Both channels are buffered so the sends below never block, even if the
// test has already failed and nobody is receiving.
firstDialEntered := make(chan struct{}, 1)
firstDone := make(chan error, 1)
go func() {
// The first dial function signals entry and then parks until ctx is
// canceled, keeping the first connect attempt in flight.
firstDone <- client.ConnectByFactory(ctx, func(ctx context.Context) (net.Conn, error) {
firstDialEntered <- struct{}{}
<-ctx.Done()
return nil, ctx.Err()
})
}()
// Wait until the first attempt is provably inside its dial function.
select {
case <-firstDialEntered:
case <-time.After(time.Second):
t.Fatal("first connect attempt did not enter dialFn")
}
// The overlapping attempt must be rejected before its dial function runs.
secondDialCalled := false
err := client.ConnectByFactory(context.Background(), func(context.Context) (net.Conn, error) {
secondDialCalled = true
return nil, errors.New("second dial should not run")
})
if err == nil || err.Error() != "client already run" {
t.Fatalf("concurrent ConnectByFactory error = %v, want client already run", err)
}
if secondDialCalled {
t.Fatal("second dialFn should not be called during first connect start")
}
// Release the parked first attempt and collect its result.
cancel()
select {
case err = <-firstDone:
case <-time.After(time.Second):
t.Fatal("first ConnectByFactory did not finish after cancel")
}
if !errors.Is(err, context.Canceled) {
t.Fatalf("first ConnectByFactory error = %v, want %v", err, context.Canceled)
}
// After the canceled attempt the client must accept a new connect: seeing
// wantErr proves the dial function was invoked rather than rejected early.
wantErr := errors.New("dial after rollback")
err = client.ConnectByFactory(context.Background(), func(context.Context) (net.Conn, error) {
return nil, wantErr
})
if !errors.Is(err, wantErr) {
t.Fatalf("ConnectByFactory after rollback error = %v, want %v", err, wantErr)
}
}
// TestConnectByConnReattachesDetachedAliveSession verifies that, after the
// transport is cleared from an established session runtime, a second
// ConnectByConn attaches the new conn to the same runtime (same queue, stopCtx
// and epoch) and that a message written by the peer on the new conn is
// dispatched to the handler registered with SetLink.
func TestConnectByConnReattachesDetachedAliveSession(t *testing.T) {
client := NewClient().(*ClientCommon)
secret := []byte("0123456789abcdef0123456789abcdef")
client.SetSecretKey(secret)
server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
server.SetSecretKey(secret)
})
// Establish the initial session over the first pipe.
firstLeft, firstRight := net.Pipe()
defer firstRight.Close()
bootstrapPeerAttachConnForTest(t, server, firstRight)
if err := client.ConnectByConn(firstLeft); err != nil {
t.Fatalf("initial ConnectByConn failed: %v", err)
}
// Record the runtime identity so it can be compared after the reattach.
before := client.clientSessionRuntimeSnapshot()
if before == nil {
t.Fatal("runtime should exist after initial connect")
}
initialEpoch := before.epoch
initialStopCtx := before.stopCtx
initialQueue := before.queue
// NOTE(review): per its name this drops only the transport and leaves the
// session runtime in place — confirm against the helper's definition.
client.clearClientSessionRuntimeTransport()
// Buffered by one so the handler can deliver the single expected message
// without waiting for the receive below.
recvCh := make(chan Message, 1)
client.SetLink("reattach-public", func(message *Message) {
recvCh <- *message
})
secondLeft, secondRight := net.Pipe()
defer secondRight.Close()
bootstrapPeerAttachConnForTest(t, server, secondRight)
if err := client.ConnectByConn(secondLeft); err != nil {
t.Fatalf("reattach ConnectByConn failed: %v", err)
}
after := client.clientSessionRuntimeSnapshot()
if after == nil {
t.Fatal("runtime should exist after reattach")
}
// The runtime must now carry the new conn while keeping the original queue,
// stop context and epoch, and must report the transport as attached.
if after.conn != secondLeft || after.queue != initialQueue || after.stopCtx != initialStopCtx || after.epoch != initialEpoch || !after.transportAttached {
t.Fatalf("reattached runtime mismatch: %+v", after)
}
// Build an encoded message and write it from the peer side of the second
// pipe to exercise the client's receive path on the reattached conn.
env, err := wrapTransferMsgEnvelope(TransferMsg{
ID: 88,
Key: "reattach-public",
Value: []byte("ok"),
Type: MSG_ASYNC,
}, client.sequenceEn)
if err != nil {
t.Fatalf("wrapTransferMsgEnvelope failed: %v", err)
}
wire, err := client.encodeEnvelope(env)
if err != nil {
t.Fatalf("encodeEnvelope failed: %v", err)
}
if _, err := secondRight.Write(wire); err != nil {
t.Fatalf("reattached conn write failed: %v", err)
}
// The link handler must receive exactly the key and value that were sent.
select {
case msg := <-recvCh:
if got, want := msg.Key, "reattach-public"; got != want {
t.Fatalf("message key mismatch: got %q want %q", got, want)
}
if got, want := string(msg.Value), "ok"; got != want {
t.Fatalf("message value mismatch: got %q want %q", got, want)
}
case <-time.After(time.Second):
t.Fatal("reattached public conn did not dispatch message")
}
client.setByeFromServer(true)
if err := client.Stop(); err != nil {
t.Fatalf("final Stop failed: %v", err)
}
}
// TestConnectByFactoryReattachesDetachedAliveSessionAndUpdatesSource verifies
// that a session first established with ConnectByConn can, after its transport
// is cleared, be reattached with ConnectByFactory: the dial function runs
// exactly once, the runtime keeps its epoch and carries the new conn, and the
// public runtime snapshot reports the factory connect source and
// CanReconnect.
func TestConnectByFactoryReattachesDetachedAliveSessionAndUpdatesSource(t *testing.T) {
client := NewClient().(*ClientCommon)
secret := []byte("0123456789abcdef0123456789abcdef")
client.SetSecretKey(secret)
server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
server.SetSecretKey(secret)
})
// Establish the initial session from an already-open conn.
firstLeft, firstRight := net.Pipe()
defer firstRight.Close()
bootstrapPeerAttachConnForTest(t, server, firstRight)
if err := client.ConnectByConn(firstLeft); err != nil {
t.Fatalf("initial ConnectByConn failed: %v", err)
}
before := client.clientSessionRuntimeSnapshot()
if before == nil {
t.Fatal("runtime should exist after initial connect")
}
initialEpoch := before.epoch
// NOTE(review): per its name this drops only the transport and leaves the
// session runtime in place — confirm against the helper's definition.
client.clearClientSessionRuntimeTransport()
// Counts dial function invocations; the reattach must dial exactly once.
var dialCount atomic.Int32
secondLeft, secondRight := net.Pipe()
defer secondRight.Close()
bootstrapPeerAttachConnForTest(t, server, secondRight)
if err := client.ConnectByFactory(context.Background(), func(context.Context) (net.Conn, error) {
dialCount.Add(1)
return secondLeft, nil
}); err != nil {
t.Fatalf("reattach ConnectByFactory failed: %v", err)
}
if got, want := dialCount.Load(), int32(1); got != want {
t.Fatalf("dial count mismatch: got %d want %d", got, want)
}
after := client.clientSessionRuntimeSnapshot()
if after == nil {
t.Fatal("runtime should exist after factory reattach")
}
// Same epoch, new conn, transport attached.
if after.epoch != initialEpoch || after.conn != secondLeft || !after.transportAttached {
t.Fatalf("reattached runtime mismatch: %+v", after)
}
// The exported snapshot must reflect that the latest connect used a factory.
snapshot, err := GetClientRuntimeSnapshot(client)
if err != nil {
t.Fatalf("GetClientRuntimeSnapshot failed: %v", err)
}
if got, want := snapshot.ConnectSource, clientConnectSourceFactory; got != want {
t.Fatalf("connect source mismatch: got %q want %q", got, want)
}
if !snapshot.CanReconnect {
t.Fatalf("snapshot should be reconnectable after factory reattach: %+v", snapshot)
}
client.setByeFromServer(true)
if err := client.Stop(); err != nil {
t.Fatalf("final Stop failed: %v", err)
}
}
// TestConnectByConnFailureCleansRuntimeAndAllowsRetry verifies that a
// ConnectByConn attempt whose key exchange fails surfaces that error, leaves
// the client not alive with reason "key exchange failed", does not signal
// StopMonitorChan, and still allows a later ConnectByConn on the same client
// to succeed.
func TestConnectByConnFailureCleansRuntimeAndAllowsRetry(t *testing.T) {
client := NewClient().(*ClientCommon)
UseLegacySecurityClient(client)
// Force the key exchange step to fail with a recognizable error.
failErr := errors.New("key exchange fail for test")
client.keyExchangeFn = func(Client) error {
return failErr
}
left1, right1 := net.Pipe()
defer right1.Close()
err := client.ConnectByConn(left1)
if !errors.Is(err, failErr) {
t.Fatalf("ConnectByConn first error = %v, want %v", err, failErr)
}
status := client.Status()
if status.Alive || status.Reason != "key exchange failed" || !errors.Is(status.Err, failErr) {
t.Fatalf("unexpected status after failed key exchange: %+v", status)
}
// Nothing may be receivable from StopMonitorChan within this short window;
// a successful receive means the channel was signaled or closed.
select {
case <-client.StopMonitorChan():
t.Fatal("StopMonitorChan should remain open after failed connect cleanup")
case <-time.After(20 * time.Millisecond):
}
// Retry on the same client, this time with the key exchange skipped, against
// a server configured for legacy security.
client.SetSkipExchangeKey(true)
left2, right2 := net.Pipe()
defer right2.Close()
server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
UseLegacySecurityServer(server)
})
bootstrapPeerAttachConnForTest(t, server, right2)
if err := client.ConnectByConn(left2); err != nil {
t.Fatalf("ConnectByConn second attempt failed: %v", err)
}
if !client.Status().Alive {
t.Fatalf("client should be alive after second ConnectByConn: %+v", client.Status())
}
client.setByeFromServer(true)
if err := client.Stop(); err != nil {
t.Fatalf("Stop failed: %v", err)
}
}
func TestListenByListenerRequiresModernPSK(t *testing.T) {
server := NewServer()
listener := newInMemoryListener()
defer listener.Close()
err := server.ListenByListener(listener)
if !errors.Is(err, errModernPSKRequired) {
t.Fatalf("ListenByListener error = %v, want %v", err, errModernPSKRequired)
}
}
func TestListenByListenerWithConfiguredSecurity(t *testing.T) {
server := NewServer().(*ServerCommon)
listener := newInMemoryListener()
defer listener.Close()
server.SetSecretKey([]byte("0123456789abcdef0123456789abcdef"))
if err := server.ListenByListener(listener); err != nil {
t.Fatalf("ListenByListener failed: %v", err)
}
if !server.Status().Alive {
t.Fatal("server should be alive after ListenByListener")
}
if err := server.Stop(); err != nil {
t.Fatalf("Stop failed: %v", err)
}
}
// TestListenByListenerRejectsNil checks that a nil listener is rejected with
// the error text "listener is nil" once a secret key is set.
func TestListenByListenerRejectsNil(t *testing.T) {
	srv := NewServer().(*ServerCommon)
	srv.SetSecretKey([]byte("0123456789abcdef0123456789abcdef"))

	err := srv.ListenByListener(nil)
	if err != nil && err.Error() == "listener is nil" {
		return
	}
	t.Fatalf("ListenByListener nil error = %v, want listener is nil", err)
}
// TestClientReadMessagePreservesUserStopReason verifies that when the user
// stops the client while readMessage is running, and the peer side of the pipe
// is closed afterwards, readMessage returns and the final status still reports
// the user stop: not alive, reason "recv stop signal from user", nil Err.
func TestClientReadMessagePreservesUserStopReason(t *testing.T) {
	client := NewClient().(*ClientCommon)
	left, right := net.Pipe()
	// Close both pipe ends on every exit path, so a failed assertion does not
	// leave them open behind a still-running readMessage goroutine. Closing a
	// net.Pipe end more than once is harmless.
	defer left.Close()
	defer right.Close()
	stopCtx, stopFn := context.WithCancel(context.Background())
	defer stopFn()

	// Wire the client by hand instead of going through a connect call.
	client.conn = left
	client.stopCtx = stopCtx
	client.stopFn = stopFn
	client.markSessionStarted()

	done := make(chan struct{})
	go func() {
		client.readMessage()
		close(done)
	}()

	if err := client.Stop(); err != nil {
		t.Fatalf("Stop failed: %v", err)
	}
	// Close the peer only after Stop, so any read error is observed after the
	// user stop has been recorded.
	_ = right.Close()

	select {
	case <-done:
	case <-time.After(time.Second):
		t.Fatal("readMessage should exit after user stop")
	}

	status := client.Status()
	if status.Alive || status.Reason != "recv stop signal from user" || status.Err != nil {
		t.Fatalf("unexpected status after user stop: %+v", status)
	}
}
// TestClientReadMessagePreservesServerStopReason verifies that when the
// session is stopped through stopClientSessionFromServer while readMessage is
// running, and the peer side of the pipe is closed afterwards, readMessage
// returns and the final status still reports the server stop: not alive,
// reason "recv stop signal from server", nil Err.
func TestClientReadMessagePreservesServerStopReason(t *testing.T) {
	client := NewClient().(*ClientCommon)
	left, right := net.Pipe()
	// Close both pipe ends on every exit path, so a failed assertion does not
	// leave them open behind a still-running readMessage goroutine. Closing a
	// net.Pipe end more than once is harmless.
	defer left.Close()
	defer right.Close()
	stopCtx, stopFn := context.WithCancel(context.Background())
	defer stopFn()

	// Wire the client by hand instead of going through a connect call.
	client.conn = left
	client.stopCtx = stopCtx
	client.stopFn = stopFn
	client.markSessionStarted()

	done := make(chan struct{})
	go func() {
		client.readMessage()
		close(done)
	}()

	client.stopClientSessionFromServer("recv stop signal from server", nil)
	// Close the peer only after the stop, so any read error is observed after
	// the server stop has been recorded.
	_ = right.Close()

	select {
	case <-done:
	case <-time.After(time.Second):
		t.Fatal("readMessage should exit after server stop")
	}

	status := client.Status()
	if status.Alive || status.Reason != "recv stop signal from server" || status.Err != nil {
		t.Fatalf("unexpected status after server stop: %+v", status)
	}
}
// TestClientStopClientSessionFromServerDisablesGoodBye checks that stopping a
// started session via stopClientSessionFromServer makes
// shouldSayGoodByeOnStop report false and records the given reason with a nil
// error on a non-alive status.
func TestClientStopClientSessionFromServerDisablesGoodBye(t *testing.T) {
	const reason = "recv stop signal from server"

	c := NewClient().(*ClientCommon)
	c.markSessionStarted()
	c.stopClientSessionFromServer(reason, nil)

	if c.shouldSayGoodByeOnStop() {
		t.Fatal("server stop should disable goodbye on stop")
	}

	st := c.Status()
	stopped := !st.Alive && st.Reason == reason && st.Err == nil
	if !stopped {
		t.Fatalf("unexpected status after server stop helper: %+v", st)
	}
}
// TestClientStopClientSessionKeepsGoodByeEnabled checks that stopping a
// started session via stopClientSession leaves shouldSayGoodByeOnStop
// reporting true and records the given reason with a nil error on a non-alive
// status.
func TestClientStopClientSessionKeepsGoodByeEnabled(t *testing.T) {
	const reason = "recv stop signal from user"

	c := NewClient().(*ClientCommon)
	c.markSessionStarted()
	c.stopClientSession(reason, nil)

	if sayBye := c.shouldSayGoodByeOnStop(); !sayBye {
		t.Fatal("local stop should keep goodbye enabled")
	}

	st := c.Status()
	stopped := !st.Alive && st.Reason == reason && st.Err == nil
	if !stopped {
		t.Fatalf("unexpected status after local stop helper: %+v", st)
	}
}
// TestClientReadMessageLoopUsesProvidedStopCtx checks that readMessageLoop
// honors the context passed as its argument rather than client.stopCtx: with
// an already-canceled argument and a live client.stopCtx the loop must return
// promptly, and the peer end of the pipe must then fail to write.
func TestClientReadMessageLoopUsesProvidedStopCtx(t *testing.T) {
	c := NewClient().(*ClientCommon)
	loopSide, peerSide := net.Pipe()
	defer peerSide.Close()

	// The context handed to the loop is canceled before the loop starts.
	canceledCtx, cancel := context.WithCancel(context.Background())
	cancel()

	// The client's own fields are set so they could not be what ends the loop.
	c.stopCtx = context.Background()
	c.conn = nil

	finished := make(chan struct{})
	go func() {
		c.readMessageLoop(canceledCtx, loopSide, nil, 1)
		close(finished)
	}()

	select {
	case <-finished:
	case <-time.After(time.Second):
		t.Fatal("readMessageLoop should exit when provided stopCtx is canceled")
	}

	_, writeErr := peerSide.Write([]byte("x"))
	if writeErr == nil {
		t.Fatal("peer conn should be closed when loop exits")
	}
}