// Source path: stario/frameio.go

package stario
import (
"errors"
"io"
"sync"
)
// DefaultFrameReaderBufferSize is the read chunk size, in bytes, that a new
// FrameReader uses for each call to the underlying transport (32 KiB).
const DefaultFrameReaderBufferSize = 32 << 10

// framePayloadPool holds payload byte slices for reuse between frames. It is
// accessed through getFramePayloadBuffer and putFramePayloadBuffer.
var framePayloadPool = sync.Pool{}
// frameReaderConnKey is the type behind the default connection key that a
// FrameReader allocates for itself. Only the pointer identity matters.
//
// The struct deliberately has non-zero size: the Go specification allows
// pointers to distinct zero-size variables to compare equal, which would make
// the default keys of different readers indistinguishable.
type frameReaderConnKey struct{ _ byte }
// FrameWriter adapts StarQueue framing helpers to an io.Writer.
//
// Every method returns io.ErrClosedPipe when the receiver, its writer, or its
// queue is nil.
type FrameWriter struct {
	// writer is the destination handed to the queue's write helpers.
	writer io.Writer
	// queue provides the framing implementation the methods delegate to.
	queue *StarQueue
}
// NewFrameWriter returns a FrameWriter that frames payloads with queue and
// sends them to writer. A nil queue is replaced by one obtained from NewQueue.
func NewFrameWriter(writer io.Writer, queue *StarQueue) *FrameWriter {
	fw := &FrameWriter{writer: writer, queue: queue}
	if fw.queue == nil {
		fw.queue = NewQueue()
	}
	return fw
}
// WriteFrame frames payload with the configured queue and writes it to the
// underlying writer. It returns io.ErrClosedPipe when the receiver, its
// writer, or its queue is nil.
func (writer *FrameWriter) WriteFrame(payload []byte) error {
	if writer == nil {
		return io.ErrClosedPipe
	}
	dst, q := writer.writer, writer.queue
	if dst == nil || q == nil {
		return io.ErrClosedPipe
	}
	return q.WriteFrame(dst, payload)
}
// WriteFrameBuffers writes one framed payload by delegating to the queue's
// WriteFrameBuffers helper, which uses net.Buffers when possible. It returns
// io.ErrClosedPipe when the receiver, its writer, or its queue is nil.
func (writer *FrameWriter) WriteFrameBuffers(payload []byte) error {
	if writer == nil {
		return io.ErrClosedPipe
	}
	dst, q := writer.writer, writer.queue
	if dst == nil || q == nil {
		return io.ErrClosedPipe
	}
	return q.WriteFrameBuffers(dst, payload)
}
// WriteFramesBuffers writes several framed payloads by delegating to the
// queue's WriteFramesBuffers helper, which batches them when possible. It
// returns io.ErrClosedPipe when the receiver, its writer, or its queue is nil.
func (writer *FrameWriter) WriteFramesBuffers(payloads ...[]byte) error {
	if writer == nil {
		return io.ErrClosedPipe
	}
	dst, q := writer.writer, writer.queue
	if dst == nil || q == nil {
		return io.ErrClosedPipe
	}
	return q.WriteFramesBuffers(dst, payloads...)
}
// FrameReader adapts StarQueue parsing helpers to an io.Reader.
//
// It reads one frame at a time: a fixed-size header of queHeaderSize bytes
// followed by the payload whose length the header announces. Progress on the
// current frame is kept in the struct between calls.
type FrameReader struct {
	// reader is the transport the header and payload bytes are read from.
	reader io.Reader
	// queue supplies maxLength for header parsing and the Encode/DecodeFunc
	// settings applied when a frame is finished.
	queue *StarQueue
	// connKey is reported as FrameView.Conn by NextView. It defaults to a
	// private key and can be replaced through SetConnKey.
	connKey interface{}
	// readSize caps how many bytes are requested from reader per Read call;
	// values <= 0 disable the cap.
	readSize int
	// header accumulates the raw header bytes of the frame in progress.
	header [queHeaderSize]byte
	// headerRead is the number of bytes of header filled so far.
	headerRead int
	// payload is the buffer for the frame in progress; nil means no payload
	// buffer has been set up yet.
	payload []byte
	// payloadRead is the number of bytes of payload filled so far.
	payloadRead int
	// release returns payload to the buffer pool; nil when payload is not
	// pooled.
	release func()
}
// NewFrameReader returns a FrameReader that reads frames from reader and
// parses them with queue. A nil queue is replaced by one obtained from
// NewQueue. The reader starts with a freshly allocated private connection key
// and a read chunk size of DefaultFrameReaderBufferSize.
func NewFrameReader(reader io.Reader, queue *StarQueue) *FrameReader {
	fr := new(FrameReader)
	fr.reader = reader
	fr.queue = queue
	if fr.queue == nil {
		fr.queue = NewQueue()
	}
	fr.connKey = &frameReaderConnKey{}
	fr.readSize = DefaultFrameReaderBufferSize
	return fr
}
// SetReadBufferSize sets the maximum number of bytes requested from the
// underlying transport per Read call. A nil receiver or a non-positive size
// leaves the current setting untouched.
func (reader *FrameReader) SetReadBufferSize(size int) {
	if reader != nil && size > 0 {
		reader.readSize = size
	}
}
// SetConnKey replaces the connection key associated with this reader.
//
// A nil conn resets the key to a freshly allocated private key. Any other
// value must pass validateConnKey; on failure the current key is kept and the
// validation error is returned. A nil receiver yields io.ErrClosedPipe.
func (reader *FrameReader) SetConnKey(conn interface{}) error {
	switch {
	case reader == nil:
		return io.ErrClosedPipe
	case conn == nil:
		reader.connKey = &frameReaderConnKey{}
		return nil
	}
	err := validateConnKey(conn)
	if err == nil {
		reader.connKey = conn
	}
	return err
}
// Next reads the next frame and returns a payload the caller owns.
//
// The frame is obtained through NextPooled. When that payload comes with a
// release function, the bytes are copied with cloneBytes and the original
// buffer is released before returning. It returns io.ErrClosedPipe when the
// receiver, its reader, or its queue is nil.
func (reader *FrameReader) Next() ([]byte, error) {
	if reader == nil {
		return nil, io.ErrClosedPipe
	}
	if reader.reader == nil || reader.queue == nil {
		return nil, io.ErrClosedPipe
	}
	data, done, err := reader.NextPooled()
	switch {
	case err != nil:
		return nil, err
	case done == nil:
		// Nothing to give back, so the slice can be handed over as is.
		return data, nil
	}
	copied := cloneBytes(data)
	done()
	return copied, nil
}
// NextPooled reads the next frame and returns its payload together with an
// optional release function. The caller must call release when it is non-nil,
// and must not use the payload afterwards.
//
// Errors:
//   - io.ErrClosedPipe when the receiver, its reader, or its queue is nil.
//   - io.EOF only when the stream ends before any byte of a new frame.
//   - io.ErrUnexpectedEOF when the stream ends inside a frame, including
//     directly after a complete header whose payload never arrives.
//   - any header parse error or other transport error, unchanged.
//
// Progress on a partially read frame is kept, so a call that failed with a
// transient transport error can be retried.
func (reader *FrameReader) NextPooled() ([]byte, func(), error) {
	if reader == nil || reader.reader == nil || reader.queue == nil {
		return nil, nil, io.ErrClosedPipe
	}
	if err := reader.readHeader(); err != nil {
		return nil, nil, err
	}
	if err := reader.ensurePayloadBuffer(true); err != nil {
		return nil, nil, err
	}
	if len(reader.payload) > 0 {
		if err := reader.readInto(reader.payload, &reader.payloadRead); err != nil {
			// A header has already been consumed, so running out of data
			// before the payload is complete is a truncated frame, not a
			// clean end of stream, even when no payload byte was read yet.
			if errors.Is(err, io.EOF) && reader.payloadRead < len(reader.payload) {
				err = io.ErrUnexpectedEOF
			}
			return nil, nil, err
		}
	}
	return reader.finishFrame()
}
// NextView reads the next frame and passes it to fn as a FrameView carrying
// the payload and this reader's connection key. The payload is valid only
// until fn returns, because any pooled buffer is released afterwards.
//
// It returns ErrQueueFrameHandlerNil when fn is nil, the read error when the
// frame cannot be read, and otherwise whatever fn returns.
func (reader *FrameReader) NextView(fn func(FrameView) error) error {
	if fn == nil {
		return ErrQueueFrameHandlerNil
	}
	data, done, err := reader.NextPooled()
	if err != nil {
		return err
	}
	if done != nil {
		// Deferred so the buffer is returned even if fn panics.
		defer done()
	}
	view := FrameView{Payload: data, Conn: reader.connKey}
	return fn(view)
}
// joinFrameReaderError merges a parse error and a read error into one value.
// When only one of them is non-nil it is returned unchanged, so callers can
// still compare it by identity. When both are set they are combined with
// errors.Join, parse error first. Two nil inputs yield nil.
func joinFrameReaderError(parseErr error, readErr error) error {
	if parseErr == nil {
		return readErr
	}
	if readErr == nil {
		return parseErr
	}
	return errors.Join(parseErr, readErr)
}
// normalizeReadErr combines parseErr and readErr into a single error. When
// readErr is io.EOF while a frame is partially buffered, the buffered state is
// discarded and io.ErrUnexpectedEOF is reported in place of the EOF.
func (reader *FrameReader) normalizeReadErr(parseErr error, readErr error) error {
	truncated := errors.Is(readErr, io.EOF) && reader.hasBufferedState()
	if !truncated {
		return joinFrameReaderError(parseErr, readErr)
	}
	reader.clearBufferedState()
	return joinFrameReaderError(parseErr, io.ErrUnexpectedEOF)
}
// hasBufferedState reports whether a frame is in progress: some header bytes
// or payload bytes have been read, or a non-empty payload buffer is held. A
// nil receiver reports false.
func (reader *FrameReader) hasBufferedState() bool {
	return reader != nil &&
		(reader.headerRead > 0 || reader.payloadRead > 0 || len(reader.payload) > 0)
}
// clearBufferedState discards any frame in progress: both read counters are
// reset, the release function (if any) is invoked, and the payload buffer is
// dropped. It is a no-op on a nil receiver.
func (reader *FrameReader) clearBufferedState() {
	if reader == nil {
		return
	}
	reader.headerRead, reader.payloadRead = 0, 0
	if done := reader.release; done != nil {
		done()
	}
	reader.release, reader.payload = nil, nil
}
// readHeader reads into the header buffer until it holds queHeaderSize bytes.
// It returns nil immediately when the header is already complete, which lets a
// retried call continue with the payload.
func (reader *FrameReader) readHeader() error {
	if reader.headerRead != queHeaderSize {
		return reader.readInto(reader.header[:], &reader.headerRead)
	}
	return nil
}
// ensurePayloadBuffer parses the completed header and sets up the payload
// buffer for the frame in progress.
//
// It does nothing when a payload buffer already exists or the header is still
// incomplete. A header that fails parseHeaderBytes clears the buffered state
// and returns the parse error. A zero-length frame gets an empty, non-nil
// payload and no release function. Otherwise the buffer comes from the payload
// pool when pooled is true and the queue is not configured to decode payloads
// (Encode set with a non-nil DecodeFunc); in every other case it is freshly
// allocated and no release function is installed.
func (reader *FrameReader) ensurePayloadBuffer(pooled bool) error {
	if reader.headerRead < queHeaderSize || reader.payload != nil {
		return nil
	}
	hdr, err := parseHeaderBytes(reader.header[:], reader.queue.maxLength)
	if err != nil {
		reader.clearBufferedState()
		return err
	}
	if hdr.Length == 0 {
		// Non-nil empty slice marks "payload buffer set up" for later calls.
		reader.payload = []byte{}
		reader.release = nil
		return nil
	}
	size := int(hdr.Length)
	decodes := reader.queue.Encode && reader.queue.DecodeFunc != nil
	if pooled && !decodes {
		scratch := getFramePayloadBuffer(size)
		reader.payload = scratch
		reader.release = func() {
			putFramePayloadBuffer(scratch)
		}
		return nil
	}
	reader.payload = make([]byte, size)
	return nil
}
// finishFrame hands out the completed frame and resets the reader for the next
// one. When the queue has Encode set and a non-nil DecodeFunc, the payload is
// passed through DecodeFunc, any release function is invoked right away, and
// the decoded bytes are returned with a nil release function. Otherwise the
// raw payload and its release function are returned. The error is always nil.
func (reader *FrameReader) finishFrame() ([]byte, func(), error) {
	data, done := reader.payload, reader.release
	reader.headerRead, reader.payloadRead = 0, 0
	reader.payload, reader.release = nil, nil

	if !reader.queue.Encode || reader.queue.DecodeFunc == nil {
		return data, done, nil
	}
	out := reader.queue.DecodeFunc(data)
	if done != nil {
		done()
	}
	return out, nil, nil
}
// readInto fills dst from the underlying reader, starting at *offset and
// advancing *offset by every byte read. Progress is kept in *offset so a later
// call can resume after a transient error. Each Read call requests at most
// readSize bytes when readSize is positive.
//
// Return values:
//   - nil once dst is full, including when the final Read delivered the last
//     bytes together with io.EOF.
//   - io.EOF when the stream ends and *offset is still 0.
//   - io.ErrUnexpectedEOF when the stream ends with dst partly filled.
//   - io.ErrNoProgress when Read returns 0 bytes and a nil error.
//   - any other Read error, unchanged.
func (reader *FrameReader) readInto(dst []byte, offset *int) error {
	for *offset < len(dst) {
		maxRead := len(dst) - *offset
		if reader.readSize > 0 && maxRead > reader.readSize {
			maxRead = reader.readSize
		}
		n, err := reader.reader.Read(dst[*offset : *offset+maxRead])
		// Account for the bytes before looking at err: a Read may return
		// data and an error from the same call.
		if n > 0 {
			*offset += n
		}
		if err != nil {
			if errors.Is(err, io.EOF) {
				switch {
				case *offset == 0:
					return io.EOF
				case *offset < len(dst):
					return io.ErrUnexpectedEOF
				default:
					// dst was completed by this very Read. The read
					// succeeded; the reader will report EOF again on the
					// next call, matching io.ReadFull semantics.
					return nil
				}
			}
			return err
		}
		if n == 0 {
			return io.ErrNoProgress
		}
	}
	return nil
}
// getFramePayloadBuffer returns a byte slice of length size. It reuses a
// buffer taken from framePayloadPool when that buffer has enough capacity and
// allocates a new one otherwise. A non-positive size yields nil.
func getFramePayloadBuffer(size int) []byte {
	if size <= 0 {
		return nil
	}
	candidate, isBytes := framePayloadPool.Get().([]byte)
	if !isBytes || cap(candidate) < size {
		return make([]byte, size)
	}
	return candidate[:size]
}
// putFramePayloadBuffer returns buf to framePayloadPool with its length reset
// to zero. Buffers with no capacity, or with capacity above 32 MiB, are not
// pooled.
func putFramePayloadBuffer(buf []byte) {
	const maxPooledCap = 32 << 20
	if c := cap(buf); c > 0 && c <= maxPooledCap {
		framePayloadPool.Put(buf[:0])
	}
}