stario/frameio_test.go

226 lines
5.7 KiB
Go
Raw Permalink Normal View History

package stario
import (
"bytes"
"errors"
"io"
"os"
"testing"
)
// chunkedReader is an io.Reader test double that hands out its data in pieces
// of at most max bytes per Read call, so tests can simulate a transport that
// delivers a frame over several partial reads. A max of zero or less disables
// the per-call cap.
type chunkedReader struct {
	data []byte // bytes not yet returned
	max  int    // per-call cap on returned bytes; <= 0 means no cap
}

// Read copies pending bytes into p, never more than r.max when r.max > 0.
// It returns io.EOF once every byte of data has been returned.
func (r *chunkedReader) Read(p []byte) (int, error) {
	if len(r.data) == 0 {
		return 0, io.EOF
	}
	limit := len(p)
	if r.max > 0 && limit > r.max {
		limit = r.max
	}
	n := copy(p[:limit], r.data)
	r.data = r.data[n:]
	return n, nil
}
// trackingReader is an io.Reader test double that remembers the largest
// buffer it has ever been asked to fill, letting tests assert how much a
// consumer requests per Read call.
type trackingReader struct {
	data         []byte // bytes not yet returned
	maxRequested int    // largest len(p) seen across all Read calls
}

// Read records len(p), then copies as many pending bytes as fit into p.
// It returns io.EOF once data is exhausted.
func (r *trackingReader) Read(p []byte) (int, error) {
	// The request size is recorded before the EOF check, so calls made after
	// the data runs out are still counted.
	if requested := len(p); requested > r.maxRequested {
		r.maxRequested = requested
	}
	remaining := r.data
	if len(remaining) == 0 {
		return 0, io.EOF
	}
	n := copy(p, remaining)
	r.data = remaining[n:]
	return n, nil
}
// stagedReaderStep is one scripted stage of a stagedReader: the bytes to
// deliver and the error to report with the Read that drains them.
type stagedReaderStep struct {
	data []byte
	err  error
}

// stagedReader is an io.Reader test double that replays a fixed script.
//
// A step's data may span several Reads when p is small. The step's err is
// returned only alongside the Read that consumes its final byte, after which
// the reader moves on to the next step. A step with no data yields (0, err)
// immediately. After the last step, Read returns io.EOF.
//
// Read consumes the data slices stored in steps in place.
type stagedReader struct {
	steps []stagedReaderStep
	index int // position of the step currently being served
}

// Read serves bytes from the current step as described on stagedReader.
func (r *stagedReader) Read(p []byte) (int, error) {
	if r.index >= len(r.steps) {
		return 0, io.EOF
	}
	current := &r.steps[r.index]
	n := copy(p, current.data)
	current.data = current.data[n:]
	if len(current.data) > 0 {
		// Step not fully drained yet; hold its error back for a later call.
		return n, nil
	}
	r.index++
	return n, current.err
}
// TestFrameWriterReaderRoundTrip writes frames through both the single-frame
// and the variadic batch API, then checks that a FrameReader over the same
// bytes yields them back in order, followed by io.EOF.
func TestFrameWriterReaderRoundTrip(t *testing.T) {
	var wire bytes.Buffer
	w := NewFrameWriter(&wire, nil)
	if err := w.WriteFrame([]byte("one")); err != nil {
		t.Fatalf("WriteFrame failed: %v", err)
	}
	if err := w.WriteFramesBuffers([]byte("two"), []byte("three")); err != nil {
		t.Fatalf("WriteFramesBuffers failed: %v", err)
	}

	r := NewFrameReader(bytes.NewReader(wire.Bytes()), nil)
	for _, want := range []string{"one", "two", "three"} {
		frame, err := r.Next()
		if err != nil {
			t.Fatalf("Next returned error: %v", err)
		}
		if got := string(frame); got != want {
			t.Fatalf("unexpected frame: got %q want %q", got, want)
		}
	}

	// Once every frame has been consumed the reader must report a clean EOF.
	if _, err := r.Next(); !errors.Is(err, io.EOF) {
		t.Fatalf("expected EOF after draining frames, got %v", err)
	}
}
// TestFrameReaderHandlesPartialTransportReads feeds a frame through a source
// that returns at most 3 bytes per Read (with a 3-byte read buffer configured
// on the FrameReader) and checks the frame is still reassembled intact.
func TestFrameReaderHandlesPartialTransportReads(t *testing.T) {
	var wire bytes.Buffer
	w := NewFrameWriter(&wire, nil)
	if err := w.WriteFrame([]byte("hello")); err != nil {
		t.Fatalf("WriteFrame failed: %v", err)
	}

	src := &chunkedReader{data: wire.Bytes(), max: 3}
	r := NewFrameReader(src, nil)
	r.SetReadBufferSize(3)

	frame, err := r.Next()
	switch {
	case err != nil:
		t.Fatalf("Next returned error: %v", err)
	case string(frame) != "hello":
		t.Fatalf("unexpected frame: %q", frame)
	}
}
// TestFrameReaderSetReadBufferSizeLimitsUnderlyingReadSize checks that after
// SetReadBufferSize(3), the largest buffer the FrameReader passes to the
// underlying io.Reader is exactly 3 bytes, while the frame still decodes.
func TestFrameReaderSetReadBufferSizeLimitsUnderlyingReadSize(t *testing.T) {
	const bufSize = 3

	var wire bytes.Buffer
	w := NewFrameWriter(&wire, nil)
	if err := w.WriteFrame([]byte("hello")); err != nil {
		t.Fatalf("WriteFrame failed: %v", err)
	}

	src := &trackingReader{data: wire.Bytes()}
	r := NewFrameReader(src, nil)
	r.SetReadBufferSize(bufSize)

	frame, err := r.Next()
	if err != nil {
		t.Fatalf("Next returned error: %v", err)
	}
	if string(frame) != "hello" {
		t.Fatalf("unexpected frame: %q", frame)
	}
	if src.maxRequested != bufSize {
		t.Fatalf("max requested read size = %d, want %d", src.maxRequested, bufSize)
	}
}
// TestFrameReaderNextPooledAndNextView reads two back-to-back frames from one
// stream:
//   - the first through NextPooled, which must also hand back a non-nil
//     release func;
//   - the second through NextView, whose payload is copied out inside the
//     callback.
func TestFrameReaderNextPooledAndNextView(t *testing.T) {
	queue := NewQueue()

	// Build the wire image in a test-owned slice, copying each message as soon
	// as it is built. Appending directly onto a slice returned by BuildMessage
	// would write into that slice's spare capacity, which the test does not
	// own. NOTE(review): BuildMessage's buffer ownership is not visible from
	// this file; copying is safe either way.
	var wire []byte
	wire = append(wire, queue.BuildMessage([]byte("alpha"))...)
	wire = append(wire, queue.BuildMessage([]byte("beta"))...)

	reader := NewFrameReader(bytes.NewReader(wire), queue)

	first, release, err := reader.NextPooled()
	if err != nil {
		t.Fatalf("NextPooled returned error: %v", err)
	}
	if string(first) != "alpha" {
		t.Fatalf("unexpected pooled frame: %q", first)
	}
	if release == nil {
		t.Fatal("NextPooled should return a release func")
	}
	release()

	// Copy the payload while inside the callback; the view is presumably only
	// valid for the duration of the call.
	var second string
	if err := reader.NextView(func(view FrameView) error {
		second = string(view.Payload)
		return nil
	}); err != nil {
		t.Fatalf("NextView returned error: %v", err)
	}
	if second != "beta" {
		t.Fatalf("unexpected view frame: %q", second)
	}
}
// TestFrameReaderPreservesPartialFrameAcrossDeadline simulates a read deadline
// that fires mid-frame. The source delivers the first 5 bytes of the frame
// together with os.ErrDeadlineExceeded, then the rest. The first NextPooled
// must surface the deadline error; the second must return the complete frame.
//
// NOTE(review): this assumes BuildMessage prepends a header, so frame is longer
// than 5 bytes and the first step is a strict prefix — confirm.
func TestFrameReaderPreservesPartialFrameAcrossDeadline(t *testing.T) {
	queue := NewQueue()
	frame := queue.BuildMessage([]byte("hello"))

	// Give the staged reader its own copies of both halves.
	head := append([]byte(nil), frame[:5]...)
	tail := append([]byte(nil), frame[5:]...)
	src := &stagedReader{steps: []stagedReaderStep{
		{data: head, err: os.ErrDeadlineExceeded},
		{data: tail},
	}}
	reader := NewFrameReader(src, queue)

	_, _, err := reader.NextPooled()
	if !errors.Is(err, os.ErrDeadlineExceeded) {
		t.Fatalf("expected deadline exceeded on partial frame, got %v", err)
	}

	got, release, err := reader.NextPooled()
	if err != nil {
		t.Fatalf("NextPooled after deadline returned error: %v", err)
	}
	if string(got) != "hello" {
		t.Fatalf("unexpected frame after deadline: %q", got)
	}
	if release != nil {
		release()
	}
}
// TestFrameWriterNilWriterFails checks that WriteFrame on a FrameWriter built
// with a nil destination reports io.ErrClosedPipe.
func TestFrameWriterNilWriterFails(t *testing.T) {
	w := NewFrameWriter(nil, nil)
	err := w.WriteFrame([]byte("hello"))
	if !errors.Is(err, io.ErrClosedPipe) {
		t.Fatalf("expected io.ErrClosedPipe, got %v", err)
	}
}
// TestFrameReaderTruncatedFrameReturnsUnexpectedEOF checks that a stream ending
// one byte short of a complete frame yields io.ErrUnexpectedEOF rather than a
// clean io.EOF.
func TestFrameReaderTruncatedFrameReturnsUnexpectedEOF(t *testing.T) {
	queue := NewQueue()
	frame := queue.BuildMessage([]byte("hello"))
	truncated := frame[:len(frame)-1] // drop the final byte

	reader := NewFrameReader(bytes.NewReader(truncated), nil)
	_, err := reader.Next()
	if !errors.Is(err, io.ErrUnexpectedEOF) {
		t.Fatalf("expected io.ErrUnexpectedEOF, got %v", err)
	}
}
// TestFrameReaderReturnsUnexpectedEOFAfterPendingFrames streams one complete
// frame followed by a second frame missing its last byte. The first frame must
// decode normally; the truncated second must then yield io.ErrUnexpectedEOF.
func TestFrameReaderReturnsUnexpectedEOFAfterPendingFrames(t *testing.T) {
	queue := NewQueue()
	first := queue.BuildMessage([]byte("one"))
	second := queue.BuildMessage([]byte("two"))

	// Assemble a fresh buffer: all of first, then second minus its final byte.
	wire := make([]byte, 0, len(first)+len(second)-1)
	wire = append(wire, first...)
	wire = append(wire, second[:len(second)-1]...)

	reader := NewFrameReader(bytes.NewReader(wire), nil)

	got, err := reader.Next()
	if err != nil {
		t.Fatalf("unexpected error on first frame: %v", err)
	}
	if string(got) != "one" {
		t.Fatalf("unexpected first frame: %q", got)
	}

	_, err = reader.Next()
	if !errors.Is(err, io.ErrUnexpectedEOF) {
		t.Fatalf("expected io.ErrUnexpectedEOF after draining pending frame, got %v", err)
	}
}