diff --git a/frameio.go b/frameio.go index d1e1ea8..8586697 100644 --- a/frameio.go +++ b/frameio.go @@ -3,12 +3,15 @@ package stario import ( "errors" "io" + "sync" ) // DefaultFrameReaderBufferSize is the default transport read chunk size used by // FrameReader. const DefaultFrameReaderBufferSize = 32 * 1024 +var framePayloadPool sync.Pool + type frameReaderConnKey struct{} // FrameWriter adapts StarQueue framing helpers to an io.Writer. @@ -56,12 +59,16 @@ func (writer *FrameWriter) WriteFramesBuffers(payloads ...[]byte) error { // FrameReader adapts StarQueue parsing helpers to an io.Reader. type FrameReader struct { - reader io.Reader - queue *StarQueue - connKey interface{} - readSize int - pending [][]byte - pendingErr error + reader io.Reader + queue *StarQueue + connKey interface{} + readSize int + + header [queHeaderSize]byte + headerRead int + payload []byte + payloadRead int + release func() } // NewFrameReader creates a framing reader backed by queue. When queue is nil, a @@ -107,44 +114,55 @@ func (reader *FrameReader) Next() ([]byte, error) { if reader == nil || reader.reader == nil || reader.queue == nil { return nil, io.ErrClosedPipe } - if len(reader.pending) > 0 { - return reader.popPending(), nil - } - if reader.pendingErr != nil { - err := reader.pendingErr - reader.pendingErr = nil + payload, release, err := reader.NextPooled() + if err != nil { return nil, err } - if reader.readSize <= 0 { - reader.readSize = DefaultFrameReaderBufferSize - } - buf := make([]byte, reader.readSize) - for { - n, readErr := reader.reader.Read(buf) - if n > 0 { - parseErr := reader.queue.ParseMessageOwned(buf[:n], reader.connKey, func(msg MsgQueue) error { - reader.pending = append(reader.pending, msg.Msg) - return nil - }) - err := reader.normalizeReadErr(parseErr, readErr) - if len(reader.pending) > 0 { - reader.pendingErr = err - return reader.popPending(), nil - } - if err != nil { - return nil, err - } - } - if readErr != nil { - return nil, reader.normalizeReadErr(nil, readErr) - } + if release == nil { + return payload, nil } + owned := cloneBytes(payload) + release() + return owned, nil } -func (reader *FrameReader) popPending() []byte { - next := reader.pending[0] - reader.pending = reader.pending[1:] - return next +// NextPooled returns the next frame payload. The caller must call release when +// it is non-nil. +func (reader *FrameReader) NextPooled() ([]byte, func(), error) { + if reader == nil || reader.reader == nil || reader.queue == nil { + return nil, nil, io.ErrClosedPipe + } + if err := reader.readHeader(); err != nil { + return nil, nil, err + } + if err := reader.ensurePayloadBuffer(true); err != nil { + return nil, nil, err + } + if len(reader.payload) > 0 { + if err := reader.readInto(reader.payload, &reader.payloadRead); err != nil { + return nil, nil, err + } + } + return reader.finishFrame() +} + +// NextView reads the next frame and exposes its payload only for the duration +// of fn. +func (reader *FrameReader) NextView(fn func(FrameView) error) error { + if fn == nil { + return ErrQueueFrameHandlerNil + } + payload, release, err := reader.NextPooled() + if err != nil { + return err + } + if release != nil { + defer release() + } + return fn(FrameView{ + Payload: payload, + Conn: reader.connKey, + }) } func joinFrameReaderError(parseErr error, readErr error) error { @@ -167,25 +185,121 @@ func (reader *FrameReader) normalizeReadErr(parseErr error, readErr error) error } func (reader *FrameReader) hasBufferedState() bool { - if reader == nil || reader.queue == nil { + if reader == nil { return false } - stateAny, ok := reader.queue.states.Load(reader.connKey) - if !ok { - return false - } - state, ok := stateAny.(*queConnState) - if !ok { - return false - } - state.mu.Lock() - defer state.mu.Unlock() - return len(state.buf) > 0 + return reader.headerRead > 0 || reader.payloadRead > 0 || len(reader.payload) > 0 } func (reader *FrameReader) clearBufferedState() { - if reader == nil || reader.queue == nil { + if reader == nil { return } - reader.queue.states.Delete(reader.connKey) + reader.headerRead = 0 + reader.payloadRead = 0 + if reader.release != nil { + reader.release() + } + reader.release = nil + reader.payload = nil +} + +func (reader *FrameReader) readHeader() error { + if reader.headerRead == queHeaderSize { + return nil + } + return reader.readInto(reader.header[:], &reader.headerRead) +} + +func (reader *FrameReader) ensurePayloadBuffer(pooled bool) error { + if reader.payload != nil || reader.headerRead < queHeaderSize { + return nil + } + header, err := parseHeaderBytes(reader.header[:], reader.queue.maxLength) + if err != nil { + reader.clearBufferedState() + return err + } + if header.Length == 0 { + reader.payload = []byte{} + reader.release = nil + return nil + } + payloadLen := int(header.Length) + switch { + case reader.queue.Encode && reader.queue.DecodeFunc != nil: + reader.payload = make([]byte, payloadLen) + case pooled: + buf := getFramePayloadBuffer(payloadLen) + reader.payload = buf + reader.release = func() { + putFramePayloadBuffer(buf) + } + default: + reader.payload = make([]byte, payloadLen) + } + return nil +} + +func (reader *FrameReader) finishFrame() ([]byte, func(), error) { + payload := reader.payload + release := reader.release + reader.payload = nil + reader.payloadRead = 0 + reader.release = nil + reader.headerRead = 0 + if reader.queue.Encode && reader.queue.DecodeFunc != nil { + decoded := reader.queue.DecodeFunc(payload) + if release != nil { + release() + release = nil + } + payload = decoded + } + return payload, release, nil +} + +func (reader *FrameReader) readInto(dst []byte, offset *int) error { + for *offset < len(dst) { + maxRead := len(dst) - *offset + if reader.readSize > 0 && maxRead > reader.readSize { + maxRead = reader.readSize + } + n, err := reader.reader.Read(dst[*offset : *offset+maxRead]) + if n > 0 { + *offset += n + } + if err != nil { + if errors.Is(err, io.EOF) { + if *offset == 0 { + return io.EOF + } + if *offset < len(dst) { + return io.ErrUnexpectedEOF + } + } + return err + } + if n == 0 { + return io.ErrNoProgress + } + } + return nil +} + +func getFramePayloadBuffer(size int) []byte { + if size <= 0 { + return nil + } + if pooled, ok := framePayloadPool.Get().([]byte); ok && cap(pooled) >= size { + return pooled[:size] + } + return make([]byte, size) +} + +func putFramePayloadBuffer(buf []byte) { + if cap(buf) == 0 || cap(buf) > 32*1024*1024 { + return + } + framePayloadPool.Put(buf[:0]) } diff --git a/frameio_test.go b/frameio_test.go index e08a0f0..ed02cd1 100644 --- a/frameio_test.go +++ b/frameio_test.go @@ -4,6 +4,7 @@ import ( "bytes" "errors" "io" + "os" "testing" ) @@ -24,6 +25,48 @@ func (reader *chunkedReader) Read(p []byte) (int, error) { return n, nil } +type trackingReader struct { + data []byte + maxRequested int +} + +func (reader *trackingReader) Read(p []byte) (int, error) { + if len(p) > reader.maxRequested { + reader.maxRequested = len(p) + } + if len(reader.data) == 0 { + return 0, io.EOF + } + n := copy(p, reader.data) + reader.data = reader.data[n:] + return n, nil +} + +type stagedReaderStep struct { + data []byte + err error +} + +type stagedReader struct { + steps []stagedReaderStep + index int +} + +func (reader *stagedReader) Read(p []byte) (int, error) { + if reader.index >= len(reader.steps) { + return 0, io.EOF + } + step := &reader.steps[reader.index] + n := copy(p, step.data) + step.data = step.data[n:] + if len(step.data) == 0 { + err := step.err + reader.index++ + return n, err + } + return n, nil +} + func TestFrameWriterReaderRoundTrip(t *testing.T) { var wire bytes.Buffer writer := NewFrameWriter(&wire, nil) @@ -68,6 +111,83 @@ func TestFrameReaderHandlesPartialTransportReads(t *testing.T) { } } +func TestFrameReaderSetReadBufferSizeLimitsUnderlyingReadSize(t *testing.T) { + var wire bytes.Buffer + writer := NewFrameWriter(&wire, nil) + if err := writer.WriteFrame([]byte("hello")); err != nil { + t.Fatalf("WriteFrame failed: %v", err) + } + + source := &trackingReader{data: wire.Bytes()} + reader := NewFrameReader(source, nil) + reader.SetReadBufferSize(3) + + got, err := reader.Next() + if err != nil { + t.Fatalf("Next returned error: %v", err) + } + if string(got) != "hello" { + t.Fatalf("unexpected frame: %q", got) + } + if got, want := source.maxRequested, 3; got != want { + t.Fatalf("max requested read size = %d, want %d", got, want) + } +} + +func TestFrameReaderNextPooledAndNextView(t *testing.T) { + queue := NewQueue() + wire := append(queue.BuildMessage([]byte("alpha")), queue.BuildMessage([]byte("beta"))...) + reader := NewFrameReader(bytes.NewReader(wire), queue) + + first, release, err := reader.NextPooled() + if err != nil { + t.Fatalf("NextPooled returned error: %v", err) + } + if string(first) != "alpha" { + t.Fatalf("unexpected pooled frame: %q", first) + } + if release == nil { + t.Fatal("NextPooled should return a release func") + } + release() + + var second string + if err := reader.NextView(func(view FrameView) error { + second = string(view.Payload) + return nil + }); err != nil { + t.Fatalf("NextView returned error: %v", err) + } + if second != "beta" { + t.Fatalf("unexpected view frame: %q", second) + } +} + +func TestFrameReaderPreservesPartialFrameAcrossDeadline(t *testing.T) { + queue := NewQueue() + frame := queue.BuildMessage([]byte("hello")) + reader := NewFrameReader(&stagedReader{ + steps: []stagedReaderStep{ + {data: append([]byte(nil), frame[:5]...), err: os.ErrDeadlineExceeded}, + {data: append([]byte(nil), frame[5:]...), err: nil}, + }, + }, queue) + + if _, _, err := reader.NextPooled(); !errors.Is(err, os.ErrDeadlineExceeded) { + t.Fatalf("expected deadline exceeded on partial frame, got %v", err) + } + got, release, err := reader.NextPooled() + if err != nil { + t.Fatalf("NextPooled after deadline returned error: %v", err) + } + if string(got) != "hello" { + t.Fatalf("unexpected frame after deadline: %q", got) + } + if release != nil { + release() + } +} + func TestFrameWriterNilWriterFails(t *testing.T) { writer := NewFrameWriter(nil, nil) if err := writer.WriteFrame([]byte("hello")); !errors.Is(err, io.ErrClosedPipe) { diff --git a/go.mod b/go.mod index fadce3f..34aceab 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module b612.me/stario -go 1.20 +go 1.18 require golang.org/x/term v0.23.0