package stario

import (
	"bytes"
	"errors"
	"io"
	"os"
	"testing"
)

// Compile-time checks that the scripted test transports satisfy io.Reader.
var (
	_ io.Reader = (*chunkedReader)(nil)
	_ io.Reader = (*trackingReader)(nil)
	_ io.Reader = (*stagedReader)(nil)
)

// chunkedReader is an io.Reader that hands out data in pieces of at most max
// bytes per Read call (no cap when max <= 0), simulating a transport that
// delivers short reads.
type chunkedReader struct {
	data []byte
	max  int
}

// Read copies up to len(p) bytes, further capped at r.max when r.max > 0.
// It returns io.EOF once all data has been consumed.
func (r *chunkedReader) Read(p []byte) (int, error) {
	if len(r.data) == 0 {
		return 0, io.EOF
	}
	if r.max > 0 && len(p) > r.max {
		p = p[:r.max]
	}
	n := copy(p, r.data)
	r.data = r.data[n:]
	return n, nil
}

// trackingReader is an io.Reader that records the largest buffer length ever
// passed to Read, so tests can assert how much a consumer requests per call.
type trackingReader struct {
	data         []byte
	maxRequested int
}

// Read records len(p) in maxRequested when it exceeds every earlier request,
// then copies as much data as fits. It returns io.EOF once data is exhausted;
// the request size is recorded even on that final call.
func (r *trackingReader) Read(p []byte) (int, error) {
	if len(p) > r.maxRequested {
		r.maxRequested = len(p)
	}
	if len(r.data) == 0 {
		return 0, io.EOF
	}
	n := copy(p, r.data)
	r.data = r.data[n:]
	return n, nil
}

// stagedReaderStep is one scripted stage of a stagedReader: data is delivered
// (possibly across several Read calls) and err is returned together with the
// last bytes of the stage.
type stagedReaderStep struct {
	data []byte
	err  error
}

// stagedReader is an io.Reader that replays a fixed script of steps, letting
// tests inject errors (such as os.ErrDeadlineExceeded) at exact points in the
// byte stream. Once every step is consumed it returns io.EOF.
//
// Read consumes step.data in place, so callers should pass slices they own.
type stagedReader struct {
	steps []stagedReaderStep
	index int
}

// Read serves bytes from the current step. When that step's data is fully
// consumed, the step's error (possibly nil) is returned alongside the final
// bytes and the reader advances to the next step. A step whose data is
// already empty therefore yields (0, step.err).
func (r *stagedReader) Read(p []byte) (int, error) {
	if r.index >= len(r.steps) {
		return 0, io.EOF
	}
	step := &r.steps[r.index]
	n := copy(p, step.data)
	step.data = step.data[n:]
	if len(step.data) > 0 {
		return n, nil
	}
	r.index++
	return n, step.err
}

// TestFrameWriterReaderRoundTrip writes frames via WriteFrame and
// WriteFramesBuffers and checks they are read back in order, followed by EOF.
func TestFrameWriterReaderRoundTrip(t *testing.T) {
	var wire bytes.Buffer
	writer := NewFrameWriter(&wire, nil)
	if err := writer.WriteFrame([]byte("one")); err != nil {
		t.Fatalf("WriteFrame failed: %v", err)
	}
	if err := writer.WriteFramesBuffers([][]byte{[]byte("two"), []byte("three")}...); err != nil {
		t.Fatalf("WriteFramesBuffers failed: %v", err)
	}

	reader := NewFrameReader(bytes.NewReader(wire.Bytes()), nil)
	cases := []string{"one", "two", "three"}
	for _, want := range cases {
		got, err := reader.Next()
		if err != nil {
			t.Fatalf("Next returned error: %v", err)
		}
		if string(got) != want {
			t.Fatalf("unexpected frame: got %q want %q", got, want)
		}
	}
	if _, err := reader.Next(); !errors.Is(err, io.EOF) {
		t.Fatalf("expected EOF after draining frames, got %v", err)
	}
}

// TestFrameReaderHandlesPartialTransportReads checks that a frame split across
// several short (3-byte) transport reads is reassembled correctly.
func TestFrameReaderHandlesPartialTransportReads(t *testing.T) {
	var wire bytes.Buffer
	writer := NewFrameWriter(&wire, nil)
	if err := writer.WriteFrame([]byte("hello")); err != nil {
		t.Fatalf("WriteFrame failed: %v", err)
	}

	reader := NewFrameReader(&chunkedReader{data: wire.Bytes(), max: 3}, nil)
	reader.SetReadBufferSize(3)
	got, err := reader.Next()
	if err != nil {
		t.Fatalf("Next returned error: %v", err)
	}
	if string(got) != "hello" {
		t.Fatalf("unexpected frame: %q", got)
	}
}

// TestFrameReaderSetReadBufferSizeLimitsUnderlyingReadSize checks that
// SetReadBufferSize bounds the buffer length passed to the transport's Read.
func TestFrameReaderSetReadBufferSizeLimitsUnderlyingReadSize(t *testing.T) {
	var wire bytes.Buffer
	writer := NewFrameWriter(&wire, nil)
	if err := writer.WriteFrame([]byte("hello")); err != nil {
		t.Fatalf("WriteFrame failed: %v", err)
	}

	source := &trackingReader{data: wire.Bytes()}
	reader := NewFrameReader(source, nil)
	reader.SetReadBufferSize(3)
	got, err := reader.Next()
	if err != nil {
		t.Fatalf("Next returned error: %v", err)
	}
	if string(got) != "hello" {
		t.Fatalf("unexpected frame: %q", got)
	}
	// Distinct name so the frame payload `got` is not shadowed by an int.
	if requested, want := source.maxRequested, 3; requested != want {
		t.Fatalf("max requested read size = %d, want %d", requested, want)
	}
}

// TestFrameReaderNextPooledAndNextView reads one frame via NextPooled (and
// releases it) and the next via NextView.
func TestFrameReaderNextPooledAndNextView(t *testing.T) {
	queue := NewQueue()
	// Copy each message into a test-owned slice before building the next one,
	// so neither spare capacity nor a reused internal buffer in BuildMessage
	// can corrupt the fixture.
	wire := append([]byte(nil), queue.BuildMessage([]byte("alpha"))...)
	wire = append(wire, queue.BuildMessage([]byte("beta"))...)
	reader := NewFrameReader(bytes.NewReader(wire), queue)

	first, release, err := reader.NextPooled()
	if err != nil {
		t.Fatalf("NextPooled returned error: %v", err)
	}
	if string(first) != "alpha" {
		t.Fatalf("unexpected pooled frame: %q", first)
	}
	if release == nil {
		t.Fatal("NextPooled should return a release func")
	}
	release()

	var second string
	if err := reader.NextView(func(view FrameView) error {
		second = string(view.Payload)
		return nil
	}); err != nil {
		t.Fatalf("NextView returned error: %v", err)
	}
	if second != "beta" {
		t.Fatalf("unexpected view frame: %q", second)
	}
}

// TestFrameReaderPreservesPartialFrameAcrossDeadline injects a deadline error
// after the first 5 bytes of a frame and checks that a retry completes the
// same frame instead of losing the bytes already read.
func TestFrameReaderPreservesPartialFrameAcrossDeadline(t *testing.T) {
	queue := NewQueue()
	frame := queue.BuildMessage([]byte("hello"))
	reader := NewFrameReader(&stagedReader{
		steps: []stagedReaderStep{
			{data: append([]byte(nil), frame[:5]...), err: os.ErrDeadlineExceeded},
			{data: append([]byte(nil), frame[5:]...), err: nil},
		},
	}, queue)

	if _, _, err := reader.NextPooled(); !errors.Is(err, os.ErrDeadlineExceeded) {
		t.Fatalf("expected deadline exceeded on partial frame, got %v", err)
	}

	got, release, err := reader.NextPooled()
	if err != nil {
		t.Fatalf("NextPooled after deadline returned error: %v", err)
	}
	if string(got) != "hello" {
		t.Fatalf("unexpected frame after deadline: %q", got)
	}
	if release != nil {
		release()
	}
}

// TestFrameWriterNilWriterFails checks that writing through a FrameWriter with
// a nil destination reports io.ErrClosedPipe.
func TestFrameWriterNilWriterFails(t *testing.T) {
	writer := NewFrameWriter(nil, nil)
	if err := writer.WriteFrame([]byte("hello")); !errors.Is(err, io.ErrClosedPipe) {
		t.Fatalf("expected io.ErrClosedPipe, got %v", err)
	}
}

// TestFrameReaderTruncatedFrameReturnsUnexpectedEOF checks that a stream
// ending one byte short of a full frame yields io.ErrUnexpectedEOF.
func TestFrameReaderTruncatedFrameReturnsUnexpectedEOF(t *testing.T) {
	queue := NewQueue()
	frame := queue.BuildMessage([]byte("hello"))
	reader := NewFrameReader(bytes.NewReader(frame[:len(frame)-1]), nil)
	if _, err := reader.Next(); !errors.Is(err, io.ErrUnexpectedEOF) {
		t.Fatalf("expected io.ErrUnexpectedEOF, got %v", err)
	}
}

// TestFrameReaderReturnsUnexpectedEOFAfterPendingFrames checks that a complete
// frame is still delivered before a trailing truncated frame reports
// io.ErrUnexpectedEOF.
func TestFrameReaderReturnsUnexpectedEOFAfterPendingFrames(t *testing.T) {
	queue := NewQueue()
	first := queue.BuildMessage([]byte("one"))
	second := queue.BuildMessage([]byte("two"))
	wire := append(append([]byte{}, first...), second[:len(second)-1]...)
	reader := NewFrameReader(bytes.NewReader(wire), nil)

	got, err := reader.Next()
	if err != nil {
		t.Fatalf("unexpected error on first frame: %v", err)
	}
	if string(got) != "one" {
		t.Fatalf("unexpected first frame: %q", got)
	}
	if _, err := reader.Next(); !errors.Is(err, io.ErrUnexpectedEOF) {
		t.Fatalf("expected io.ErrUnexpectedEOF after draining pending frame, got %v", err)
	}
}