stario/que_test.go

521 lines
14 KiB
Go
Raw Permalink Normal View History

2024-08-18 17:17:20 +08:00
package stario
import (
"bytes"
"context"
"errors"
"io"
2024-08-18 17:17:20 +08:00
"testing"
"time"
)
// TestQueueBuildMessageUsesVersionedHeader checks the byte layout of a frame
// produced by BuildMessage. The expected layout is:
//   - the queMagic prefix;
//   - a 4-byte payload length decoded with ByteToUint32;
//   - the version byte at offset 12;
//   - the flags byte at offset 13;
//   - the payload itself, starting at queHeaderSize.
func TestQueueBuildMessageUsesVersionedHeader(t *testing.T) {
	const payload = "hello"
	frame := NewQueue().BuildMessage([]byte(payload))

	wantLen := queHeaderSize + len(payload)
	if len(frame) != wantLen {
		t.Fatalf("unexpected frame length: got %d want %d", len(frame), wantLen)
	}

	magic := frame[:queMagicSize]
	if !bytes.Equal(magic, queMagic) {
		t.Fatalf("unexpected magic: %v", magic)
	}

	lengthField := frame[queMagicSize : queMagicSize+4]
	if n := ByteToUint32(lengthField); n != 5 {
		t.Fatalf("unexpected payload length: got %d want 5", n)
	}

	// Offsets 12 and 13 are hard-coded here; they presumably equal
	// queMagicSize+4 and queMagicSize+5 — NOTE(review): confirm.
	if version := frame[12]; version != queVersionV1 {
		t.Fatalf("unexpected version: got %d want %d", version, queVersionV1)
	}
	if flags := frame[13]; flags != queSupportedFlags {
		t.Fatalf("unexpected flags: got %d want %d", flags, queSupportedFlags)
	}

	if body := frame[queHeaderSize:]; !bytes.Equal(body, []byte(payload)) {
		t.Fatalf("unexpected payload: %q", body)
	}
}
// TestQueueWriteFrameMatchesBuildMessage ensures WriteFrame emits exactly the
// bytes BuildMessage returns for the same payload.
func TestQueueWriteFrameMatchesBuildMessage(t *testing.T) {
	const payload = "hello"
	que := NewQueue()
	expected := que.BuildMessage([]byte(payload))

	out := new(bytes.Buffer)
	err := que.WriteFrame(out, []byte(payload))
	if err != nil {
		t.Fatalf("WriteFrame failed: %v", err)
	}

	actual := out.Bytes()
	if !bytes.Equal(actual, expected) {
		t.Fatalf("WriteFrame mismatch: got %v want %v", actual, expected)
	}
}
// TestQueueWriteFrameBuffersMatchesBuildMessage ensures WriteFrameBuffers
// produces the same byte stream as BuildMessage for a single payload.
func TestQueueWriteFrameBuffersMatchesBuildMessage(t *testing.T) {
	const payload = "hello"
	que := NewQueue()
	expected := que.BuildMessage([]byte(payload))

	out := new(bytes.Buffer)
	err := que.WriteFrameBuffers(out, []byte(payload))
	if err != nil {
		t.Fatalf("WriteFrameBuffers failed: %v", err)
	}

	actual := out.Bytes()
	if !bytes.Equal(actual, expected) {
		t.Fatalf("WriteFrameBuffers mismatch: got %v want %v", actual, expected)
	}
}
// TestQueueWriteFramesBuffersMatchesBuildMessage ensures that writing several
// payloads in one WriteFramesBuffers call yields the concatenation of the
// individual BuildMessage frames, in order.
func TestQueueWriteFramesBuffersMatchesBuildMessage(t *testing.T) {
	que := NewQueue()
	payloads := [][]byte{[]byte("hello"), []byte("world"), []byte("batch")}

	// Expected stream: each payload framed on its own, back to back.
	var expected bytes.Buffer
	for i := range payloads {
		expected.Write(que.BuildMessage(payloads[i]))
	}

	var out bytes.Buffer
	if err := que.WriteFramesBuffers(&out, payloads...); err != nil {
		t.Fatalf("WriteFramesBuffers failed: %v", err)
	}

	if got, want := out.Bytes(), expected.Bytes(); !bytes.Equal(got, want) {
		t.Fatalf("WriteFramesBuffers mismatch: got %v want %v", got, want)
	}
}
// TestQueueWriteFrameHonorsEncodeFunc checks that, with Encode enabled and a
// custom EncodeFunc installed, WriteFrame output still matches BuildMessage.
// BuildMessage presumably applies the same encoding.
func TestQueueWriteFrameHonorsEncodeFunc(t *testing.T) {
	que := NewQueue()
	// Uppercasing makes it easy to spot in a failure dump whether encoding ran.
	que.Encode, que.EncodeFunc = true, bytes.ToUpper

	const payload = "hello"
	expected := que.BuildMessage([]byte(payload))

	out := new(bytes.Buffer)
	if err := que.WriteFrame(out, []byte(payload)); err != nil {
		t.Fatalf("WriteFrame failed: %v", err)
	}

	if actual := out.Bytes(); !bytes.Equal(actual, expected) {
		t.Fatalf("WriteFrame mismatch with EncodeFunc: got %v want %v", actual, expected)
	}
}
// TestQueueWriteFrameBuffersHonorsEncodeFunc checks that WriteFrameBuffers
// agrees with BuildMessage when Encode is on and EncodeFunc is customized.
func TestQueueWriteFrameBuffersHonorsEncodeFunc(t *testing.T) {
	que := NewQueue()
	// Uppercasing makes it easy to spot in a failure dump whether encoding ran.
	que.Encode, que.EncodeFunc = true, bytes.ToUpper

	const payload = "hello"
	expected := que.BuildMessage([]byte(payload))

	out := new(bytes.Buffer)
	if err := que.WriteFrameBuffers(out, []byte(payload)); err != nil {
		t.Fatalf("WriteFrameBuffers failed: %v", err)
	}

	if actual := out.Bytes(); !bytes.Equal(actual, expected) {
		t.Fatalf("WriteFrameBuffers mismatch with EncodeFunc: got %v want %v", actual, expected)
	}
}
// TestQueueWriteFramesBuffersHonorsEncodeFunc checks that the batched writer
// output equals the concatenation of per-payload BuildMessage frames. Encode
// is enabled with a custom EncodeFunc.
func TestQueueWriteFramesBuffersHonorsEncodeFunc(t *testing.T) {
	que := NewQueue()
	// Uppercasing makes it easy to spot in a failure dump whether encoding ran.
	que.Encode, que.EncodeFunc = true, bytes.ToUpper

	inputs := [][]byte{[]byte("hello"), []byte("batch")}

	var expected bytes.Buffer
	for _, in := range inputs {
		expected.Write(que.BuildMessage(in))
	}

	var out bytes.Buffer
	if err := que.WriteFramesBuffers(&out, inputs...); err != nil {
		t.Fatalf("WriteFramesBuffers failed: %v", err)
	}

	if got, want := out.Bytes(), expected.Bytes(); !bytes.Equal(got, want) {
		t.Fatalf("WriteFramesBuffers mismatch with EncodeFunc: got %v want %v", got, want)
	}
}
// TestQueueParseMessageSplitAcrossCalls feeds one frame to ParseMessage in
// three chunks. The chunks are part of the magic, the rest of the header, and
// the payload. The test checks that exactly that frame is then delivered on
// RestoreChan, with ID left at zero.
func TestQueueParseMessageSplitAcrossCalls(t *testing.T) {
	que := NewQueueWithCount(1)
	frame := que.BuildMessage([]byte("hello"))

	steps := []struct {
		name string
		part []byte
	}{
		{"partial magic", frame[:3]},
		{"partial header", frame[3:11]},
		{"payload completion", frame[11:]},
	}
	for _, step := range steps {
		if err := que.ParseMessage(step.part, "split"); err != nil {
			t.Fatalf("unexpected error on %s: %v", step.name, err)
		}
	}

	timeout := time.After(200 * time.Millisecond)
	select {
	case <-timeout:
		t.Fatal("did not restore parsed frame")
	case data := <-que.RestoreChan():
		switch {
		case data.ID != 0:
			t.Fatalf("expected deprecated frame ID to stay zero, got %d", data.ID)
		case data.Conn != "split":
			t.Fatalf("unexpected conn: %#v", data.Conn)
		case !bytes.Equal(data.Msg, []byte("hello")):
			t.Fatalf("unexpected payload: %q", data.Msg)
		}
	}
}
// TestQueueParseMessageViewSplitAcrossCalls is the ParseMessageView variant of
// the split-input test. The handler must fire exactly once, after the last
// chunk arrives, with the full payload and the caller's conn key.
func TestQueueParseMessageViewSplitAcrossCalls(t *testing.T) {
	que := NewQueue()
	frame := que.BuildMessage([]byte("hello"))

	var payloads [][]byte
	collect := func(view FrameView) error {
		// Copy the payload: the view is only inspected after parsing returns.
		payloads = append(payloads, cloneBytes(view.Payload))
		if view.Conn != "split-view" {
			t.Fatalf("unexpected conn: %#v", view.Conn)
		}
		return nil
	}

	steps := []struct {
		name string
		part []byte
	}{
		{"partial magic", frame[:3]},
		{"partial header", frame[3:11]},
		{"payload completion", frame[11:]},
	}
	for _, step := range steps {
		if err := que.ParseMessageView(step.part, "split-view", collect); err != nil {
			t.Fatalf("unexpected error on %s: %v", step.name, err)
		}
	}

	if n := len(payloads); n != 1 {
		t.Fatalf("unexpected frame count: got %d want 1", n)
	}
	if first := payloads[0]; !bytes.Equal(first, []byte("hello")) {
		t.Fatalf("unexpected payload: %q", first)
	}
}
// TestQueueParseMessageOwnedSplitAcrossCalls feeds a frame to
// ParseMessageOwned in three chunks. It checks that the callback receives the
// reassembled message exactly once, and that nothing is pushed to RestoreChan.
func TestQueueParseMessageOwnedSplitAcrossCalls(t *testing.T) {
	que := NewQueue()
	frame := que.BuildMessage([]byte("hello"))

	var received []MsgQueue
	onMsg := func(msg MsgQueue) error {
		// Defensive copy of the payload before the assertions below read it.
		received = append(received, MsgQueue{Conn: msg.Conn, Msg: cloneBytes(msg.Msg)})
		return nil
	}

	steps := []struct {
		name string
		part []byte
	}{
		{"partial magic", frame[:3]},
		{"partial header", frame[3:11]},
		{"payload completion", frame[11:]},
	}
	for _, step := range steps {
		if err := que.ParseMessageOwned(step.part, "split-owned", onMsg); err != nil {
			t.Fatalf("unexpected error on %s: %v", step.name, err)
		}
	}

	if len(received) != 1 {
		t.Fatalf("unexpected frame count: got %d want 1", len(received))
	}
	first := &received[0]
	if first.Conn != "split-owned" {
		t.Fatalf("unexpected conn: %#v", first.Conn)
	}
	if !bytes.Equal(first.Msg, []byte("hello")) {
		t.Fatalf("unexpected payload: %q", first.Msg)
	}

	// The callback is the only delivery path; a non-blocking receive must
	// find RestoreChan empty.
	select {
	default:
	case msg := <-que.RestoreChan():
		t.Fatalf("ParseMessageOwned should not use RestoreChan, got %#v", msg)
	}
}
// TestQueueParseMessageSkipsGarbagePrefix puts junk bytes in front of a valid
// frame. ParseMessage must report ErrQueueDataFormat and still resynchronize
// on the magic, so the frame that follows is delivered.
func TestQueueParseMessageSkipsGarbagePrefix(t *testing.T) {
	que := NewQueueWithCount(1)
	input := append([]byte("junk"), que.BuildMessage([]byte("hello"))...)

	if err := que.ParseMessage(input, "garbage"); !errors.Is(err, ErrQueueDataFormat) {
		t.Fatalf("expected data format error, got %v", err)
	}

	timeout := time.After(200 * time.Millisecond)
	select {
	case <-timeout:
		t.Fatal("did not restore frame after skipping garbage")
	case data := <-que.RestoreChan():
		if !bytes.Equal(data.Msg, []byte("hello")) {
			t.Fatalf("unexpected payload after resync: %q", data.Msg)
		}
	}
}
// TestQueueParseMessageViewSkipsGarbagePrefix is the ParseMessageView variant
// of the garbage-prefix test. It expects ErrQueueDataFormat, and the handler
// must still see the valid frame that follows the junk.
func TestQueueParseMessageViewSkipsGarbagePrefix(t *testing.T) {
	que := NewQueue()
	input := append([]byte("junk"), que.BuildMessage([]byte("hello"))...)

	var seen []byte
	record := func(view FrameView) error {
		seen = cloneBytes(view.Payload)
		return nil
	}

	if err := que.ParseMessageView(input, "garbage-view", record); !errors.Is(err, ErrQueueDataFormat) {
		t.Fatalf("expected data format error, got %v", err)
	}
	if !bytes.Equal(seen, []byte("hello")) {
		t.Fatalf("unexpected payload after resync: %q", seen)
	}
}
// TestQueueParseMessageViewNilHandler checks that ParseMessageView rejects a
// nil handler with ErrQueueFrameHandlerNil.
func TestQueueParseMessageViewNilHandler(t *testing.T) {
	que := NewQueue()
	frame := que.BuildMessage([]byte("hello"))
	if err := que.ParseMessageView(frame, "nil-handler", nil); !errors.Is(err, ErrQueueFrameHandlerNil) {
		t.Fatalf("ParseMessageView error = %v, want %v", err, ErrQueueFrameHandlerNil)
	}
}
// TestQueueParseMessageOwnedNilHandler checks that ParseMessageOwned rejects
// a nil handler with ErrQueueFrameHandlerNil.
func TestQueueParseMessageOwnedNilHandler(t *testing.T) {
	que := NewQueue()
	frame := que.BuildMessage([]byte("hello"))
	if err := que.ParseMessageOwned(frame, "nil-handler-owned", nil); !errors.Is(err, ErrQueueFrameHandlerNil) {
		t.Fatalf("ParseMessageOwned error = %v, want %v", err, ErrQueueFrameHandlerNil)
	}
}
// TestQueueWriteFrameNilWriter checks that WriteFrame reports io.ErrClosedPipe
// when given a nil writer.
func TestQueueWriteFrameNilWriter(t *testing.T) {
	const payload = "hello"
	err := NewQueue().WriteFrame(nil, []byte(payload))
	if want := io.ErrClosedPipe; !errors.Is(err, want) {
		t.Fatalf("WriteFrame error = %v, want %v", err, want)
	}
}
// TestQueueWriteFrameBuffersNilWriter checks that WriteFrameBuffers reports
// io.ErrClosedPipe when given a nil writer.
func TestQueueWriteFrameBuffersNilWriter(t *testing.T) {
	const payload = "hello"
	err := NewQueue().WriteFrameBuffers(nil, []byte(payload))
	if want := io.ErrClosedPipe; !errors.Is(err, want) {
		t.Fatalf("WriteFrameBuffers error = %v, want %v", err, want)
	}
}
// TestQueueWriteFramesBuffersNilWriter checks that WriteFramesBuffers reports
// io.ErrClosedPipe when given a nil writer.
func TestQueueWriteFramesBuffersNilWriter(t *testing.T) {
	const payload = "hello"
	err := NewQueue().WriteFramesBuffers(nil, []byte(payload))
	if want := io.ErrClosedPipe; !errors.Is(err, want) {
		t.Fatalf("WriteFramesBuffers error = %v, want %v", err, want)
	}
}
// TestQueueParseMessageRejectsUnsupportedVersion overwrites the version byte
// (offset 12) of a valid frame. ParseMessage must fail with
// ErrQueueUnsupportedVersion and must not deliver anything.
func TestQueueParseMessageRejectsUnsupportedVersion(t *testing.T) {
	const badVersion = 2

	que := NewQueueWithCount(1)
	frame := que.BuildMessage([]byte("hello"))
	frame[12] = badVersion

	if err := que.ParseMessage(frame, "version"); !errors.Is(err, ErrQueueUnsupportedVersion) {
		t.Fatalf("expected unsupported version error, got %v", err)
	}

	// Non-blocking receive: RestoreChan must be empty.
	select {
	default:
	case data := <-que.RestoreChan():
		t.Fatalf("unexpected restored frame: %#v", data)
	}
}
// TestQueueParseMessageRejectsMessageTooLarge parses a 5-byte payload on a
// queue built with NewQueueCtx(..., 1, 4). The third argument appears to be
// the maximum accepted payload size (NOTE(review): confirm against
// NewQueueCtx). ParseMessage must fail with ErrQueueMessageTooLarge and
// must not deliver the frame.
func TestQueueParseMessageRejectsMessageTooLarge(t *testing.T) {
	// Pass a real context instead of nil: a nil context.Context is flagged by
	// staticcheck (SA1012), and nil-context handling is not what this test
	// is about.
	que := NewQueueCtx(context.Background(), 1, 4)
	frame := que.BuildMessage([]byte("hello"))

	err := que.ParseMessage(frame, "large")
	if !errors.Is(err, ErrQueueMessageTooLarge) {
		t.Fatalf("expected message too large error, got %v", err)
	}

	// Non-blocking receive: the oversized frame must not have been queued.
	select {
	case data := <-que.RestoreChan():
		t.Fatalf("unexpected restored frame: %#v", data)
	default:
	}
}
// TestQueueParseMessageRejectsInvalidConnKey passes a []byte as the conn key.
// A slice is not a comparable Go type, and ParseMessage is expected to
// reject it with ErrQueueConnKeyInvalid.
func TestQueueParseMessageRejectsInvalidConnKey(t *testing.T) {
	que := NewQueue()
	badKey := []byte("not-comparable")
	if err := que.ParseMessage(que.BuildMessage([]byte("hello")), badKey); !errors.Is(err, ErrQueueConnKeyInvalid) {
		t.Fatalf("expected invalid conn key error, got %v", err)
	}
}
// TestQueueParseMessageRejectsNilConnKey checks that a nil conn key is
// rejected with ErrQueueConnKeyNil.
func TestQueueParseMessageRejectsNilConnKey(t *testing.T) {
	que := NewQueue()
	if err := que.ParseMessage(que.BuildMessage([]byte("hello")), nil); !errors.Is(err, ErrQueueConnKeyNil) {
		t.Fatalf("expected nil conn key error, got %v", err)
	}
}
// TestQueuePayloadSizeToUint32RejectsOverflow checks that a size one past the
// largest uint32 value is rejected with ErrQueueMessageTooLarge.
func TestQueuePayloadSizeToUint32RejectsOverflow(t *testing.T) {
	// 1<<32 == uint64(math.MaxUint32) + 1, the first value that cannot fit.
	tooLarge := uint64(1) << 32
	if _, err := payloadSizeToUint32(tooLarge); !errors.Is(err, ErrQueueMessageTooLarge) {
		t.Fatalf("expected message too large error, got %v", err)
	}
}
func TestQueueRestoreDurationZeroWaitsUntilMessage(t *testing.T) {
que := NewQueueWithCount(1)
que.RestoreDuration(0)
type restoreResult struct {
msg MsgQueue
err error
}
resultCh := make(chan restoreResult, 1)
2024-08-18 17:17:20 +08:00
go func() {
msg, err := que.Restore()
resultCh <- restoreResult{msg: msg, err: err}
}()
select {
case result := <-resultCh:
t.Fatalf("Restore returned too early: %#v", result)
case <-time.After(50 * time.Millisecond):
}
if err := que.ParseMessage(que.BuildMessage([]byte("hello")), "forever"); err != nil {
t.Fatalf("ParseMessage failed: %v", err)
}
select {
case result := <-resultCh:
if result.err != nil {
t.Fatalf("Restore returned error: %v", result.err)
2024-08-18 17:17:20 +08:00
}
if result.msg.Conn != "forever" || !bytes.Equal(result.msg.Msg, []byte("hello")) {
t.Fatalf("unexpected restore result: %#v", result.msg)
}
case <-time.After(200 * time.Millisecond):
t.Fatal("Restore did not return after message arrival")
}
}
func TestQueueRestoreReturnsContextErrorOnStop(t *testing.T) {
que := NewQueue()
resultCh := make(chan error, 1)
go func() {
_, err := que.Restore()
resultCh <- err
2024-08-18 17:17:20 +08:00
}()
que.Stop()
select {
case err := <-resultCh:
if !errors.Is(err, context.Canceled) {
t.Fatalf("expected context.Canceled, got %v", err)
2024-08-18 17:17:20 +08:00
}
case <-time.After(200 * time.Millisecond):
t.Fatal("Restore did not return after Stop")
}
}
// TestQueueRestoreChanClosesOnStop checks that Stop closes the channel
// returned by RestoreChan. A receive then reports ok == false instead of
// blocking.
func TestQueueRestoreChanClosesOnStop(t *testing.T) {
	que := NewQueue()

	openCh := make(chan bool, 1)
	go func() {
		_, open := <-que.RestoreChan()
		openCh <- open
	}()

	que.Stop()

	timer := time.NewTimer(200 * time.Millisecond)
	defer timer.Stop()
	select {
	case <-timer.C:
		t.Fatal("RestoreChan did not close after Stop")
	case open := <-openCh:
		if open {
			t.Fatal("expected RestoreChan to close after Stop")
		}
	}
}
// TestQueueRestoreChanClosesOnContextCancel checks that cancelling the
// context passed to NewQueueCtx closes RestoreChan. Stop is never called
// explicitly.
func TestQueueRestoreChanClosesOnContextCancel(t *testing.T) {
	parent, cancelParent := context.WithCancel(context.Background())
	que := NewQueueCtx(parent, 1, 0)

	openCh := make(chan bool, 1)
	go func() {
		_, open := <-que.RestoreChan()
		openCh <- open
	}()

	cancelParent()

	timer := time.NewTimer(200 * time.Millisecond)
	defer timer.Stop()
	select {
	case <-timer.C:
		t.Fatal("RestoreChan did not close after context cancel")
	case open := <-openCh:
		if open {
			t.Fatal("expected RestoreChan to close after context cancel")
		}
	}
}
// TestQueueParseMessageReturnsContextErrorWhenStoppedWhilePoolIsFull fills a
// queue created with count 1, then starts a second ParseMessage. That call
// must stay blocked until Stop, and must then return an error matching
// context.Canceled.
func TestQueueParseMessageReturnsContextErrorWhenStoppedWhilePoolIsFull(t *testing.T) {
	que := NewQueueWithCount(1)
	if err := que.ParseMessage(que.BuildMessage([]byte("first")), "full"); err != nil {
		t.Fatalf("ParseMessage first failed: %v", err)
	}

	blocked := make(chan error, 1)
	go func() {
		blocked <- que.ParseMessage(que.BuildMessage([]byte("second")), "full")
	}()

	// The second call must still be waiting after 50ms.
	select {
	case <-time.After(50 * time.Millisecond):
	case err := <-blocked:
		t.Fatalf("ParseMessage returned before Stop: %v", err)
	}

	que.Stop()

	select {
	case <-time.After(200 * time.Millisecond):
		t.Fatal("ParseMessage did not return after Stop")
	case err := <-blocked:
		if !errors.Is(err, context.Canceled) {
			t.Fatalf("expected context.Canceled, got %v", err)
		}
	}
}
func TestQueueParseMessageViewAllowsReentrantParseOnSameConn(t *testing.T) {
que := NewQueue()
reentered := false
err := que.ParseMessageView(que.BuildMessage([]byte("outer")), "reentrant", func(view FrameView) error {
if !bytes.Equal(view.Payload, []byte("outer")) {
t.Fatalf("unexpected outer payload: %q", view.Payload)
}
return que.ParseMessageView(que.BuildMessage([]byte("inner")), "reentrant", func(inner FrameView) error {
reentered = true
if !bytes.Equal(inner.Payload, []byte("inner")) {
t.Fatalf("unexpected inner payload: %q", inner.Payload)
}
return nil
})
})
if err != nil {
t.Fatalf("ParseMessageView failed: %v", err)
}
if !reentered {
t.Fatal("expected reentrant ParseMessageView to run")
2024-08-18 17:17:20 +08:00
}
}