package notify

import (
	"context"
	"errors"
	"math"
	"net"
	"testing"
	"time"

	"b612.me/stario"
)

// TestServerStopClientSessionMarksStoppedAndCleansScopedState verifies that
// stopClientSession:
//   - marks the client not alive, with the given reason and a nil error;
//   - closes its stop monitor channel;
//   - cancels file/signal ack waiters prepared for the client's file scope;
//   - clears the received-signal cache entry for that scope;
//   - removes the logical connection from the server registry.
func TestServerStopClientSessionMarksStoppedAndCleansScopedState(t *testing.T) {
	server := NewServer().(*ServerCommon)
	stopCtx, stopFn := context.WithCancel(context.Background())
	defer stopFn()
	client, _, _ := newRegisteredServerClientForTest(t, server, "client-a", nil, stopCtx, stopFn)
	scope := serverFileScope(client)
	fileWait := server.getFileAckPool().prepare(scope, "file-1", "end", 0)
	signalWait := server.getSignalAckPool().prepare(scope, 1001)
	cache := server.getReceivedSignalCache()
	if seen := cache.seenOrRemember(scope, 1001); seen {
		t.Fatal("first seenOrRemember should report unseen signal")
	}

	server.stopClientSession(client, "manual stop", nil)

	if status := client.Status(); status.Alive || status.Reason != "manual stop" || status.Err != nil {
		t.Fatalf("unexpected client status after stop: %+v", status)
	}
	select {
	case <-client.StopMonitorChan():
	default:
		t.Fatal("client stop context should be closed")
	}
	if err := server.getFileAckPool().waitPrepared(fileWait, defaultFileAckTimeout); err == nil || err.Error() != "file ack canceled" {
		t.Fatalf("file waiter cancel mismatch: %v", err)
	}
	if err := server.getSignalAckPool().waitPrepared(signalWait, defaultSignalAckTimeout); err == nil || err.Error() != "signal ack canceled" {
		t.Fatalf("signal waiter cancel mismatch: %v", err)
	}
	// A second seenOrRemember reporting "unseen" proves the scope's cache
	// entry was dropped by the stop.
	if seen := cache.seenOrRemember(scope, 1001); seen {
		t.Fatal("received signal cache should be cleared for removed client scope")
	}
	if got := server.GetLogicalConn(client.ClientID); got != nil {
		t.Fatalf("logical should be removed from registry, got %+v", got)
	}
}

// TestServerStopLogicalSessionMarksStoppedAndCleansScopedState is the
// stopLogicalSession counterpart of the test above. It checks that stopping
// via the logical connection has the same effects on the client and its
// scoped state.
func TestServerStopLogicalSessionMarksStoppedAndCleansScopedState(t *testing.T) {
	server := NewServer().(*ServerCommon)
	stopCtx, stopFn := context.WithCancel(context.Background())
	defer stopFn()
	client, _, _ := newRegisteredServerClientForTest(t, server, "client-logical-stop", nil, stopCtx, stopFn)
	logical := client.LogicalConn()
	scope := serverFileScope(client)
	fileWait := server.getFileAckPool().prepare(scope, "file-logical-stop", "end", 0)
	signalWait := server.getSignalAckPool().prepare(scope, 1002)
	cache := server.getReceivedSignalCache()
	if seen := cache.seenOrRemember(scope, 1002); seen {
		t.Fatal("first seenOrRemember should report unseen signal")
	}

	server.stopLogicalSession(logical, "logical stop", nil)

	if status := client.Status(); status.Alive || status.Reason != "logical stop" || status.Err != nil {
		t.Fatalf("unexpected client status after logical stop: %+v", status)
	}
	select {
	case <-client.StopMonitorChan():
	default:
		t.Fatal("client stop context should be closed")
	}
	if err := server.getFileAckPool().waitPrepared(fileWait, defaultFileAckTimeout); err == nil || err.Error() != "file ack canceled" {
		t.Fatalf("file waiter cancel mismatch: %v", err)
	}
	if err := server.getSignalAckPool().waitPrepared(signalWait, defaultSignalAckTimeout); err == nil || err.Error() != "signal ack canceled" {
		t.Fatalf("signal waiter cancel mismatch: %v", err)
	}
	if seen := cache.seenOrRemember(scope, 1002); seen {
		t.Fatal("received signal cache should be cleared for removed logical scope")
	}
	if got := server.GetLogicalConn(client.ClientID); got != nil {
		t.Fatalf("logical should be removed from registry, got %+v", got)
	}
}

// TestServerStopClientSessionResetsScopedBulkWithServiceShutdown verifies two
// things about a bulk handle registered under the client's file scope when
// stopClientSession runs: it observes errServiceShutdown, and it is removed
// from the bulk runtime.
func TestServerStopClientSessionResetsScopedBulkWithServiceShutdown(t *testing.T) {
	server := NewServer().(*ServerCommon)
	stopCtx, stopFn := context.WithCancel(context.Background())
	defer stopFn()
	client, _, _ := newRegisteredServerClientForTest(t, server, "client-stop-bulk", nil, stopCtx, stopFn)
	scope := serverFileScope(client)
	bulk := newBulkHandle(context.Background(), server.getBulkRuntime(), scope, BulkOpenRequest{
		BulkID: "client-stop-bulk",
		DataID: 1,
		Range:  BulkRange{Length: 1},
	}, 0, client.LogicalConn(), nil, 0, nil, nil, nil, nil, nil)
	if err := server.getBulkRuntime().register(scope, bulk); err != nil {
		t.Fatalf("register bulk failed: %v", err)
	}

	server.stopClientSession(client, "manual stop", nil)

	if err := readBulkError(t, bulk, time.Second); !errors.Is(err, errServiceShutdown) {
		t.Fatalf("stopped client bulk error = %v, want %v", err, errServiceShutdown)
	}
	if _, ok := server.getBulkRuntime().lookup(scope, bulk.ID()); ok {
		t.Fatal("stopped client bulk should be removed from runtime")
	}
}

// TestServerStopLogicalSessionResetsScopedIOWithServiceShutdown verifies what
// stopLogicalSession does to a stream handle and a bulk handle registered
// under the logical's file scope: each observes errServiceShutdown, and each
// is removed from its runtime.
func TestServerStopLogicalSessionResetsScopedIOWithServiceShutdown(t *testing.T) {
	server := NewServer().(*ServerCommon)
	stopCtx, stopFn := context.WithCancel(context.Background())
	defer stopFn()
	logical, _, _ := newRegisteredServerLogicalForTest(t, server, "logical-stop-io", nil, stopCtx, stopFn)
	scope := serverFileScope(logical)
	stream := newStreamHandle(context.Background(), server.getStreamRuntime(), scope, StreamOpenRequest{
		StreamID: "logical-stop-stream",
		Channel:  StreamDataChannel,
	}, 0, logical, nil, 0, nil, nil, nil, defaultStreamConfig())
	if err := server.getStreamRuntime().register(scope, stream); err != nil {
		t.Fatalf("register stream failed: %v", err)
	}
	bulk := newBulkHandle(context.Background(), server.getBulkRuntime(), scope, BulkOpenRequest{
		BulkID: "logical-stop-bulk",
		DataID: 1,
		Range:  BulkRange{Length: 1},
	}, 0, logical, nil, 0, nil, nil, nil, nil, nil)
	if err := server.getBulkRuntime().register(scope, bulk); err != nil {
		t.Fatalf("register bulk failed: %v", err)
	}

	server.stopLogicalSession(logical, "logical stop", nil)

	if err := readStreamError(t, stream, time.Second); !errors.Is(err, errServiceShutdown) {
		t.Fatalf("stopped logical stream error = %v, want %v", err, errServiceShutdown)
	}
	if err := readBulkError(t, bulk, time.Second); !errors.Is(err, errServiceShutdown) {
		t.Fatalf("stopped logical bulk error = %v, want %v", err, errServiceShutdown)
	}
	if _, ok := server.getStreamRuntime().lookup(scope, stream.ID()); ok {
		t.Fatal("stopped logical stream should be removed from runtime")
	}
	if _, ok := server.getBulkRuntime().lookup(scope, bulk.ID()); ok {
		t.Fatal("stopped logical bulk should be removed from runtime")
	}
}

// TestServerStopClientSessionIsSafeToRepeat verifies that calling
// stopClientSession twice does not fail. The second call's reason is the one
// reported by Status, and the logical stays out of the registry.
func TestServerStopClientSessionIsSafeToRepeat(t *testing.T) {
	server := NewServer().(*ServerCommon)
	stopCtx, stopFn := context.WithCancel(context.Background())
	defer stopFn()
	client, _, _ := newRegisteredServerClientForTest(t, server, "client-repeat", nil, stopCtx, stopFn)

	server.stopClientSession(client, "first stop", nil)
	server.stopClientSession(client, "second stop", nil)

	if status := client.Status(); status.Alive || status.Reason != "second stop" {
		t.Fatalf("unexpected repeated stop status: %+v", status)
	}
	if got := server.GetLogicalConn(client.ClientID); got != nil {
		t.Fatalf("logical should stay removed after repeated stop, got %+v", got)
	}
}

// TestServerBootstrapAcceptedLogicalRegistersStartedSession verifies that
// bootstrapAcceptedLogical, given a stream conn, returns a logical with these
// properties:
//   - it is registered and alive;
//   - it holds the accepted conn and its remote address;
//   - it has an initialized stop context and stop func;
//   - it inherits the server's read/write timeouts and secret key.
func TestServerBootstrapAcceptedLogicalRegistersStartedSession(t *testing.T) {
	server := NewServer().(*ServerCommon)
	UseLegacySecurityServer(server)
	server.maxReadTimeout = 2
	server.maxWriteTimeout = 3
	left, right := net.Pipe()
	defer left.Close()
	defer right.Close()

	logical := server.bootstrapAcceptedLogical("client-stream", nil, left)
	if logical == nil {
		t.Fatal("bootstrapAcceptedLogical should return logical")
	}
	if got := server.GetLogicalConn(logical.ClientID); got != logical {
		t.Fatal("bootstrapAcceptedLogical should register logical in registry")
	}
	client := clientConnFromLogical(logical)
	if !client.Status().Alive {
		t.Fatalf("client should start alive: %+v", client.Status())
	}
	if got := client.clientConnTransportSnapshot(); got != left {
		t.Fatal("client tuConn should match accepted stream conn")
	}
	if client.ClientAddr == nil || client.ClientAddr.String() != left.RemoteAddr().String() {
		t.Fatalf("client addr mismatch: got %v want %v", client.ClientAddr, left.RemoteAddr())
	}
	if client.clientConnStopContextSnapshot() == nil || client.clientConnStopFuncSnapshot() == nil {
		t.Fatal("client stop context should be initialized")
	}
	if got, want := client.clientConnMaxReadTimeoutSnapshot(), server.maxReadTimeout; got != want {
		t.Fatalf("maxReadTimeout mismatch: got %v want %v", got, want)
	}
	if got, want := client.clientConnMaxWriteTimeoutSnapshot(), server.maxWriteTimeout; got != want {
		t.Fatalf("maxWriteTimeout mismatch: got %v want %v", got, want)
	}
	if string(client.GetSecretKey()) != string(server.GetSecretKey()) {
		t.Fatal("client secret key should inherit server transport key")
	}
}

// TestServerBootstrapAcceptedLogicalSupportsPacketClient verifies how
// bootstrapAcceptedLogical behaves with a UDP address and no stream conn. It
// must produce a registered, alive logical that keeps the packet address and
// no transport conn.
func TestServerBootstrapAcceptedLogicalSupportsPacketClient(t *testing.T) {
	server := NewServer().(*ServerCommon)
	addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:12345")
	if err != nil {
		t.Fatalf("ResolveUDPAddr failed: %v", err)
	}

	logical := server.bootstrapAcceptedLogical("client-udp", addr, nil)
	if logical == nil {
		t.Fatal("bootstrapAcceptedLogical should return logical")
	}
	client := clientConnFromLogical(logical)
	if got := client.clientConnTransportSnapshot(); got != nil {
		t.Fatal("packet client should not keep stream conn")
	}
	if client.ClientAddr == nil || client.ClientAddr.String() != addr.String() {
		t.Fatalf("packet client addr mismatch: got %v want %v", client.ClientAddr, addr)
	}
	if got := server.GetLogicalConn(logical.ClientID); got != logical {
		t.Fatal("packet logical should be registered in registry")
	}
	if !client.Status().Alive {
		t.Fatalf("packet client should start alive: %+v", client.Status())
	}
}

// TestServerAttachAcceptedLogicalTransportRebindsExistingPeer verifies what
// happens when attachAcceptedLogicalTransport moves an existing logical onto
// a new conn:
//   - the logical stays registered;
//   - its transport snapshot and address switch to the new conn;
//   - a frame written on the new conn reaches the server queue, tagged with
//     the client's current transport generation.
func TestServerAttachAcceptedLogicalTransportRebindsExistingPeer(t *testing.T) {
	server := NewServer().(*ServerCommon)
	stopCtx, stopFn := context.WithCancel(context.Background())
	defer stopFn()
	queue := stario.NewQueueCtx(stopCtx, 4, math.MaxUint32)
	server.setServerSessionRuntime(&serverSessionRuntime{
		stopCtx: stopCtx,
		stopFn:  stopFn,
		queue:   queue,
	})
	oldLeft, oldRight := net.Pipe()
	defer oldRight.Close()
	logical := server.bootstrapAcceptedLogical("client-reattach", nil, oldLeft)
	if logical == nil {
		t.Fatal("bootstrapAcceptedLogical should return logical")
	}
	client := clientConnFromLogical(logical)
	newLeft, newRight := net.Pipe()
	defer newRight.Close()

	if err := server.attachAcceptedLogicalTransport(logical, newLeft.RemoteAddr(), newLeft); err != nil {
		t.Fatalf("attachAcceptedLogicalTransport failed: %v", err)
	}
	if got := server.GetLogicalConn(logical.ClientID); got != logical {
		t.Fatal("reattached logical should remain registered in registry")
	}
	if got := client.clientConnTransportSnapshot(); got != newLeft {
		t.Fatal("client transport snapshot should switch to new conn")
	}
	if client.ClientAddr == nil || client.ClientAddr.String() != newLeft.RemoteAddr().String() {
		t.Fatalf("client addr mismatch after attach: got %v want %v", client.ClientAddr, newLeft.RemoteAddr())
	}

	wire := queue.BuildMessage([]byte("reattached-peer"))
	// net.Pipe writes block until the other end reads. The deadline turns a
	// missing server-side reader into a test failure instead of a hang that
	// the timed select below could never catch.
	if err := newRight.SetWriteDeadline(time.Now().Add(time.Second)); err != nil {
		t.Fatalf("set write deadline failed: %v", err)
	}
	if _, err := newRight.Write(wire); err != nil {
		t.Fatalf("new transport write failed: %v", err)
	}
	select {
	case msg := <-queue.RestoreChan():
		source := assertServerInboundQueueSource(t, msg.Conn, logical)
		if got, want := source.TransportGeneration, client.clientConnTransportGenerationSnapshot(); got != want {
			t.Fatalf("queue transport generation mismatch: got %d want %d", got, want)
		}
		if got, want := string(msg.Msg), "reattached-peer"; got != want {
			t.Fatalf("queue payload mismatch: got %q want %q", got, want)
		}
	case <-time.After(time.Second):
		t.Fatal("reattached peer did not push framed message")
	}
}

// TestServerUpsertAcceptedLogicalReusesExistingPeerByID covers two calls to
// upsertAcceptedLogical with the same client id:
//   - the first call creates a logical (reused == false);
//   - the second returns that same logical (reused == true), bound to the new
//     conn;
//   - a frame written on the new conn reaches the server queue.
func TestServerUpsertAcceptedLogicalReusesExistingPeerByID(t *testing.T) {
	server := NewServer().(*ServerCommon)
	stopCtx, stopFn := context.WithCancel(context.Background())
	defer stopFn()
	queue := stario.NewQueueCtx(stopCtx, 4, math.MaxUint32)
	server.setServerSessionRuntime(&serverSessionRuntime{
		stopCtx: stopCtx,
		stopFn:  stopFn,
		queue:   queue,
	})
	oldLeft, oldRight := net.Pipe()
	defer oldRight.Close()

	initial, reused, err := server.upsertAcceptedLogical("client-upsert", nil, oldLeft)
	if err != nil {
		t.Fatalf("first upsertAcceptedLogical failed: %v", err)
	}
	if reused {
		t.Fatal("first upsertAcceptedLogical should create, not reuse")
	}
	// Guard against a nil result so the generation lookup below fails cleanly
	// instead of panicking.
	if initial == nil {
		t.Fatal("first upsertAcceptedLogical should return logical")
	}

	newLeft, newRight := net.Pipe()
	defer newRight.Close()
	reattached, reused, err := server.upsertAcceptedLogical("client-upsert", newLeft.RemoteAddr(), newLeft)
	if err != nil {
		t.Fatalf("second upsertAcceptedLogical failed: %v", err)
	}
	if !reused {
		t.Fatal("second upsertAcceptedLogical should reuse existing peer")
	}
	if reattached != initial {
		t.Fatal("upsertAcceptedLogical should return the existing logical when ids match")
	}
	if got := server.GetLogicalConn("client-upsert"); got != initial {
		t.Fatal("logical registry should still point at reused peer")
	}

	wire := queue.BuildMessage([]byte("upsert-reused"))
	// net.Pipe writes block until the other end reads; bound the write so a
	// missing reader fails the test rather than hanging it.
	if err := newRight.SetWriteDeadline(time.Now().Add(time.Second)); err != nil {
		t.Fatalf("set write deadline failed: %v", err)
	}
	if _, err := newRight.Write(wire); err != nil {
		t.Fatalf("new transport write failed: %v", err)
	}
	select {
	case msg := <-queue.RestoreChan():
		source := assertServerInboundQueueSource(t, msg.Conn, initial)
		if got, want := source.TransportGeneration, initial.clientConnTransportGenerationSnapshot(); got != want {
			t.Fatalf("queue transport generation mismatch: got %d want %d", got, want)
		}
		if got, want := string(msg.Msg), "upsert-reused"; got != want {
			t.Fatalf("queue payload mismatch: got %q want %q", got, want)
		}
	case <-time.After(time.Second):
		t.Fatal("upsertAcceptedLogical reused peer did not push framed message")
	}
}

// TestServerTransportScopedWaitsSwitchGenerationOnReattach verifies the
// lifecycle of transport-scoped waits across a detach and reattach:
//   - detachClientSessionTransport cancels pending, file-ack and signal-ack
//     waits created under the old transport scope;
//   - after attachAcceptedLogicalTransport the scope changes;
//   - deliveries addressed to the stale scope do not match waits created
//     under the new scope;
//   - deliveries via serverTransportDeliveryScopes do match them.
func TestServerTransportScopedWaitsSwitchGenerationOnReattach(t *testing.T) {
	server := NewServer().(*ServerCommon)
	firstLeft, firstRight := net.Pipe()
	defer firstRight.Close()
	logical := server.bootstrapAcceptedLogical("client-transport-scope", nil, firstLeft)
	if logical == nil {
		t.Fatal("bootstrapAcceptedLogical should return logical")
	}
	client := clientConnFromLogical(logical)
	client.markClientConnIdentityBound()

	scope1 := serverTransportScope(client)
	pending1 := server.getPendingWaitPool().createAndStoreWithScope(TransferMsg{ID: 8801, Type: MSG_SYNC_ASK}, scope1)
	fileWait1 := server.getFileAckPool().prepare(scope1, "file-transport-scope", "end", 0)
	signalWait1 := server.getSignalAckPool().prepare(scope1, 8802)

	server.detachClientSessionTransport(client, "read error", nil)

	// The reply channel must already be closed (not merely empty) after detach.
	select {
	case _, ok := <-pending1.Reply:
		if ok {
			t.Fatal("pending wait from detached generation should be canceled")
		}
	default:
		t.Fatal("pending wait from detached generation should close immediately")
	}
	if err := server.getFileAckPool().waitPrepared(fileWait1, defaultFileAckTimeout); err == nil || err.Error() != "file ack canceled" {
		t.Fatalf("file waiter from detached generation cancel mismatch: %v", err)
	}
	if err := server.getSignalAckPool().waitPrepared(signalWait1, defaultSignalAckTimeout); err == nil || err.Error() != "signal ack canceled" {
		t.Fatalf("signal waiter from detached generation cancel mismatch: %v", err)
	}

	secondLeft, secondRight := net.Pipe()
	defer secondRight.Close()
	if err := server.attachAcceptedLogicalTransport(logical, secondLeft.RemoteAddr(), secondLeft); err != nil {
		t.Fatalf("attachAcceptedLogicalTransport failed: %v", err)
	}
	scope2 := serverTransportScope(client)
	if scope2 == scope1 {
		t.Fatalf("transport scope should change after reattach: got %q", scope2)
	}

	// Pending waits: stale scope must miss, current delivery scopes must hit.
	pending2 := server.getPendingWaitPool().createAndStoreWithScope(TransferMsg{ID: 8803, Type: MSG_SYNC_ASK}, scope2)
	if server.getPendingWaitPool().deliverWithScopes(8803, []string{scope1}, Message{TransferMsg: TransferMsg{ID: 8803}}) {
		t.Fatal("stale generation pending reply should not match new scoped wait")
	}
	if !server.getPendingWaitPool().deliverWithScopes(8803, serverTransportDeliveryScopes(client), Message{TransferMsg: TransferMsg{ID: 8803}}) {
		t.Fatal("current generation pending reply should match scoped wait")
	}
	select {
	case _, ok := <-pending2.Reply:
		if !ok {
			t.Fatal("pending wait reply channel should remain open long enough to read reply")
		}
	case <-time.After(time.Second):
		t.Fatal("current generation pending wait should receive reply")
	}

	// File acks: same stale-vs-current check.
	fileWait2 := server.getFileAckPool().prepare(scope2, "file-transport-scope", "end", 0)
	if server.getFileAckPool().deliver(scope1, FileEvent{Packet: FilePacket{FileID: "file-transport-scope", Stage: "end"}}) {
		t.Fatal("stale generation file ack should not match new scoped wait")
	}
	if !server.getFileAckPool().deliverAny(serverTransportDeliveryScopes(client), FileEvent{Packet: FilePacket{FileID: "file-transport-scope", Stage: "end"}}) {
		t.Fatal("current generation file ack should match scoped wait")
	}
	if err := server.getFileAckPool().waitPrepared(fileWait2, defaultFileAckTimeout); err != nil {
		t.Fatalf("current generation file waiter should succeed: %v", err)
	}

	// Signal acks: same stale-vs-current check.
	signalWait2 := server.getSignalAckPool().prepare(scope2, 8804)
	if server.getSignalAckPool().deliver(scope1, 8804) {
		t.Fatal("stale generation signal ack should not match new scoped wait")
	}
	if !server.getSignalAckPool().deliverAny(serverTransportDeliveryScopes(client), 8804) {
		t.Fatal("current generation signal ack should match scoped wait")
	}
	if err := server.getSignalAckPool().waitPrepared(signalWait2, defaultSignalAckTimeout); err != nil {
		t.Fatalf("current generation signal waiter should succeed: %v", err)
	}
}