package stario import ( "bytes" "context" "errors" "io" "testing" "time" ) func TestQueueBuildMessageUsesVersionedHeader(t *testing.T) { que := NewQueue() frame := que.BuildMessage([]byte("hello")) if len(frame) != queHeaderSize+5 { t.Fatalf("unexpected frame length: got %d want %d", len(frame), queHeaderSize+5) } if !bytes.Equal(frame[:queMagicSize], queMagic) { t.Fatalf("unexpected magic: %v", frame[:queMagicSize]) } if got := ByteToUint32(frame[queMagicSize : queMagicSize+4]); got != 5 { t.Fatalf("unexpected payload length: got %d want 5", got) } if frame[12] != queVersionV1 { t.Fatalf("unexpected version: got %d want %d", frame[12], queVersionV1) } if frame[13] != queSupportedFlags { t.Fatalf("unexpected flags: got %d want %d", frame[13], queSupportedFlags) } if !bytes.Equal(frame[queHeaderSize:], []byte("hello")) { t.Fatalf("unexpected payload: %q", frame[queHeaderSize:]) } } func TestQueueWriteFrameMatchesBuildMessage(t *testing.T) { que := NewQueue() want := que.BuildMessage([]byte("hello")) var buf bytes.Buffer if err := que.WriteFrame(&buf, []byte("hello")); err != nil { t.Fatalf("WriteFrame failed: %v", err) } if got := buf.Bytes(); !bytes.Equal(got, want) { t.Fatalf("WriteFrame mismatch: got %v want %v", got, want) } } func TestQueueWriteFrameBuffersMatchesBuildMessage(t *testing.T) { que := NewQueue() want := que.BuildMessage([]byte("hello")) var buf bytes.Buffer if err := que.WriteFrameBuffers(&buf, []byte("hello")); err != nil { t.Fatalf("WriteFrameBuffers failed: %v", err) } if got := buf.Bytes(); !bytes.Equal(got, want) { t.Fatalf("WriteFrameBuffers mismatch: got %v want %v", got, want) } } func TestQueueWriteFramesBuffersMatchesBuildMessage(t *testing.T) { que := NewQueue() payloads := [][]byte{ []byte("hello"), []byte("world"), []byte("batch"), } var want []byte for _, payload := range payloads { want = append(want, que.BuildMessage(payload)...) } var buf bytes.Buffer if err := que.WriteFramesBuffers(&buf, payloads...); err != nil { t.Fatalf("WriteFramesBuffers failed: %v", err) } if got := buf.Bytes(); !bytes.Equal(got, want) { t.Fatalf("WriteFramesBuffers mismatch: got %v want %v", got, want) } } func TestQueueWriteFrameHonorsEncodeFunc(t *testing.T) { que := NewQueue() que.Encode = true que.EncodeFunc = bytes.ToUpper want := que.BuildMessage([]byte("hello")) var buf bytes.Buffer if err := que.WriteFrame(&buf, []byte("hello")); err != nil { t.Fatalf("WriteFrame failed: %v", err) } if got := buf.Bytes(); !bytes.Equal(got, want) { t.Fatalf("WriteFrame mismatch with EncodeFunc: got %v want %v", got, want) } } func TestQueueWriteFrameBuffersHonorsEncodeFunc(t *testing.T) { que := NewQueue() que.Encode = true que.EncodeFunc = bytes.ToUpper want := que.BuildMessage([]byte("hello")) var buf bytes.Buffer if err := que.WriteFrameBuffers(&buf, []byte("hello")); err != nil { t.Fatalf("WriteFrameBuffers failed: %v", err) } if got := buf.Bytes(); !bytes.Equal(got, want) { t.Fatalf("WriteFrameBuffers mismatch with EncodeFunc: got %v want %v", got, want) } } func TestQueueWriteFramesBuffersHonorsEncodeFunc(t *testing.T) { que := NewQueue() que.Encode = true que.EncodeFunc = bytes.ToUpper payloads := [][]byte{ []byte("hello"), []byte("batch"), } var want []byte for _, payload := range payloads { want = append(want, que.BuildMessage(payload)...) } var buf bytes.Buffer if err := que.WriteFramesBuffers(&buf, payloads...); err != nil { t.Fatalf("WriteFramesBuffers failed: %v", err) } if got := buf.Bytes(); !bytes.Equal(got, want) { t.Fatalf("WriteFramesBuffers mismatch with EncodeFunc: got %v want %v", got, want) } } func TestQueueParseMessageSplitAcrossCalls(t *testing.T) { que := NewQueueWithCount(1) frame := que.BuildMessage([]byte("hello")) if err := que.ParseMessage(frame[:3], "split"); err != nil { t.Fatalf("unexpected error on partial magic: %v", err) } if err := que.ParseMessage(frame[3:11], "split"); err != nil { t.Fatalf("unexpected error on partial header: %v", err) } if err := que.ParseMessage(frame[11:], "split"); err != nil { t.Fatalf("unexpected error on payload completion: %v", err) } select { case data := <-que.RestoreChan(): if data.ID != 0 { t.Fatalf("expected deprecated frame ID to stay zero, got %d", data.ID) } if data.Conn != "split" { t.Fatalf("unexpected conn: %#v", data.Conn) } if !bytes.Equal(data.Msg, []byte("hello")) { t.Fatalf("unexpected payload: %q", data.Msg) } case <-time.After(200 * time.Millisecond): t.Fatal("did not restore parsed frame") } } func TestQueueParseMessageViewSplitAcrossCalls(t *testing.T) { que := NewQueue() frame := que.BuildMessage([]byte("hello")) var got [][]byte handler := func(view FrameView) error { got = append(got, cloneBytes(view.Payload)) if view.Conn != "split-view" { t.Fatalf("unexpected conn: %#v", view.Conn) } return nil } if err := que.ParseMessageView(frame[:3], "split-view", handler); err != nil { t.Fatalf("unexpected error on partial magic: %v", err) } if err := que.ParseMessageView(frame[3:11], "split-view", handler); err != nil { t.Fatalf("unexpected error on partial header: %v", err) } if err := que.ParseMessageView(frame[11:], "split-view", handler); err != nil { t.Fatalf("unexpected error on payload completion: %v", err) } if len(got) != 1 { t.Fatalf("unexpected frame count: got %d want 1", len(got)) } if !bytes.Equal(got[0], []byte("hello")) { t.Fatalf("unexpected payload: %q", got[0]) } } func TestQueueParseMessageOwnedSplitAcrossCalls(t *testing.T) { que := NewQueue() frame := que.BuildMessage([]byte("hello")) var got []MsgQueue handler := func(msg MsgQueue) error { got = append(got, MsgQueue{ Msg: cloneBytes(msg.Msg), Conn: msg.Conn, }) return nil } if err := que.ParseMessageOwned(frame[:3], "split-owned", handler); err != nil { t.Fatalf("unexpected error on partial magic: %v", err) } if err := que.ParseMessageOwned(frame[3:11], "split-owned", handler); err != nil { t.Fatalf("unexpected error on partial header: %v", err) } if err := que.ParseMessageOwned(frame[11:], "split-owned", handler); err != nil { t.Fatalf("unexpected error on payload completion: %v", err) } if len(got) != 1 { t.Fatalf("unexpected frame count: got %d want 1", len(got)) } if got[0].Conn != "split-owned" { t.Fatalf("unexpected conn: %#v", got[0].Conn) } if !bytes.Equal(got[0].Msg, []byte("hello")) { t.Fatalf("unexpected payload: %q", got[0].Msg) } select { case msg := <-que.RestoreChan(): t.Fatalf("ParseMessageOwned should not use RestoreChan, got %#v", msg) default: } } func TestQueueParseMessageSkipsGarbagePrefix(t *testing.T) { que := NewQueueWithCount(1) frame := que.BuildMessage([]byte("hello")) err := que.ParseMessage(append([]byte("junk"), frame...), "garbage") if !errors.Is(err, ErrQueueDataFormat) { t.Fatalf("expected data format error, got %v", err) } select { case data := <-que.RestoreChan(): if !bytes.Equal(data.Msg, []byte("hello")) { t.Fatalf("unexpected payload after resync: %q", data.Msg) } case <-time.After(200 * time.Millisecond): t.Fatal("did not restore frame after skipping garbage") } } func TestQueueParseMessageViewSkipsGarbagePrefix(t *testing.T) { que := NewQueue() frame := que.BuildMessage([]byte("hello")) var got []byte err := que.ParseMessageView(append([]byte("junk"), frame...), "garbage-view", func(view FrameView) error { got = cloneBytes(view.Payload) return nil }) if !errors.Is(err, ErrQueueDataFormat) { t.Fatalf("expected data format error, got %v", err) } if !bytes.Equal(got, []byte("hello")) { t.Fatalf("unexpected payload after resync: %q", got) } } func TestQueueParseMessageViewNilHandler(t *testing.T) { que := NewQueue() err := que.ParseMessageView(que.BuildMessage([]byte("hello")), "nil-handler", nil) if !errors.Is(err, ErrQueueFrameHandlerNil) { t.Fatalf("ParseMessageView error = %v, want %v", err, ErrQueueFrameHandlerNil) } } func TestQueueParseMessageOwnedNilHandler(t *testing.T) { que := NewQueue() err := que.ParseMessageOwned(que.BuildMessage([]byte("hello")), "nil-handler-owned", nil) if !errors.Is(err, ErrQueueFrameHandlerNil) { t.Fatalf("ParseMessageOwned error = %v, want %v", err, ErrQueueFrameHandlerNil) } } func TestQueueWriteFrameNilWriter(t *testing.T) { que := NewQueue() err := que.WriteFrame(nil, []byte("hello")) if !errors.Is(err, io.ErrClosedPipe) { t.Fatalf("WriteFrame error = %v, want %v", err, io.ErrClosedPipe) } } func TestQueueWriteFrameBuffersNilWriter(t *testing.T) { que := NewQueue() err := que.WriteFrameBuffers(nil, []byte("hello")) if !errors.Is(err, io.ErrClosedPipe) { t.Fatalf("WriteFrameBuffers error = %v, want %v", err, io.ErrClosedPipe) } } func TestQueueWriteFramesBuffersNilWriter(t *testing.T) { que := NewQueue() err := que.WriteFramesBuffers(nil, []byte("hello")) if !errors.Is(err, io.ErrClosedPipe) { t.Fatalf("WriteFramesBuffers error = %v, want %v", err, io.ErrClosedPipe) } } func TestQueueParseMessageRejectsUnsupportedVersion(t *testing.T) { que := NewQueueWithCount(1) frame := que.BuildMessage([]byte("hello")) frame[12] = 2 err := que.ParseMessage(frame, "version") if !errors.Is(err, ErrQueueUnsupportedVersion) { t.Fatalf("expected unsupported version error, got %v", err) } select { case data := <-que.RestoreChan(): t.Fatalf("unexpected restored frame: %#v", data) default: } } func TestQueueParseMessageRejectsMessageTooLarge(t *testing.T) { que := NewQueueCtx(nil, 1, 4) frame := que.BuildMessage([]byte("hello")) err := que.ParseMessage(frame, "large") if !errors.Is(err, ErrQueueMessageTooLarge) { t.Fatalf("expected message too large error, got %v", err) } select { case data := <-que.RestoreChan(): t.Fatalf("unexpected restored frame: %#v", data) default: } } func TestQueueParseMessageRejectsInvalidConnKey(t *testing.T) { que := NewQueue() frame := que.BuildMessage([]byte("hello")) err := que.ParseMessage(frame, []byte("not-comparable")) if !errors.Is(err, ErrQueueConnKeyInvalid) { t.Fatalf("expected invalid conn key error, got %v", err) } } func TestQueueParseMessageRejectsNilConnKey(t *testing.T) { que := NewQueue() frame := que.BuildMessage([]byte("hello")) err := que.ParseMessage(frame, nil) if !errors.Is(err, ErrQueueConnKeyNil) { t.Fatalf("expected nil conn key error, got %v", err) } } func TestQueuePayloadSizeToUint32RejectsOverflow(t *testing.T) { _, err := payloadSizeToUint32(uint64(^uint32(0)) + 1) if !errors.Is(err, ErrQueueMessageTooLarge) { t.Fatalf("expected message too large error, got %v", err) } } func TestQueueRestoreDurationZeroWaitsUntilMessage(t *testing.T) { que := NewQueueWithCount(1) que.RestoreDuration(0) type restoreResult struct { msg MsgQueue err error } resultCh := make(chan restoreResult, 1) go func() { msg, err := que.Restore() resultCh <- restoreResult{msg: msg, err: err} }() select { case result := <-resultCh: t.Fatalf("Restore returned too early: %#v", result) case <-time.After(50 * time.Millisecond): } if err := que.ParseMessage(que.BuildMessage([]byte("hello")), "forever"); err != nil { t.Fatalf("ParseMessage failed: %v", err) } select { case result := <-resultCh: if result.err != nil { t.Fatalf("Restore returned error: %v", result.err) } if result.msg.Conn != "forever" || !bytes.Equal(result.msg.Msg, []byte("hello")) { t.Fatalf("unexpected restore result: %#v", result.msg) } case <-time.After(200 * time.Millisecond): t.Fatal("Restore did not return after message arrival") } } func TestQueueRestoreReturnsContextErrorOnStop(t *testing.T) { que := NewQueue() resultCh := make(chan error, 1) go func() { _, err := que.Restore() resultCh <- err }() que.Stop() select { case err := <-resultCh: if !errors.Is(err, context.Canceled) { t.Fatalf("expected context.Canceled, got %v", err) } case <-time.After(200 * time.Millisecond): t.Fatal("Restore did not return after Stop") } } func TestQueueRestoreChanClosesOnStop(t *testing.T) { que := NewQueue() resultCh := make(chan bool, 1) go func() { _, ok := <-que.RestoreChan() resultCh <- ok }() que.Stop() select { case ok := <-resultCh: if ok { t.Fatal("expected RestoreChan to close after Stop") } case <-time.After(200 * time.Millisecond): t.Fatal("RestoreChan did not close after Stop") } } func TestQueueRestoreChanClosesOnContextCancel(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) que := NewQueueCtx(ctx, 1, 0) resultCh := make(chan bool, 1) go func() { _, ok := <-que.RestoreChan() resultCh <- ok }() cancel() select { case ok := <-resultCh: if ok { t.Fatal("expected RestoreChan to close after context cancel") } case <-time.After(200 * time.Millisecond): t.Fatal("RestoreChan did not close after context cancel") } } func TestQueueParseMessageReturnsContextErrorWhenStoppedWhilePoolIsFull(t *testing.T) { que := NewQueueWithCount(1) if err := que.ParseMessage(que.BuildMessage([]byte("first")), "full"); err != nil { t.Fatalf("ParseMessage first failed: %v", err) } errCh := make(chan error, 1) go func() { errCh <- que.ParseMessage(que.BuildMessage([]byte("second")), "full") }() select { case err := <-errCh: t.Fatalf("ParseMessage returned before Stop: %v", err) case <-time.After(50 * time.Millisecond): } que.Stop() select { case err := <-errCh: if !errors.Is(err, context.Canceled) { t.Fatalf("expected context.Canceled, got %v", err) } case <-time.After(200 * time.Millisecond): t.Fatal("ParseMessage did not return after Stop") } } func TestQueueParseMessageViewAllowsReentrantParseOnSameConn(t *testing.T) { que := NewQueue() reentered := false err := que.ParseMessageView(que.BuildMessage([]byte("outer")), "reentrant", func(view FrameView) error { if !bytes.Equal(view.Payload, []byte("outer")) { t.Fatalf("unexpected outer payload: %q", view.Payload) } return que.ParseMessageView(que.BuildMessage([]byte("inner")), "reentrant", func(inner FrameView) error { reentered = true if !bytes.Equal(inner.Payload, []byte("inner")) { t.Fatalf("unexpected inner payload: %q", inner.Payload) } return nil }) }) if err != nil { t.Fatalf("ParseMessageView failed: %v", err) } if !reentered { t.Fatal("expected reentrant ParseMessageView to run") } }