Compare commits

..

No commits in common. "master" and "v0.1.0" have entirely different histories.

3 changed files with 52 additions and 286 deletions

View File

@ -3,15 +3,12 @@ package stario
import ( import (
"errors" "errors"
"io" "io"
"sync"
) )
// DefaultFrameReaderBufferSize is the default transport read chunk size used by // DefaultFrameReaderBufferSize is the default transport read chunk size used by
// FrameReader. // FrameReader.
const DefaultFrameReaderBufferSize = 32 * 1024 const DefaultFrameReaderBufferSize = 32 * 1024
var framePayloadPool sync.Pool
type frameReaderConnKey struct{} type frameReaderConnKey struct{}
// FrameWriter adapts StarQueue framing helpers to an io.Writer. // FrameWriter adapts StarQueue framing helpers to an io.Writer.
@ -59,16 +56,12 @@ func (writer *FrameWriter) WriteFramesBuffers(payloads ...[]byte) error {
// FrameReader adapts StarQueue parsing helpers to an io.Reader. // FrameReader adapts StarQueue parsing helpers to an io.Reader.
type FrameReader struct { type FrameReader struct {
reader io.Reader reader io.Reader
queue *StarQueue queue *StarQueue
connKey interface{} connKey interface{}
readSize int readSize int
pending [][]byte
header [queHeaderSize]byte pendingErr error
headerRead int
payload []byte
payloadRead int
release func()
} }
// NewFrameReader creates a framing reader backed by queue. When queue is nil, a // NewFrameReader creates a framing reader backed by queue. When queue is nil, a
@ -114,55 +107,44 @@ func (reader *FrameReader) Next() ([]byte, error) {
if reader == nil || reader.reader == nil || reader.queue == nil { if reader == nil || reader.reader == nil || reader.queue == nil {
return nil, io.ErrClosedPipe return nil, io.ErrClosedPipe
} }
payload, release, err := reader.NextPooled() if len(reader.pending) > 0 {
if err != nil { return reader.popPending(), nil
}
if reader.pendingErr != nil {
err := reader.pendingErr
reader.pendingErr = nil
return nil, err return nil, err
} }
if release == nil { if reader.readSize <= 0 {
return payload, nil reader.readSize = DefaultFrameReaderBufferSize
} }
owned := cloneBytes(payload) buf := make([]byte, reader.readSize)
release() for {
return owned, nil n, readErr := reader.reader.Read(buf)
} if n > 0 {
parseErr := reader.queue.ParseMessageOwned(buf[:n], reader.connKey, func(msg MsgQueue) error {
// NextPooled returns the next frame payload. The caller must call release when reader.pending = append(reader.pending, msg.Msg)
// it is non-nil. return nil
func (reader *FrameReader) NextPooled() ([]byte, func(), error) { })
if reader == nil || reader.reader == nil || reader.queue == nil { err := reader.normalizeReadErr(parseErr, readErr)
return nil, nil, io.ErrClosedPipe if len(reader.pending) > 0 {
} reader.pendingErr = err
if err := reader.readHeader(); err != nil { return reader.popPending(), nil
return nil, nil, err }
} if err != nil {
if err := reader.ensurePayloadBuffer(true); err != nil { return nil, err
return nil, nil, err }
} }
if len(reader.payload) > 0 { if readErr != nil {
if err := reader.readInto(reader.payload, &reader.payloadRead); err != nil { return nil, reader.normalizeReadErr(nil, readErr)
return nil, nil, err
} }
} }
return reader.finishFrame()
} }
// NextView reads the next frame and exposes its payload only for the duration func (reader *FrameReader) popPending() []byte {
// of fn. next := reader.pending[0]
func (reader *FrameReader) NextView(fn func(FrameView) error) error { reader.pending = reader.pending[1:]
if fn == nil { return next
return ErrQueueFrameHandlerNil
}
payload, release, err := reader.NextPooled()
if err != nil {
return err
}
if release != nil {
defer release()
}
return fn(FrameView{
Payload: payload,
Conn: reader.connKey,
})
} }
func joinFrameReaderError(parseErr error, readErr error) error { func joinFrameReaderError(parseErr error, readErr error) error {
@ -185,121 +167,25 @@ func (reader *FrameReader) normalizeReadErr(parseErr error, readErr error) error
} }
func (reader *FrameReader) hasBufferedState() bool { func (reader *FrameReader) hasBufferedState() bool {
if reader == nil { if reader == nil || reader.queue == nil {
return false return false
} }
return reader.headerRead > 0 || reader.payloadRead > 0 || len(reader.payload) > 0 stateAny, ok := reader.queue.states.Load(reader.connKey)
if !ok {
return false
}
state, ok := stateAny.(*queConnState)
if !ok {
return false
}
state.mu.Lock()
defer state.mu.Unlock()
return len(state.buf) > 0
} }
func (reader *FrameReader) clearBufferedState() { func (reader *FrameReader) clearBufferedState() {
if reader == nil { if reader == nil || reader.queue == nil {
return return
} }
reader.headerRead = 0 reader.queue.states.Delete(reader.connKey)
reader.payloadRead = 0
if reader.release != nil {
reader.release()
}
reader.release = nil
reader.payload = nil
}
func (reader *FrameReader) readHeader() error {
if reader.headerRead == queHeaderSize {
return nil
}
return reader.readInto(reader.header[:], &reader.headerRead)
}
func (reader *FrameReader) ensurePayloadBuffer(pooled bool) error {
if reader.payload != nil || reader.headerRead < queHeaderSize {
return nil
}
header, err := parseHeaderBytes(reader.header[:], reader.queue.maxLength)
if err != nil {
reader.clearBufferedState()
return err
}
if header.Length == 0 {
reader.payload = []byte{}
reader.release = nil
return nil
}
payloadLen := int(header.Length)
switch {
case reader.queue.Encode && reader.queue.DecodeFunc != nil:
reader.payload = make([]byte, payloadLen)
case pooled:
buf := getFramePayloadBuffer(payloadLen)
reader.payload = buf
reader.release = func() {
putFramePayloadBuffer(buf)
}
default:
reader.payload = make([]byte, payloadLen)
}
return nil
}
func (reader *FrameReader) finishFrame() ([]byte, func(), error) {
payload := reader.payload
release := reader.release
reader.payload = nil
reader.payloadRead = 0
reader.release = nil
reader.headerRead = 0
if reader.queue.Encode && reader.queue.DecodeFunc != nil {
decoded := reader.queue.DecodeFunc(payload)
if release != nil {
release()
release = nil
}
payload = decoded
}
return payload, release, nil
}
func (reader *FrameReader) readInto(dst []byte, offset *int) error {
for *offset < len(dst) {
maxRead := len(dst) - *offset
if reader.readSize > 0 && maxRead > reader.readSize {
maxRead = reader.readSize
}
n, err := reader.reader.Read(dst[*offset : *offset+maxRead])
if n > 0 {
*offset += n
}
if err != nil {
if errors.Is(err, io.EOF) {
if *offset == 0 {
return io.EOF
}
if *offset < len(dst) {
return io.ErrUnexpectedEOF
}
}
return err
}
if n == 0 {
return io.ErrNoProgress
}
}
return nil
}
func getFramePayloadBuffer(size int) []byte {
if size <= 0 {
return nil
}
if pooled, ok := framePayloadPool.Get().([]byte); ok && cap(pooled) >= size {
return pooled[:size]
}
return make([]byte, size)
}
func putFramePayloadBuffer(buf []byte) {
if cap(buf) == 0 || cap(buf) > 32*1024*1024 {
return
}
framePayloadPool.Put(buf[:0])
} }

View File

@ -4,7 +4,6 @@ import (
"bytes" "bytes"
"errors" "errors"
"io" "io"
"os"
"testing" "testing"
) )
@ -25,48 +24,6 @@ func (reader *chunkedReader) Read(p []byte) (int, error) {
return n, nil return n, nil
} }
type trackingReader struct {
data []byte
maxRequested int
}
func (reader *trackingReader) Read(p []byte) (int, error) {
if len(p) > reader.maxRequested {
reader.maxRequested = len(p)
}
if len(reader.data) == 0 {
return 0, io.EOF
}
n := copy(p, reader.data)
reader.data = reader.data[n:]
return n, nil
}
type stagedReaderStep struct {
data []byte
err error
}
type stagedReader struct {
steps []stagedReaderStep
index int
}
func (reader *stagedReader) Read(p []byte) (int, error) {
if reader.index >= len(reader.steps) {
return 0, io.EOF
}
step := &reader.steps[reader.index]
n := copy(p, step.data)
step.data = step.data[n:]
if len(step.data) == 0 {
err := step.err
reader.index++
return n, err
}
return n, nil
}
func TestFrameWriterReaderRoundTrip(t *testing.T) { func TestFrameWriterReaderRoundTrip(t *testing.T) {
var wire bytes.Buffer var wire bytes.Buffer
writer := NewFrameWriter(&wire, nil) writer := NewFrameWriter(&wire, nil)
@ -111,83 +68,6 @@ func TestFrameReaderHandlesPartialTransportReads(t *testing.T) {
} }
} }
func TestFrameReaderSetReadBufferSizeLimitsUnderlyingReadSize(t *testing.T) {
var wire bytes.Buffer
writer := NewFrameWriter(&wire, nil)
if err := writer.WriteFrame([]byte("hello")); err != nil {
t.Fatalf("WriteFrame failed: %v", err)
}
source := &trackingReader{data: wire.Bytes()}
reader := NewFrameReader(source, nil)
reader.SetReadBufferSize(3)
got, err := reader.Next()
if err != nil {
t.Fatalf("Next returned error: %v", err)
}
if string(got) != "hello" {
t.Fatalf("unexpected frame: %q", got)
}
if got, want := source.maxRequested, 3; got != want {
t.Fatalf("max requested read size = %d, want %d", got, want)
}
}
func TestFrameReaderNextPooledAndNextView(t *testing.T) {
queue := NewQueue()
wire := append(queue.BuildMessage([]byte("alpha")), queue.BuildMessage([]byte("beta"))...)
reader := NewFrameReader(bytes.NewReader(wire), queue)
first, release, err := reader.NextPooled()
if err != nil {
t.Fatalf("NextPooled returned error: %v", err)
}
if string(first) != "alpha" {
t.Fatalf("unexpected pooled frame: %q", first)
}
if release == nil {
t.Fatal("NextPooled should return a release func")
}
release()
var second string
if err := reader.NextView(func(view FrameView) error {
second = string(view.Payload)
return nil
}); err != nil {
t.Fatalf("NextView returned error: %v", err)
}
if second != "beta" {
t.Fatalf("unexpected view frame: %q", second)
}
}
func TestFrameReaderPreservesPartialFrameAcrossDeadline(t *testing.T) {
queue := NewQueue()
frame := queue.BuildMessage([]byte("hello"))
reader := NewFrameReader(&stagedReader{
steps: []stagedReaderStep{
{data: append([]byte(nil), frame[:5]...), err: os.ErrDeadlineExceeded},
{data: append([]byte(nil), frame[5:]...), err: nil},
},
}, queue)
if _, _, err := reader.NextPooled(); !errors.Is(err, os.ErrDeadlineExceeded) {
t.Fatalf("expected deadline exceeded on partial frame, got %v", err)
}
got, release, err := reader.NextPooled()
if err != nil {
t.Fatalf("NextPooled after deadline returned error: %v", err)
}
if string(got) != "hello" {
t.Fatalf("unexpected frame after deadline: %q", got)
}
if release != nil {
release()
}
}
func TestFrameWriterNilWriterFails(t *testing.T) { func TestFrameWriterNilWriterFails(t *testing.T) {
writer := NewFrameWriter(nil, nil) writer := NewFrameWriter(nil, nil)
if err := writer.WriteFrame([]byte("hello")); !errors.Is(err, io.ErrClosedPipe) { if err := writer.WriteFrame([]byte("hello")); !errors.Is(err, io.ErrClosedPipe) {

2
go.mod
View File

@ -1,6 +1,6 @@
module b612.me/stario module b612.me/stario
go 1.18 go 1.20
require golang.org/x/term v0.23.0 require golang.org/x/term v0.23.0