stario/que_parse.go

238 lines
5.3 KiB
Go
Raw Permalink Normal View History

package stario
import (
"bytes"
"fmt"
"reflect"
)
// ParseMessage parses the frames contained in a received msg chunk.
//
// Every payload that is extracted is wrapped in a MsgQueue together with conn
// and handed to push2list. The returned error is whatever parseMessage
// reports.
func (q *StarQueue) ParseMessage(msg []byte, conn interface{}) error {
	enqueue := func(payload []byte) error {
		item := MsgQueue{Msg: payload, Conn: conn}
		return q.push2list(item)
	}
	return q.parseMessage(msg, conn, false, enqueue)
}
// ParseMessageView parses frames and passes each payload to fn as a FrameView,
// asking the parser to borrow the payload instead of cloning it first.
//
// The payload handed to fn should be treated as valid only for the duration
// of the callback. If q uses the legacy DecodeFunc path (q.Encode set and
// q.DecodeFunc non-nil), a decoded payload still has to be allocated.
//
// A nil fn is rejected with ErrQueueFrameHandlerNil.
func (q *StarQueue) ParseMessageView(msg []byte, conn interface{}, fn func(FrameView) error) error {
	if fn == nil {
		return ErrQueueFrameHandlerNil
	}
	deliver := func(payload []byte) error {
		view := FrameView{Payload: payload, Conn: conn}
		return fn(view)
	}
	return q.parseMessage(msg, conn, true, deliver)
}
// ParseMessageOwned parses frames and hands owned payload copies to fn
// directly instead of routing them through push2list.
//
// Parsing runs to completion first; the collected frames are then passed to fn
// in order. The first error returned by fn stops delivery, so frames collected
// after it are not delivered. If parsing also reported an error, its text is
// prefixed to the wrapped fn error; otherwise the parse error (possibly nil) is
// returned once every frame has been delivered.
//
// A nil fn is rejected with ErrQueueFrameHandlerNil.
func (q *StarQueue) ParseMessageOwned(msg []byte, conn interface{}, fn func(MsgQueue) error) error {
	if fn == nil {
		return ErrQueueFrameHandlerNil
	}

	collected := make([]MsgQueue, 0, 1)
	collect := func(payload []byte) error {
		collected = append(collected, MsgQueue{Msg: payload, Conn: conn})
		return nil
	}
	parseErr := q.parseMessage(msg, conn, false, collect)

	for i := range collected {
		handlerErr := fn(collected[i])
		if handlerErr == nil {
			continue
		}
		if parseErr == nil {
			return handlerErr
		}
		return fmt.Errorf("%v: %w", parseErr, handlerErr)
	}
	return parseErr
}
// parseMessage feeds msg into the buffered state for conn and calls emit once
// for every complete frame payload that can be extracted.
//
// msg is passed to nextPayload on the first iteration only; later iterations
// just drain what is already buffered. The first error reported by nextPayload
// is remembered and returned once no further frame is available. If emit
// fails, the loop stops immediately and emit's error is returned, wrapped
// behind the remembered parse error's text when there is one.
func (q *StarQueue) parseMessage(msg []byte, conn interface{}, borrowPayload bool, emit func([]byte) error) error {
	state, err := q.connState(conn)
	if err != nil {
		return err
	}

	var parseErr error
	pending := msg
	for {
		payload, ok, stepErr := q.nextPayload(state, conn, pending, borrowPayload)
		// The incoming bytes must be appended exactly once.
		pending = nil
		if parseErr == nil && stepErr != nil {
			parseErr = stepErr
		}
		if !ok {
			return parseErr
		}

		emitErr := emit(payload)
		if emitErr == nil {
			continue
		}
		if parseErr == nil {
			return emitErr
		}
		return fmt.Errorf("%v: %w", parseErr, emitErr)
	}
}
// nextPayload appends msg to the bytes buffered in state and tries to pull one
// complete frame out of that buffer.
//
// It returns the frame payload, whether a frame was produced, and the first
// sync or header error met while searching. The error can be non-nil even when
// a frame is returned, because bad bytes ahead of the frame are skipped rather
// than aborting the search. state.mu is held for the whole call.
//
// With borrowPayload set the payload is not cloned, unless q.Encode is set and
// q.DecodeFunc is non-nil; in that case a clone is passed through DecodeFunc.
// Whenever the buffer ends up empty, the entry for conn is deleted from
// q.states.
func (q *StarQueue) nextPayload(state *queConnState, conn interface{}, msg []byte, borrowPayload bool) ([]byte, bool, error) {
	state.mu.Lock()
	defer state.mu.Unlock()

	if len(msg) > 0 {
		state.buf = append(state.buf, msg...)
	}

	var parseErr error
	for {
		aligned, syncErr := syncFrameStart(&state.buf)
		if parseErr == nil && syncErr != nil {
			parseErr = syncErr
		}
		if !aligned {
			// No frame start available; drop the state when nothing is pending.
			if len(state.buf) == 0 {
				q.states.Delete(conn)
			}
			return nil, false, parseErr
		}
		if len(state.buf) < queHeaderSize {
			return nil, false, parseErr
		}

		header, headerErr := parseHeaderBytes(state.buf[:queHeaderSize], q.maxLength)
		if headerErr != nil {
			if parseErr == nil {
				parseErr = headerErr
			}
			// Skip one byte so the next pass looks for a later frame start.
			state.buf = shrinkBuffer(state.buf[1:])
			continue
		}

		total := queHeaderSize + int(header.Length)
		if total > len(state.buf) {
			// Header is valid but the body has not fully arrived yet.
			return nil, false, parseErr
		}

		// A payload that must be decoded is always cloned before decoding.
		decode := q.Encode && q.DecodeFunc != nil
		payload, remaining := extractPayload(state.buf, total, borrowPayload && !decode)
		state.buf = remaining
		if decode {
			payload = q.DecodeFunc(payload)
		}
		if len(state.buf) == 0 {
			q.states.Delete(conn)
		}
		return payload, true, parseErr
	}
}
// extractPayload splits buf into the payload of its first frame and the bytes
// that follow the frame. frameLen is the full frame length including the
// header of queHeaderSize bytes.
//
// Without borrowPayload the payload is cloned and the remainder is returned
// through shrinkBuffer. With borrowPayload the payload aliases buf and the
// remainder is cloned instead, or nil when the frame fills buf exactly.
func extractPayload(buf []byte, frameLen int, borrowPayload bool) ([]byte, []byte) {
	body := buf[queHeaderSize:frameLen]
	tail := buf[frameLen:]
	switch {
	case !borrowPayload:
		return cloneBytes(body), shrinkBuffer(tail)
	case len(tail) == 0:
		return body, nil
	default:
		// Detach the remainder so the borrowed payload keeps the old array.
		return body, cloneBytes(tail)
	}
}
// push2list sends msg on q.msgPool, giving up with q.ctx.Err() once q.ctx is
// done.
//
// Cancellation is checked once before q.sendMu is read-locked, so a queue
// that is already cancelled never touches the lock. The send itself then
// blocks, with the read lock held, until either the message is accepted or
// the context is cancelled.
func (q *StarQueue) push2list(msg MsgQueue) error {
	done := q.ctx.Done()

	select {
	case <-done:
		return q.ctx.Err()
	default:
	}

	q.sendMu.RLock()
	defer q.sendMu.RUnlock()

	select {
	case q.msgPool <- msg:
		return nil
	case <-done:
		return q.ctx.Err()
	}
}
// validateConnKey reports whether conn may be used as the key under which
// per-connection state is stored.
//
// It returns ErrQueueConnKeyNil for a nil conn and ErrQueueConnKeyInvalid for
// a value that cannot be hashed as a map key. Any other value yields nil.
//
// reflect.Type.Comparable is a static check only: a struct or array type is
// reported comparable even when one of its interface-typed fields or elements
// currently holds an unhashable value such as a slice or map. Using such a
// value as a map key panics at run time, so struct and array keys are
// additionally probed by inserting them into a throwaway map, and a panic
// there is turned into ErrQueueConnKeyInvalid.
func validateConnKey(conn interface{}) (err error) {
	if conn == nil {
		return ErrQueueConnKeyNil
	}
	// conn is non-nil here, so TypeOf never returns a nil Type.
	typ := reflect.TypeOf(conn)
	if !typ.Comparable() {
		return ErrQueueConnKeyInvalid
	}
	switch typ.Kind() {
	case reflect.Struct, reflect.Array:
		// Only these kinds can embed interface values; probe them below.
	default:
		return nil
	}

	defer func() {
		if recover() != nil {
			err = ErrQueueConnKeyInvalid
		}
	}()
	probe := make(map[interface{}]struct{}, 1)
	probe[conn] = struct{}{} // panics when conn holds an unhashable value
	return nil
}
// connState returns the buffered parse state registered for conn in q.states,
// storing a fresh empty state first when none exists yet.
//
// conn is checked with validateConnKey before it is used as a key; a rejected
// key yields a nil state and the validation error.
func (q *StarQueue) connState(conn interface{}) (*queConnState, error) {
	err := validateConnKey(conn)
	if err != nil {
		return nil, err
	}
	entry, _ := q.states.LoadOrStore(conn, &queConnState{})
	return entry.(*queConnState), nil
}
// syncFrameStart makes *buf begin at a frame magic, discarding leading bytes
// that cannot belong to a frame.
//
// The boolean result reports whether *buf now starts with queMagic. The error
// is ErrQueueDataFormat whenever bytes had to be discarded and nil otherwise:
//
//   - empty buffer: false, nil
//   - magic at offset 0: true, nil, buffer untouched
//   - magic at a later offset: true, error, buffer replaced by a copy starting
//     at the magic
//   - no magic, whole buffer is a proper prefix of the magic: false, nil
//   - no magic, buffer ends with a partial magic: false, error, only that
//     partial magic is kept
//   - otherwise: false, error, buffer truncated to length zero
func syncFrameStart(buf *[]byte) (bool, error) {
	data := *buf
	if len(data) == 0 {
		return false, nil
	}
	if len(data) >= queMagicSize && equalMagic(data[:queMagicSize]) {
		return true, nil
	}

	switch idx := bytes.Index(data, queMagic); {
	case idx == 0:
		return true, nil
	case idx > 0:
		*buf = cloneBytes(data[idx:])
		return true, ErrQueueDataFormat
	}

	// No complete magic present; keep at most a trailing partial one, since
	// the rest of it may arrive with the next chunk.
	switch keep := trailingMagicPrefixLen(data); {
	case keep == len(data):
		return false, nil
	case keep > 0:
		*buf = cloneBytes(data[len(data)-keep:])
	default:
		*buf = data[:0]
	}
	return false, ErrQueueDataFormat
}
// trailingMagicPrefixLen returns the length of the longest proper prefix of
// queMagic that buf ends with, or 0 when buf ends with none.
//
// Candidate lengths are capped at both len(buf) and queMagicSize-1, so a
// complete magic is never counted.
func trailingMagicPrefixLen(buf []byte) int {
	limit := queMagicSize - 1
	if len(buf) < limit {
		limit = len(buf)
	}
	// Try the longest candidate first so the first hit is the answer.
	for n := limit; n >= 1; n-- {
		if bytes.HasSuffix(buf, queMagic[:n]) {
			return n
		}
	}
	return 0
}
// shrinkBuffer returns buf in a form suitable for keeping around.
//
// An empty buf becomes nil. A buf whose capacity is more than four times its
// length is replaced by the result of cloneBytes, so a small remainder does
// not pin a large backing array. Any other buf is returned unchanged.
func shrinkBuffer(buf []byte) []byte {
	switch n := len(buf); {
	case n == 0:
		return nil
	case cap(buf) > n*4:
		return cloneBytes(buf)
	default:
		return buf
	}
}