Compare commits

...

11 Commits

1593
curl.go

File diff suppressed because it is too large Load Diff

@ -0,0 +1,464 @@
package starnet
import (
"fmt"
"net/http"
"net/http/httptest"
"testing"
)
func TestUrlEncodeRaw(t *testing.T) {
input := "hello world!@#$%^&*()_+-=~`"
expected := "hello%20world%21%40%23%24%25%5E%26%2A%28%29_%2B-%3D~%60"
result := UrlEncodeRaw(input)
if result != expected {
t.Errorf("UrlEncodeRaw(%q) = %q; want %q", input, result, expected)
}
}
func TestUrlEncode(t *testing.T) {
input := "hello world!@#$%^&*()_+-=~`"
expected := `hello+world%21%40%23%24%25%5E%26%2A%28%29_%2B-%3D~%60`
result := UrlEncode(input)
if result != expected {
t.Errorf("UrlEncode(%q) = %q; want %q", input, result, expected)
}
}
func TestUrlDecode(t *testing.T) {
input := "hello%20world%21%40%23%24%25%5E%26*%28%29_%2B-%3D~%60"
expected := "hello world!@#$%^&*()_+-=~`"
result, err := UrlDecode(input)
if err != nil {
t.Errorf("UrlDecode(%q) returned error: %v", input, err)
}
if result != expected {
t.Errorf("UrlDecode(%q) = %q; want %q", input, result, expected)
}
// Test for error case
invalidInput := "%zz"
_, err = UrlDecode(invalidInput)
if err == nil {
t.Errorf("UrlDecode(%q) expected error, got nil", invalidInput)
}
}
func TestBuildPostForm_WithValidInput(t *testing.T) {
input := map[string]string{
"key1": "value1",
"key2": "value2",
}
expected := []byte("key1=value1&key2=value2")
result := BuildPostForm(input)
if string(result) != string(expected) {
t.Errorf("BuildPostForm(%v) = %v; want %v", input, result, expected)
}
}
func TestBuildPostForm_WithEmptyInput(t *testing.T) {
input := map[string]string{}
expected := []byte("")
result := BuildPostForm(input)
if string(result) != string(expected) {
t.Errorf("BuildPostForm(%v) = %v; want %v", input, result, expected)
}
}
func TestBuildPostForm_WithNilInput(t *testing.T) {
var input map[string]string
expected := []byte("")
result := BuildPostForm(input)
if string(result) != string(expected) {
t.Errorf("BuildPostForm(%v) = %v; want %v", input, result, expected)
}
}
func TestGetRequest(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.Write([]byte(`OK`))
}))
defer server.Close()
resp, err := Get(server.URL)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
body := resp.Body().String()
if body != "OK" {
t.Errorf("Expected OK, got %v", body)
}
}
func TestPostRequest(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
if req.Method != http.MethodPost {
t.Errorf("Expected 'POST', got %v", req.Method)
}
rw.Write([]byte(`OK`))
}))
defer server.Close()
resp, err := Post(server.URL)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
body := resp.Body().String()
if body != "OK" {
t.Errorf("Expected OK, got %v", body)
}
}
func TestOptionsRequestWithValidInput(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
if req.Method != http.MethodOptions {
t.Errorf("Expected 'OPTIONS', got %v", req.Method)
}
rw.Write([]byte(`OK`))
}))
defer server.Close()
resp, err := Options(server.URL)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
body := resp.Body().String()
if body != "OK" {
t.Errorf("Expected OK, got %v", body)
}
}
func TestPutRequestWithValidInput(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
if req.Method != http.MethodPut {
t.Errorf("Expected 'PUT', got %v", req.Method)
}
rw.Write([]byte(`OK`))
}))
defer server.Close()
resp, err := Put(server.URL)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
body := resp.Body().String()
if body != "OK" {
t.Errorf("Expected OK, got %v", body)
}
}
func TestDeleteRequestWithValidInput(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
if req.Method != http.MethodDelete {
t.Errorf("Expected 'DELETE', got %v", req.Method)
}
rw.Write([]byte(`OK`))
}))
defer server.Close()
resp, err := Delete(server.URL)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
body := resp.Body().String()
if body != "OK" {
t.Errorf("Expected OK, got %v", body)
}
}
func TestHeadRequestWithValidInput(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
if req.Method != http.MethodHead {
t.Errorf("Expected 'HEAD', got %v", req.Method)
}
rw.Write([]byte(`OK`))
}))
defer server.Close()
resp, err := Head(server.URL)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
body := resp.Body().String()
if body == "OK" {
t.Errorf("Expected , got %v", body)
}
}
func TestPatchRequestWithValidInput(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
if req.Method != http.MethodPatch {
t.Errorf("Expected 'PATCH', got %v", req.Method)
}
rw.Write([]byte(`OK`))
}))
defer server.Close()
resp, err := Patch(server.URL)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
body := resp.Body().String()
if body != "OK" {
t.Errorf("Expected OK, got %v", body)
}
}
func TestTraceRequestWithValidInput(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
if req.Method != http.MethodTrace {
t.Errorf("Expected 'TRACE', got %v", req.Method)
}
rw.Write([]byte(`OK`))
}))
defer server.Close()
resp, err := Trace(server.URL)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
body := resp.Body().String()
if body != "OK" {
t.Errorf("Expected OK, got %v", body)
}
}
func TestConnectRequestWithValidInput(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
if req.Method != http.MethodConnect {
t.Errorf("Expected 'CONNECT', got %v", req.Method)
}
rw.Write([]byte(`OK`))
}))
defer server.Close()
resp, err := Connect(server.URL)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
body := resp.Body().String()
if body != "OK" {
t.Errorf("Expected OK, got %v", body)
}
}
func TestMethodReturnsCorrectValue(t *testing.T) {
req := NewReq("https://example.com")
req.SetMethodNoError("GET")
if req.Method() != "GET" {
t.Errorf("Expected 'GET', got %v", req.Method())
}
}
func TestSetMethodHandlesInvalidInput(t *testing.T) {
req := NewReq("https://example.com")
err := req.SetMethod("我是谁")
if err == nil {
t.Errorf("Expected error, got nil")
}
}
func TestSetMethodNoErrorSetsMethodCorrectly(t *testing.T) {
req := NewReq("https://example.com")
req.SetMethodNoError("POST")
if req.Method() != "POST" {
t.Errorf("Expected 'POST', got %v", req.Method())
}
}
func TestSetMethodNoErrorIgnoresInvalidInput(t *testing.T) {
req := NewReq("https://example.com")
req.SetMethodNoError("你是谁")
if req.Method() != "GET" {
t.Errorf("Expected '', got %v", req.Method())
}
}
func TestUriReturnsCorrectValue(t *testing.T) {
req := NewReq("https://example.com")
if req.Uri() != "https://example.com" {
t.Errorf("Expected 'https://example.com', got %v", req.Uri())
}
}
func TestSetUriHandlesValidInput(t *testing.T) {
req := NewReq("https://example.com")
err := req.SetUri("https://newexample.com")
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if req.Uri() != "https://newexample.com" {
t.Errorf("Expected 'https://newexample.com', got %v", req.Uri())
}
}
func TestSetUriHandlesInvalidInput(t *testing.T) {
req := NewReq("https://example.com")
err := req.SetUri("://invalidurl")
if err == nil {
t.Errorf("Expected error, got nil")
}
}
func TestSetUriNoErrorSetsUriCorrectly(t *testing.T) {
req := NewReq("https://example.com")
req.SetUriNoError("https://newexample.com")
if req.Uri() != "https://newexample.com" {
t.Errorf("Expected 'https://newexample.com', got %v", req.Uri())
}
}
func TestSetUriNoErrorIgnoresInvalidInput(t *testing.T) {
req := NewReq("https://example.com")
req.SetUriNoError("://invalidurl")
if req.Uri() != "https://example.com" {
t.Errorf("Expected 'https://example.com', got %v", req.Uri())
}
}
type postmanReply struct {
Args struct {
} `json:"args"`
Form map[string]string `json:"form"`
Headers map[string]string `json:"headers"`
Url string `json:"url"`
}
func TestGet(t *testing.T) {
var reply postmanReply
resp, err := NewReq("https://postman-echo.com/get").
AddHeader("hello", "nononmo").
SetAutoCalcContentLength(true).
Do()
if err != nil {
t.Error(err)
}
err = resp.Body().Unmarshal(&reply)
if err != nil {
t.Error(err)
}
fmt.Println(resp.Body().String())
fmt.Println(reply.Headers)
fmt.Println(resp.Cookies())
}
type testData struct {
name string
args *Request
want func(*Response) error
wantErr bool
}
func headerTestData() []testData {
return []testData{
{
name: "addHeader",
args: NewReq("https://postman-echo.com/get").
AddHeader("b612", "test-data").
AddHeader("b612", "test-header").
AddSimpleCookie("b612", "test-cookie").
SetHeader("User-Agent", "starnet test"),
want: func(resp *Response) error {
//fmt.Println(resp.Body().String())
if resp == nil {
return fmt.Errorf("response is nil")
}
if resp.StatusCode != 200 {
return fmt.Errorf("status code is %d", resp.StatusCode)
}
var reply postmanReply
err := resp.Body().Unmarshal(&reply)
if err != nil {
return err
}
if reply.Headers["b612"] != "test-data, test-header" {
return fmt.Errorf("header not found")
}
if reply.Headers["user-agent"] != "starnet test" {
return fmt.Errorf("user-agent not found")
}
if reply.Headers["cookie"] != "b612=test-cookie" {
return fmt.Errorf("cookie not found")
}
return nil
},
wantErr: false,
},
{
name: "postForm",
args: NewSimpleRequest("https://postman-echo.com/post", "POST").
AddHeader("b612", "test-data").
AddHeader("b612", "test-header").
AddSimpleCookie("b612", "test-cookie").
SetHeader("User-Agent", "starnet test").
//SetHeader("Content-Type", "application/x-www-form-urlencoded").
AddFormData("hello", "world").
AddFormData("hello2", "world2").
SetMethodNoError("POST"),
want: func(resp *Response) error {
//fmt.Println(resp.Body().String())
if resp == nil {
return fmt.Errorf("response is nil")
}
if resp.StatusCode != 200 {
return fmt.Errorf("status code is %d", resp.StatusCode)
}
var reply postmanReply
err := resp.Body().Unmarshal(&reply)
if err != nil {
return err
}
if reply.Headers["b612"] != "test-data, test-header" {
return fmt.Errorf("header not found")
}
if reply.Headers["user-agent"] != "starnet test" {
return fmt.Errorf("user-agent not found")
}
if reply.Headers["cookie"] != "b612=test-cookie" {
return fmt.Errorf("cookie not found")
}
if reply.Form["hello"] != "world" {
return fmt.Errorf("form data not found")
}
if reply.Form["hello2"] != "world2" {
return fmt.Errorf("form data not found")
}
return nil
},
wantErr: false,
},
}
}
func TestCurl(t *testing.T) {
for _, tt := range headerTestData() {
t.Run(tt.name, func(t *testing.T) {
got, err := Curl(tt.args)
if (err != nil) != tt.wantErr {
t.Errorf("Curl() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.want != nil {
if err := tt.want(got); err != nil {
t.Errorf("Curl() = %v", err)
}
}
})
}
}

@ -1,5 +1,3 @@
module b612.me/starnet module b612.me/starnet
go 1.16 go 1.16
require b612.me/stario v0.0.5

@ -1,13 +0,0 @@
b612.me/stario v0.0.5 h1:Q1OGF+8eOoK49zMzkyh80GWaMuknhey6+PWJJL9ZuNo=
b612.me/stario v0.0.5/go.mod h1:or4ssWcxQSjMeu+hRKEgtp0X517b3zdlEOAms8Qscvw=
golang.org/x/crypto v0.0.0-20220313003712-b769efc7c000 h1:SL+8VVnkqyshUSz5iNnXtrBQzvFF2SkROm6t5RczFAE=
golang.org/x/crypto v0.0.0-20220313003712-b769efc7c000/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1 h1:SrN+KX8Art/Sf4HNj6Zcz06G7VEz+7w9tdXTPOZ7+l4=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1 h1:v+OssWQX+hTHEmOBgwxdZxK4zHq3yOs8F9J7mk0PY8E=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=

@ -0,0 +1,120 @@
package starnet
import "strings"
var isTokenTable = [127]bool{
'!': true,
'#': true,
'$': true,
'%': true,
'&': true,
'\'': true,
'*': true,
'+': true,
'-': true,
'.': true,
'0': true,
'1': true,
'2': true,
'3': true,
'4': true,
'5': true,
'6': true,
'7': true,
'8': true,
'9': true,
'A': true,
'B': true,
'C': true,
'D': true,
'E': true,
'F': true,
'G': true,
'H': true,
'I': true,
'J': true,
'K': true,
'L': true,
'M': true,
'N': true,
'O': true,
'P': true,
'Q': true,
'R': true,
'S': true,
'T': true,
'U': true,
'W': true,
'V': true,
'X': true,
'Y': true,
'Z': true,
'^': true,
'_': true,
'`': true,
'a': true,
'b': true,
'c': true,
'd': true,
'e': true,
'f': true,
'g': true,
'h': true,
'i': true,
'j': true,
'k': true,
'l': true,
'm': true,
'n': true,
'o': true,
'p': true,
'q': true,
'r': true,
's': true,
't': true,
'u': true,
'v': true,
'w': true,
'x': true,
'y': true,
'z': true,
'|': true,
'~': true,
}
func IsTokenRune(r rune) bool {
i := int(r)
return i < len(isTokenTable) && isTokenTable[i]
}
func validMethod(method string) bool {
/*
Method = "OPTIONS" ; Section 9.2
| "GET" ; Section 9.3
| "HEAD" ; Section 9.4
| "POST" ; Section 9.5
| "PUT" ; Section 9.6
| "DELETE" ; Section 9.7
| "TRACE" ; Section 9.8
| "CONNECT" ; Section 9.9
| extension-method
extension-method = token
token = 1*<any CHAR except CTLs or separators>
*/
return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1
}
func isNotToken(r rune) bool {
return !IsTokenRune(r)
}
func hasPort(s string) bool { return strings.LastIndex(s, ":") > strings.LastIndex(s, "]") }
// removeEmptyPort strips the empty port in ":port" to ""
// as mandated by RFC 3986 Section 6.2.3.
func removeEmptyPort(host string) string {
if hasPort(host) {
return strings.TrimSuffix(host, ":")
}
return host
}

@ -33,6 +33,7 @@ func getICMP(seq uint16) ICMP {
func sendICMPRequest(icmp ICMP, destAddr *net.IPAddr, timeout time.Duration) (PingResult, error) { func sendICMPRequest(icmp ICMP, destAddr *net.IPAddr, timeout time.Duration) (PingResult, error) {
var res PingResult var res PingResult
res.RemoteIP = destAddr.String()
conn, err := net.DialIP("ip:icmp", nil, destAddr) conn, err := net.DialIP("ip:icmp", nil, destAddr)
if err != nil { if err != nil {
return res, err return res, err
@ -84,6 +85,7 @@ func checkSum(data []byte) uint16 {
type PingResult struct { type PingResult struct {
Duration time.Duration Duration time.Duration
RecvCount int RecvCount int
RemoteIP string
} }
func Ping(ip string, seq int, timeout time.Duration) (PingResult, error) { func Ping(ip string, seq int, timeout time.Duration) (PingResult, error) {
@ -95,3 +97,14 @@ func Ping(ip string, seq int, timeout time.Duration) (PingResult, error) {
icmp := getICMP(uint16(seq)) icmp := getICMP(uint16(seq))
return sendICMPRequest(icmp, ipAddr, timeout) return sendICMPRequest(icmp, ipAddr, timeout)
} }
func IsIpPingable(ip string, timeout time.Duration, retryLimit int) bool {
for i := 0; i < retryLimit; i++ {
_, err := Ping(ip, 29, timeout)
if err != nil {
continue
}
return true
}
return false
}

@ -7,5 +7,9 @@ import (
) )
func Test_Ping(t *testing.T) { func Test_Ping(t *testing.T) {
fmt.Println(Ping("baidu.com", 0, time.Second*2)) fmt.Println(Ping("baidu.com", 29, time.Second*2))
fmt.Println(Ping("www.b612.me", 29, time.Second*2))
fmt.Println(IsIpPingable("baidu.com", time.Second*2, 3))
fmt.Println(IsIpPingable("www.b612.me", time.Second*2, 3))
} }

317
que.go

@ -1,317 +0,0 @@
package starnet
import (
"bytes"
"context"
"encoding/binary"
"errors"
"os"
"sync"
"time"
)
// 识别头
var header = []byte{11, 27, 19, 96, 12, 25, 02, 20}
// MsgQueue 为基本的信息单位
type MsgQueue struct {
ID uint16
Msg []byte
Conn interface{}
}
// StarQueue 为流数据中的消息队列分发
type StarQueue struct {
count int64
Encode bool
Reserve uint16
Msgid uint16
MsgPool chan MsgQueue
UnFinMsg sync.Map
LastID int //= -1
ctx context.Context
cancel context.CancelFunc
duration time.Duration
EncodeFunc func([]byte) []byte
DecodeFunc func([]byte) []byte
//restoreMu sync.Mutex
}
func NewQueueCtx(ctx context.Context, count int64) *StarQueue {
var que StarQueue
que.Encode = false
que.count = count
que.MsgPool = make(chan MsgQueue, count)
if ctx == nil {
que.ctx, que.cancel = context.WithCancel(context.Background())
} else {
que.ctx, que.cancel = context.WithCancel(ctx)
}
que.duration = 0
return &que
}
func NewQueueWithCount(count int64) *StarQueue {
return NewQueueCtx(nil, count)
}
// NewQueue 建立一个新消息队列
func NewQueue() *StarQueue {
return NewQueueWithCount(32)
}
// Uint32ToByte 4位uint32转byte
func Uint32ToByte(src uint32) []byte {
res := make([]byte, 4)
res[3] = uint8(src)
res[2] = uint8(src >> 8)
res[1] = uint8(src >> 16)
res[0] = uint8(src >> 24)
return res
}
// ByteToUint32 byte转4位uint32
func ByteToUint32(src []byte) uint32 {
var res uint32
buffer := bytes.NewBuffer(src)
binary.Read(buffer, binary.BigEndian, &res)
return res
}
// Uint16ToByte 2位uint16转byte
func Uint16ToByte(src uint16) []byte {
res := make([]byte, 2)
res[1] = uint8(src)
res[0] = uint8(src >> 8)
return res
}
// ByteToUint16 用于byte转uint16
func ByteToUint16(src []byte) uint16 {
var res uint16
buffer := bytes.NewBuffer(src)
binary.Read(buffer, binary.BigEndian, &res)
return res
}
// BuildMessage 生成编码后的信息用于发送
func (que *StarQueue) BuildMessage(src []byte) []byte {
var buff bytes.Buffer
que.Msgid++
if que.Encode {
src = que.EncodeFunc(src)
}
length := uint32(len(src))
buff.Write(header)
buff.Write(Uint32ToByte(length))
buff.Write(Uint16ToByte(que.Msgid))
buff.Write(src)
return buff.Bytes()
}
// BuildHeader 生成编码后的Header用于发送
func (que *StarQueue) BuildHeader(length uint32) []byte {
var buff bytes.Buffer
que.Msgid++
buff.Write(header)
buff.Write(Uint32ToByte(length))
buff.Write(Uint16ToByte(que.Msgid))
return buff.Bytes()
}
type unFinMsg struct {
ID uint16
LengthRecv uint32
// HeaderMsg 信息头应当为14位8位识别码+4位长度码+2位id
HeaderMsg []byte
RecvMsg []byte
}
func (que *StarQueue) push2list(msg MsgQueue) {
que.MsgPool <- msg
}
// ParseMessage 用于解析收到的msg信息
func (que *StarQueue) ParseMessage(msg []byte, conn interface{}) error {
return que.parseMessage(msg, conn)
}
// parseMessage 用于解析收到的msg信息
func (que *StarQueue) parseMessage(msg []byte, conn interface{}) error {
tmp, ok := que.UnFinMsg.Load(conn)
if ok { //存在未完成的信息
lastMsg := tmp.(*unFinMsg)
headerLen := len(lastMsg.HeaderMsg)
if headerLen < 14 { //未完成头标题
//传输的数据不能填充header头
if len(msg) < 14-headerLen {
//加入header头并退出
lastMsg.HeaderMsg = bytesMerge(lastMsg.HeaderMsg, msg)
que.UnFinMsg.Store(conn, lastMsg)
return nil
}
//获取14字节完整的header
header := msg[0 : 14-headerLen]
lastMsg.HeaderMsg = bytesMerge(lastMsg.HeaderMsg, header)
//检查收到的header是否为认证header
//若不是,丢弃并重新来过
if !checkHeader(lastMsg.HeaderMsg[0:8]) {
que.UnFinMsg.Delete(conn)
if len(msg) == 0 {
return nil
}
return que.parseMessage(msg, conn)
}
//获得本数据包长度
lastMsg.LengthRecv = ByteToUint32(lastMsg.HeaderMsg[8:12])
//获得本数据包ID
lastMsg.ID = ByteToUint16(lastMsg.HeaderMsg[12:14])
//存入列表
que.UnFinMsg.Store(conn, lastMsg)
msg = msg[14-headerLen:]
if uint32(len(msg)) < lastMsg.LengthRecv {
lastMsg.RecvMsg = msg
que.UnFinMsg.Store(conn, lastMsg)
return nil
}
if uint32(len(msg)) >= lastMsg.LengthRecv {
lastMsg.RecvMsg = msg[0:lastMsg.LengthRecv]
if que.Encode {
lastMsg.RecvMsg = que.DecodeFunc(lastMsg.RecvMsg)
}
msg = msg[lastMsg.LengthRecv:]
storeMsg := MsgQueue{
ID: lastMsg.ID,
Msg: lastMsg.RecvMsg,
Conn: conn,
}
//que.restoreMu.Lock()
que.push2list(storeMsg)
//que.restoreMu.Unlock()
que.UnFinMsg.Delete(conn)
return que.parseMessage(msg, conn)
}
} else {
lastID := int(lastMsg.LengthRecv) - len(lastMsg.RecvMsg)
if lastID < 0 {
que.UnFinMsg.Delete(conn)
return que.parseMessage(msg, conn)
}
if len(msg) >= lastID {
lastMsg.RecvMsg = bytesMerge(lastMsg.RecvMsg, msg[0:lastID])
if que.Encode {
lastMsg.RecvMsg = que.DecodeFunc(lastMsg.RecvMsg)
}
storeMsg := MsgQueue{
ID: lastMsg.ID,
Msg: lastMsg.RecvMsg,
Conn: conn,
}
que.push2list(storeMsg)
que.UnFinMsg.Delete(conn)
if len(msg) == lastID {
return nil
}
msg = msg[lastID:]
return que.parseMessage(msg, conn)
}
lastMsg.RecvMsg = bytesMerge(lastMsg.RecvMsg, msg)
que.UnFinMsg.Store(conn, lastMsg)
return nil
}
}
if len(msg) == 0 {
return nil
}
var start int
if start = searchHeader(msg); start == -1 {
return errors.New("data format error")
}
msg = msg[start:]
lastMsg := unFinMsg{}
que.UnFinMsg.Store(conn, &lastMsg)
return que.parseMessage(msg, conn)
}
func checkHeader(msg []byte) bool {
if len(msg) != 8 {
return false
}
for k, v := range msg {
if v != header[k] {
return false
}
}
return true
}
func searchHeader(msg []byte) int {
if len(msg) < 8 {
return 0
}
for k, v := range msg {
find := 0
if v == header[0] {
for k2, v2 := range header {
if msg[k+k2] == v2 {
find++
} else {
break
}
}
if find == 8 {
return k
}
}
}
return -1
}
func bytesMerge(src ...[]byte) []byte {
var buff bytes.Buffer
for _, v := range src {
buff.Write(v)
}
return buff.Bytes()
}
// Restore 获取收到的信息
func (que *StarQueue) Restore() (MsgQueue, error) {
if que.duration.Seconds() == 0 {
que.duration = 86400 * time.Second
}
for {
select {
case <-que.ctx.Done():
return MsgQueue{}, errors.New("Stoped By External Function Call")
case <-time.After(que.duration):
if que.duration != 0 {
return MsgQueue{}, os.ErrDeadlineExceeded
}
case data, ok := <-que.MsgPool:
if !ok {
return MsgQueue{}, os.ErrClosed
}
return data, nil
}
}
}
// RestoreOne 获取收到的一个信息
//兼容性修改
func (que *StarQueue) RestoreOne() (MsgQueue, error) {
return que.Restore()
}
// Stop 立即停止Restore
func (que *StarQueue) Stop() {
que.cancel()
}
// RestoreDuration Restore最大超时时间
func (que *StarQueue) RestoreDuration(tm time.Duration) {
que.duration = tm
}
func (que *StarQueue) RestoreChan() <-chan MsgQueue {
return que.MsgPool
}

@ -1,42 +0,0 @@
package starnet
import (
"fmt"
"testing"
"time"
)
func Test_QueSpeed(t *testing.T) {
que := NewQueueWithCount(0)
stop := make(chan struct{}, 1)
que.RestoreDuration(time.Second * 10)
var count int64
go func() {
for {
select {
case <-stop:
//fmt.Println(count)
return
default:
}
_, err := que.RestoreOne()
if err == nil {
count++
}
}
}()
cp := 0
stoped := time.After(time.Second * 10)
data := que.BuildMessage([]byte("hello"))
for {
select {
case <-stoped:
fmt.Println(count, cp)
stop <- struct{}{}
return
default:
que.ParseMessage(data, "lala")
cp++
}
}
}
Loading…
Cancel
Save