stario/que_frame.go

209 lines
5.1 KiB
Go
Raw Permalink Normal View History

package stario
import (
"encoding/binary"
"fmt"
"io"
"net"
)
// BuildMessage returns one complete frame for src by delegating to
// BuildMessageErr.
//
// It panics with the error from BuildMessageErr when the frame cannot be
// built. Callers that want to handle that failure themselves should call
// BuildMessageErr directly.
func (q *StarQueue) BuildMessage(src []byte) []byte {
	built, buildErr := q.BuildMessageErr(src)
	if buildErr == nil {
		return built
	}
	panic(buildErr)
}
// BuildMessageErr assembles one complete frame: the encoded header followed
// by the payload.
//
// When q.Encode is set and q.EncodeFunc is non-nil, src is passed through
// q.EncodeFunc first and the result becomes the payload. An error is returned,
// with a nil frame, when the payload length does not fit in the header's
// 32-bit length field.
func (q *StarQueue) BuildMessageErr(src []byte) ([]byte, error) {
	body := src
	if q.Encode && q.EncodeFunc != nil {
		body = q.EncodeFunc(body)
	}

	size, err := payloadSizeToUint32(uint64(len(body)))
	if err != nil {
		return nil, err
	}

	head := q.BuildHeader(size)
	// Allocate the exact frame size once and copy both parts into place.
	out := make([]byte, len(head)+len(body))
	n := copy(out, head)
	copy(out[n:], body)
	return out, nil
}
// WriteFrame sends one framed payload to w as a header write followed by a
// payload write, without first concatenating them into a single slice.
//
// It returns io.ErrClosedPipe when w is nil. When q.Encode is set and
// q.EncodeFunc is non-nil, src is passed through q.EncodeFunc before framing.
// An error is returned before anything is written if the payload length does
// not fit in the 32-bit length field; otherwise the first write error is
// returned.
func (q *StarQueue) WriteFrame(w io.Writer, src []byte) error {
	if w == nil {
		return io.ErrClosedPipe
	}

	body := src
	if q.Encode && q.EncodeFunc != nil {
		body = q.EncodeFunc(body)
	}

	size, err := payloadSizeToUint32(uint64(len(body)))
	if err != nil {
		return err
	}

	var head [queHeaderSize]byte
	writeHeaderBytes(head[:], queHeader{
		Flags:   queSupportedFlags,
		Version: queVersionV1,
		Length:  size,
	})

	// Header first, then payload; stop at the first failing segment.
	for _, part := range [][]byte{head[:], body} {
		if err := writeFull(w, part); err != nil {
			return err
		}
	}
	return nil
}
// WriteFrameBuffers sends one framed payload to w through a net.Buffers value
// holding the header and the payload, so writers that support gather writes
// can take advantage of them.
//
// It returns io.ErrClosedPipe when w is nil. When q.Encode is set and
// q.EncodeFunc is non-nil, src is passed through q.EncodeFunc before framing.
// An error is returned before anything is written if the payload length does
// not fit in the 32-bit length field; otherwise the error from
// net.Buffers.WriteTo is returned.
func (q *StarQueue) WriteFrameBuffers(w io.Writer, src []byte) error {
	if w == nil {
		return io.ErrClosedPipe
	}

	body := src
	if q.Encode && q.EncodeFunc != nil {
		body = q.EncodeFunc(body)
	}

	size, err := payloadSizeToUint32(uint64(len(body)))
	if err != nil {
		return err
	}

	var head [queHeaderSize]byte
	writeHeaderBytes(head[:], queHeader{
		Flags:   queSupportedFlags,
		Version: queVersionV1,
		Length:  size,
	})

	segments := net.Buffers{head[:], body}
	if _, err := segments.WriteTo(w); err != nil {
		return err
	}
	return nil
}
// WriteFramesBuffers frames every element of payloads and sends them all to w
// in a single net.Buffers batch, laid out as header, payload, header,
// payload, and so on.
//
// It returns io.ErrClosedPipe when w is nil and nil when payloads is empty.
// When q.Encode is set and q.EncodeFunc is non-nil, each payload is passed
// through q.EncodeFunc before framing. All payloads are sized before any
// write happens, so a payload too large for the 32-bit length field aborts
// the call with nothing written.
func (q *StarQueue) WriteFramesBuffers(w io.Writer, payloads ...[]byte) error {
	if w == nil {
		return io.ErrClosedPipe
	}
	count := len(payloads)
	if count == 0 {
		return nil
	}

	// One header slot and two batch entries (header + payload) per frame.
	heads := make([][queHeaderSize]byte, count)
	batch := make(net.Buffers, 2*count)
	for i := range payloads {
		body := payloads[i]
		if q.Encode && q.EncodeFunc != nil {
			body = q.EncodeFunc(body)
		}
		size, err := payloadSizeToUint32(uint64(len(body)))
		if err != nil {
			return err
		}
		writeHeaderBytes(heads[i][:], queHeader{
			Flags:   queSupportedFlags,
			Version: queVersionV1,
			Length:  size,
		})
		batch[2*i] = heads[i][:]
		batch[2*i+1] = body
	}

	_, err := batch.WriteTo(w)
	return err
}
// BuildHeader returns the encoded header bytes to send ahead of a payload of
// the given length. The header carries version queVersionV1 and flags
// queSupportedFlags.
func (q *StarQueue) BuildHeader(length uint32) []byte {
	hdr := queHeader{
		Flags:   queSupportedFlags,
		Version: queVersionV1,
		Length:  length,
	}
	return buildHeaderBytes(hdr)
}
// buildHeaderBytes encodes header into a freshly allocated slice of
// queHeaderSize bytes and returns it.
func buildHeaderBytes(header queHeader) []byte {
	var raw [queHeaderSize]byte
	writeHeaderBytes(raw[:], header)
	return raw[:]
}
// writeHeaderBytes encodes header into the front of dst: the magic bytes,
// then the payload length as a big-endian uint32, then the version byte and
// the flags byte.
//
// dst is left untouched when it is shorter than queHeaderSize.
func writeHeaderBytes(dst []byte, header queHeader) {
	if len(dst) >= queHeaderSize {
		copy(dst[:queMagicSize], queMagic)
		lengthField := dst[queMagicSize : queMagicSize+4]
		binary.BigEndian.PutUint32(lengthField, header.Length)
		// NOTE(review): offsets 12 and 13 are hard-coded; they directly follow
		// the length field only if queMagicSize is 8 — confirm.
		dst[12] = header.Version
		dst[13] = header.Flags
	}
}
// payloadSizeToUint32 converts a payload size to the uint32 stored in the
// frame header. Sizes above the maximum uint32 value yield 0 and an error
// wrapping ErrQueueMessageTooLarge.
func payloadSizeToUint32(size uint64) (uint32, error) {
	const limit = uint64(^uint32(0))
	if size <= limit {
		return uint32(size), nil
	}
	return 0, fmt.Errorf("%w: %d > %d", ErrQueueMessageTooLarge, size, limit)
}
// writeFull writes all of data to w, calling w.Write again after each partial
// write until every byte has been accepted.
//
// It returns the first error reported by w. A write that accepts zero bytes
// without reporting an error yields io.ErrNoProgress, so a stalled writer
// cannot spin the loop forever. A writer that breaks the io.Writer contract by
// reporting a negative count, or more bytes than it was handed, yields an
// error instead of an endless loop or a slice-bounds panic.
func writeFull(w io.Writer, data []byte) error {
	for len(data) > 0 {
		n, err := w.Write(data)
		if n < 0 || n > len(data) {
			// The count cannot be trusted; prefer the writer's own error when
			// it supplied one.
			if err != nil {
				return err
			}
			return fmt.Errorf("invalid write count %d for %d-byte buffer", n, len(data))
		}
		data = data[n:]
		if err != nil {
			return err
		}
		if n == 0 {
			return io.ErrNoProgress
		}
	}
	return nil
}
// parseHeaderBytes decodes and validates a frame header from the front of
// src.
//
// It returns ErrQueueDataFormat when src is shorter than queHeaderSize or does
// not start with the expected magic bytes. It returns an error wrapping
// ErrQueueUnsupportedVersion or ErrQueueUnsupportedFlags when the version or
// flags byte differs from queVersionV1 or queSupportedFlags. It returns an
// error wrapping ErrQueueMessageTooLarge when maxLength is non-zero and the
// declared payload length exceeds it; a maxLength of 0 disables that check.
// On any error the returned header is the zero value.
func parseHeaderBytes(src []byte, maxLength uint32) (queHeader, error) {
	if len(src) < queHeaderSize || !equalMagic(src[:queMagicSize]) {
		return queHeader{}, ErrQueueDataFormat
	}

	var parsed queHeader
	parsed.Length = ByteToUint32(src[queMagicSize : queMagicSize+4])
	parsed.Version = src[12]
	parsed.Flags = src[13]

	// Checks run in order: version, then flags, then the optional size limit.
	switch {
	case parsed.Version != queVersionV1:
		return queHeader{}, fmt.Errorf("%w: %d", ErrQueueUnsupportedVersion, parsed.Version)
	case parsed.Flags != queSupportedFlags:
		return queHeader{}, fmt.Errorf("%w: %d", ErrQueueUnsupportedFlags, parsed.Flags)
	case maxLength != 0 && parsed.Length > maxLength:
		return queHeader{}, fmt.Errorf("%w: %d > %d", ErrQueueMessageTooLarge, parsed.Length, maxLength)
	}
	return parsed, nil
}
// equalMagic reports whether src is exactly queMagicSize bytes long and each
// byte equals the byte at the same index in queMagic.
func equalMagic(src []byte) bool {
	matched := len(src) == queMagicSize
	// Stop at the first mismatching byte.
	for i := 0; matched && i < len(src); i++ {
		matched = src[i] == queMagic[i]
	}
	return matched
}