Compare commits

..

1 Commits

Author SHA1 Message Date
348ff7418b
fix(frameio): honor configured read buffer size
- make FrameReader readInto respect SetReadBufferSize
  - add regression test to verify underlying read chunk size
  - keep default FrameReader buffer conservative for generic stario use
  - declare module go version as 1.18
2026-04-18 15:58:29 +08:00
3 changed files with 288 additions and 54 deletions

View File

@ -3,12 +3,15 @@ package stario
import ( import (
"errors" "errors"
"io" "io"
"sync"
) )
// DefaultFrameReaderBufferSize is the default transport read chunk size used by // DefaultFrameReaderBufferSize is the default transport read chunk size used by
// FrameReader. // FrameReader.
const DefaultFrameReaderBufferSize = 32 * 1024 const DefaultFrameReaderBufferSize = 32 * 1024
var framePayloadPool sync.Pool
type frameReaderConnKey struct{} type frameReaderConnKey struct{}
// FrameWriter adapts StarQueue framing helpers to an io.Writer. // FrameWriter adapts StarQueue framing helpers to an io.Writer.
@ -56,12 +59,16 @@ func (writer *FrameWriter) WriteFramesBuffers(payloads ...[]byte) error {
// FrameReader adapts StarQueue parsing helpers to an io.Reader. // FrameReader adapts StarQueue parsing helpers to an io.Reader.
type FrameReader struct { type FrameReader struct {
reader io.Reader reader io.Reader
queue *StarQueue queue *StarQueue
connKey interface{} connKey interface{}
readSize int readSize int
pending [][]byte
pendingErr error header [queHeaderSize]byte
headerRead int
payload []byte
payloadRead int
release func()
} }
// NewFrameReader creates a framing reader backed by queue. When queue is nil, a // NewFrameReader creates a framing reader backed by queue. When queue is nil, a
@ -107,44 +114,55 @@ func (reader *FrameReader) Next() ([]byte, error) {
if reader == nil || reader.reader == nil || reader.queue == nil { if reader == nil || reader.reader == nil || reader.queue == nil {
return nil, io.ErrClosedPipe return nil, io.ErrClosedPipe
} }
if len(reader.pending) > 0 { payload, release, err := reader.NextPooled()
return reader.popPending(), nil if err != nil {
}
if reader.pendingErr != nil {
err := reader.pendingErr
reader.pendingErr = nil
return nil, err return nil, err
} }
if reader.readSize <= 0 { if release == nil {
reader.readSize = DefaultFrameReaderBufferSize return payload, nil
}
buf := make([]byte, reader.readSize)
for {
n, readErr := reader.reader.Read(buf)
if n > 0 {
parseErr := reader.queue.ParseMessageOwned(buf[:n], reader.connKey, func(msg MsgQueue) error {
reader.pending = append(reader.pending, msg.Msg)
return nil
})
err := reader.normalizeReadErr(parseErr, readErr)
if len(reader.pending) > 0 {
reader.pendingErr = err
return reader.popPending(), nil
}
if err != nil {
return nil, err
}
}
if readErr != nil {
return nil, reader.normalizeReadErr(nil, readErr)
}
} }
owned := cloneBytes(payload)
release()
return owned, nil
} }
func (reader *FrameReader) popPending() []byte { // NextPooled returns the next frame payload. The caller must call release when
next := reader.pending[0] // it is non-nil.
reader.pending = reader.pending[1:] func (reader *FrameReader) NextPooled() ([]byte, func(), error) {
return next if reader == nil || reader.reader == nil || reader.queue == nil {
return nil, nil, io.ErrClosedPipe
}
if err := reader.readHeader(); err != nil {
return nil, nil, err
}
if err := reader.ensurePayloadBuffer(true); err != nil {
return nil, nil, err
}
if len(reader.payload) > 0 {
if err := reader.readInto(reader.payload, &reader.payloadRead); err != nil {
return nil, nil, err
}
}
return reader.finishFrame()
}
// NextView reads the next frame and exposes its payload only for the duration
// of fn.
func (reader *FrameReader) NextView(fn func(FrameView) error) error {
if fn == nil {
return ErrQueueFrameHandlerNil
}
payload, release, err := reader.NextPooled()
if err != nil {
return err
}
if release != nil {
defer release()
}
return fn(FrameView{
Payload: payload,
Conn: reader.connKey,
})
} }
func joinFrameReaderError(parseErr error, readErr error) error { func joinFrameReaderError(parseErr error, readErr error) error {
@ -167,25 +185,121 @@ func (reader *FrameReader) normalizeReadErr(parseErr error, readErr error) error
} }
func (reader *FrameReader) hasBufferedState() bool { func (reader *FrameReader) hasBufferedState() bool {
if reader == nil || reader.queue == nil { if reader == nil {
return false return false
} }
stateAny, ok := reader.queue.states.Load(reader.connKey) return reader.headerRead > 0 || reader.payloadRead > 0 || len(reader.payload) > 0
if !ok {
return false
}
state, ok := stateAny.(*queConnState)
if !ok {
return false
}
state.mu.Lock()
defer state.mu.Unlock()
return len(state.buf) > 0
} }
func (reader *FrameReader) clearBufferedState() { func (reader *FrameReader) clearBufferedState() {
if reader == nil || reader.queue == nil { if reader == nil {
return return
} }
reader.queue.states.Delete(reader.connKey) reader.headerRead = 0
reader.payloadRead = 0
if reader.release != nil {
reader.release()
}
reader.release = nil
reader.payload = nil
}
func (reader *FrameReader) readHeader() error {
if reader.headerRead == queHeaderSize {
return nil
}
return reader.readInto(reader.header[:], &reader.headerRead)
}
func (reader *FrameReader) ensurePayloadBuffer(pooled bool) error {
if reader.payload != nil || reader.headerRead < queHeaderSize {
return nil
}
header, err := parseHeaderBytes(reader.header[:], reader.queue.maxLength)
if err != nil {
reader.clearBufferedState()
return err
}
if header.Length == 0 {
reader.payload = []byte{}
reader.release = nil
return nil
}
payloadLen := int(header.Length)
switch {
case reader.queue.Encode && reader.queue.DecodeFunc != nil:
reader.payload = make([]byte, payloadLen)
case pooled:
buf := getFramePayloadBuffer(payloadLen)
reader.payload = buf
reader.release = func() {
putFramePayloadBuffer(buf)
}
default:
reader.payload = make([]byte, payloadLen)
}
return nil
}
func (reader *FrameReader) finishFrame() ([]byte, func(), error) {
payload := reader.payload
release := reader.release
reader.payload = nil
reader.payloadRead = 0
reader.release = nil
reader.headerRead = 0
if reader.queue.Encode && reader.queue.DecodeFunc != nil {
decoded := reader.queue.DecodeFunc(payload)
if release != nil {
release()
release = nil
}
payload = decoded
}
return payload, release, nil
}
func (reader *FrameReader) readInto(dst []byte, offset *int) error {
for *offset < len(dst) {
maxRead := len(dst) - *offset
if reader.readSize > 0 && maxRead > reader.readSize {
maxRead = reader.readSize
}
n, err := reader.reader.Read(dst[*offset : *offset+maxRead])
if n > 0 {
*offset += n
}
if err != nil {
if errors.Is(err, io.EOF) {
if *offset == 0 {
return io.EOF
}
if *offset < len(dst) {
return io.ErrUnexpectedEOF
}
}
return err
}
if n == 0 {
return io.ErrNoProgress
}
}
return nil
}
func getFramePayloadBuffer(size int) []byte {
if size <= 0 {
return nil
}
if pooled, ok := framePayloadPool.Get().([]byte); ok && cap(pooled) >= size {
return pooled[:size]
}
return make([]byte, size)
}
func putFramePayloadBuffer(buf []byte) {
if cap(buf) == 0 || cap(buf) > 32*1024*1024 {
return
}
framePayloadPool.Put(buf[:0])
} }

View File

@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"errors" "errors"
"io" "io"
"os"
"testing" "testing"
) )
@ -24,6 +25,48 @@ func (reader *chunkedReader) Read(p []byte) (int, error) {
return n, nil return n, nil
} }
type trackingReader struct {
data []byte
maxRequested int
}
func (reader *trackingReader) Read(p []byte) (int, error) {
if len(p) > reader.maxRequested {
reader.maxRequested = len(p)
}
if len(reader.data) == 0 {
return 0, io.EOF
}
n := copy(p, reader.data)
reader.data = reader.data[n:]
return n, nil
}
type stagedReaderStep struct {
data []byte
err error
}
type stagedReader struct {
steps []stagedReaderStep
index int
}
func (reader *stagedReader) Read(p []byte) (int, error) {
if reader.index >= len(reader.steps) {
return 0, io.EOF
}
step := &reader.steps[reader.index]
n := copy(p, step.data)
step.data = step.data[n:]
if len(step.data) == 0 {
err := step.err
reader.index++
return n, err
}
return n, nil
}
func TestFrameWriterReaderRoundTrip(t *testing.T) { func TestFrameWriterReaderRoundTrip(t *testing.T) {
var wire bytes.Buffer var wire bytes.Buffer
writer := NewFrameWriter(&wire, nil) writer := NewFrameWriter(&wire, nil)
@ -68,6 +111,83 @@ func TestFrameReaderHandlesPartialTransportReads(t *testing.T) {
} }
} }
func TestFrameReaderSetReadBufferSizeLimitsUnderlyingReadSize(t *testing.T) {
var wire bytes.Buffer
writer := NewFrameWriter(&wire, nil)
if err := writer.WriteFrame([]byte("hello")); err != nil {
t.Fatalf("WriteFrame failed: %v", err)
}
source := &trackingReader{data: wire.Bytes()}
reader := NewFrameReader(source, nil)
reader.SetReadBufferSize(3)
got, err := reader.Next()
if err != nil {
t.Fatalf("Next returned error: %v", err)
}
if string(got) != "hello" {
t.Fatalf("unexpected frame: %q", got)
}
if got, want := source.maxRequested, 3; got != want {
t.Fatalf("max requested read size = %d, want %d", got, want)
}
}
func TestFrameReaderNextPooledAndNextView(t *testing.T) {
queue := NewQueue()
wire := append(queue.BuildMessage([]byte("alpha")), queue.BuildMessage([]byte("beta"))...)
reader := NewFrameReader(bytes.NewReader(wire), queue)
first, release, err := reader.NextPooled()
if err != nil {
t.Fatalf("NextPooled returned error: %v", err)
}
if string(first) != "alpha" {
t.Fatalf("unexpected pooled frame: %q", first)
}
if release == nil {
t.Fatal("NextPooled should return a release func")
}
release()
var second string
if err := reader.NextView(func(view FrameView) error {
second = string(view.Payload)
return nil
}); err != nil {
t.Fatalf("NextView returned error: %v", err)
}
if second != "beta" {
t.Fatalf("unexpected view frame: %q", second)
}
}
func TestFrameReaderPreservesPartialFrameAcrossDeadline(t *testing.T) {
queue := NewQueue()
frame := queue.BuildMessage([]byte("hello"))
reader := NewFrameReader(&stagedReader{
steps: []stagedReaderStep{
{data: append([]byte(nil), frame[:5]...), err: os.ErrDeadlineExceeded},
{data: append([]byte(nil), frame[5:]...), err: nil},
},
}, queue)
if _, _, err := reader.NextPooled(); !errors.Is(err, os.ErrDeadlineExceeded) {
t.Fatalf("expected deadline exceeded on partial frame, got %v", err)
}
got, release, err := reader.NextPooled()
if err != nil {
t.Fatalf("NextPooled after deadline returned error: %v", err)
}
if string(got) != "hello" {
t.Fatalf("unexpected frame after deadline: %q", got)
}
if release != nil {
release()
}
}
func TestFrameWriterNilWriterFails(t *testing.T) { func TestFrameWriterNilWriterFails(t *testing.T) {
writer := NewFrameWriter(nil, nil) writer := NewFrameWriter(nil, nil)
if err := writer.WriteFrame([]byte("hello")); !errors.Is(err, io.ErrClosedPipe) { if err := writer.WriteFrame([]byte("hello")); !errors.Is(err, io.ErrClosedPipe) {

2
go.mod
View File

@ -1,6 +1,6 @@
module b612.me/stario module b612.me/stario
go 1.20 go 1.18
require golang.org/x/term v0.23.0 require golang.org/x/term v0.23.0