Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 348ff7418b |
220
frameio.go
220
frameio.go
@ -3,12 +3,15 @@ package stario
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
// DefaultFrameReaderBufferSize is the default transport read chunk size used by
|
// DefaultFrameReaderBufferSize is the default transport read chunk size used by
|
||||||
// FrameReader.
|
// FrameReader.
|
||||||
const DefaultFrameReaderBufferSize = 32 * 1024
|
const DefaultFrameReaderBufferSize = 32 * 1024
|
||||||
|
|
||||||
|
var framePayloadPool sync.Pool
|
||||||
|
|
||||||
type frameReaderConnKey struct{}
|
type frameReaderConnKey struct{}
|
||||||
|
|
||||||
// FrameWriter adapts StarQueue framing helpers to an io.Writer.
|
// FrameWriter adapts StarQueue framing helpers to an io.Writer.
|
||||||
@ -56,12 +59,16 @@ func (writer *FrameWriter) WriteFramesBuffers(payloads ...[]byte) error {
|
|||||||
|
|
||||||
// FrameReader adapts StarQueue parsing helpers to an io.Reader.
|
// FrameReader adapts StarQueue parsing helpers to an io.Reader.
|
||||||
type FrameReader struct {
|
type FrameReader struct {
|
||||||
reader io.Reader
|
reader io.Reader
|
||||||
queue *StarQueue
|
queue *StarQueue
|
||||||
connKey interface{}
|
connKey interface{}
|
||||||
readSize int
|
readSize int
|
||||||
pending [][]byte
|
|
||||||
pendingErr error
|
header [queHeaderSize]byte
|
||||||
|
headerRead int
|
||||||
|
payload []byte
|
||||||
|
payloadRead int
|
||||||
|
release func()
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewFrameReader creates a framing reader backed by queue. When queue is nil, a
|
// NewFrameReader creates a framing reader backed by queue. When queue is nil, a
|
||||||
@ -107,44 +114,55 @@ func (reader *FrameReader) Next() ([]byte, error) {
|
|||||||
if reader == nil || reader.reader == nil || reader.queue == nil {
|
if reader == nil || reader.reader == nil || reader.queue == nil {
|
||||||
return nil, io.ErrClosedPipe
|
return nil, io.ErrClosedPipe
|
||||||
}
|
}
|
||||||
if len(reader.pending) > 0 {
|
payload, release, err := reader.NextPooled()
|
||||||
return reader.popPending(), nil
|
if err != nil {
|
||||||
}
|
|
||||||
if reader.pendingErr != nil {
|
|
||||||
err := reader.pendingErr
|
|
||||||
reader.pendingErr = nil
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if reader.readSize <= 0 {
|
if release == nil {
|
||||||
reader.readSize = DefaultFrameReaderBufferSize
|
return payload, nil
|
||||||
}
|
|
||||||
buf := make([]byte, reader.readSize)
|
|
||||||
for {
|
|
||||||
n, readErr := reader.reader.Read(buf)
|
|
||||||
if n > 0 {
|
|
||||||
parseErr := reader.queue.ParseMessageOwned(buf[:n], reader.connKey, func(msg MsgQueue) error {
|
|
||||||
reader.pending = append(reader.pending, msg.Msg)
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
err := reader.normalizeReadErr(parseErr, readErr)
|
|
||||||
if len(reader.pending) > 0 {
|
|
||||||
reader.pendingErr = err
|
|
||||||
return reader.popPending(), nil
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if readErr != nil {
|
|
||||||
return nil, reader.normalizeReadErr(nil, readErr)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
owned := cloneBytes(payload)
|
||||||
|
release()
|
||||||
|
return owned, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (reader *FrameReader) popPending() []byte {
|
// NextPooled returns the next frame payload. The caller must call release when
|
||||||
next := reader.pending[0]
|
// it is non-nil.
|
||||||
reader.pending = reader.pending[1:]
|
func (reader *FrameReader) NextPooled() ([]byte, func(), error) {
|
||||||
return next
|
if reader == nil || reader.reader == nil || reader.queue == nil {
|
||||||
|
return nil, nil, io.ErrClosedPipe
|
||||||
|
}
|
||||||
|
if err := reader.readHeader(); err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
if err := reader.ensurePayloadBuffer(true); err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
if len(reader.payload) > 0 {
|
||||||
|
if err := reader.readInto(reader.payload, &reader.payloadRead); err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return reader.finishFrame()
|
||||||
|
}
|
||||||
|
|
||||||
|
// NextView reads the next frame and exposes its payload only for the duration
|
||||||
|
// of fn.
|
||||||
|
func (reader *FrameReader) NextView(fn func(FrameView) error) error {
|
||||||
|
if fn == nil {
|
||||||
|
return ErrQueueFrameHandlerNil
|
||||||
|
}
|
||||||
|
payload, release, err := reader.NextPooled()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if release != nil {
|
||||||
|
defer release()
|
||||||
|
}
|
||||||
|
return fn(FrameView{
|
||||||
|
Payload: payload,
|
||||||
|
Conn: reader.connKey,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func joinFrameReaderError(parseErr error, readErr error) error {
|
func joinFrameReaderError(parseErr error, readErr error) error {
|
||||||
@ -167,25 +185,121 @@ func (reader *FrameReader) normalizeReadErr(parseErr error, readErr error) error
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (reader *FrameReader) hasBufferedState() bool {
|
func (reader *FrameReader) hasBufferedState() bool {
|
||||||
if reader == nil || reader.queue == nil {
|
if reader == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
stateAny, ok := reader.queue.states.Load(reader.connKey)
|
return reader.headerRead > 0 || reader.payloadRead > 0 || len(reader.payload) > 0
|
||||||
if !ok {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
state, ok := stateAny.(*queConnState)
|
|
||||||
if !ok {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
state.mu.Lock()
|
|
||||||
defer state.mu.Unlock()
|
|
||||||
return len(state.buf) > 0
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (reader *FrameReader) clearBufferedState() {
|
func (reader *FrameReader) clearBufferedState() {
|
||||||
if reader == nil || reader.queue == nil {
|
if reader == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
reader.queue.states.Delete(reader.connKey)
|
reader.headerRead = 0
|
||||||
|
reader.payloadRead = 0
|
||||||
|
if reader.release != nil {
|
||||||
|
reader.release()
|
||||||
|
}
|
||||||
|
reader.release = nil
|
||||||
|
reader.payload = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (reader *FrameReader) readHeader() error {
|
||||||
|
if reader.headerRead == queHeaderSize {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return reader.readInto(reader.header[:], &reader.headerRead)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (reader *FrameReader) ensurePayloadBuffer(pooled bool) error {
|
||||||
|
if reader.payload != nil || reader.headerRead < queHeaderSize {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
header, err := parseHeaderBytes(reader.header[:], reader.queue.maxLength)
|
||||||
|
if err != nil {
|
||||||
|
reader.clearBufferedState()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if header.Length == 0 {
|
||||||
|
reader.payload = []byte{}
|
||||||
|
reader.release = nil
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
payloadLen := int(header.Length)
|
||||||
|
switch {
|
||||||
|
case reader.queue.Encode && reader.queue.DecodeFunc != nil:
|
||||||
|
reader.payload = make([]byte, payloadLen)
|
||||||
|
case pooled:
|
||||||
|
buf := getFramePayloadBuffer(payloadLen)
|
||||||
|
reader.payload = buf
|
||||||
|
reader.release = func() {
|
||||||
|
putFramePayloadBuffer(buf)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
reader.payload = make([]byte, payloadLen)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (reader *FrameReader) finishFrame() ([]byte, func(), error) {
|
||||||
|
payload := reader.payload
|
||||||
|
release := reader.release
|
||||||
|
reader.payload = nil
|
||||||
|
reader.payloadRead = 0
|
||||||
|
reader.release = nil
|
||||||
|
reader.headerRead = 0
|
||||||
|
if reader.queue.Encode && reader.queue.DecodeFunc != nil {
|
||||||
|
decoded := reader.queue.DecodeFunc(payload)
|
||||||
|
if release != nil {
|
||||||
|
release()
|
||||||
|
release = nil
|
||||||
|
}
|
||||||
|
payload = decoded
|
||||||
|
}
|
||||||
|
return payload, release, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (reader *FrameReader) readInto(dst []byte, offset *int) error {
|
||||||
|
for *offset < len(dst) {
|
||||||
|
maxRead := len(dst) - *offset
|
||||||
|
if reader.readSize > 0 && maxRead > reader.readSize {
|
||||||
|
maxRead = reader.readSize
|
||||||
|
}
|
||||||
|
n, err := reader.reader.Read(dst[*offset : *offset+maxRead])
|
||||||
|
if n > 0 {
|
||||||
|
*offset += n
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, io.EOF) {
|
||||||
|
if *offset == 0 {
|
||||||
|
return io.EOF
|
||||||
|
}
|
||||||
|
if *offset < len(dst) {
|
||||||
|
return io.ErrUnexpectedEOF
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if n == 0 {
|
||||||
|
return io.ErrNoProgress
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getFramePayloadBuffer(size int) []byte {
|
||||||
|
if size <= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if pooled, ok := framePayloadPool.Get().([]byte); ok && cap(pooled) >= size {
|
||||||
|
return pooled[:size]
|
||||||
|
}
|
||||||
|
return make([]byte, size)
|
||||||
|
}
|
||||||
|
|
||||||
|
func putFramePayloadBuffer(buf []byte) {
|
||||||
|
if cap(buf) == 0 || cap(buf) > 32*1024*1024 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
framePayloadPool.Put(buf[:0])
|
||||||
}
|
}
|
||||||
|
|||||||
120
frameio_test.go
120
frameio_test.go
@ -4,6 +4,7 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -24,6 +25,48 @@ func (reader *chunkedReader) Read(p []byte) (int, error) {
|
|||||||
return n, nil
|
return n, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type trackingReader struct {
|
||||||
|
data []byte
|
||||||
|
maxRequested int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (reader *trackingReader) Read(p []byte) (int, error) {
|
||||||
|
if len(p) > reader.maxRequested {
|
||||||
|
reader.maxRequested = len(p)
|
||||||
|
}
|
||||||
|
if len(reader.data) == 0 {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
n := copy(p, reader.data)
|
||||||
|
reader.data = reader.data[n:]
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type stagedReaderStep struct {
|
||||||
|
data []byte
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
type stagedReader struct {
|
||||||
|
steps []stagedReaderStep
|
||||||
|
index int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (reader *stagedReader) Read(p []byte) (int, error) {
|
||||||
|
if reader.index >= len(reader.steps) {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
step := &reader.steps[reader.index]
|
||||||
|
n := copy(p, step.data)
|
||||||
|
step.data = step.data[n:]
|
||||||
|
if len(step.data) == 0 {
|
||||||
|
err := step.err
|
||||||
|
reader.index++
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
func TestFrameWriterReaderRoundTrip(t *testing.T) {
|
func TestFrameWriterReaderRoundTrip(t *testing.T) {
|
||||||
var wire bytes.Buffer
|
var wire bytes.Buffer
|
||||||
writer := NewFrameWriter(&wire, nil)
|
writer := NewFrameWriter(&wire, nil)
|
||||||
@ -68,6 +111,83 @@ func TestFrameReaderHandlesPartialTransportReads(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestFrameReaderSetReadBufferSizeLimitsUnderlyingReadSize(t *testing.T) {
|
||||||
|
var wire bytes.Buffer
|
||||||
|
writer := NewFrameWriter(&wire, nil)
|
||||||
|
if err := writer.WriteFrame([]byte("hello")); err != nil {
|
||||||
|
t.Fatalf("WriteFrame failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
source := &trackingReader{data: wire.Bytes()}
|
||||||
|
reader := NewFrameReader(source, nil)
|
||||||
|
reader.SetReadBufferSize(3)
|
||||||
|
|
||||||
|
got, err := reader.Next()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Next returned error: %v", err)
|
||||||
|
}
|
||||||
|
if string(got) != "hello" {
|
||||||
|
t.Fatalf("unexpected frame: %q", got)
|
||||||
|
}
|
||||||
|
if got, want := source.maxRequested, 3; got != want {
|
||||||
|
t.Fatalf("max requested read size = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFrameReaderNextPooledAndNextView(t *testing.T) {
|
||||||
|
queue := NewQueue()
|
||||||
|
wire := append(queue.BuildMessage([]byte("alpha")), queue.BuildMessage([]byte("beta"))...)
|
||||||
|
reader := NewFrameReader(bytes.NewReader(wire), queue)
|
||||||
|
|
||||||
|
first, release, err := reader.NextPooled()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NextPooled returned error: %v", err)
|
||||||
|
}
|
||||||
|
if string(first) != "alpha" {
|
||||||
|
t.Fatalf("unexpected pooled frame: %q", first)
|
||||||
|
}
|
||||||
|
if release == nil {
|
||||||
|
t.Fatal("NextPooled should return a release func")
|
||||||
|
}
|
||||||
|
release()
|
||||||
|
|
||||||
|
var second string
|
||||||
|
if err := reader.NextView(func(view FrameView) error {
|
||||||
|
second = string(view.Payload)
|
||||||
|
return nil
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("NextView returned error: %v", err)
|
||||||
|
}
|
||||||
|
if second != "beta" {
|
||||||
|
t.Fatalf("unexpected view frame: %q", second)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFrameReaderPreservesPartialFrameAcrossDeadline(t *testing.T) {
|
||||||
|
queue := NewQueue()
|
||||||
|
frame := queue.BuildMessage([]byte("hello"))
|
||||||
|
reader := NewFrameReader(&stagedReader{
|
||||||
|
steps: []stagedReaderStep{
|
||||||
|
{data: append([]byte(nil), frame[:5]...), err: os.ErrDeadlineExceeded},
|
||||||
|
{data: append([]byte(nil), frame[5:]...), err: nil},
|
||||||
|
},
|
||||||
|
}, queue)
|
||||||
|
|
||||||
|
if _, _, err := reader.NextPooled(); !errors.Is(err, os.ErrDeadlineExceeded) {
|
||||||
|
t.Fatalf("expected deadline exceeded on partial frame, got %v", err)
|
||||||
|
}
|
||||||
|
got, release, err := reader.NextPooled()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NextPooled after deadline returned error: %v", err)
|
||||||
|
}
|
||||||
|
if string(got) != "hello" {
|
||||||
|
t.Fatalf("unexpected frame after deadline: %q", got)
|
||||||
|
}
|
||||||
|
if release != nil {
|
||||||
|
release()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestFrameWriterNilWriterFails(t *testing.T) {
|
func TestFrameWriterNilWriterFails(t *testing.T) {
|
||||||
writer := NewFrameWriter(nil, nil)
|
writer := NewFrameWriter(nil, nil)
|
||||||
if err := writer.WriteFrame([]byte("hello")); !errors.Is(err, io.ErrClosedPipe) {
|
if err := writer.WriteFrame([]byte("hello")); !errors.Is(err, io.ErrClosedPipe) {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user