package notify

import (
	"bytes"
	"context"
	"math"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"b612.me/stario"
)

// newModernPSKFastPathTestClient builds a *ClientCommon configured with the
// shared modern-PSK integration credentials, failing the calling test if
// UseModernPSKClient rejects them.
func newModernPSKFastPathTestClient(t *testing.T) *ClientCommon {
	t.Helper()
	client := NewClient().(*ClientCommon)
	if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKClient failed: %v", err)
	}
	return client
}

// TestStreamFastDataFrameRoundTrip checks that a single frame encoded with
// encodeStreamFastDataFrame is recognised by decodeStreamFastDataFrame and
// comes back with its data ID, sequence number and payload intact.
func TestStreamFastDataFrameRoundTrip(t *testing.T) {
	frame, err := encodeStreamFastDataFrame(11, 7, []byte("payload"))
	if err != nil {
		t.Fatalf("encodeStreamFastDataFrame failed: %v", err)
	}
	got, matched, err := decodeStreamFastDataFrame(frame)
	if err != nil {
		t.Fatalf("decodeStreamFastDataFrame failed: %v", err)
	}
	if !matched {
		t.Fatal("decodeStreamFastDataFrame should match fast payload")
	}
	if got.DataID != 11 {
		t.Fatalf("data id = %d, want %d", got.DataID, 11)
	}
	if got.Seq != 7 {
		t.Fatalf("seq = %d, want %d", got.Seq, 7)
	}
	if string(got.Payload) != "payload" {
		t.Fatalf("payload = %q, want %q", got.Payload, "payload")
	}
}

// TestStreamFastBatchPlainRoundTrip checks that the output of
// encodeStreamFastBatchPlain decodes, via decodeStreamFastBatchPlain, back into
// the same frames in the same order.
func TestStreamFastBatchPlainRoundTrip(t *testing.T) {
	frames := []streamFastDataFrame{
		{DataID: 11, Seq: 7, Payload: []byte("alpha")},
		{DataID: 12, Seq: 8, Payload: []byte("beta")},
	}
	wire, err := encodeStreamFastBatchPlain(frames)
	if err != nil {
		t.Fatalf("encodeStreamFastBatchPlain failed: %v", err)
	}
	decoded, matched, err := decodeStreamFastBatchPlain(wire)
	if err != nil {
		t.Fatalf("decodeStreamFastBatchPlain failed: %v", err)
	}
	if !matched {
		t.Fatal("decodeStreamFastBatchPlain should match encoded batch")
	}
	if got, want := len(decoded), len(frames); got != want {
		t.Fatalf("decoded frame count = %d, want %d", got, want)
	}
	for index := range frames {
		if got, want := decoded[index].DataID, frames[index].DataID; got != want {
			t.Fatalf("frame %d data id = %d, want %d", index, got, want)
		}
		if got, want := decoded[index].Seq, frames[index].Seq; got != want {
			t.Fatalf("frame %d seq = %d, want %d", index, got, want)
		}
		if got, want := decoded[index].Payload, frames[index].Payload; !bytes.Equal(got, want) {
			t.Fatalf("frame %d payload = %q, want %q", index, got, want)
		}
	}
}

// TestClientDispatchInboundTransportPayloadFastStream registers a client-side
// stream for data ID 23. It then pushes a payload built by
// encodeFastStreamDataPayload through dispatchInboundTransportPayload and
// expects those bytes to become readable from the stream.
func TestClientDispatchInboundTransportPayloadFastStream(t *testing.T) {
	client := newModernPSKFastPathTestClient(t)
	runtime := client.getStreamRuntime()
	if runtime == nil {
		t.Fatal("client stream runtime should not be nil")
	}
	stream := newStreamHandle(context.Background(), runtime, clientFileScope(), StreamOpenRequest{
		StreamID: "fast-client",
		DataID:   23,
		Channel:  StreamDataChannel,
	}, 0, nil, nil, 0, nil, nil, nil, runtime.configSnapshot())
	if err := runtime.register(clientFileScope(), stream); err != nil {
		t.Fatalf("register stream failed: %v", err)
	}

	payload, err := client.encodeFastStreamDataPayload(23, 1, []byte("fast-payload"))
	if err != nil {
		t.Fatalf("encodeFastStreamDataPayload failed: %v", err)
	}
	if err := client.dispatchInboundTransportPayload(payload, time.Now()); err != nil {
		t.Fatalf("dispatchInboundTransportPayload failed: %v", err)
	}
	readStreamExactly(t, stream, "fast-payload", 2*time.Second)
}

// TestClientDispatchInboundTransportPayloadFastStreamBatch registers two
// client-side streams (data IDs 23 and 24) on the current fast-path version.
// It dispatches a single batch payload carrying one frame for each, and expects
// every stream to receive only its own frame's bytes.
func TestClientDispatchInboundTransportPayloadFastStreamBatch(t *testing.T) {
	client := newModernPSKFastPathTestClient(t)
	runtime := client.getStreamRuntime()
	if runtime == nil {
		t.Fatal("client stream runtime should not be nil")
	}
	streamA := newStreamHandle(context.Background(), runtime, clientFileScope(), StreamOpenRequest{
		StreamID:        "fast-client-a",
		DataID:          23,
		FastPathVersion: streamFastPathVersionCurrent,
		Channel:         StreamDataChannel,
	}, 0, nil, nil, 0, nil, nil, nil, runtime.configSnapshot())
	streamB := newStreamHandle(context.Background(), runtime, clientFileScope(), StreamOpenRequest{
		StreamID:        "fast-client-b",
		DataID:          24,
		FastPathVersion: streamFastPathVersionCurrent,
		Channel:         StreamDataChannel,
	}, 0, nil, nil, 0, nil, nil, nil, runtime.configSnapshot())
	if err := runtime.register(clientFileScope(), streamA); err != nil {
		t.Fatalf("register streamA failed: %v", err)
	}
	if err := runtime.register(clientFileScope(), streamB); err != nil {
		t.Fatalf("register streamB failed: %v", err)
	}

	payload, err := client.encodeFastStreamBatchPayload([]streamFastDataFrame{
		{DataID: 23, Seq: 1, Payload: []byte("fast-a")},
		{DataID: 24, Seq: 2, Payload: []byte("fast-b")},
	})
	if err != nil {
		t.Fatalf("encodeFastStreamBatchPayload failed: %v", err)
	}
	if err := client.dispatchInboundTransportPayload(payload, time.Now()); err != nil {
		t.Fatalf("dispatchInboundTransportPayload failed: %v", err)
	}
	readStreamExactly(t, streamA, "fast-a", 2*time.Second)
	readStreamExactly(t, streamB, "fast-b", 2*time.Second)
}

// TestStreamBatchSenderEncodeRequestsCoalescesFastV2Frames expects two
// single-frame V2 requests to be merged into one encodeBatch call. The test
// requires that call to produce the only payload and that encodeSingle is
// never invoked.
func TestStreamBatchSenderEncodeRequestsCoalescesFastV2Frames(t *testing.T) {
	var (
		singleCalls int
		batchCalls  [][]streamFastDataFrame
	)
	sender := &streamBatchSender{
		codec: streamBatchCodec{
			encodeSingle: func(frame streamFastDataFrame) ([]byte, error) {
				singleCalls++
				return []byte("single"), nil
			},
			encodeBatch: func(frames []streamFastDataFrame) ([]byte, error) {
				cloned := make([]streamFastDataFrame, len(frames))
				copy(cloned, frames)
				batchCalls = append(batchCalls, cloned)
				return []byte("batch"), nil
			},
		},
	}

	payloads, err := sender.encodeRequests([]streamBatchRequest{
		{
			frames: []streamFastDataFrame{{
				DataID:  101,
				Seq:     1,
				Payload: []byte("a"),
			}},
			fastPathVersion: streamFastPathVersionV2,
		},
		{
			frames: []streamFastDataFrame{{
				DataID:  102,
				Seq:     2,
				Payload: []byte("b"),
			}},
			fastPathVersion: streamFastPathVersionV2,
		},
	})
	if err != nil {
		t.Fatalf("encodeRequests failed: %v", err)
	}
	if got, want := len(payloads), 1; got != want {
		t.Fatalf("payload count = %d, want %d", got, want)
	}
	// The sole payload must be the encodeBatch output, not a single-frame encoding.
	if got, want := string(payloads[0]), "batch"; got != want {
		t.Fatalf("payload = %q, want %q", got, want)
	}
	if singleCalls != 0 {
		t.Fatalf("single encode calls = %d, want 0", singleCalls)
	}
	if got, want := len(batchCalls), 1; got != want {
		t.Fatalf("batch encode calls = %d, want %d", got, want)
	}
	if got, want := len(batchCalls[0]), 2; got != want {
		t.Fatalf("batched frame count = %d, want %d", got, want)
	}
}

// TestStreamBatchSenderEncodeRequestsFlushesBeforePreEncodedPayload checks that
// frames accumulated before a pre-encoded request are flushed as one batch
// ahead of it. The raw payload must then be forwarded unchanged, and the
// trailing lone frame encoded individually, preserving request order.
func TestStreamBatchSenderEncodeRequestsFlushesBeforePreEncodedPayload(t *testing.T) {
	var (
		singleCalls int
		batchCalls  [][]streamFastDataFrame
	)
	sender := &streamBatchSender{
		codec: streamBatchCodec{
			encodeSingle: func(frame streamFastDataFrame) ([]byte, error) {
				singleCalls++
				return append([]byte("single-"), frame.Payload...), nil
			},
			encodeBatch: func(frames []streamFastDataFrame) ([]byte, error) {
				cloned := make([]streamFastDataFrame, len(frames))
				copy(cloned, frames)
				batchCalls = append(batchCalls, cloned)
				return []byte("batch-2"), nil
			},
		},
	}

	payloads, err := sender.encodeRequests([]streamBatchRequest{
		{
			frames: []streamFastDataFrame{
				{DataID: 101, Seq: 1, Payload: []byte("a")},
				{DataID: 102, Seq: 2, Payload: []byte("b")},
			},
			fastPathVersion: streamFastPathVersionV2,
		},
		{
			encodedPayload:  []byte("raw"),
			hasEncoded:      true,
			fastPathVersion: streamFastPathVersionV2,
		},
		{
			frames: []streamFastDataFrame{
				{DataID: 103, Seq: 3, Payload: []byte("c")},
			},
			fastPathVersion: streamFastPathVersionV2,
		},
	})
	if err != nil {
		t.Fatalf("encodeRequests failed: %v", err)
	}
	if got, want := len(payloads), 3; got != want {
		t.Fatalf("payload count = %d, want %d", got, want)
	}
	if got, want := string(payloads[0]), "batch-2"; got != want {
		t.Fatalf("first payload = %q, want %q", got, want)
	}
	if got, want := string(payloads[1]), "raw"; got != want {
		t.Fatalf("second payload = %q, want %q", got, want)
	}
	if got, want := string(payloads[2]), "single-c"; got != want {
		t.Fatalf("third payload = %q, want %q", got, want)
	}
	if got, want := len(batchCalls), 1; got != want {
		t.Fatalf("batch encode calls = %d, want %d", got, want)
	}
	// The flushed batch must carry both frames queued before the raw payload.
	if got, want := len(batchCalls[0]), 2; got != want {
		t.Fatalf("flushed batch frame count = %d, want %d", got, want)
	}
	if singleCalls != 1 {
		t.Fatalf("single encode calls = %d, want %d", singleCalls, 1)
	}
}

// TestClientPushMessageFastDispatchesDirectWithRuntimeDispatcher installs a
// client session runtime and registers a link handler. It encodes an async
// message envelope for that link and expects pushMessageFast to report direct
// dispatch. The handler must receive the message within one second, and
// nothing may be placed on the queue's RestoreChan.
func TestClientPushMessageFastDispatchesDirectWithRuntimeDispatcher(t *testing.T) {
	client := newModernPSKFastPathTestClient(t)
	stopCtx, stopFn := context.WithCancel(context.Background())
	defer stopFn()
	queue := stario.NewQueueCtx(stopCtx, 4, math.MaxUint32)
	rt := newClientSessionRuntime(nil, stopCtx, stopFn, queue, 1)
	client.setClientSessionRuntime(rt)

	// Buffered so the handler never blocks on the single expected delivery.
	gotCh := make(chan Message, 1)
	client.SetLink("client-fast-dispatch", func(msg *Message) {
		gotCh <- *msg
	})

	env, err := wrapTransferMsgEnvelope(TransferMsg{
		ID:    31,
		Key:   "client-fast-dispatch",
		Value: MsgVal("payload"),
		Type:  MSG_ASYNC,
	}, client.sequenceEn)
	if err != nil {
		t.Fatalf("wrapTransferMsgEnvelope failed: %v", err)
	}
	wire, err := client.encodeEnvelope(env)
	if err != nil {
		t.Fatalf("encodeEnvelope failed: %v", err)
	}
	if !client.pushMessageFast(queue, wire, rt.inboundDispatcher) {
		t.Fatal("pushMessageFast should use direct dispatch")
	}

	select {
	case msg := <-gotCh:
		if got, want := msg.Key, "client-fast-dispatch"; got != want {
			t.Fatalf("message key = %q, want %q", got, want)
		}
	case <-time.After(time.Second):
		t.Fatal("timed out waiting for direct client dispatch")
	}

	select {
	case msg := <-queue.RestoreChan():
		t.Fatalf("fast path should not enqueue RestoreChan message, got %+v", msg)
	default:
	}
}

// TestStreamBatchSenderEncodeRequestsResetsBatchBytesAfterFlushBoundary sends
// one near-limit frame, then a pre-encoded payload, then two small frames. The
// byte budget consumed by the large frame must not carry over the flush
// boundary: the large frame is encoded alone, the raw payload passes through,
// and the two small frames are batched together afterwards.
func TestStreamBatchSenderEncodeRequestsResetsBatchBytesAfterFlushBoundary(t *testing.T) {
	var (
		singleCalls int
		batchCalls  [][]streamFastDataFrame
	)
	sender := &streamBatchSender{
		codec: streamBatchCodec{
			encodeSingle: func(frame streamFastDataFrame) ([]byte, error) {
				singleCalls++
				return []byte("single"), nil
			},
			encodeBatch: func(frames []streamFastDataFrame) ([]byte, error) {
				cloned := make([]streamFastDataFrame, len(frames))
				copy(cloned, frames)
				batchCalls = append(batchCalls, cloned)
				return []byte("batch"), nil
			},
		},
	}

	// Sized to leave only 128 bytes of headroom under the plain batch limit.
	largePayload := make([]byte, streamFastBatchMaxPlainBytes-streamFastBatchHeaderLen-streamFastBatchItemHeaderLen-128)
	payloads, err := sender.encodeRequests([]streamBatchRequest{
		{
			frames: []streamFastDataFrame{{
				DataID:  101,
				Seq:     1,
				Payload: largePayload,
			}},
			fastPathVersion: streamFastPathVersionV2,
		},
		{
			encodedPayload:  []byte("raw"),
			hasEncoded:      true,
			fastPathVersion: streamFastPathVersionV2,
		},
		{
			frames: []streamFastDataFrame{{
				DataID:  202,
				Seq:     1,
				Payload: []byte("a"),
			}},
			fastPathVersion: streamFastPathVersionV2,
		},
		{
			frames: []streamFastDataFrame{{
				DataID:  202,
				Seq:     2,
				Payload: []byte("b"),
			}},
			fastPathVersion: streamFastPathVersionV2,
		},
	})
	if err != nil {
		t.Fatalf("encodeRequests failed: %v", err)
	}
	if got, want := len(payloads), 3; got != want {
		t.Fatalf("payload count = %d, want %d", got, want)
	}
	// Output order must follow request order across the flush boundary.
	for index, want := range []string{"single", "raw", "batch"} {
		if got := string(payloads[index]); got != want {
			t.Fatalf("payload %d = %q, want %q", index, got, want)
		}
	}
	if got, want := singleCalls, 1; got != want {
		t.Fatalf("single encode calls = %d, want %d", got, want)
	}
	if got, want := len(batchCalls), 1; got != want {
		t.Fatalf("batch encode calls = %d, want %d", got, want)
	}
	if got, want := len(batchCalls[0]), 2; got != want {
		t.Fatalf("post-flush batched frame count = %d, want %d", got, want)
	}
}

// TestStreamBatchSenderEncodeRequestsUsesAdaptiveSoftLimit feeds the binding an
// adaptive write observation (2 MiB taking 640ms). It then expects two 160 KiB
// frames to be encoded individually rather than batched, i.e. the sender's
// batching limit follows the binding's adaptive state.
func TestStreamBatchSenderEncodeRequestsUsesAdaptiveSoftLimit(t *testing.T) {
	binding := &transportBinding{}
	binding.observeStreamAdaptivePayloadWrite(2*1024*1024, 640*time.Millisecond, 0, nil)

	var (
		singleCalls int
		batchCalls  int
	)
	sender := &streamBatchSender{
		binding: binding,
		codec: streamBatchCodec{
			encodeSingle: func(frame streamFastDataFrame) ([]byte, error) {
				singleCalls++
				return []byte("single"), nil
			},
			encodeBatch: func(frames []streamFastDataFrame) ([]byte, error) {
				batchCalls++
				return []byte("batch"), nil
			},
		},
	}

	payload := make([]byte, 160*1024)
	payloads, err := sender.encodeRequests([]streamBatchRequest{
		{
			frames: []streamFastDataFrame{{
				DataID:  101,
				Seq:     1,
				Payload: payload,
			}},
			fastPathVersion: streamFastPathVersionV2,
		},
		{
			frames: []streamFastDataFrame{{
				DataID:  102,
				Seq:     2,
				Payload: payload,
			}},
			fastPathVersion: streamFastPathVersionV2,
		},
	})
	if err != nil {
		t.Fatalf("encodeRequests failed: %v", err)
	}
	if got, want := len(payloads), 2; got != want {
		t.Fatalf("payload count = %d, want %d", got, want)
	}
	if got, want := singleCalls, 2; got != want {
		t.Fatalf("single encode calls = %d, want %d", got, want)
	}
	if batchCalls != 0 {
		t.Fatalf("batch encode calls = %d, want 0", batchCalls)
	}
}

// TestClientSendFastStreamDataSplitsLargeChunkWhenAdaptiveSoftLimitShrinks
// primes a transport binding with an adaptive write observation, then sends a
// chunk larger than streamFastBatchDirectLimit through sendFastStreamData. It
// expects:
//   - the chunk to be split into streamFastSplitFrameCount(...) frames (more than one);
//   - one sequence number reserved per frame;
//   - every frame to be encoded singly, with consecutive seqs starting at 1;
//   - the concatenated frame payloads to reproduce the original chunk.
func TestClientSendFastStreamDataSplitsLargeChunkWhenAdaptiveSoftLimitShrinks(t *testing.T) {
	binding := newTransportBinding(&delayedWriteConn{}, stario.NewQueue())
	binding.observeStreamAdaptivePayloadWrite(2*1024*1024, 640*time.Millisecond, 0, nil)

	var (
		mu           sync.Mutex
		singleFrames []streamFastDataFrame
		batchFrames  [][]streamFastDataFrame
	)
	// The codec callbacks deep-copy payloads because the sender may reuse its
	// buffers; mu guards the recordings in case encoding runs off the test goroutine.
	binding.streamSender = newStreamBatchSender(binding, streamBatchCodec{
		encodeSingle: func(frame streamFastDataFrame) ([]byte, error) {
			mu.Lock()
			singleFrames = append(singleFrames, streamFastDataFrame{
				DataID:  frame.DataID,
				Seq:     frame.Seq,
				Payload: append([]byte(nil), frame.Payload...),
			})
			mu.Unlock()
			return []byte{1}, nil
		},
		encodeBatch: func(frames []streamFastDataFrame) ([]byte, error) {
			cloned := make([]streamFastDataFrame, len(frames))
			for index := range frames {
				cloned[index] = streamFastDataFrame{
					DataID:  frames[index].DataID,
					Seq:     frames[index].Seq,
					Payload: append([]byte(nil), frames[index].Payload...),
				}
			}
			mu.Lock()
			batchFrames = append(batchFrames, cloned)
			mu.Unlock()
			return []byte{2}, nil
		},
	}, nil)
	defer binding.stopBackgroundWorkers()

	client := NewClient().(*ClientCommon)
	stopCtx, stopFn := context.WithCancel(context.Background())
	defer stopFn()
	client.setClientSessionRuntime(&clientSessionRuntime{
		transport:             binding,
		transportAttached:     true,
		stopCtx:               stopCtx,
		stopFn:                stopFn,
		queue:                 binding.queueSnapshot(),
		inboundDispatcher:     newInboundDispatcher(),
		suppressGoodByeOnStop: &atomic.Bool{},
	})

	stream := &streamHandle{
		dataID:          41,
		fastPathVersion: streamFastPathVersionV2,
	}
	chunk := make([]byte, streamFastBatchDirectLimit+128*1024)
	for index := range chunk {
		chunk[index] = byte(index)
	}
	if err := client.sendFastStreamData(context.Background(), stream, chunk); err != nil {
		t.Fatalf("sendFastStreamData failed: %v", err)
	}

	expectedFrames := streamFastSplitFrameCount(len(chunk), streamAdaptiveFramePayloadLimit(binding))
	// Without this guard the test would pass trivially with a single unsplit frame.
	if expectedFrames < 2 {
		t.Fatalf("expected frame count = %d, want > 1 so the chunk is actually split", expectedFrames)
	}
	if got, want := int(stream.outboundSeq.Load()), expectedFrames; got != want {
		t.Fatalf("reserved seq count = %d, want %d", got, want)
	}

	mu.Lock()
	defer mu.Unlock()
	if len(batchFrames) != 0 {
		t.Fatalf("batch encode calls = %d, want 0", len(batchFrames))
	}
	if got, want := len(singleFrames), expectedFrames; got != want {
		t.Fatalf("single frame count = %d, want %d", got, want)
	}
	rebuilt := make([]byte, 0, len(chunk))
	for index, frame := range singleFrames {
		if got, want := frame.DataID, uint64(41); got != want {
			t.Fatalf("frame %d data id = %d, want %d", index, got, want)
		}
		if got, want := frame.Seq, uint64(index+1); got != want {
			t.Fatalf("frame %d seq = %d, want %d", index, got, want)
		}
		rebuilt = append(rebuilt, frame.Payload...)
	}
	if !bytes.Equal(rebuilt, chunk) {
		t.Fatal("rebuilt payload does not match original chunk")
	}
}