package notify import ( "context" "errors" "net" "sync" "sync/atomic" "testing" "time" ) var errInMemoryListenerClosed = errors.New("in-memory listener closed") type inMemoryListener struct { closed chan struct{} once sync.Once } func newInMemoryListener() *inMemoryListener { return &inMemoryListener{ closed: make(chan struct{}), } } func (l *inMemoryListener) Accept() (net.Conn, error) { <-l.closed return nil, errInMemoryListenerClosed } func (l *inMemoryListener) Close() error { l.once.Do(func() { close(l.closed) }) return nil } func (l *inMemoryListener) Addr() net.Addr { return inMemoryAddr("in-memory-listener") } type inMemoryAddr string func (a inMemoryAddr) Network() string { return "in-memory" } func (a inMemoryAddr) String() string { return string(a) } func TestConnectByConnRequiresModernPSK(t *testing.T) { client := NewClient() left, right := net.Pipe() defer left.Close() defer right.Close() err := client.ConnectByConn(left) if !errors.Is(err, errModernPSKRequired) { t.Fatalf("ConnectByConn error = %v, want %v", err, errModernPSKRequired) } } func TestConnectByConnWithConfiguredSecurity(t *testing.T) { client := NewClient().(*ClientCommon) secret := []byte("0123456789abcdef0123456789abcdef") left, right := net.Pipe() defer right.Close() server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) { server.SetSecretKey(secret) }) bootstrapPeerAttachConnForTest(t, server, right) client.SetSecretKey(secret) if err := client.ConnectByConn(left); err != nil { t.Fatalf("ConnectByConn failed: %v", err) } client.setByeFromServer(true) if err := client.Stop(); err != nil { t.Fatalf("Stop failed: %v", err) } } func TestConnectByFactoryRequiresModernPSK(t *testing.T) { client := NewClient() called := false err := client.ConnectByFactory(context.Background(), func(context.Context) (net.Conn, error) { called = true left, right := net.Pipe() _ = right.Close() return left, nil }) if !errors.Is(err, errModernPSKRequired) { t.Fatalf("ConnectByFactory error = %v, want %v", err, errModernPSKRequired) } if called { t.Fatal("dialFn should not be called before security validation passes") } } func TestConnectByFactoryRejectsNilDialFn(t *testing.T) { client := NewClient().(*ClientCommon) client.SetSecretKey([]byte("0123456789abcdef0123456789abcdef")) err := client.ConnectByFactory(context.Background(), nil) if err == nil || err.Error() != "dialFn is nil" { t.Fatalf("ConnectByFactory nil dialFn error = %v, want dialFn is nil", err) } } func TestConnectByFactoryPropagatesDialError(t *testing.T) { client := NewClient().(*ClientCommon) client.SetSecretKey([]byte("0123456789abcdef0123456789abcdef")) wantErr := errors.New("dial failed") err := client.ConnectByFactory(context.Background(), func(context.Context) (net.Conn, error) { return nil, wantErr }) if !errors.Is(err, wantErr) { t.Fatalf("ConnectByFactory error = %v, want %v", err, wantErr) } } func TestConnectByFactoryWithConfiguredSecurity(t *testing.T) { client := NewClient().(*ClientCommon) secret := []byte("0123456789abcdef0123456789abcdef") left, right := net.Pipe() defer right.Close() server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) { server.SetSecretKey(secret) }) bootstrapPeerAttachConnForTest(t, server, right) client.SetSecretKey(secret) if err := client.ConnectByFactory(nil, func(ctx context.Context) (net.Conn, error) { if ctx == nil { t.Fatal("ConnectByFactory should normalize nil context") } return left, nil }); err != nil { t.Fatalf("ConnectByFactory failed: %v", err) } client.setByeFromServer(true) if err := client.Stop(); err != nil { t.Fatalf("Stop failed: %v", err) } } func TestConnectByFactoryRejectsConcurrentStart(t *testing.T) { client := NewClient().(*ClientCommon) client.SetSecretKey([]byte("0123456789abcdef0123456789abcdef")) ctx, cancel := context.WithCancel(context.Background()) defer cancel() firstDialEntered := make(chan struct{}, 1) firstDone := make(chan error, 1) go func() { firstDone <- client.ConnectByFactory(ctx, func(ctx context.Context) (net.Conn, error) { firstDialEntered <- struct{}{} <-ctx.Done() return nil, ctx.Err() }) }() select { case <-firstDialEntered: case <-time.After(time.Second): t.Fatal("first connect attempt did not enter dialFn") } secondDialCalled := false err := client.ConnectByFactory(context.Background(), func(context.Context) (net.Conn, error) { secondDialCalled = true return nil, errors.New("second dial should not run") }) if err == nil || err.Error() != "client already run" { t.Fatalf("concurrent ConnectByFactory error = %v, want client already run", err) } if secondDialCalled { t.Fatal("second dialFn should not be called during first connect start") } cancel() select { case err = <-firstDone: case <-time.After(time.Second): t.Fatal("first ConnectByFactory did not finish after cancel") } if !errors.Is(err, context.Canceled) { t.Fatalf("first ConnectByFactory error = %v, want %v", err, context.Canceled) } wantErr := errors.New("dial after rollback") err = client.ConnectByFactory(context.Background(), func(context.Context) (net.Conn, error) { return nil, wantErr }) if !errors.Is(err, wantErr) { t.Fatalf("ConnectByFactory after rollback error = %v, want %v", err, wantErr) } } func TestConnectByConnReattachesDetachedAliveSession(t *testing.T) { client := NewClient().(*ClientCommon) secret := []byte("0123456789abcdef0123456789abcdef") client.SetSecretKey(secret) server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) { server.SetSecretKey(secret) }) firstLeft, firstRight := net.Pipe() defer firstRight.Close() bootstrapPeerAttachConnForTest(t, server, firstRight) if err := client.ConnectByConn(firstLeft); err != nil { t.Fatalf("initial ConnectByConn failed: %v", err) } before := client.clientSessionRuntimeSnapshot() if before == nil { t.Fatal("runtime should exist after initial connect") } initialEpoch := before.epoch initialStopCtx := before.stopCtx initialQueue := before.queue client.clearClientSessionRuntimeTransport() recvCh := make(chan Message, 1) client.SetLink("reattach-public", func(message *Message) { recvCh <- *message }) secondLeft, secondRight := net.Pipe() defer secondRight.Close() bootstrapPeerAttachConnForTest(t, server, secondRight) if err := client.ConnectByConn(secondLeft); err != nil { t.Fatalf("reattach ConnectByConn failed: %v", err) } after := client.clientSessionRuntimeSnapshot() if after == nil { t.Fatal("runtime should exist after reattach") } if after.conn != secondLeft || after.queue != initialQueue || after.stopCtx != initialStopCtx || after.epoch != initialEpoch || !after.transportAttached { t.Fatalf("reattached runtime mismatch: %+v", after) } env, err := wrapTransferMsgEnvelope(TransferMsg{ ID: 88, Key: "reattach-public", Value: []byte("ok"), Type: MSG_ASYNC, }, client.sequenceEn) if err != nil { t.Fatalf("wrapTransferMsgEnvelope failed: %v", err) } wire, err := client.encodeEnvelope(env) if err != nil { t.Fatalf("encodeEnvelope failed: %v", err) } if _, err := secondRight.Write(wire); err != nil { t.Fatalf("reattached conn write failed: %v", err) } select { case msg := <-recvCh: if got, want := msg.Key, "reattach-public"; got != want { t.Fatalf("message key mismatch: got %q want %q", got, want) } if got, want := string(msg.Value), "ok"; got != want { t.Fatalf("message value mismatch: got %q want %q", got, want) } case <-time.After(time.Second): t.Fatal("reattached public conn did not dispatch message") } client.setByeFromServer(true) if err := client.Stop(); err != nil { t.Fatalf("final Stop failed: %v", err) } } func TestConnectByFactoryReattachesDetachedAliveSessionAndUpdatesSource(t *testing.T) { client := NewClient().(*ClientCommon) secret := []byte("0123456789abcdef0123456789abcdef") client.SetSecretKey(secret) server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) { server.SetSecretKey(secret) }) firstLeft, firstRight := net.Pipe() defer firstRight.Close() bootstrapPeerAttachConnForTest(t, server, firstRight) if err := client.ConnectByConn(firstLeft); err != nil { t.Fatalf("initial ConnectByConn failed: %v", err) } before := client.clientSessionRuntimeSnapshot() if before == nil { t.Fatal("runtime should exist after initial connect") } initialEpoch := before.epoch client.clearClientSessionRuntimeTransport() var dialCount atomic.Int32 secondLeft, secondRight := net.Pipe() defer secondRight.Close() bootstrapPeerAttachConnForTest(t, server, secondRight) if err := client.ConnectByFactory(context.Background(), func(context.Context) (net.Conn, error) { dialCount.Add(1) return secondLeft, nil }); err != nil { t.Fatalf("reattach ConnectByFactory failed: %v", err) } if got, want := dialCount.Load(), int32(1); got != want { t.Fatalf("dial count mismatch: got %d want %d", got, want) } after := client.clientSessionRuntimeSnapshot() if after == nil { t.Fatal("runtime should exist after factory reattach") } if after.epoch != initialEpoch || after.conn != secondLeft || !after.transportAttached { t.Fatalf("reattached runtime mismatch: %+v", after) } snapshot, err := GetClientRuntimeSnapshot(client) if err != nil { t.Fatalf("GetClientRuntimeSnapshot failed: %v", err) } if got, want := snapshot.ConnectSource, clientConnectSourceFactory; got != want { t.Fatalf("connect source mismatch: got %q want %q", got, want) } if !snapshot.CanReconnect { t.Fatalf("snapshot should be reconnectable after factory reattach: %+v", snapshot) } client.setByeFromServer(true) if err := client.Stop(); err != nil { t.Fatalf("final Stop failed: %v", err) } } func TestConnectByConnFailureCleansRuntimeAndAllowsRetry(t *testing.T) { client := NewClient().(*ClientCommon) UseLegacySecurityClient(client) failErr := errors.New("key exchange fail for test") client.keyExchangeFn = func(Client) error { return failErr } left1, right1 := net.Pipe() defer right1.Close() err := client.ConnectByConn(left1) if !errors.Is(err, failErr) { t.Fatalf("ConnectByConn first error = %v, want %v", err, failErr) } status := client.Status() if status.Alive || status.Reason != "key exchange failed" || !errors.Is(status.Err, failErr) { t.Fatalf("unexpected status after failed key exchange: %+v", status) } select { case <-client.StopMonitorChan(): t.Fatal("StopMonitorChan should remain open after failed connect cleanup") case <-time.After(20 * time.Millisecond): } client.SetSkipExchangeKey(true) left2, right2 := net.Pipe() defer right2.Close() server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) { UseLegacySecurityServer(server) }) bootstrapPeerAttachConnForTest(t, server, right2) if err := client.ConnectByConn(left2); err != nil { t.Fatalf("ConnectByConn second attempt failed: %v", err) } if !client.Status().Alive { t.Fatalf("client should be alive after second ConnectByConn: %+v", client.Status()) } client.setByeFromServer(true) if err := client.Stop(); err != nil { t.Fatalf("Stop failed: %v", err) } } func TestListenByListenerRequiresModernPSK(t *testing.T) { server := NewServer() listener := newInMemoryListener() defer listener.Close() err := server.ListenByListener(listener) if !errors.Is(err, errModernPSKRequired) { t.Fatalf("ListenByListener error = %v, want %v", err, errModernPSKRequired) } } func TestListenByListenerWithConfiguredSecurity(t *testing.T) { server := NewServer().(*ServerCommon) listener := newInMemoryListener() defer listener.Close() server.SetSecretKey([]byte("0123456789abcdef0123456789abcdef")) if err := server.ListenByListener(listener); err != nil { t.Fatalf("ListenByListener failed: %v", err) } if !server.Status().Alive { t.Fatal("server should be alive after ListenByListener") } if err := server.Stop(); err != nil { t.Fatalf("Stop failed: %v", err) } } func TestListenByListenerRejectsNil(t *testing.T) { server := NewServer().(*ServerCommon) server.SetSecretKey([]byte("0123456789abcdef0123456789abcdef")) err := server.ListenByListener(nil) if err == nil || err.Error() != "listener is nil" { t.Fatalf("ListenByListener nil error = %v, want listener is nil", err) } } func TestClientReadMessagePreservesUserStopReason(t *testing.T) { client := NewClient().(*ClientCommon) left, right := net.Pipe() stopCtx, stopFn := context.WithCancel(context.Background()) defer stopFn() client.conn = left client.stopCtx = stopCtx client.stopFn = stopFn client.markSessionStarted() done := make(chan struct{}) go func() { client.readMessage() close(done) }() if err := client.Stop(); err != nil { t.Fatalf("Stop failed: %v", err) } _ = right.Close() select { case <-done: case <-time.After(time.Second): t.Fatal("readMessage should exit after user stop") } status := client.Status() if status.Alive || status.Reason != "recv stop signal from user" || status.Err != nil { t.Fatalf("unexpected status after user stop: %+v", status) } } func TestClientReadMessagePreservesServerStopReason(t *testing.T) { client := NewClient().(*ClientCommon) left, right := net.Pipe() stopCtx, stopFn := context.WithCancel(context.Background()) defer stopFn() client.conn = left client.stopCtx = stopCtx client.stopFn = stopFn client.markSessionStarted() done := make(chan struct{}) go func() { client.readMessage() close(done) }() client.stopClientSessionFromServer("recv stop signal from server", nil) _ = right.Close() select { case <-done: case <-time.After(time.Second): t.Fatal("readMessage should exit after server stop") } status := client.Status() if status.Alive || status.Reason != "recv stop signal from server" || status.Err != nil { t.Fatalf("unexpected status after server stop: %+v", status) } } func TestClientStopClientSessionFromServerDisablesGoodBye(t *testing.T) { client := NewClient().(*ClientCommon) client.markSessionStarted() client.stopClientSessionFromServer("recv stop signal from server", nil) if client.shouldSayGoodByeOnStop() { t.Fatal("server stop should disable goodbye on stop") } status := client.Status() if status.Alive || status.Reason != "recv stop signal from server" || status.Err != nil { t.Fatalf("unexpected status after server stop helper: %+v", status) } } func TestClientStopClientSessionKeepsGoodByeEnabled(t *testing.T) { client := NewClient().(*ClientCommon) client.markSessionStarted() client.stopClientSession("recv stop signal from user", nil) if !client.shouldSayGoodByeOnStop() { t.Fatal("local stop should keep goodbye enabled") } status := client.Status() if status.Alive || status.Reason != "recv stop signal from user" || status.Err != nil { t.Fatalf("unexpected status after local stop helper: %+v", status) } } func TestClientReadMessageLoopUsesProvidedStopCtx(t *testing.T) { client := NewClient().(*ClientCommon) left, right := net.Pipe() defer right.Close() loopCtx, loopCancel := context.WithCancel(context.Background()) loopCancel() client.stopCtx = context.Background() client.conn = nil done := make(chan struct{}) go func() { client.readMessageLoop(loopCtx, left, nil, 1) close(done) }() select { case <-done: case <-time.After(time.Second): t.Fatal("readMessageLoop should exit when provided stopCtx is canceled") } if _, err := right.Write([]byte("x")); err == nil { t.Fatal("peer conn should be closed when loop exits") } }