From a986fafa0987861f8a61916c59220a50f4b27a94 Mon Sep 17 00:00:00 2001 From: starainrt Date: Sun, 18 Aug 2024 17:17:20 +0800 Subject: [PATCH] move starqueue from starnet to stario --- que.go | 325 ++++++++++++++++++++++++++++++++++++++++++++++++++++ que_test.go | 42 +++++++ 2 files changed, 367 insertions(+) create mode 100644 que.go create mode 100644 que_test.go diff --git a/que.go b/que.go new file mode 100644 index 0000000..8c02643 --- /dev/null +++ b/que.go @@ -0,0 +1,325 @@ +package stario + +import ( + "bytes" + "context" + "encoding/binary" + "errors" + "fmt" + "os" + "sync" + "time" +) + +var ErrDeadlineExceeded error = errors.New("deadline exceeded") + +// 识别头 +var header = []byte{11, 27, 19, 96, 12, 25, 02, 20} + +// MsgQueue 为基本的信息单位 +type MsgQueue struct { + ID uint16 + Msg []byte + Conn interface{} +} + +// StarQueue 为流数据中的消息队列分发 +type StarQueue struct { + maxLength uint32 + count int64 + Encode bool + msgID uint16 + msgPool chan MsgQueue + unFinMsg sync.Map + lastID int //= -1 + ctx context.Context + cancel context.CancelFunc + duration time.Duration + EncodeFunc func([]byte) []byte + DecodeFunc func([]byte) []byte + //restoreMu sync.Mutex +} + +func NewQueueCtx(ctx context.Context, count int64, maxMsgLength uint32) *StarQueue { + var q StarQueue + q.Encode = false + q.count = count + q.maxLength = maxMsgLength + q.msgPool = make(chan MsgQueue, count) + if ctx == nil { + q.ctx, q.cancel = context.WithCancel(context.Background()) + } else { + q.ctx, q.cancel = context.WithCancel(ctx) + } + q.duration = 0 + return &q +} +func NewQueueWithCount(count int64) *StarQueue { + return NewQueueCtx(nil, count, 0) +} + +// NewQueue 建立一个新消息队列 +func NewQueue() *StarQueue { + return NewQueueWithCount(32) +} + +// Uint32ToByte 4位uint32转byte +func Uint32ToByte(src uint32) []byte { + res := make([]byte, 4) + res[3] = uint8(src) + res[2] = uint8(src >> 8) + res[1] = uint8(src >> 16) + res[0] = uint8(src >> 24) + return res +} + +// ByteToUint32 byte转4位uint32 +func ByteToUint32(src []byte) uint32 { + var res uint32 + buffer := bytes.NewBuffer(src) + binary.Read(buffer, binary.BigEndian, &res) + return res +} + +// Uint16ToByte 2位uint16转byte +func Uint16ToByte(src uint16) []byte { + res := make([]byte, 2) + res[1] = uint8(src) + res[0] = uint8(src >> 8) + return res +} + +// ByteToUint16 用于byte转uint16 +func ByteToUint16(src []byte) uint16 { + var res uint16 + buffer := bytes.NewBuffer(src) + binary.Read(buffer, binary.BigEndian, &res) + return res +} + +// BuildMessage 生成编码后的信息用于发送 +func (q *StarQueue) BuildMessage(src []byte) []byte { + var buff bytes.Buffer + q.msgID++ + if q.Encode { + src = q.EncodeFunc(src) + } + length := uint32(len(src)) + buff.Write(header) + buff.Write(Uint32ToByte(length)) + buff.Write(Uint16ToByte(q.msgID)) + buff.Write(src) + return buff.Bytes() +} + +// BuildHeader 生成编码后的Header用于发送 +func (q *StarQueue) BuildHeader(length uint32) []byte { + var buff bytes.Buffer + q.msgID++ + buff.Write(header) + buff.Write(Uint32ToByte(length)) + buff.Write(Uint16ToByte(q.msgID)) + return buff.Bytes() +} + +type unFinMsg struct { + ID uint16 + LengthRecv uint32 + // HeaderMsg 信息头,应当为14位:8位识别码+4位长度码+2位id + HeaderMsg []byte + RecvMsg []byte +} + +func (q *StarQueue) push2list(msg MsgQueue) { + q.msgPool <- msg +} + +// ParseMessage 用于解析收到的msg信息 +func (q *StarQueue) ParseMessage(msg []byte, conn interface{}) error { + return q.parseMessage(msg, conn) +} + +// parseMessage 用于解析收到的msg信息 +func (q *StarQueue) parseMessage(msg []byte, conn interface{}) error { + tmp, ok := q.unFinMsg.Load(conn) + if ok { //存在未完成的信息 + lastMsg := tmp.(*unFinMsg) + headerLen := len(lastMsg.HeaderMsg) + if headerLen < 14 { //未完成头标题 + //传输的数据不能填充header头 + if len(msg) < 14-headerLen { + //加入header头并退出 + lastMsg.HeaderMsg = bytesMerge(lastMsg.HeaderMsg, msg) + q.unFinMsg.Store(conn, lastMsg) + return nil + } + //获取14字节完整的header + header := msg[0 : 14-headerLen] + lastMsg.HeaderMsg = bytesMerge(lastMsg.HeaderMsg, header) + //检查收到的header是否为认证header + //若不是,丢弃并重新来过 + if !checkHeader(lastMsg.HeaderMsg[0:8]) { + q.unFinMsg.Delete(conn) + if len(msg) == 0 { + return nil + } + return q.parseMessage(msg, conn) + } + //获得本数据包长度 + lastMsg.LengthRecv = ByteToUint32(lastMsg.HeaderMsg[8:12]) + if q.maxLength != 0 && lastMsg.LengthRecv > q.maxLength { + q.unFinMsg.Delete(conn) + return fmt.Errorf("msg length is %d ,too large than %d", lastMsg.LengthRecv, q.maxLength) + } + //获得本数据包ID + lastMsg.ID = ByteToUint16(lastMsg.HeaderMsg[12:14]) + //存入列表 + q.unFinMsg.Store(conn, lastMsg) + msg = msg[14-headerLen:] + if uint32(len(msg)) < lastMsg.LengthRecv { + lastMsg.RecvMsg = msg + q.unFinMsg.Store(conn, lastMsg) + return nil + } + if uint32(len(msg)) >= lastMsg.LengthRecv { + lastMsg.RecvMsg = msg[0:lastMsg.LengthRecv] + if q.Encode { + lastMsg.RecvMsg = q.DecodeFunc(lastMsg.RecvMsg) + } + msg = msg[lastMsg.LengthRecv:] + storeMsg := MsgQueue{ + ID: lastMsg.ID, + Msg: lastMsg.RecvMsg, + Conn: conn, + } + //q.restoreMu.Lock() + q.push2list(storeMsg) + //q.restoreMu.Unlock() + q.unFinMsg.Delete(conn) + return q.parseMessage(msg, conn) + } + } else { + lastID := int(lastMsg.LengthRecv) - len(lastMsg.RecvMsg) + if lastID < 0 { + q.unFinMsg.Delete(conn) + return q.parseMessage(msg, conn) + } + if len(msg) >= lastID { + lastMsg.RecvMsg = bytesMerge(lastMsg.RecvMsg, msg[0:lastID]) + if q.Encode { + lastMsg.RecvMsg = q.DecodeFunc(lastMsg.RecvMsg) + } + storeMsg := MsgQueue{ + ID: lastMsg.ID, + Msg: lastMsg.RecvMsg, + Conn: conn, + } + q.push2list(storeMsg) + q.unFinMsg.Delete(conn) + if len(msg) == lastID { + return nil + } + msg = msg[lastID:] + return q.parseMessage(msg, conn) + } + lastMsg.RecvMsg = bytesMerge(lastMsg.RecvMsg, msg) + q.unFinMsg.Store(conn, lastMsg) + return nil + } + } + if len(msg) == 0 { + return nil + } + var start int + if start = searchHeader(msg); start == -1 { + return errors.New("data format error") + } + msg = msg[start:] + lastMsg := unFinMsg{} + q.unFinMsg.Store(conn, &lastMsg) + return q.parseMessage(msg, conn) +} + +func checkHeader(msg []byte) bool { + if len(msg) != 8 { + return false + } + for k, v := range msg { + if v != header[k] { + return false + } + } + return true +} + +func searchHeader(msg []byte) int { + if len(msg) < 8 { + return 0 + } + for k, v := range msg { + find := 0 + if v == header[0] { + for k2, v2 := range header { + if msg[k+k2] == v2 { + find++ + } else { + break + } + } + if find == 8 { + return k + } + } + } + return -1 +} + +func bytesMerge(src ...[]byte) []byte { + var buff bytes.Buffer + for _, v := range src { + buff.Write(v) + } + return buff.Bytes() +} + +// Restore 获取收到的信息 +func (q *StarQueue) Restore() (MsgQueue, error) { + if q.duration.Seconds() == 0 { + q.duration = 86400 * time.Second + } + for { + select { + case <-q.ctx.Done(): + return MsgQueue{}, errors.New("Stoped By External Function Call") + case <-time.After(q.duration): + if q.duration != 0 { + return MsgQueue{}, ErrDeadlineExceeded + } + case data, ok := <-q.msgPool: + if !ok { + return MsgQueue{}, os.ErrClosed + } + return data, nil + } + } +} + +// RestoreOne 获取收到的一个信息 +// 兼容性修改 +func (q *StarQueue) RestoreOne() (MsgQueue, error) { + return q.Restore() +} + +// Stop 立即停止Restore +func (q *StarQueue) Stop() { + q.cancel() +} + +// RestoreDuration Restore最大超时时间 +func (q *StarQueue) RestoreDuration(tm time.Duration) { + q.duration = tm +} + +func (q *StarQueue) RestoreChan() <-chan MsgQueue { + return q.msgPool +} diff --git a/que_test.go b/que_test.go new file mode 100644 index 0000000..d4a693a --- /dev/null +++ b/que_test.go @@ -0,0 +1,42 @@ +package stario + +import ( + "fmt" + "testing" + "time" +) + +func Test_QueSpeed(t *testing.T) { + que := NewQueueWithCount(0) + stop := make(chan struct{}, 1) + que.RestoreDuration(time.Second * 10) + var count int64 + go func() { + for { + select { + case <-stop: + //fmt.Println(count) + return + default: + } + _, err := que.RestoreOne() + if err == nil { + count++ + } + } + }() + cp := 0 + stoped := time.After(time.Second * 10) + data := que.BuildMessage([]byte("hello")) + for { + select { + case <-stoped: + fmt.Println(count, cp) + stop <- struct{}{} + return + default: + que.ParseMessage(data, "lala") + cp++ + } + } +}