209 lines
5.1 KiB
Go
209 lines
5.1 KiB
Go
|
|
package stario
|
||
|
|
|
||
|
|
import (
|
||
|
|
"encoding/binary"
|
||
|
|
"fmt"
|
||
|
|
"io"
|
||
|
|
"net"
|
||
|
|
)
|
||
|
|
|
||
|
|
// BuildMessage builds one full frame and panics if the payload is too large to
|
||
|
|
// fit in the framing format.
|
||
|
|
//
|
||
|
|
// New code should prefer BuildMessageErr when it needs an explicit error path.
|
||
|
|
func (q *StarQueue) BuildMessage(src []byte) []byte {
|
||
|
|
frame, err := q.BuildMessageErr(src)
|
||
|
|
if err != nil {
|
||
|
|
panic(err)
|
||
|
|
}
|
||
|
|
return frame
|
||
|
|
}
|
||
|
|
|
||
|
|
// BuildMessageErr builds one full frame and returns an explicit error when the
|
||
|
|
// payload exceeds the 32-bit frame length.
|
||
|
|
func (q *StarQueue) BuildMessageErr(src []byte) ([]byte, error) {
|
||
|
|
payload := src
|
||
|
|
if q.Encode && q.EncodeFunc != nil {
|
||
|
|
payload = q.EncodeFunc(payload)
|
||
|
|
}
|
||
|
|
length, err := payloadSizeToUint32(uint64(len(payload)))
|
||
|
|
if err != nil {
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
header := q.BuildHeader(length)
|
||
|
|
frame := make([]byte, 0, len(header)+len(payload))
|
||
|
|
frame = append(frame, header...)
|
||
|
|
frame = append(frame, payload...)
|
||
|
|
return frame, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// WriteFrame writes one framed payload directly to w without building an
|
||
|
|
// intermediate full-frame slice first.
|
||
|
|
func (q *StarQueue) WriteFrame(w io.Writer, src []byte) error {
|
||
|
|
if w == nil {
|
||
|
|
return io.ErrClosedPipe
|
||
|
|
}
|
||
|
|
payload := src
|
||
|
|
if q.Encode && q.EncodeFunc != nil {
|
||
|
|
payload = q.EncodeFunc(payload)
|
||
|
|
}
|
||
|
|
length, err := payloadSizeToUint32(uint64(len(payload)))
|
||
|
|
if err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
var header [queHeaderSize]byte
|
||
|
|
writeHeaderBytes(header[:], queHeader{
|
||
|
|
Length: length,
|
||
|
|
Version: queVersionV1,
|
||
|
|
Flags: queSupportedFlags,
|
||
|
|
})
|
||
|
|
if err := writeFull(w, header[:]); err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
return writeFull(w, payload)
|
||
|
|
}
|
||
|
|
|
||
|
|
// WriteFrameBuffers writes one framed payload using net.Buffers so callers can
|
||
|
|
// opt into gather writes when the underlying writer supports it well.
|
||
|
|
func (q *StarQueue) WriteFrameBuffers(w io.Writer, src []byte) error {
|
||
|
|
if w == nil {
|
||
|
|
return io.ErrClosedPipe
|
||
|
|
}
|
||
|
|
payload := src
|
||
|
|
if q.Encode && q.EncodeFunc != nil {
|
||
|
|
payload = q.EncodeFunc(payload)
|
||
|
|
}
|
||
|
|
length, err := payloadSizeToUint32(uint64(len(payload)))
|
||
|
|
if err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
var header [queHeaderSize]byte
|
||
|
|
writeHeaderBytes(header[:], queHeader{
|
||
|
|
Length: length,
|
||
|
|
Version: queVersionV1,
|
||
|
|
Flags: queSupportedFlags,
|
||
|
|
})
|
||
|
|
buffers := net.Buffers{header[:], payload}
|
||
|
|
_, err = buffers.WriteTo(w)
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
|
||
|
|
// WriteFramesBuffers writes multiple framed payloads using one net.Buffers
|
||
|
|
// batch so callers can reduce write calls on stream transports.
|
||
|
|
func (q *StarQueue) WriteFramesBuffers(w io.Writer, payloads ...[]byte) error {
|
||
|
|
if w == nil {
|
||
|
|
return io.ErrClosedPipe
|
||
|
|
}
|
||
|
|
if len(payloads) == 0 {
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
buffers := make(net.Buffers, 0, len(payloads)*2)
|
||
|
|
headers := make([][queHeaderSize]byte, len(payloads))
|
||
|
|
for i, src := range payloads {
|
||
|
|
payload := src
|
||
|
|
if q.Encode && q.EncodeFunc != nil {
|
||
|
|
payload = q.EncodeFunc(payload)
|
||
|
|
}
|
||
|
|
length, err := payloadSizeToUint32(uint64(len(payload)))
|
||
|
|
if err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
writeHeaderBytes(headers[i][:], queHeader{
|
||
|
|
Length: length,
|
||
|
|
Version: queVersionV1,
|
||
|
|
Flags: queSupportedFlags,
|
||
|
|
})
|
||
|
|
buffers = append(buffers, headers[i][:], payload)
|
||
|
|
}
|
||
|
|
_, err := buffers.WriteTo(w)
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
|
||
|
|
// BuildHeader 生成编码后的Header用于发送。
|
||
|
|
func (q *StarQueue) BuildHeader(length uint32) []byte {
|
||
|
|
return buildHeaderBytes(queHeader{
|
||
|
|
Length: length,
|
||
|
|
Version: queVersionV1,
|
||
|
|
Flags: queSupportedFlags,
|
||
|
|
})
|
||
|
|
}
|
||
|
|
|
||
|
|
func buildHeaderBytes(header queHeader) []byte {
|
||
|
|
buf := make([]byte, queHeaderSize)
|
||
|
|
writeHeaderBytes(buf, header)
|
||
|
|
return buf
|
||
|
|
}
|
||
|
|
|
||
|
|
func writeHeaderBytes(dst []byte, header queHeader) {
|
||
|
|
if len(dst) < queHeaderSize {
|
||
|
|
return
|
||
|
|
}
|
||
|
|
copy(dst[:queMagicSize], queMagic)
|
||
|
|
binary.BigEndian.PutUint32(dst[queMagicSize:queMagicSize+4], header.Length)
|
||
|
|
dst[12] = header.Version
|
||
|
|
dst[13] = header.Flags
|
||
|
|
}
|
||
|
|
|
||
|
|
func payloadSizeToUint32(size uint64) (uint32, error) {
|
||
|
|
const maxFramePayload = ^uint32(0)
|
||
|
|
if size > uint64(maxFramePayload) {
|
||
|
|
return 0, fmt.Errorf("%w: %d > %d", ErrQueueMessageTooLarge, size, uint64(maxFramePayload))
|
||
|
|
}
|
||
|
|
return uint32(size), nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// writeFull writes all of data to w, retrying after short writes.
//
// It returns nil once every byte is accepted (immediately for empty data).
// Otherwise it returns one of:
//   - the first error reported by w;
//   - io.ErrNoProgress if w reports writing zero bytes without an error;
//   - an "invalid write result" error if w reports a byte count outside
//     [0, len(p)].
//
// The last case breaks the io.Writer contract. Before this check it could
// panic on the reslice or loop forever.
func writeFull(w io.Writer, data []byte) error {
	for len(data) > 0 {
		n, err := w.Write(data)
		// Mirror io.Copy's guard against writers that misreport n.
		if n < 0 || n > len(data) {
			return fmt.Errorf("invalid write result: wrote %d of %d bytes", n, len(data))
		}
		data = data[n:]
		if err != nil {
			return err
		}
		if n == 0 {
			return io.ErrNoProgress
		}
	}
	return nil
}
|
||
|
|
|
||
|
|
func parseHeaderBytes(src []byte, maxLength uint32) (queHeader, error) {
|
||
|
|
if len(src) < queHeaderSize {
|
||
|
|
return queHeader{}, ErrQueueDataFormat
|
||
|
|
}
|
||
|
|
if !equalMagic(src[:queMagicSize]) {
|
||
|
|
return queHeader{}, ErrQueueDataFormat
|
||
|
|
}
|
||
|
|
|
||
|
|
header := queHeader{
|
||
|
|
Length: ByteToUint32(src[queMagicSize : queMagicSize+4]),
|
||
|
|
Version: src[12],
|
||
|
|
Flags: src[13],
|
||
|
|
}
|
||
|
|
|
||
|
|
if header.Version != queVersionV1 {
|
||
|
|
return queHeader{}, fmt.Errorf("%w: %d", ErrQueueUnsupportedVersion, header.Version)
|
||
|
|
}
|
||
|
|
if header.Flags != queSupportedFlags {
|
||
|
|
return queHeader{}, fmt.Errorf("%w: %d", ErrQueueUnsupportedFlags, header.Flags)
|
||
|
|
}
|
||
|
|
if maxLength != 0 && header.Length > maxLength {
|
||
|
|
return queHeader{}, fmt.Errorf("%w: %d > %d", ErrQueueMessageTooLarge, header.Length, maxLength)
|
||
|
|
}
|
||
|
|
|
||
|
|
return header, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func equalMagic(src []byte) bool {
|
||
|
|
if len(src) != queMagicSize {
|
||
|
|
return false
|
||
|
|
}
|
||
|
|
for i, b := range src {
|
||
|
|
if b != queMagic[i] {
|
||
|
|
return false
|
||
|
|
}
|
||
|
|
}
|
||
|
|
return true
|
||
|
|
}
|