package notify import ( "b612.me/stario" "context" "math" "net" "testing" "time" ) func assertServerInboundQueueSource(t *testing.T, raw interface{}, peer any) serverInboundSource { t.Helper() source, ok := raw.(serverInboundSource) if !ok { t.Fatalf("queue source type = %T, want serverInboundSource", raw) } if source.Logical != logicalConnFromPeer(peer) { t.Fatalf("queue source logical mismatch: got %+v want %+v", source.Logical, peer) } if source.Source == "" { t.Fatal("queue source should expose stable source string") } return source } func TestResolveInboundSourcePreservesStaleTransportGeneration(t *testing.T) { server := NewServer().(*ServerCommon) firstLeft, firstRight := net.Pipe() defer firstRight.Close() logical := server.bootstrapAcceptedLogical("inbound-stale-source", nil, firstLeft) if logical == nil { t.Fatal("bootstrapAcceptedLogical should return logical") } rt := logical.clientConnSessionRuntimeSnapshot() if rt == nil { t.Fatal("logical runtime should exist") } source := newServerInboundSource(logical, firstLeft, nil, rt.transportGeneration) secondLeft, secondRight := net.Pipe() defer secondLeft.Close() defer secondRight.Close() if err := server.attachAcceptedLogicalTransport(logical, secondLeft.RemoteAddr(), secondLeft); err != nil { t.Fatalf("attachAcceptedLogicalTransport failed: %v", err) } resolved, transport := server.resolveInboundSource(source) if resolved != logical { t.Fatal("resolveInboundSource should return original logical client") } if transport == nil { t.Fatal("resolveInboundSource should reconstruct transport snapshot") } if got, want := transport.TransportGeneration(), source.TransportGeneration; got != want { t.Fatalf("transport generation = %d, want %d", got, want) } if transport.IsCurrent() { t.Fatal("stale inbound transport should not be current after reattach") } if transport.Attached() { t.Fatal("stale inbound transport should not remain attached after reattach") } if !transport.HasRuntimeConn() { t.Fatal("stale inbound stream transport should keep runtime conn marker") } } func TestResolveInboundSourceRebindsHandedOffConnToCurrentLogical(t *testing.T) { server := NewServer().(*ServerCommon) dstLeft, dstRight := net.Pipe() defer dstRight.Close() dst := server.bootstrapAcceptedLogical("inbound-handoff-dst", nil, dstLeft) if dst == nil { t.Fatal("bootstrapAcceptedLogical(dst) should return logical") } srcLeft, srcRight := net.Pipe() defer srcRight.Close() src := server.bootstrapAcceptedLogical("inbound-handoff-src", nil, srcLeft) if src == nil { t.Fatal("bootstrapAcceptedLogical(src) should return logical") } rt := src.clientConnSessionRuntimeSnapshot() if rt == nil { t.Fatal("source runtime should exist") } source := newServerInboundSource(src, srcLeft, nil, rt.transportGeneration) if err := server.handoffAcceptedLogicalTransport(dst, src); err != nil { t.Fatalf("handoffAcceptedLogicalTransport failed: %v", err) } resolved, transport := server.resolveInboundSource(source) if resolved != dst { t.Fatalf("resolveInboundSource should rebind to current logical: got %+v want %+v", resolved, dst) } if transport == nil { t.Fatal("resolveInboundSource should reconstruct transport snapshot") } if transport.IsCurrent() { t.Fatal("queued inbound source from pre-handoff generation should remain stale") } } func TestResolveLogicalBySourceReturnsNilOnAmbiguousAddress(t *testing.T) { server := NewServer().(*ServerCommon) addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 32001} first := server.bootstrapAcceptedLogical("ambiguous-first", addr, nil) second := server.bootstrapAcceptedLogical("ambiguous-second", addr, nil) if first == nil || second == nil { t.Fatal("bootstrapAcceptedLogical should create both logical peers") } if got := server.resolveLogicalBySource(addr.String()); got != nil { t.Fatalf("resolveLogicalBySource should reject ambiguous addr match, got %+v", got) } } func TestNewServerInboundSourcePrefersLogicalIDOverRemoteAddr(t *testing.T) { server := NewServer().(*ServerCommon) addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 32002} logical := server.bootstrapAcceptedLogical("inbound-source-logical-id", addr, nil) if logical == nil { t.Fatal("bootstrapAcceptedLogical should return logical") } source := newServerInboundSource(logical, nil, addr, 0) if got, want := source.Source, logical.ID(); got != want { t.Fatalf("source.Source = %q, want %q", got, want) } resolved, transport := server.resolveInboundSource(source.Source) if resolved != logical { t.Fatalf("resolveInboundSource by logical source = %+v, want %+v", resolved, logical) } if transport != nil { t.Fatalf("resolveInboundSource by logical source transport = %+v, want nil", transport) } } func TestServerDispatchEnvelopePreservesExplicitInboundTransport(t *testing.T) { server := NewServer().(*ServerCommon) UseLegacySecurityServer(server) firstLeft, firstRight := net.Pipe() defer firstRight.Close() logical := server.bootstrapAcceptedLogical("inbound-dispatch", nil, firstLeft) if logical == nil { t.Fatal("bootstrapAcceptedLogical should return logical") } rt := logical.clientConnSessionRuntimeSnapshot() if rt == nil { t.Fatal("logical runtime should exist") } staleTransport := logical.transportConnSnapshotForInbound(firstLeft, nil, rt.transportGeneration, true) if staleTransport == nil { t.Fatal("stale transport snapshot should exist") } secondLeft, secondRight := net.Pipe() defer secondLeft.Close() defer secondRight.Close() if err := server.attachAcceptedLogicalTransport(logical, secondLeft.RemoteAddr(), secondLeft); err != nil { t.Fatalf("attachAcceptedLogicalTransport failed: %v", err) } gotCh := make(chan Message, 1) server.SetLink("inbound-explicit", func(msg *Message) { gotCh <- *msg }) env, err := wrapTransferMsgEnvelope(TransferMsg{ ID: 7, Key: "inbound-explicit", Value: MsgVal("payload"), Type: MSG_ASYNC, }, server.sequenceEn) if err != nil { t.Fatalf("wrapTransferMsgEnvelope failed: %v", err) } server.dispatchEnvelope(logical, staleTransport, firstLeft, env, time.Now()) select { case msg := <-gotCh: if msg.LogicalConn != logical { t.Fatal("message logical conn mismatch") } if msg.TransportConn == nil { t.Fatal("message transport conn should be preserved") } if got, want := msg.TransportConn.TransportGeneration(), staleTransport.TransportGeneration(); got != want { t.Fatalf("message transport generation = %d, want %d", got, want) } if msg.TransportConn.IsCurrent() { t.Fatal("message transport should stay stale instead of being backfilled to current") } case <-time.After(time.Second): t.Fatal("timed out waiting for dispatched message") } } func TestServerDispatchFileAckUsesExplicitInboundTransportScope(t *testing.T) { server := NewServer().(*ServerCommon) firstLeft, firstRight := net.Pipe() defer firstRight.Close() logical := server.bootstrapAcceptedLogical("inbound-file-ack", nil, firstLeft) if logical == nil { t.Fatal("bootstrapAcceptedLogical should return logical") } rt := logical.clientConnSessionRuntimeSnapshot() if rt == nil { t.Fatal("logical runtime should exist") } staleTransport := logical.transportConnSnapshotForInbound(firstLeft, nil, rt.transportGeneration, true) if staleTransport == nil { t.Fatal("stale transport snapshot should exist") } secondLeft, secondRight := net.Pipe() defer secondLeft.Close() defer secondRight.Close() if err := server.attachAcceptedLogicalTransport(logical, secondLeft.RemoteAddr(), secondLeft); err != nil { t.Fatalf("attachAcceptedLogicalTransport failed: %v", err) } currentTransport := logical.CurrentTransportConn() if currentTransport == nil { t.Fatal("current transport snapshot should exist") } waitOld := server.getFileAckPool().prepare(serverTransportScopeForTransport(staleTransport), "file-1", "end", 0) waitCurrent := server.getFileAckPool().prepare(serverTransportScopeForTransport(currentTransport), "file-1", "end", 0) server.dispatchFileEnvelope(logical, staleTransport, firstLeft, newFileAckEnvelope("file-1", "end", 0, ""), time.Now()) if err := server.getFileAckPool().waitPrepared(waitOld, defaultFileAckTimeout); err != nil { t.Fatalf("old transport scoped ack should succeed: %v", err) } select { case event, ok := <-waitCurrent.reply: t.Fatalf("current transport scoped ack should remain pending, got (%+v, %v)", event, ok) default: } waitCurrent.cancel() } func TestServerPushMessageSourceDispatchesDirectWithRuntimeDispatcher(t *testing.T) { server := NewServer().(*ServerCommon) UseLegacySecurityServer(server) stopCtx, stopFn := context.WithCancel(context.Background()) defer stopFn() queue := stario.NewQueueCtx(stopCtx, 4, math.MaxUint32) dispatcher := newInboundDispatcher() defer dispatcher.CloseAndWait() server.setServerSessionRuntime(&serverSessionRuntime{ stopCtx: stopCtx, stopFn: stopFn, queue: queue, inboundDispatcher: dispatcher, }) left, right := net.Pipe() defer left.Close() defer right.Close() logical := server.bootstrapAcceptedLogical("inbound-fast-path", nil, left) if logical == nil { t.Fatal("bootstrapAcceptedLogical should return logical") } currentTransport := logical.CurrentTransportConn() if currentTransport == nil { t.Fatal("current transport snapshot should exist") } gotCh := make(chan Message, 1) server.SetLink("inbound-fast-path", func(msg *Message) { gotCh <- *msg }) env, err := wrapTransferMsgEnvelope(TransferMsg{ ID: 17, Key: "inbound-fast-path", Value: MsgVal("payload"), Type: MSG_ASYNC, }, server.sequenceEn) if err != nil { t.Fatalf("wrapTransferMsgEnvelope failed: %v", err) } wire, err := server.encodeEnvelopeLogical(logical, env) if err != nil { t.Fatalf("encodeEnvelopeLogical failed: %v", err) } source := newServerInboundSource(logical, left, nil, currentTransport.TransportGeneration()) server.pushMessageSource(wire, source) select { case msg := <-gotCh: if msg.LogicalConn != logical { t.Fatal("message logical conn mismatch") } if msg.TransportConn == nil { t.Fatal("message transport conn should be resolved") } if got, want := msg.TransportConn.TransportGeneration(), currentTransport.TransportGeneration(); got != want { t.Fatalf("message transport generation = %d, want %d", got, want) } case <-time.After(time.Second): t.Fatal("timed out waiting for direct push dispatch") } select { case msg := <-queue.RestoreChan(): t.Fatalf("fast path should not enqueue RestoreChan message, got %+v", msg) default: } }