notify/stream_fastpath_test.go

527 lines
15 KiB
Go
Raw Normal View History

package notify
import (
"b612.me/stario"
"context"
"math"
"sync"
"sync/atomic"
"testing"
"time"
)
func TestStreamFastDataFrameRoundTrip(t *testing.T) {
frame, err := encodeStreamFastDataFrame(11, 7, []byte("payload"))
if err != nil {
t.Fatalf("encodeStreamFastDataFrame failed: %v", err)
}
got, matched, err := decodeStreamFastDataFrame(frame)
if err != nil {
t.Fatalf("decodeStreamFastDataFrame failed: %v", err)
}
if !matched {
t.Fatal("decodeStreamFastDataFrame should match fast payload")
}
if got.DataID != 11 {
t.Fatalf("data id = %d, want %d", got.DataID, 11)
}
if got.Seq != 7 {
t.Fatalf("seq = %d, want %d", got.Seq, 7)
}
if string(got.Payload) != "payload" {
t.Fatalf("payload = %q, want %q", got.Payload, "payload")
}
}
// TestStreamFastBatchPlainRoundTrip checks that a plain batch of fast-path
// frames decodes back into the same frames, in the same order, with every
// data ID, sequence number and payload preserved.
func TestStreamFastBatchPlainRoundTrip(t *testing.T) {
	input := []streamFastDataFrame{
		{DataID: 11, Seq: 7, Payload: []byte("alpha")},
		{DataID: 12, Seq: 8, Payload: []byte("beta")},
	}
	encoded, err := encodeStreamFastBatchPlain(input)
	if err != nil {
		t.Fatalf("encodeStreamFastBatchPlain failed: %v", err)
	}
	output, ok, err := decodeStreamFastBatchPlain(encoded)
	if err != nil {
		t.Fatalf("decodeStreamFastBatchPlain failed: %v", err)
	}
	if !ok {
		t.Fatal("decodeStreamFastBatchPlain should match encoded batch")
	}
	if have, want := len(output), len(input); have != want {
		t.Fatalf("decoded frame count = %d, want %d", have, want)
	}
	for i, want := range input {
		have := output[i]
		if have.DataID != want.DataID {
			t.Fatalf("frame %d data id = %d, want %d", i, have.DataID, want.DataID)
		}
		if have.Seq != want.Seq {
			t.Fatalf("frame %d seq = %d, want %d", i, have.Seq, want.Seq)
		}
		if h, w := string(have.Payload), string(want.Payload); h != w {
			t.Fatalf("frame %d payload = %q, want %q", i, h, w)
		}
	}
}
// TestClientDispatchInboundTransportPayloadFastStream verifies that a single
// fast-path stream payload, encoded by a modern-PSK client and dispatched back
// into it, is delivered to the registered stream with the matching data ID.
func TestClientDispatchInboundTransportPayloadFastStream(t *testing.T) {
	c := NewClient().(*ClientCommon)
	if err := UseModernPSKClient(c, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKClient failed: %v", err)
	}
	rt := c.getStreamRuntime()
	if rt == nil {
		t.Fatal("client stream runtime should not be nil")
	}
	openReq := StreamOpenRequest{
		StreamID: "fast-client",
		DataID:   23,
		Channel:  StreamDataChannel,
	}
	handle := newStreamHandle(context.Background(), rt, clientFileScope(), openReq, 0, nil, nil, 0, nil, nil, nil, rt.configSnapshot())
	if err := rt.register(clientFileScope(), handle); err != nil {
		t.Fatalf("register stream failed: %v", err)
	}
	wire, err := c.encodeFastStreamDataPayload(23, 1, []byte("fast-payload"))
	if err != nil {
		t.Fatalf("encodeFastStreamDataPayload failed: %v", err)
	}
	if err := c.dispatchInboundTransportPayload(wire, time.Now()); err != nil {
		t.Fatalf("dispatchInboundTransportPayload failed: %v", err)
	}
	readStreamExactly(t, handle, "fast-payload", 2*time.Second)
}
// TestClientDispatchInboundTransportPayloadFastStreamBatch verifies that one
// fast-path batch payload carrying frames for two data IDs is split and each
// frame is delivered to the stream registered under that data ID.
func TestClientDispatchInboundTransportPayloadFastStreamBatch(t *testing.T) {
	c := NewClient().(*ClientCommon)
	if err := UseModernPSKClient(c, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKClient failed: %v", err)
	}
	rt := c.getStreamRuntime()
	if rt == nil {
		t.Fatal("client stream runtime should not be nil")
	}
	reqA := StreamOpenRequest{
		StreamID:        "fast-client-a",
		DataID:          23,
		FastPathVersion: streamFastPathVersionCurrent,
		Channel:         StreamDataChannel,
	}
	reqB := StreamOpenRequest{
		StreamID:        "fast-client-b",
		DataID:          24,
		FastPathVersion: streamFastPathVersionCurrent,
		Channel:         StreamDataChannel,
	}
	first := newStreamHandle(context.Background(), rt, clientFileScope(), reqA, 0, nil, nil, 0, nil, nil, nil, rt.configSnapshot())
	second := newStreamHandle(context.Background(), rt, clientFileScope(), reqB, 0, nil, nil, 0, nil, nil, nil, rt.configSnapshot())
	if err := rt.register(clientFileScope(), first); err != nil {
		t.Fatalf("register streamA failed: %v", err)
	}
	if err := rt.register(clientFileScope(), second); err != nil {
		t.Fatalf("register streamB failed: %v", err)
	}
	batch := []streamFastDataFrame{
		{DataID: 23, Seq: 1, Payload: []byte("fast-a")},
		{DataID: 24, Seq: 2, Payload: []byte("fast-b")},
	}
	wire, err := c.encodeFastStreamBatchPayload(batch)
	if err != nil {
		t.Fatalf("encodeFastStreamBatchPayload failed: %v", err)
	}
	if err := c.dispatchInboundTransportPayload(wire, time.Now()); err != nil {
		t.Fatalf("dispatchInboundTransportPayload failed: %v", err)
	}
	readStreamExactly(t, first, "fast-a", 2*time.Second)
	readStreamExactly(t, second, "fast-b", 2*time.Second)
}
// TestStreamBatchSenderEncodeRequestsCoalescesFastV2Frames verifies that two
// adjacent single-frame V2 requests are merged into one batch payload: the
// batch encoder runs once with both frames and the single encoder never runs.
func TestStreamBatchSenderEncodeRequestsCoalescesFastV2Frames(t *testing.T) {
	singles := 0
	var batches [][]streamFastDataFrame
	codec := streamBatchCodec{
		encodeSingle: func(frame streamFastDataFrame) ([]byte, error) {
			singles++
			return []byte("single"), nil
		},
		encodeBatch: func(frames []streamFastDataFrame) ([]byte, error) {
			// Snapshot the frames so the assertions below do not depend on
			// what the sender does with its slice after this call.
			batches = append(batches, append([]streamFastDataFrame(nil), frames...))
			return []byte("batch"), nil
		},
	}
	sender := &streamBatchSender{codec: codec}
	requests := []streamBatchRequest{
		{
			frames:          []streamFastDataFrame{{DataID: 101, Seq: 1, Payload: []byte("a")}},
			fastPathVersion: streamFastPathVersionV2,
		},
		{
			frames:          []streamFastDataFrame{{DataID: 102, Seq: 2, Payload: []byte("b")}},
			fastPathVersion: streamFastPathVersionV2,
		},
	}
	payloads, err := sender.encodeRequests(requests)
	if err != nil {
		t.Fatalf("encodeRequests failed: %v", err)
	}
	if n := len(payloads); n != 1 {
		t.Fatalf("payload count = %d, want %d", n, 1)
	}
	if singles != 0 {
		t.Fatalf("single encode calls = %d, want 0", singles)
	}
	if n := len(batches); n != 1 {
		t.Fatalf("batch encode calls = %d, want %d", n, 1)
	}
	if n := len(batches[0]); n != 2 {
		t.Fatalf("batched frame count = %d, want %d", n, 2)
	}
}
// TestStreamBatchSenderEncodeRequestsFlushesBeforePreEncodedPayload verifies
// that a pre-encoded request acts as a flush boundary. Frames queued before it
// are emitted (here as one batch) ahead of the raw payload. The frame queued
// after it is emitted afterwards (here via the single-frame encoder), so
// request order is preserved and frames never cross the boundary.
func TestStreamBatchSenderEncodeRequestsFlushesBeforePreEncodedPayload(t *testing.T) {
	var (
		singleCalls int
		batchCalls  [][]streamFastDataFrame
	)
	sender := &streamBatchSender{
		codec: streamBatchCodec{
			encodeSingle: func(frame streamFastDataFrame) ([]byte, error) {
				singleCalls++
				return append([]byte("single-"), frame.Payload...), nil
			},
			encodeBatch: func(frames []streamFastDataFrame) ([]byte, error) {
				cloned := make([]streamFastDataFrame, len(frames))
				copy(cloned, frames)
				batchCalls = append(batchCalls, cloned)
				return []byte("batch-2"), nil
			},
		},
	}
	payloads, err := sender.encodeRequests([]streamBatchRequest{
		{
			frames: []streamFastDataFrame{
				{DataID: 101, Seq: 1, Payload: []byte("a")},
				{DataID: 102, Seq: 2, Payload: []byte("b")},
			},
			fastPathVersion: streamFastPathVersionV2,
		},
		{
			encodedPayload:  []byte("raw"),
			hasEncoded:      true,
			fastPathVersion: streamFastPathVersionV2,
		},
		{
			frames: []streamFastDataFrame{
				{DataID: 103, Seq: 3, Payload: []byte("c")},
			},
			fastPathVersion: streamFastPathVersionV2,
		},
	})
	if err != nil {
		t.Fatalf("encodeRequests failed: %v", err)
	}
	if got, want := len(payloads), 3; got != want {
		t.Fatalf("payload count = %d, want %d", got, want)
	}
	if got, want := string(payloads[0]), "batch-2"; got != want {
		t.Fatalf("first payload = %q, want %q", got, want)
	}
	if got, want := string(payloads[1]), "raw"; got != want {
		t.Fatalf("second payload = %q, want %q", got, want)
	}
	if got, want := string(payloads[2]), "single-c"; got != want {
		t.Fatalf("third payload = %q, want %q", got, want)
	}
	if got, want := len(batchCalls), 1; got != want {
		t.Fatalf("batch encode calls = %d, want %d", got, want)
	}
	if singleCalls != 1 {
		t.Fatalf("single encode calls = %d, want %d", singleCalls, 1)
	}
	// The batch must contain exactly the frames queued before the raw
	// payload, in request order; frame 103 must not leak into it.
	batched := batchCalls[0]
	wantIDs := []uint64{101, 102}
	if got, want := len(batched), len(wantIDs); got != want {
		t.Fatalf("batched frame count = %d, want %d", got, want)
	}
	for index, wantID := range wantIDs {
		if got := batched[index].DataID; got != wantID {
			t.Fatalf("batched frame %d data id = %d, want %d", index, got, wantID)
		}
	}
}
func TestClientPushMessageFastDispatchesDirectWithRuntimeDispatcher(t *testing.T) {
client := NewClient().(*ClientCommon)
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
t.Fatalf("UseModernPSKClient failed: %v", err)
}
stopCtx, stopFn := context.WithCancel(context.Background())
defer stopFn()
queue := stario.NewQueueCtx(stopCtx, 4, math.MaxUint32)
rt := newClientSessionRuntime(nil, stopCtx, stopFn, queue, 1)
client.setClientSessionRuntime(rt)
gotCh := make(chan Message, 1)
client.SetLink("client-fast-dispatch", func(msg *Message) {
gotCh <- *msg
})
env, err := wrapTransferMsgEnvelope(TransferMsg{
ID: 31,
Key: "client-fast-dispatch",
Value: MsgVal("payload"),
Type: MSG_ASYNC,
}, client.sequenceEn)
if err != nil {
t.Fatalf("wrapTransferMsgEnvelope failed: %v", err)
}
wire, err := client.encodeEnvelope(env)
if err != nil {
t.Fatalf("encodeEnvelope failed: %v", err)
}
if !client.pushMessageFast(queue, wire, rt.inboundDispatcher) {
t.Fatal("pushMessageFast should use direct dispatch")
}
select {
case msg := <-gotCh:
if got, want := msg.Key, "client-fast-dispatch"; got != want {
t.Fatalf("message key = %q, want %q", got, want)
}
case <-time.After(time.Second):
t.Fatal("timed out waiting for direct client dispatch")
}
select {
case msg := <-queue.RestoreChan():
t.Fatalf("fast path should not enqueue RestoreChan message, got %+v", msg)
default:
}
}
// TestStreamBatchSenderEncodeRequestsResetsBatchBytesAfterFlushBoundary
// verifies that the pending batch's byte accounting is cleared when a
// pre-encoded payload forces a flush. A nearly-full first frame is flushed on
// its own ahead of the raw payload. The two small frames queued after the raw
// payload must still coalesce into a single batch, and the three payloads
// must appear in request order.
func TestStreamBatchSenderEncodeRequestsResetsBatchBytesAfterFlushBoundary(t *testing.T) {
	var (
		singleCalls int
		batchCalls  [][]streamFastDataFrame
	)
	sender := &streamBatchSender{
		codec: streamBatchCodec{
			encodeSingle: func(frame streamFastDataFrame) ([]byte, error) {
				singleCalls++
				return []byte("single"), nil
			},
			encodeBatch: func(frames []streamFastDataFrame) ([]byte, error) {
				cloned := make([]streamFastDataFrame, len(frames))
				copy(cloned, frames)
				batchCalls = append(batchCalls, cloned)
				return []byte("batch"), nil
			},
		},
	}
	// Sized to leave only 128 bytes of headroom below the plain batch cap
	// after header overhead; stale accounting carried past the flush would
	// presumably crowd out the follow-up frames.
	largePayload := make([]byte, streamFastBatchMaxPlainBytes-streamFastBatchHeaderLen-streamFastBatchItemHeaderLen-128)
	payloads, err := sender.encodeRequests([]streamBatchRequest{
		{
			frames: []streamFastDataFrame{{
				DataID:  101,
				Seq:     1,
				Payload: largePayload,
			}},
			fastPathVersion: streamFastPathVersionV2,
		},
		{
			encodedPayload:  []byte("raw"),
			hasEncoded:      true,
			fastPathVersion: streamFastPathVersionV2,
		},
		{
			frames: []streamFastDataFrame{{
				DataID:  202,
				Seq:     1,
				Payload: []byte("a"),
			}},
			fastPathVersion: streamFastPathVersionV2,
		},
		{
			frames: []streamFastDataFrame{{
				DataID:  202,
				Seq:     2,
				Payload: []byte("b"),
			}},
			fastPathVersion: streamFastPathVersionV2,
		},
	})
	if err != nil {
		t.Fatalf("encodeRequests failed: %v", err)
	}
	if got, want := len(payloads), 3; got != want {
		t.Fatalf("payload count = %d, want %d", got, want)
	}
	// Wire order must follow request order: the lone large frame, then the
	// raw payload, then the coalesced post-flush batch.
	for index, want := range []string{"single", "raw", "batch"} {
		if got := string(payloads[index]); got != want {
			t.Fatalf("payload %d = %q, want %q", index, got, want)
		}
	}
	if got, want := singleCalls, 1; got != want {
		t.Fatalf("single encode calls = %d, want %d", got, want)
	}
	if got, want := len(batchCalls), 1; got != want {
		t.Fatalf("batch encode calls = %d, want %d", got, want)
	}
	if got, want := len(batchCalls[0]), 2; got != want {
		t.Fatalf("post-flush batched frame count = %d, want %d", got, want)
	}
	// The batch must hold the two post-flush frames, in sequence order.
	for index, frame := range batchCalls[0] {
		if got, want := frame.DataID, uint64(202); got != want {
			t.Fatalf("post-flush batched frame %d data id = %d, want %d", index, got, want)
		}
		if got, want := frame.Seq, uint64(index+1); got != want {
			t.Fatalf("post-flush batched frame %d seq = %d, want %d", index, got, want)
		}
	}
}
func TestStreamBatchSenderEncodeRequestsUsesAdaptiveSoftLimit(t *testing.T) {
binding := &transportBinding{}
binding.observeStreamAdaptivePayloadWrite(2*1024*1024, 640*time.Millisecond, 0, nil)
var (
singleCalls int
batchCalls int
)
sender := &streamBatchSender{
binding: binding,
codec: streamBatchCodec{
encodeSingle: func(frame streamFastDataFrame) ([]byte, error) {
singleCalls++
return []byte("single"), nil
},
encodeBatch: func(frames []streamFastDataFrame) ([]byte, error) {
batchCalls++
return []byte("batch"), nil
},
},
}
payload := make([]byte, 160*1024)
payloads, err := sender.encodeRequests([]streamBatchRequest{
{
frames: []streamFastDataFrame{{
DataID: 101,
Seq: 1,
Payload: payload,
}},
fastPathVersion: streamFastPathVersionV2,
},
{
frames: []streamFastDataFrame{{
DataID: 102,
Seq: 2,
Payload: payload,
}},
fastPathVersion: streamFastPathVersionV2,
},
})
if err != nil {
t.Fatalf("encodeRequests failed: %v", err)
}
if got, want := len(payloads), 2; got != want {
t.Fatalf("payload count = %d, want %d", got, want)
}
if got, want := singleCalls, 2; got != want {
t.Fatalf("single encode calls = %d, want %d", got, want)
}
if batchCalls != 0 {
t.Fatalf("batch encode calls = %d, want 0", batchCalls)
}
}
// TestClientSendFastStreamDataSplitsLargeChunkWhenAdaptiveSoftLimitShrinks
// verifies that after the binding records an adaptive payload-write sample,
// sendFastStreamData splits a chunk larger than streamFastBatchDirectLimit
// into several single-encoded frames. The frames must carry consecutive
// sequence numbers starting at 1 and reassemble byte-for-byte into the
// original chunk. The encoders record frames under mu because they may run
// on the binding's background workers.
func TestClientSendFastStreamDataSplitsLargeChunkWhenAdaptiveSoftLimitShrinks(t *testing.T) {
	binding := newTransportBinding(&delayedWriteConn{}, stario.NewQueue())
	binding.observeStreamAdaptivePayloadWrite(2*1024*1024, 640*time.Millisecond, 0, nil)
	var (
		mu           sync.Mutex
		singleFrames []streamFastDataFrame
		batchFrames  [][]streamFastDataFrame
	)
	binding.streamSender = newStreamBatchSender(binding, streamBatchCodec{
		encodeSingle: func(frame streamFastDataFrame) ([]byte, error) {
			mu.Lock()
			singleFrames = append(singleFrames, streamFastDataFrame{
				DataID:  frame.DataID,
				Seq:     frame.Seq,
				Payload: append([]byte(nil), frame.Payload...),
			})
			mu.Unlock()
			return []byte{1}, nil
		},
		encodeBatch: func(frames []streamFastDataFrame) ([]byte, error) {
			cloned := make([]streamFastDataFrame, len(frames))
			for index := range frames {
				cloned[index] = streamFastDataFrame{
					DataID:  frames[index].DataID,
					Seq:     frames[index].Seq,
					Payload: append([]byte(nil), frames[index].Payload...),
				}
			}
			mu.Lock()
			batchFrames = append(batchFrames, cloned)
			mu.Unlock()
			return []byte{2}, nil
		},
	}, nil)
	defer binding.stopBackgroundWorkers()
	client := NewClient().(*ClientCommon)
	stopCtx, stopFn := context.WithCancel(context.Background())
	defer stopFn()
	client.setClientSessionRuntime(&clientSessionRuntime{
		transport:             binding,
		transportAttached:     true,
		stopCtx:               stopCtx,
		stopFn:                stopFn,
		queue:                 binding.queueSnapshot(),
		inboundDispatcher:     newInboundDispatcher(),
		suppressGoodByeOnStop: &atomic.Bool{},
	})
	stream := &streamHandle{
		dataID:          41,
		fastPathVersion: streamFastPathVersionV2,
	}
	chunk := make([]byte, streamFastBatchDirectLimit+128*1024)
	for index := range chunk {
		chunk[index] = byte(index)
	}
	if err := client.sendFastStreamData(context.Background(), stream, chunk); err != nil {
		t.Fatalf("sendFastStreamData failed: %v", err)
	}
	expectedFrames := streamFastSplitFrameCount(len(chunk), streamAdaptiveFramePayloadLimit(binding))
	// Without this guard the test would pass vacuously (one frame, one seq)
	// if the adaptive limit ever stopped shrinking below the chunk size.
	if expectedFrames < 2 {
		t.Fatalf("adaptive soft limit should split %d-byte chunk, expected frame count = %d", len(chunk), expectedFrames)
	}
	if got, want := int(stream.outboundSeq.Load()), expectedFrames; got != want {
		t.Fatalf("reserved seq count = %d, want %d", got, want)
	}
	mu.Lock()
	defer mu.Unlock()
	if len(batchFrames) != 0 {
		t.Fatalf("batch encode calls = %d, want 0", len(batchFrames))
	}
	if got, want := len(singleFrames), expectedFrames; got != want {
		t.Fatalf("single frame count = %d, want %d", got, want)
	}
	rebuilt := make([]byte, 0, len(chunk))
	for index, frame := range singleFrames {
		if got, want := frame.DataID, uint64(41); got != want {
			t.Fatalf("frame %d data id = %d, want %d", index, got, want)
		}
		if got, want := frame.Seq, uint64(index+1); got != want {
			t.Fatalf("frame %d seq = %d, want %d", index, got, want)
		}
		rebuilt = append(rebuilt, frame.Payload...)
	}
	if got, want := len(rebuilt), len(chunk); got != want {
		t.Fatalf("rebuilt payload length = %d, want %d", got, want)
	}
	if string(rebuilt) != string(chunk) {
		t.Fatal("rebuilt payload does not match original chunk")
	}
}