package stario import ( "errors" "io" "sync" ) // DefaultFrameReaderBufferSize is the default transport read chunk size used by // FrameReader. const DefaultFrameReaderBufferSize = 32 * 1024 var framePayloadPool sync.Pool type frameReaderConnKey struct{} // FrameWriter adapts StarQueue framing helpers to an io.Writer. type FrameWriter struct { writer io.Writer queue *StarQueue } // NewFrameWriter creates a framing writer backed by queue. When queue is nil, a // default StarQueue is created. func NewFrameWriter(writer io.Writer, queue *StarQueue) *FrameWriter { if queue == nil { queue = NewQueue() } return &FrameWriter{ writer: writer, queue: queue, } } // WriteFrame writes one framed payload. func (writer *FrameWriter) WriteFrame(payload []byte) error { if writer == nil || writer.writer == nil || writer.queue == nil { return io.ErrClosedPipe } return writer.queue.WriteFrame(writer.writer, payload) } // WriteFrameBuffers writes one framed payload using net.Buffers when possible. func (writer *FrameWriter) WriteFrameBuffers(payload []byte) error { if writer == nil || writer.writer == nil || writer.queue == nil { return io.ErrClosedPipe } return writer.queue.WriteFrameBuffers(writer.writer, payload) } // WriteFramesBuffers writes multiple framed payloads in one batch when // possible. func (writer *FrameWriter) WriteFramesBuffers(payloads ...[]byte) error { if writer == nil || writer.writer == nil || writer.queue == nil { return io.ErrClosedPipe } return writer.queue.WriteFramesBuffers(writer.writer, payloads...) } // FrameReader adapts StarQueue parsing helpers to an io.Reader. type FrameReader struct { reader io.Reader queue *StarQueue connKey interface{} readSize int header [queHeaderSize]byte headerRead int payload []byte payloadRead int release func() } // NewFrameReader creates a framing reader backed by queue. When queue is nil, a // default StarQueue is created. func NewFrameReader(reader io.Reader, queue *StarQueue) *FrameReader { if queue == nil { queue = NewQueue() } return &FrameReader{ reader: reader, queue: queue, connKey: &frameReaderConnKey{}, readSize: DefaultFrameReaderBufferSize, } } // SetReadBufferSize updates the underlying transport read chunk size. func (reader *FrameReader) SetReadBufferSize(size int) { if reader == nil || size <= 0 { return } reader.readSize = size } // SetConnKey overrides the internal queue connection key. func (reader *FrameReader) SetConnKey(conn interface{}) error { if reader == nil { return io.ErrClosedPipe } if conn == nil { reader.connKey = &frameReaderConnKey{} return nil } if err := validateConnKey(conn); err != nil { return err } reader.connKey = conn return nil } // Next returns the next framed payload. func (reader *FrameReader) Next() ([]byte, error) { if reader == nil || reader.reader == nil || reader.queue == nil { return nil, io.ErrClosedPipe } payload, release, err := reader.NextPooled() if err != nil { return nil, err } if release == nil { return payload, nil } owned := cloneBytes(payload) release() return owned, nil } // NextPooled returns the next frame payload. The caller must call release when // it is non-nil. func (reader *FrameReader) NextPooled() ([]byte, func(), error) { if reader == nil || reader.reader == nil || reader.queue == nil { return nil, nil, io.ErrClosedPipe } if err := reader.readHeader(); err != nil { return nil, nil, err } if err := reader.ensurePayloadBuffer(true); err != nil { return nil, nil, err } if len(reader.payload) > 0 { if err := reader.readInto(reader.payload, &reader.payloadRead); err != nil { return nil, nil, err } } return reader.finishFrame() } // NextView reads the next frame and exposes its payload only for the duration // of fn. func (reader *FrameReader) NextView(fn func(FrameView) error) error { if fn == nil { return ErrQueueFrameHandlerNil } payload, release, err := reader.NextPooled() if err != nil { return err } if release != nil { defer release() } return fn(FrameView{ Payload: payload, Conn: reader.connKey, }) } func joinFrameReaderError(parseErr error, readErr error) error { switch { case parseErr == nil: return readErr case readErr == nil: return parseErr default: return errors.Join(parseErr, readErr) } } func (reader *FrameReader) normalizeReadErr(parseErr error, readErr error) error { if errors.Is(readErr, io.EOF) && reader.hasBufferedState() { reader.clearBufferedState() readErr = io.ErrUnexpectedEOF } return joinFrameReaderError(parseErr, readErr) } func (reader *FrameReader) hasBufferedState() bool { if reader == nil { return false } return reader.headerRead > 0 || reader.payloadRead > 0 || len(reader.payload) > 0 } func (reader *FrameReader) clearBufferedState() { if reader == nil { return } reader.headerRead = 0 reader.payloadRead = 0 if reader.release != nil { reader.release() } reader.release = nil reader.payload = nil } func (reader *FrameReader) readHeader() error { if reader.headerRead == queHeaderSize { return nil } return reader.readInto(reader.header[:], &reader.headerRead) } func (reader *FrameReader) ensurePayloadBuffer(pooled bool) error { if reader.payload != nil || reader.headerRead < queHeaderSize { return nil } header, err := parseHeaderBytes(reader.header[:], reader.queue.maxLength) if err != nil { reader.clearBufferedState() return err } if header.Length == 0 { reader.payload = []byte{} reader.release = nil return nil } payloadLen := int(header.Length) switch { case reader.queue.Encode && reader.queue.DecodeFunc != nil: reader.payload = make([]byte, payloadLen) case pooled: buf := getFramePayloadBuffer(payloadLen) reader.payload = buf reader.release = func() { putFramePayloadBuffer(buf) } default: reader.payload = make([]byte, payloadLen) } return nil } func (reader *FrameReader) finishFrame() ([]byte, func(), error) { payload := reader.payload release := reader.release reader.payload = nil reader.payloadRead = 0 reader.release = nil reader.headerRead = 0 if reader.queue.Encode && reader.queue.DecodeFunc != nil { decoded := reader.queue.DecodeFunc(payload) if release != nil { release() release = nil } payload = decoded } return payload, release, nil } func (reader *FrameReader) readInto(dst []byte, offset *int) error { for *offset < len(dst) { maxRead := len(dst) - *offset if reader.readSize > 0 && maxRead > reader.readSize { maxRead = reader.readSize } n, err := reader.reader.Read(dst[*offset : *offset+maxRead]) if n > 0 { *offset += n } if err != nil { if errors.Is(err, io.EOF) { if *offset == 0 { return io.EOF } if *offset < len(dst) { return io.ErrUnexpectedEOF } } return err } if n == 0 { return io.ErrNoProgress } } return nil } func getFramePayloadBuffer(size int) []byte { if size <= 0 { return nil } if pooled, ok := framePayloadPool.Get().([]byte); ok && cap(pooled) >= size { return pooled[:size] } return make([]byte, size) } func putFramePayloadBuffer(buf []byte) { if cap(buf) == 0 || cap(buf) > 32*1024*1024 { return } framePayloadPool.Put(buf[:0]) }