Compare commits

..

18 Commits

Author SHA1 Message Date
b90c59d6e7
修改版本号 2025-08-21 21:40:29 +08:00
4e154cc17b
update benchmark 2025-08-21 21:37:21 +08:00
67b0025f9c
更新content-length的默认处理方式 2025-08-21 19:17:19 +08:00
c4fa62536a
为client新增部分函数 2025-08-21 15:32:19 +08:00
260ceb90ed
重构http Client部分 2025-08-21 15:02:02 +08:00
d260181adf
update 2025-08-15 15:07:51 +08:00
e3b7369e12
bug fix:nil pointer error 2025-08-13 10:16:08 +08:00
4e17fee681
bug fix 2025-07-14 18:38:31 +08:00
a8eed30db5
add http client control 2025-07-14 18:23:14 +08:00
c1eaf43058 update 2025-06-17 12:36:57 +08:00
9f5aca124d update 2025-06-17 12:09:12 +08:00
54958724e7 bug fix 2025-06-13 17:16:38 +08:00
7a17672149 update tls sniffer 2025-06-12 16:50:47 +08:00
44b807d3d1 update 2025-06-06 15:43:38 +08:00
0d847462b3 bug fix:nil pointer 2025-04-28 13:19:45 +08:00
deed4207ea bug fix 2024-08-30 23:44:49 +08:00
f6363fed07 move starqueue from starnet to stario 2024-08-18 17:18:52 +08:00
1de78f2f06 rewrite curl.go 2024-08-08 22:03:10 +08:00
12 changed files with 3572 additions and 805 deletions

1
.gitignore vendored Normal file
View File

@ -0,0 +1 @@
.idea

2156
curl.go

File diff suppressed because it is too large Load Diff

198
curl_default.go Normal file
View File

@ -0,0 +1,198 @@
package starnet
import (
"context"
"crypto/tls"
"fmt"
"net"
"net/http"
"net/url"
"strings"
"time"
)
const (
HEADER_FORM_URLENCODE = `application/x-www-form-urlencoded`
HEADER_FORM_DATA = `multipart/form-data`
HEADER_JSON = `application/json`
HEADER_PLAIN = `text/plain`
)
var (
DefaultDialTimeout = 5 * time.Second
DefaultTimeout = 10 * time.Second
DefaultFetchRespBody = false
DefaultHttpClient = NewHttpClientNoErr()
)
func UrlEncodeRaw(str string) string {
strs := strings.Replace(url.QueryEscape(str), "+", "%20", -1)
return strs
}
func UrlEncode(str string) string {
return url.QueryEscape(str)
}
func UrlDecode(str string) (string, error) {
return url.QueryUnescape(str)
}
func BuildQuery(queryData map[string]string) string {
query := url.Values{}
for k, v := range queryData {
query.Add(k, v)
}
return query.Encode()
}
// BuildPostForm takes a map of string keys and values, converts it into a URL-encoded query string,
// and then converts that string into a byte slice. This function is useful for preparing data for HTTP POST requests,
// where the server expects the request body to be URL-encoded form data.
//
// Parameters:
// queryMap: A map where the key-value pairs represent the form data to be sent in the HTTP POST request.
//
// Returns:
// A byte slice representing the URL-encoded form data.
func BuildPostForm(queryMap map[string]string) []byte {
return []byte(BuildQuery(queryMap))
}
func Get(uri string, opts ...RequestOpt) (*Response, error) {
return NewSimpleRequestWithClient(DefaultHttpClient, uri, "GET", opts...).Do()
}
func Post(uri string, opts ...RequestOpt) (*Response, error) {
return NewSimpleRequestWithClient(DefaultHttpClient, uri, "POST", opts...).Do()
}
func Options(uri string, opts ...RequestOpt) (*Response, error) {
return NewSimpleRequestWithClient(DefaultHttpClient, uri, "OPTIONS", opts...).Do()
}
func Put(uri string, opts ...RequestOpt) (*Response, error) {
return NewSimpleRequestWithClient(DefaultHttpClient, uri, "PUT", opts...).Do()
}
func Delete(uri string, opts ...RequestOpt) (*Response, error) {
return NewSimpleRequestWithClient(DefaultHttpClient, uri, "DELETE", opts...).Do()
}
func Head(uri string, opts ...RequestOpt) (*Response, error) {
return NewSimpleRequestWithClient(DefaultHttpClient, uri, "HEAD", opts...).Do()
}
func Patch(uri string, opts ...RequestOpt) (*Response, error) {
return NewSimpleRequestWithClient(DefaultHttpClient, uri, "PATCH", opts...).Do()
}
func Trace(uri string, opts ...RequestOpt) (*Response, error) {
return NewSimpleRequestWithClient(DefaultHttpClient, uri, "TRACE", opts...).Do()
}
func Connect(uri string, opts ...RequestOpt) (*Response, error) {
return NewSimpleRequestWithClient(DefaultHttpClient, uri, "CONNECT", opts...).Do()
}
func DefaultCheckRedirectFunc(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
}
func DefaultDialFunc(ctx context.Context, netType, addr string) (net.Conn, error) {
var lastErr error
var addrs []string
if dialFn, ok := ctx.Value("dialFunc").(func(context.Context, string, string) (net.Conn, error)); ok {
if dialFn != nil {
return dialFn(ctx, netType, addr)
}
}
customIP, ok := ctx.Value("customIP").([]string)
if !ok {
customIP = nil
}
dialTimeout, ok := ctx.Value("dialTimeout").(time.Duration)
if !ok {
dialTimeout = DefaultDialTimeout
}
timeout, ok := ctx.Value("timeout").(time.Duration)
if !ok {
timeout = DefaultTimeout
}
lookUpIPfn, ok := ctx.Value("lookUpIP").(func(context.Context, string) ([]net.IPAddr, error))
if !ok {
lookUpIPfn = net.DefaultResolver.LookupIPAddr
}
host, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
proxy, ok := ctx.Value("proxy").(string)
if !ok {
proxy = ""
}
if proxy == "" && len(customIP) > 0 {
for _, v := range customIP {
ipAddr := net.ParseIP(v)
if ipAddr == nil {
return nil, fmt.Errorf("invalid custom ip: %s", customIP)
}
tmpAddr := net.JoinHostPort(v, port)
addrs = append(addrs, tmpAddr)
}
} else {
ipLists, err := lookUpIPfn(ctx, host)
if err != nil {
return nil, err
}
for _, v := range ipLists {
tmpAddr := net.JoinHostPort(v.String(), port)
addrs = append(addrs, tmpAddr)
}
}
for _, addr := range addrs {
c, err := net.DialTimeout(netType, addr, dialTimeout)
if err != nil {
lastErr = err
continue
}
if timeout != 0 {
err = c.SetDeadline(time.Now().Add(timeout))
}
return c, nil
}
return nil, lastErr
}
func DefaultDialTlsFunc(ctx context.Context, netType, addr string) (net.Conn, error) {
conn, err := DefaultDialFunc(ctx, netType, addr)
if err != nil {
return nil, err
}
tlsConfig, ok := ctx.Value("tlsConfig").(*tls.Config)
if !ok || tlsConfig == nil {
return nil, fmt.Errorf("tlsConfig is not set in context")
}
tlsConn := tls.Client(conn, tlsConfig)
if err := tlsConn.Handshake(); err != nil {
return nil, fmt.Errorf("tls handshake failed: %w", err)
}
return tlsConn, nil
}
func DefaultProxyURL() func(*http.Request) (*url.URL, error) {
return func(req *http.Request) (*url.URL, error) {
if req == nil {
return nil, fmt.Errorf("request is nil")
}
proxyURL, ok := req.Context().Value("proxy").(string)
if !ok || proxyURL == "" {
return nil, nil
}
parsedURL, err := url.Parse(proxyURL)
if err != nil {
return nil, fmt.Errorf("failed to parse proxy URL: %w", err)
}
return parsedURL, nil
}
}

698
curl_test.go Normal file
View File

@ -0,0 +1,698 @@
package starnet
import (
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
)
func TestUrlEncodeRaw(t *testing.T) {
input := "hello world!@#$%^&*()_+-=~`"
expected := "hello%20world%21%40%23%24%25%5E%26%2A%28%29_%2B-%3D~%60"
result := UrlEncodeRaw(input)
if result != expected {
t.Errorf("UrlEncodeRaw(%q) = %q; want %q", input, result, expected)
}
}
func TestUrlEncode(t *testing.T) {
input := "hello world!@#$%^&*()_+-=~`"
expected := `hello+world%21%40%23%24%25%5E%26%2A%28%29_%2B-%3D~%60`
result := UrlEncode(input)
if result != expected {
t.Errorf("UrlEncode(%q) = %q; want %q", input, result, expected)
}
}
func TestUrlDecode(t *testing.T) {
input := "hello%20world%21%40%23%24%25%5E%26*%28%29_%2B-%3D~%60"
expected := "hello world!@#$%^&*()_+-=~`"
result, err := UrlDecode(input)
if err != nil {
t.Errorf("UrlDecode(%q) returned error: %v", input, err)
}
if result != expected {
t.Errorf("UrlDecode(%q) = %q; want %q", input, result, expected)
}
// Test for error case
invalidInput := "%zz"
_, err = UrlDecode(invalidInput)
if err == nil {
t.Errorf("UrlDecode(%q) expected error, got nil", invalidInput)
}
}
func TestBuildPostForm_WithValidInput(t *testing.T) {
input := map[string]string{
"key1": "value1",
"key2": "value2",
}
expected := []byte("key1=value1&key2=value2")
result := BuildPostForm(input)
if string(result) != string(expected) {
t.Errorf("BuildPostForm(%v) = %v; want %v", input, result, expected)
}
}
func TestBuildPostForm_WithEmptyInput(t *testing.T) {
input := map[string]string{}
expected := []byte("")
result := BuildPostForm(input)
if string(result) != string(expected) {
t.Errorf("BuildPostForm(%v) = %v; want %v", input, result, expected)
}
}
func TestBuildPostForm_WithNilInput(t *testing.T) {
var input map[string]string
expected := []byte("")
result := BuildPostForm(input)
if string(result) != string(expected) {
t.Errorf("BuildPostForm(%v) = %v; want %v", input, result, expected)
}
}
func TestGetRequest(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.Write([]byte(`OK`))
}))
defer server.Close()
resp, err := Get(server.URL, WithSkipTLSVerify(true), WithHeader("hello", "world"), WithUserAgent("hello world"))
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
body := resp.Body().String()
if body != "OK" {
t.Errorf("Expected OK, got %v", body)
}
}
func TestPostRequest(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
if req.Method != http.MethodPost {
t.Errorf("Expected 'POST', got %v", req.Method)
}
rw.Write([]byte(`OK`))
}))
defer server.Close()
resp, err := Post(server.URL)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
body := resp.Body().String()
if body != "OK" {
t.Errorf("Expected OK, got %v", body)
}
}
func TestOptionsRequestWithValidInput(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
if req.Method != http.MethodOptions {
t.Errorf("Expected 'OPTIONS', got %v", req.Method)
}
rw.Write([]byte(`OK`))
}))
defer server.Close()
resp, err := Options(server.URL)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
body := resp.Body().String()
if body != "OK" {
t.Errorf("Expected OK, got %v", body)
}
}
func TestPutRequestWithValidInput(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
if req.Method != http.MethodPut {
t.Errorf("Expected 'PUT', got %v", req.Method)
}
rw.Write([]byte(`OK`))
}))
defer server.Close()
resp, err := Put(server.URL)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
body := resp.Body().String()
if body != "OK" {
t.Errorf("Expected OK, got %v", body)
}
}
func TestDeleteRequestWithValidInput(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
if req.Method != http.MethodDelete {
t.Errorf("Expected 'DELETE', got %v", req.Method)
}
rw.Write([]byte(`OK`))
}))
defer server.Close()
resp, err := Delete(server.URL)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
body := resp.Body().String()
if body != "OK" {
t.Errorf("Expected OK, got %v", body)
}
}
func TestHeadRequestWithValidInput(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
if req.Method != http.MethodHead {
t.Errorf("Expected 'HEAD', got %v", req.Method)
}
rw.Write([]byte(`OK`))
}))
defer server.Close()
resp, err := Head(server.URL)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
body := resp.Body().String()
if body == "OK" {
t.Errorf("Expected , got %v", body)
}
}
func TestPatchRequestWithValidInput(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
if req.Method != http.MethodPatch {
t.Errorf("Expected 'PATCH', got %v", req.Method)
}
rw.Write([]byte(`OK`))
}))
defer server.Close()
resp, err := Patch(server.URL)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
body := resp.Body().String()
if body != "OK" {
t.Errorf("Expected OK, got %v", body)
}
}
func TestTraceRequestWithValidInput(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
if req.Method != http.MethodTrace {
t.Errorf("Expected 'TRACE', got %v", req.Method)
}
rw.Write([]byte(`OK`))
}))
defer server.Close()
resp, err := Trace(server.URL)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
body := resp.Body().String()
if body != "OK" {
t.Errorf("Expected OK, got %v", body)
}
}
func TestConnectRequestWithValidInput(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
if req.Method != http.MethodConnect {
t.Errorf("Expected 'CONNECT', got %v", req.Method)
}
rw.Write([]byte(`OK`))
}))
defer server.Close()
resp, err := Connect(server.URL)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
body := resp.Body().String()
if body != "OK" {
t.Errorf("Expected OK, got %v", body)
}
}
func TestMethodReturnsCorrectValue(t *testing.T) {
req := NewReq("https://example.com")
req.SetMethodNoError("GET")
if req.Method() != "GET" {
t.Errorf("Expected 'GET', got %v", req.Method())
}
}
func TestSetMethodHandlesInvalidInput(t *testing.T) {
req := NewReq("https://example.com")
err := req.SetMethod("我是谁")
if err == nil {
t.Errorf("Expected error, got nil")
}
}
func TestSetMethodNoErrorSetsMethodCorrectly(t *testing.T) {
req := NewReq("https://example.com")
req.SetMethodNoError("POST")
if req.Method() != "POST" {
t.Errorf("Expected 'POST', got %v", req.Method())
}
}
func TestSetMethodNoErrorIgnoresInvalidInput(t *testing.T) {
req := NewReq("https://example.com")
req.SetMethodNoError("你是谁")
if req.Method() != "GET" {
t.Errorf("Expected '', got %v", req.Method())
}
}
func TestUriReturnsCorrectValue(t *testing.T) {
req := NewReq("https://example.com")
if req.Uri() != "https://example.com" {
t.Errorf("Expected 'https://example.com', got %v", req.Uri())
}
}
func TestSetUriHandlesValidInput(t *testing.T) {
req := NewReq("https://example.com")
err := req.SetUri("https://newexample.com")
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if req.Uri() != "https://newexample.com" {
t.Errorf("Expected 'https://newexample.com', got %v", req.Uri())
}
}
func TestSetUriHandlesInvalidInput(t *testing.T) {
req := NewReq("https://example.com")
err := req.SetUri("://invalidurl")
if err == nil {
t.Errorf("Expected error, got nil")
}
}
func TestSetUriNoErrorSetsUriCorrectly(t *testing.T) {
req := NewReq("https://example.com")
req.SetUriNoError("https://newexample.com")
if req.Uri() != "https://newexample.com" {
t.Errorf("Expected 'https://newexample.com', got %v", req.Uri())
}
}
func TestSetUriNoErrorIgnoresInvalidInput(t *testing.T) {
req := NewReq("https://example.com")
req.SetUriNoError("://invalidurl")
if req.Uri() != "https://example.com" {
t.Errorf("Expected 'https://example.com', got %v", req.Uri())
}
}
type postmanReply struct {
Args struct {
} `json:"args"`
Form map[string]string `json:"form"`
Headers map[string]string `json:"headers"`
Url string `json:"url"`
}
func TestGet(t *testing.T) {
var reply postmanReply
resp, err := NewReq("https://postman-echo.com/get").
AddHeader("hello", "nononmo").
SetAutoCalcContentLengthNoError(true).Do()
if err != nil {
t.Error(err)
}
fmt.Println(resp.Proto)
err = resp.Body().Unmarshal(&reply)
if err != nil {
t.Error(err)
}
fmt.Println(resp.Body().String())
fmt.Println(reply.Headers)
fmt.Println(resp.Cookies())
}
type testData struct {
name string
args *Request
want func(*Response) error
wantErr bool
}
func headerTestData() []testData {
return []testData{
{
name: "addHeader",
args: NewReq("https://postman-echo.com/get").
AddHeader("b612", "test-data").
AddHeader("b612", "test-header").
AddSimpleCookie("b612", "test-cookie").
SetHeader("User-Agent", "starnet test"),
want: func(resp *Response) error {
//fmt.Println(resp.Body().String())
if resp == nil {
return fmt.Errorf("response is nil")
}
if resp.StatusCode != 200 {
return fmt.Errorf("status code is %d", resp.StatusCode)
}
var reply postmanReply
err := resp.Body().Unmarshal(&reply)
if err != nil {
return err
}
if reply.Headers["b612"] != "test-data, test-header" {
return fmt.Errorf("header not found")
}
if reply.Headers["user-agent"] != "starnet test" {
return fmt.Errorf("user-agent not found")
}
if reply.Headers["cookie"] != "b612=test-cookie" {
return fmt.Errorf("cookie not found")
}
return nil
},
wantErr: false,
},
{
name: "postForm",
args: NewSimpleRequest("https://postman-echo.com/post", "POST").
AddHeader("b612", "test-data").
AddHeader("b612", "test-header").
AddSimpleCookie("b612", "test-cookie").
SetHeader("User-Agent", "starnet test").
//SetHeader("Content-Type", "application/x-www-form-urlencoded").
AddFormData("hello", "world").
AddFormData("hello2", "world2").
SetMethodNoError("POST"),
want: func(resp *Response) error {
//fmt.Println(resp.Body().String())
if resp == nil {
return fmt.Errorf("response is nil")
}
if resp.StatusCode != 200 {
return fmt.Errorf("status code is %d", resp.StatusCode)
}
var reply postmanReply
err := resp.Body().Unmarshal(&reply)
if err != nil {
return err
}
if reply.Headers["b612"] != "test-data, test-header" {
return fmt.Errorf("header not found")
}
if reply.Headers["user-agent"] != "starnet test" {
return fmt.Errorf("user-agent not found")
}
if reply.Headers["cookie"] != "b612=test-cookie" {
return fmt.Errorf("cookie not found")
}
if reply.Form["hello"] != "world" {
return fmt.Errorf("form data not found")
}
if reply.Form["hello2"] != "world2" {
return fmt.Errorf("form data not found")
}
return nil
},
wantErr: false,
},
}
}
func TestCurl(t *testing.T) {
for _, tt := range headerTestData() {
t.Run(tt.name, func(t *testing.T) {
got, err := Curl(tt.args)
if (err != nil) != tt.wantErr {
t.Errorf("Curl() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.want != nil {
if err := tt.want(got); err != nil {
t.Errorf("Curl() = %v", err)
}
}
})
}
}
func TestReqClone(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
if req.Header.Get("hello") != "world" {
rw.WriteHeader(http.StatusBadRequest)
rw.Write([]byte("hello world failed"))
return
}
rw.Write([]byte(`OK`))
}))
defer server.Close()
req := NewSimpleRequestWithClient(NewClientFromHttpClientNoError(http.DefaultClient), server.URL, "GET", WithHeader("hello", "world"))
resp, err := req.Do()
if err != nil {
t.Error(err)
}
if resp.StatusCode != 200 {
resp.CloseAll()
t.Errorf("status code is %d", resp.StatusCode)
}
resp.CloseAll()
req = req.Clone()
req.AddHeader("ok", "good")
resp, err = req.Do()
if err != nil {
t.Error(err)
}
if resp.StatusCode != 200 {
resp.CloseAll()
t.Errorf("status code is %d", resp.StatusCode)
}
resp.CloseAll()
}
func TestUploadFile(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
if req.Header.Get("hello") != "world" {
rw.WriteHeader(http.StatusBadRequest)
rw.Write([]byte("hello world failed"))
return
}
files, header, err := req.FormFile("666")
if err == nil {
fmt.Println(header.Filename)
fmt.Println(header.Size)
fmt.Println(files.Close())
}
files, header, err = req.FormFile("777")
if err == nil {
fmt.Println(header.Filename)
fmt.Println(header.Size)
fmt.Println(files.Close())
}
files, header, err = req.FormFile("888")
if err == nil {
fmt.Println(header.Filename)
fmt.Println(header.Size)
fmt.Println(files.Close())
}
rw.Write([]byte(`OK`))
}))
defer server.Close()
req := NewSimpleRequestWithClient(NewClientFromHttpClientNoError(http.DefaultClient), server.URL, "GET", WithHeader("hello", "world"))
req.AddFileWithName("666", "./curl.go", "curl.go")
req.AddFile("777", "./go.mod")
req.AddFileWithNameAndType("888", "./ping.go", "ping.go", "html")
resp, err := req.Do()
if err != nil {
t.Error(err)
}
if resp.StatusCode != 200 {
resp.CloseAll()
t.Errorf("status code is %d", resp.StatusCode)
}
resp.CloseAll()
req = req.Clone()
req.AddHeader("ok", "good")
resp, err = req.Do()
if err != nil {
t.Error(err)
}
if resp.StatusCode != 200 {
resp.CloseAll()
t.Errorf("status code is %d", resp.StatusCode)
}
resp.CloseAll()
}
func TestTlsConfig(t *testing.T) {
server := httptest.NewTLSServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
if req.Header.Get("hello") != "world" {
rw.WriteHeader(http.StatusBadRequest)
rw.Write([]byte("hello world failed"))
return
}
rw.Write([]byte(`OK`))
}))
defer server.Close()
client, err := NewHttpClient(WithSkipTLSVerify(false))
if err != nil {
t.Error(err)
}
req := client.NewSimpleRequest(server.URL, "GET", WithHeader("hello", "world"))
//SetClientSkipVerify(client, true)
//req.SetDoRawClient(false)
//req.SetDoRawTransport(false)
req.SetSkipTLSVerify(true)
req.SetProxy("http://127.0.0.1:29992")
resp, err := req.Do()
if err != nil {
t.Error(err)
}
fmt.Println(resp.Proto)
if resp.StatusCode != 200 {
resp.CloseAll()
t.Errorf("status code is %d", resp.StatusCode)
}
resp.CloseAll()
req = req.Clone()
req.AddHeader("ok", "good")
resp, err = req.Do()
if err != nil {
t.Error(err)
}
if resp.StatusCode != 200 {
resp.CloseAll()
t.Errorf("status code is %d", resp.StatusCode)
}
resp.CloseAll()
req = req.Clone()
req.SetSkipTLSVerify(false)
resp, err = req.Do()
if err == nil {
t.Error(err)
}
}
func TestHttpPostAndChunked(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
if req.Method != http.MethodPost {
t.Errorf("Expected 'POST', got %v", req.Method)
}
buf := make([]byte, 1024)
n, _ := req.Body.Read(buf)
if string(buf[:n]) != "hello world" {
t.Errorf("Expected body to be 'hello world', got %s", string(buf[:n]))
}
if req.Header.Get("chunked") == "true" {
if req.TransferEncoding[0] != "chunked" {
t.Errorf("Expected Transfer-Encoding to be 'chunked', got %s", req.Header.Get("Transfer-Encoding"))
}
} else {
if len(req.TransferEncoding) > 0 && req.TransferEncoding[0] == "chunked" {
t.Errorf("Expected Transfer-Encoding to not be 'chunked', got %s", req.Header.Get("Transfer-Encoding"))
}
}
rw.Write([]byte(`OK`))
}))
defer server.Close()
resp, err := Post(server.URL, WithBytes([]byte("hello world")), WithContentLength(-1), WithHeader("Content-Type", "text/plain"),
WithHeader("chunked", "true"))
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
body := resp.Body().String()
if body != "OK" {
t.Errorf("Expected OK, got %v", body)
}
resp.Close()
resp, err = Post(server.URL, WithBytes([]byte("hello world")), WithHeader("Content-Type", "text/plain"),
WithHeader("chunked", "false"))
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
defer resp.Close()
body = resp.Body().String()
if body != "OK" {
t.Errorf("Expected OK, got %v", body)
}
}
func TestWithTimeout(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
time.Sleep(time.Second * 30)
rw.Write([]byte(`OK`))
}))
funcList := []func(string, ...RequestOpt) (*Response, error){
Get,
Post,
Put,
Delete,
Options,
Patch,
Head,
Trace,
Connect,
}
defer server.Close()
for i := 1; i < 30; i++ {
go func(i int) {
old := time.Now()
fn := funcList[i%len(funcList)]
resp, err := fn(server.URL, WithTimeout(time.Second*time.Duration(i)))
if time.Since(old) > time.Second*time.Duration(i+2) || time.Since(old) < time.Second*time.Duration(i) {
t.Errorf("timeout not work")
}
fmt.Println(time.Since(old))
if err == nil {
t.Error(err)
resp.CloseAll()
} else {
fmt.Println(err)
}
}(i)
}
resp, err := Get(server.URL, WithTimeout(time.Second*60))
if err != nil {
t.Error(err)
} else {
fmt.Println(resp.Body().String())
if resp.StatusCode != 200 {
resp.CloseAll()
t.Errorf("status code is %d", resp.StatusCode)
}
resp.CloseAll()
}
}

165
curl_transport.go Normal file
View File

@ -0,0 +1,165 @@
package starnet
import (
"context"
"crypto/tls"
"fmt"
"net/http"
"reflect"
)
type Client struct {
*http.Client
}
// NewHttpClient creates a new http.Client with the specified options.
func NewHttpClient(opts ...RequestOpt) (Client, error) {
req, err := newRequest(context.Background(), "", "", opts...)
if err != nil {
return Client{}, err
}
defer func() {
req = nil
}()
cl, err := req.HttpClient()
return Client{
Client: cl,
}, err
}
func NewHttpClientNoErr(opts ...RequestOpt) Client {
c, _ := NewHttpClient(opts...)
return c
}
func NewClientFromHttpClient(httpClient *http.Client) (Client, error) {
if httpClient == nil {
return Client{}, fmt.Errorf("httpClient cannot be nil")
}
if httpClient.Transport == nil {
httpClient.Transport = &Transport{
base: &http.Transport{},
}
} else {
switch t := httpClient.Transport.(type) {
case *Transport:
if t.base == nil {
t.base = &http.Transport{}
}
case *http.Transport:
httpClient.Transport = &Transport{
base: t,
}
default:
return Client{}, fmt.Errorf("unsupported transport type: %T", t)
}
}
return Client{
Client: httpClient,
}, nil
}
func NewClientFromHttpClientNoError(httpClient *http.Client) Client {
return Client{Client: httpClient}
}
// DisableRedirect returns whether the request will disable HTTP redirects.
// if true, the request will not follow redirects automatically.
// for example, if the server responds with a 301 or 302 status code, the request will not automatically follow the redirect.
// you will get the original response with the redirect status code and Location header.
func (c Client) DisableRedirect() bool {
return reflect.ValueOf(c.Client.CheckRedirect).Pointer() == reflect.ValueOf(DefaultCheckRedirectFunc).Pointer()
}
// SetDisableRedirect sets whether the request will disable HTTP redirects.
// if true, the request will not follow redirects automatically.
// for example, if the server responds with a 301 or 302 status code, the request will not automatically follow the redirect.
// you will get the original response with the redirect status code and Location header.
func (c Client) SetDisableRedirect(disableRedirect bool) {
if disableRedirect {
c.Client.CheckRedirect = DefaultCheckRedirectFunc
}
}
func (c Client) SetDefaultSkipTLSVerify(skip bool) {
if c.Client.Transport == nil {
c.Client.Transport = &Transport{
base: &http.Transport{},
}
}
if transport, ok := c.Client.Transport.(*Transport); ok {
if transport.base.TLSClientConfig == nil {
transport.base.TLSClientConfig = &tls.Config{}
}
transport.base.TLSClientConfig.InsecureSkipVerify = skip
} else if transport, ok := c.Client.Transport.(*http.Transport); ok {
if transport.TLSClientConfig == nil {
transport.TLSClientConfig = &tls.Config{}
}
transport.TLSClientConfig.InsecureSkipVerify = skip
}
}
func (c Client) SetDefaultTLSConfig(tlsConfig *tls.Config) {
if c.Client.Transport == nil {
c.Client.Transport = &Transport{
base: &http.Transport{},
}
}
if transport, ok := c.Client.Transport.(*Transport); ok {
transport.base.TLSClientConfig = tlsConfig
} else if transport, ok := c.Client.Transport.(*http.Transport); ok {
transport.TLSClientConfig = tlsConfig
}
}
func (c Client) NewRequest(url, method string, opts ...RequestOpt) (*Request, error) {
if c.Client == nil {
return nil, fmt.Errorf("http client is nil")
}
req, err := NewRequestWithContextWithClient(context.Background(), c, url, method, opts...)
return req, err
}
func (c Client) NewRequestContext(ctx context.Context, url, method string, opts ...RequestOpt) (*Request, error) {
if c.Client == nil {
return nil, fmt.Errorf("http client is nil")
}
req, err := NewRequestWithContextWithClient(ctx, c, url, method, opts...)
return req, err
}
func (c Client) NewSimpleRequest(url, method string, opts ...RequestOpt) *Request {
req, _ := c.NewRequest(url, method, opts...)
return req
}
func (c Client) NewSimpleRequestContext(ctx context.Context, url, method string, opts ...RequestOpt) *Request {
req, _ := c.NewRequestContext(ctx, url, method, opts...)
return req
}
type Transport struct {
base *http.Transport
}
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
if t.base == nil {
t.base = &http.Transport{}
}
transport, ok := req.Context().Value("transport").(*http.Transport)
if ok && transport != nil {
return transport.RoundTrip(req)
}
proxy, ok := req.Context().Value("proxy").(string)
if ok && proxy != "" {
tlsConfig, ok := req.Context().Value("tlsConfig").(*tls.Config)
if ok && tlsConfig != nil {
tmpTransport := t.base.Clone()
tmpTransport.TLSClientConfig = tlsConfig
return tmpTransport.RoundTrip(req)
}
}
return t.base.RoundTrip(req)
}

198
curlbench_test.go Normal file
View File

@ -0,0 +1,198 @@
package starnet
import (
"fmt"
"net/http"
"net/http/httptest"
"runtime"
"testing"
)
// BenchmarkGetRequest 测试单个 GET 请求的性能
func BenchmarkGetRequest(b *testing.B) {
// 创建测试服务器
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.Write([]byte(`OK`))
}))
defer server.Close()
// 重置计时器,排除设置代码的影响
b.ResetTimer()
// 报告内存分配情况
b.ReportAllocs()
// 运行基准测试
for i := 0; i < b.N; i++ {
resp, err := Get(server.URL, WithSkipTLSVerify(true))
if err != nil {
b.Errorf("Unexpected error: %v", err)
}
body := resp.Body().String()
if body != "OK" {
b.Errorf("Expected OK, got %v", body)
}
}
}
// BenchmarkGetRequestWithHeaders 测试带请求头的 GET 请求性能
func BenchmarkGetRequestWithHeaders(b *testing.B) {
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
// 验证请求头
if req.Header.Get("hello") != "world" {
rw.WriteHeader(http.StatusBadRequest)
return
}
rw.Write([]byte(`OK`))
}))
defer server.Close()
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
resp, err := Get(server.URL,
WithSkipTLSVerify(true),
WithHeader("hello", "world"),
WithUserAgent("hello world"))
if err != nil {
b.Errorf("Unexpected error: %v", err)
}
body := resp.Body().String()
if body != "OK" {
b.Errorf("Expected OK, got %v", body)
}
}
}
// BenchmarkPostRequest 测试 POST 请求的性能
func BenchmarkPostRequest(b *testing.B) {
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
// 读取并返回请求体
body := make([]byte, req.ContentLength)
req.Body.Read(body)
rw.Write(body)
}))
defer server.Close()
testData := "This is a test payload for POST request"
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
resp, err := Post(server.URL,
WithSkipTLSVerify(true),
WithBytes([]byte(testData)),
WithContentType("text/plain"))
if err != nil {
b.Errorf("Unexpected error: %v", err)
}
body := resp.Body().String()
if body != testData {
b.Errorf("Expected %s, got %v", testData, body)
}
}
}
// BenchmarkConcurrentRequests 测试并发请求性能
func BenchmarkConcurrentRequests(b *testing.B) {
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.Write([]byte(`OK`))
}))
defer server.Close()
b.ResetTimer()
b.ReportAllocs()
// 运行并发基准测试
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
resp, err := Get(server.URL, WithSkipTLSVerify(true))
if err != nil {
b.Errorf("Unexpected error: %v", err)
}
body := resp.Body().String()
if body != "OK" {
b.Errorf("Expected OK, got %v", body)
}
}
})
}
// BenchmarkMemoryUsage 专门测试内存使用情况
func BenchmarkMemoryUsage(b *testing.B) {
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.Write([]byte(`OK`))
}))
defer server.Close()
// 禁用默认的测试时间,只关注内存分配
b.ReportAllocs()
var memStatsStart, memStatsEnd runtime.MemStats
runtime.GC()
runtime.ReadMemStats(&memStatsStart)
for i := 0; i < b.N; i++ {
resp, err := Get(server.URL, WithSkipTLSVerify(true))
if err != nil {
b.Errorf("Unexpected error: %v", err)
}
body := resp.Body().String()
if body != "OK" {
b.Errorf("Expected OK, got %v", body)
}
}
runtime.GC()
runtime.ReadMemStats(&memStatsEnd)
// 计算每次操作的平均内存分配
allocsPerOp := float64(memStatsEnd.Mallocs-memStatsStart.Mallocs) / float64(b.N)
bytesPerOp := float64(memStatsEnd.TotalAlloc-memStatsStart.TotalAlloc) / float64(b.N)
b.ReportMetric(allocsPerOp, "allocs/op")
b.ReportMetric(bytesPerOp, "bytes/op")
}
// BenchmarkDifferentResponseSizes 测试不同响应大小的性能
func BenchmarkDifferentResponseSizes(b *testing.B) {
// 测试不同大小的响应
responseSizes := []int{100, 1024, 10240, 102400} // 100B, 1KB, 10KB, 100KB
for _, size := range responseSizes {
// 生成指定大小的响应数据
responseData := make([]byte, size)
for i := 0; i < size; i++ {
responseData[i] = 'A'
}
b.Run(fmt.Sprintf("Size_%d", size), func(b *testing.B) {
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.Write(responseData)
}))
defer server.Close()
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
resp, err := Get(server.URL, WithSkipTLSVerify(true))
if err != nil {
b.Errorf("Unexpected error: %v", err)
}
body := resp.Body().Bytes()
if len(body) != size {
b.Errorf("Expected size %d, got %d", size, len(body))
}
}
})
}
}

2
go.mod
View File

@ -1,5 +1,3 @@
module b612.me/starnet
go 1.16
require b612.me/stario v0.0.9

47
go.sum
View File

@ -1,47 +0,0 @@
b612.me/stario v0.0.9 h1:bFDlejUJMwZ12a09snZJspQsOlkqpDAl9qKPEYOGWCk=
b612.me/stario v0.0.9/go.mod h1:x4D/x8zA5SC0pj/uJAi4FyG5p4j5UZoMEZfvuRR6VNw=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA=
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4=
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
golang.org/x/term v0.18.0 h1:FcHjZXDMxI8mM3nwhX9HlKop4C0YQvCVCdwYl2wOtE8=
golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=

120
httpguts.go Normal file
View File

@ -0,0 +1,120 @@
package starnet
import "strings"
var isTokenTable = [127]bool{
'!': true,
'#': true,
'$': true,
'%': true,
'&': true,
'\'': true,
'*': true,
'+': true,
'-': true,
'.': true,
'0': true,
'1': true,
'2': true,
'3': true,
'4': true,
'5': true,
'6': true,
'7': true,
'8': true,
'9': true,
'A': true,
'B': true,
'C': true,
'D': true,
'E': true,
'F': true,
'G': true,
'H': true,
'I': true,
'J': true,
'K': true,
'L': true,
'M': true,
'N': true,
'O': true,
'P': true,
'Q': true,
'R': true,
'S': true,
'T': true,
'U': true,
'W': true,
'V': true,
'X': true,
'Y': true,
'Z': true,
'^': true,
'_': true,
'`': true,
'a': true,
'b': true,
'c': true,
'd': true,
'e': true,
'f': true,
'g': true,
'h': true,
'i': true,
'j': true,
'k': true,
'l': true,
'm': true,
'n': true,
'o': true,
'p': true,
'q': true,
'r': true,
's': true,
't': true,
'u': true,
'v': true,
'w': true,
'x': true,
'y': true,
'z': true,
'|': true,
'~': true,
}
func IsTokenRune(r rune) bool {
i := int(r)
return i < len(isTokenTable) && isTokenTable[i]
}
func validMethod(method string) bool {
/*
Method = "OPTIONS" ; Section 9.2
| "GET" ; Section 9.3
| "HEAD" ; Section 9.4
| "POST" ; Section 9.5
| "PUT" ; Section 9.6
| "DELETE" ; Section 9.7
| "TRACE" ; Section 9.8
| "CONNECT" ; Section 9.9
| extension-method
extension-method = token
token = 1*<any CHAR except CTLs or separators>
*/
return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1
}
func isNotToken(r rune) bool {
return !IsTokenRune(r)
}
func hasPort(s string) bool { return strings.LastIndex(s, ":") > strings.LastIndex(s, "]") }
// removeEmptyPort strips the empty port in ":port" to ""
// as mandated by RFC 3986 Section 6.2.3.
func removeEmptyPort(host string) string {
if hasPort(host) {
return strings.TrimSuffix(host, ":")
}
return host
}

325
que.go
View File

@ -1,325 +0,0 @@
package starnet
import (
"bytes"
"context"
"encoding/binary"
"errors"
"fmt"
"os"
"sync"
"time"
)
var ErrDeadlineExceeded error = errors.New("deadline exceeded")
// 识别头
var header = []byte{11, 27, 19, 96, 12, 25, 02, 20}
// MsgQueue 为基本的信息单位
type MsgQueue struct {
ID uint16
Msg []byte
Conn interface{}
}
// StarQueue 为流数据中的消息队列分发
type StarQueue struct {
maxLength uint32
count int64
Encode bool
msgID uint16
msgPool chan MsgQueue
unFinMsg sync.Map
lastID int //= -1
ctx context.Context
cancel context.CancelFunc
duration time.Duration
EncodeFunc func([]byte) []byte
DecodeFunc func([]byte) []byte
//restoreMu sync.Mutex
}
func NewQueueCtx(ctx context.Context, count int64, maxMsgLength uint32) *StarQueue {
var q StarQueue
q.Encode = false
q.count = count
q.maxLength = maxMsgLength
q.msgPool = make(chan MsgQueue, count)
if ctx == nil {
q.ctx, q.cancel = context.WithCancel(context.Background())
} else {
q.ctx, q.cancel = context.WithCancel(ctx)
}
q.duration = 0
return &q
}
func NewQueueWithCount(count int64) *StarQueue {
return NewQueueCtx(nil, count, 0)
}
// NewQueue 建立一个新消息队列
func NewQueue() *StarQueue {
return NewQueueWithCount(32)
}
// Uint32ToByte 4位uint32转byte
func Uint32ToByte(src uint32) []byte {
res := make([]byte, 4)
res[3] = uint8(src)
res[2] = uint8(src >> 8)
res[1] = uint8(src >> 16)
res[0] = uint8(src >> 24)
return res
}
// ByteToUint32 byte转4位uint32
func ByteToUint32(src []byte) uint32 {
var res uint32
buffer := bytes.NewBuffer(src)
binary.Read(buffer, binary.BigEndian, &res)
return res
}
// Uint16ToByte 2位uint16转byte
func Uint16ToByte(src uint16) []byte {
res := make([]byte, 2)
res[1] = uint8(src)
res[0] = uint8(src >> 8)
return res
}
// ByteToUint16 用于byte转uint16
func ByteToUint16(src []byte) uint16 {
var res uint16
buffer := bytes.NewBuffer(src)
binary.Read(buffer, binary.BigEndian, &res)
return res
}
// BuildMessage 生成编码后的信息用于发送
func (q *StarQueue) BuildMessage(src []byte) []byte {
var buff bytes.Buffer
q.msgID++
if q.Encode {
src = q.EncodeFunc(src)
}
length := uint32(len(src))
buff.Write(header)
buff.Write(Uint32ToByte(length))
buff.Write(Uint16ToByte(q.msgID))
buff.Write(src)
return buff.Bytes()
}
// BuildHeader 生成编码后的Header用于发送
func (q *StarQueue) BuildHeader(length uint32) []byte {
var buff bytes.Buffer
q.msgID++
buff.Write(header)
buff.Write(Uint32ToByte(length))
buff.Write(Uint16ToByte(q.msgID))
return buff.Bytes()
}
type unFinMsg struct {
ID uint16
LengthRecv uint32
// HeaderMsg 信息头应当为14位8位识别码+4位长度码+2位id
HeaderMsg []byte
RecvMsg []byte
}
func (q *StarQueue) push2list(msg MsgQueue) {
q.msgPool <- msg
}
// ParseMessage 用于解析收到的msg信息
func (q *StarQueue) ParseMessage(msg []byte, conn interface{}) error {
return q.parseMessage(msg, conn)
}
// parseMessage 用于解析收到的msg信息
func (q *StarQueue) parseMessage(msg []byte, conn interface{}) error {
tmp, ok := q.unFinMsg.Load(conn)
if ok { //存在未完成的信息
lastMsg := tmp.(*unFinMsg)
headerLen := len(lastMsg.HeaderMsg)
if headerLen < 14 { //未完成头标题
//传输的数据不能填充header头
if len(msg) < 14-headerLen {
//加入header头并退出
lastMsg.HeaderMsg = bytesMerge(lastMsg.HeaderMsg, msg)
q.unFinMsg.Store(conn, lastMsg)
return nil
}
//获取14字节完整的header
header := msg[0 : 14-headerLen]
lastMsg.HeaderMsg = bytesMerge(lastMsg.HeaderMsg, header)
//检查收到的header是否为认证header
//若不是,丢弃并重新来过
if !checkHeader(lastMsg.HeaderMsg[0:8]) {
q.unFinMsg.Delete(conn)
if len(msg) == 0 {
return nil
}
return q.parseMessage(msg, conn)
}
//获得本数据包长度
lastMsg.LengthRecv = ByteToUint32(lastMsg.HeaderMsg[8:12])
if q.maxLength != 0 && lastMsg.LengthRecv > q.maxLength {
q.unFinMsg.Delete(conn)
return fmt.Errorf("msg length is %d ,too large than %d", lastMsg.LengthRecv, q.maxLength)
}
//获得本数据包ID
lastMsg.ID = ByteToUint16(lastMsg.HeaderMsg[12:14])
//存入列表
q.unFinMsg.Store(conn, lastMsg)
msg = msg[14-headerLen:]
if uint32(len(msg)) < lastMsg.LengthRecv {
lastMsg.RecvMsg = msg
q.unFinMsg.Store(conn, lastMsg)
return nil
}
if uint32(len(msg)) >= lastMsg.LengthRecv {
lastMsg.RecvMsg = msg[0:lastMsg.LengthRecv]
if q.Encode {
lastMsg.RecvMsg = q.DecodeFunc(lastMsg.RecvMsg)
}
msg = msg[lastMsg.LengthRecv:]
storeMsg := MsgQueue{
ID: lastMsg.ID,
Msg: lastMsg.RecvMsg,
Conn: conn,
}
//q.restoreMu.Lock()
q.push2list(storeMsg)
//q.restoreMu.Unlock()
q.unFinMsg.Delete(conn)
return q.parseMessage(msg, conn)
}
} else {
lastID := int(lastMsg.LengthRecv) - len(lastMsg.RecvMsg)
if lastID < 0 {
q.unFinMsg.Delete(conn)
return q.parseMessage(msg, conn)
}
if len(msg) >= lastID {
lastMsg.RecvMsg = bytesMerge(lastMsg.RecvMsg, msg[0:lastID])
if q.Encode {
lastMsg.RecvMsg = q.DecodeFunc(lastMsg.RecvMsg)
}
storeMsg := MsgQueue{
ID: lastMsg.ID,
Msg: lastMsg.RecvMsg,
Conn: conn,
}
q.push2list(storeMsg)
q.unFinMsg.Delete(conn)
if len(msg) == lastID {
return nil
}
msg = msg[lastID:]
return q.parseMessage(msg, conn)
}
lastMsg.RecvMsg = bytesMerge(lastMsg.RecvMsg, msg)
q.unFinMsg.Store(conn, lastMsg)
return nil
}
}
if len(msg) == 0 {
return nil
}
var start int
if start = searchHeader(msg); start == -1 {
return errors.New("data format error")
}
msg = msg[start:]
lastMsg := unFinMsg{}
q.unFinMsg.Store(conn, &lastMsg)
return q.parseMessage(msg, conn)
}
func checkHeader(msg []byte) bool {
if len(msg) != 8 {
return false
}
for k, v := range msg {
if v != header[k] {
return false
}
}
return true
}
func searchHeader(msg []byte) int {
if len(msg) < 8 {
return 0
}
for k, v := range msg {
find := 0
if v == header[0] {
for k2, v2 := range header {
if msg[k+k2] == v2 {
find++
} else {
break
}
}
if find == 8 {
return k
}
}
}
return -1
}
func bytesMerge(src ...[]byte) []byte {
var buff bytes.Buffer
for _, v := range src {
buff.Write(v)
}
return buff.Bytes()
}
// Restore 获取收到的信息
func (q *StarQueue) Restore() (MsgQueue, error) {
if q.duration.Seconds() == 0 {
q.duration = 86400 * time.Second
}
for {
select {
case <-q.ctx.Done():
return MsgQueue{}, errors.New("Stoped By External Function Call")
case <-time.After(q.duration):
if q.duration != 0 {
return MsgQueue{}, ErrDeadlineExceeded
}
case data, ok := <-q.msgPool:
if !ok {
return MsgQueue{}, os.ErrClosed
}
return data, nil
}
}
}
// RestoreOne 获取收到的一个信息
// 兼容性修改
func (q *StarQueue) RestoreOne() (MsgQueue, error) {
return q.Restore()
}
// Stop 立即停止Restore
func (q *StarQueue) Stop() {
q.cancel()
}
// RestoreDuration Restore最大超时时间
func (q *StarQueue) RestoreDuration(tm time.Duration) {
q.duration = tm
}
func (q *StarQueue) RestoreChan() <-chan MsgQueue {
return q.msgPool
}

View File

@ -1,42 +0,0 @@
package starnet
import (
"fmt"
"testing"
"time"
)
func Test_QueSpeed(t *testing.T) {
que := NewQueueWithCount(0)
stop := make(chan struct{}, 1)
que.RestoreDuration(time.Second * 10)
var count int64
go func() {
for {
select {
case <-stop:
//fmt.Println(count)
return
default:
}
_, err := que.RestoreOne()
if err == nil {
count++
}
}
}()
cp := 0
stoped := time.After(time.Second * 10)
data := que.BuildMessage([]byte("hello"))
for {
select {
case <-stoped:
fmt.Println(count, cp)
stop <- struct{}{}
return
default:
que.ParseMessage(data, "lala")
cp++
}
}
}

401
tlssniffer.go Normal file
View File

@ -0,0 +1,401 @@
package starnet
import (
"bytes"
"context"
"crypto/tls"
"io"
"net"
"sync"
"time"
)
type myConn struct {
reader io.Reader
conn net.Conn
isReadOnly bool
multiReader io.Reader
}
func (c *myConn) Read(p []byte) (int, error) {
if c.isReadOnly {
return c.reader.Read(p)
}
if c.multiReader == nil {
c.multiReader = io.MultiReader(c.reader, c.conn)
}
return c.multiReader.Read(p)
}
func (c *myConn) Write(p []byte) (int, error) {
if c.isReadOnly {
return 0, io.ErrClosedPipe
}
return c.conn.Write(p)
}
func (c *myConn) Close() error {
if c.isReadOnly {
return nil
}
return c.conn.Close()
}
func (c *myConn) LocalAddr() net.Addr {
if c.isReadOnly {
return nil
}
return c.conn.LocalAddr()
}
func (c *myConn) RemoteAddr() net.Addr {
if c.isReadOnly {
return nil
}
return c.conn.RemoteAddr()
}
func (c *myConn) SetDeadline(t time.Time) error {
if c.isReadOnly {
return nil
}
return c.conn.SetDeadline(t)
}
func (c *myConn) SetReadDeadline(t time.Time) error {
if c.isReadOnly {
return nil
}
return c.conn.SetReadDeadline(t)
}
func (c *myConn) SetWriteDeadline(t time.Time) error {
if c.isReadOnly {
return nil
}
return c.conn.SetWriteDeadline(t)
}
type Listener struct {
net.Listener
cfg *tls.Config
getConfigForClient func(hostname string) *tls.Config
allowNonTls bool
}
func (l *Listener) GetConfigForClient() func(hostname string) *tls.Config {
return l.getConfigForClient
}
func (l *Listener) SetConfigForClient(getConfigForClient func(hostname string) *tls.Config) {
l.getConfigForClient = getConfigForClient
}
func Listen(network, address string) (*Listener, error) {
listener, err := net.Listen(network, address)
if err != nil {
return nil, err
}
return &Listener{Listener: listener}, nil
}
func ListenTLSWithListenConfig(liscfg net.ListenConfig, network, address string, config *tls.Config, getConfigForClient func(hostname string) *tls.Config, allowNonTls bool) (*Listener, error) {
listener, err := liscfg.Listen(context.Background(), network, address)
if err != nil {
return nil, err
}
return &Listener{
Listener: listener,
cfg: config,
getConfigForClient: getConfigForClient,
allowNonTls: allowNonTls,
}, nil
}
func ListenWithListener(listener net.Listener, config *tls.Config, getConfigForClient func(hostname string) *tls.Config, allowNonTls bool) (*Listener, error) {
return &Listener{
Listener: listener,
cfg: config,
getConfigForClient: getConfigForClient,
allowNonTls: allowNonTls,
}, nil
}
func ListenTLSWithConfig(network, address string, config *tls.Config, getConfigForClient func(hostname string) *tls.Config, allowNonTls bool) (*Listener, error) {
listener, err := net.Listen(network, address)
if err != nil {
return nil, err
}
return &Listener{
Listener: listener,
cfg: config,
getConfigForClient: getConfigForClient,
allowNonTls: allowNonTls,
}, nil
}
func ListenTLS(network, address string, certFile, keyFile string, allowNonTls bool) (*Listener, error) {
config, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return nil, err
}
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{config},
}
listener, err := net.Listen(network, address)
if err != nil {
return nil, err
}
return &Listener{
Listener: listener,
cfg: tlsConfig,
allowNonTls: allowNonTls,
}, nil
}
func (l *Listener) Accept() (net.Conn, error) {
conn, err := l.Listener.Accept()
if err != nil {
return nil, err
}
return &Conn{
Conn: conn,
tlsCfg: l.cfg,
getConfigForClient: l.getConfigForClient,
allowNonTls: l.allowNonTls,
}, nil
}
type Conn struct {
net.Conn
once sync.Once
initErr error
isTLS bool
tlsCfg *tls.Config
tlsConn *tls.Conn
buffer *bytes.Buffer
noTlsReader io.Reader
isOriginal bool
getConfigForClient func(hostname string) *tls.Config
hostname string
allowNonTls bool
}
func (c *Conn) Hostname() string {
if c.hostname != "" {
return c.hostname
}
if c.isTLS && c.tlsConn != nil {
if c.tlsConn.ConnectionState().ServerName != "" {
c.hostname = c.tlsConn.ConnectionState().ServerName
return c.hostname
}
}
return ""
}
func (c *Conn) IsTLS() bool {
return c.isTLS
}
func (c *Conn) TlsConn() *tls.Conn {
return c.tlsConn
}
func (c *Conn) isTLSConnection() (bool, error) {
if c.getConfigForClient == nil {
peek := make([]byte, 5)
n, err := io.ReadFull(c.Conn, peek)
if err != nil {
return false, err
}
isTLS := n >= 3 && peek[0] == 0x16 && peek[1] == 0x03
c.buffer = bytes.NewBuffer(peek[:n])
return isTLS, nil
}
c.buffer = new(bytes.Buffer)
r := io.TeeReader(c.Conn, c.buffer)
var hello *tls.ClientHelloInfo
tls.Server(&myConn{reader: r, isReadOnly: true}, &tls.Config{
GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) {
hello = new(tls.ClientHelloInfo)
*hello = *argHello
return nil, nil
},
}).Handshake()
peek := c.buffer.Bytes()
n := len(peek)
isTLS := n >= 3 && peek[0] == 0x16 && peek[1] == 0x03
if hello == nil {
return isTLS, nil
}
c.hostname = hello.ServerName
if c.hostname == "" {
c.hostname, _, _ = net.SplitHostPort(c.Conn.LocalAddr().String())
}
return isTLS, nil
}
func (c *Conn) init() {
c.once.Do(func() {
if c.isOriginal {
return
}
if c.tlsCfg != nil {
isTLS, err := c.isTLSConnection()
if err != nil {
c.initErr = err
return
}
c.isTLS = isTLS
}
if c.isTLS {
var cfg = c.tlsCfg
if c.getConfigForClient != nil {
cfg = c.getConfigForClient(c.hostname)
if cfg == nil {
cfg = c.tlsCfg
}
}
c.tlsConn = tls.Server(&myConn{
reader: c.buffer,
conn: c.Conn,
isReadOnly: false,
}, cfg)
} else {
if !c.allowNonTls {
c.initErr = net.ErrClosed
return
}
c.noTlsReader = io.MultiReader(c.buffer, c.Conn)
}
})
}
func (c *Conn) Read(b []byte) (int, error) {
c.init()
if c.initErr != nil {
return 0, c.initErr
}
if c.isTLS {
return c.tlsConn.Read(b)
}
return c.noTlsReader.Read(b)
}
func (c *Conn) Write(b []byte) (int, error) {
c.init()
if c.initErr != nil {
return 0, c.initErr
}
if c.isTLS {
return c.tlsConn.Write(b)
}
return c.Conn.Write(b)
}
func (c *Conn) Close() error {
if c.isTLS && c.tlsConn != nil {
return c.tlsConn.Close()
}
return c.Conn.Close()
}
func (c *Conn) SetDeadline(t time.Time) error {
if c.isTLS && c.tlsConn != nil {
return c.tlsConn.SetDeadline(t)
}
return c.Conn.SetDeadline(t)
}
func (c *Conn) SetReadDeadline(t time.Time) error {
if c.isTLS && c.tlsConn != nil {
return c.tlsConn.SetReadDeadline(t)
}
return c.Conn.SetReadDeadline(t)
}
func (c *Conn) SetWriteDeadline(t time.Time) error {
if c.isTLS && c.tlsConn != nil {
return c.tlsConn.SetWriteDeadline(t)
}
return c.Conn.SetWriteDeadline(t)
}
func (c *Conn) TlsConnection() (*tls.Conn, error) {
if c.initErr != nil {
return nil, c.initErr
}
if !c.isTLS {
return nil, net.ErrClosed
}
return c.tlsConn, nil
}
func (c *Conn) OriginalConn() net.Conn {
return c.Conn
}
func NewClientTlsConn(conn net.Conn, cfg *tls.Config) (*Conn, error) {
if conn == nil {
return nil, net.ErrClosed
}
c := &Conn{
Conn: conn,
isTLS: true,
tlsCfg: cfg,
tlsConn: tls.Client(conn, cfg),
isOriginal: true,
}
return c, nil
}
func NewServerTlsConn(conn net.Conn, cfg *tls.Config) (*Conn, error) {
if conn == nil {
return nil, net.ErrClosed
}
c := &Conn{
Conn: conn,
isTLS: true,
tlsCfg: cfg,
tlsConn: tls.Server(conn, cfg),
isOriginal: true,
}
c.init()
return c, nil
}
func Dial(network, address string) (*Conn, error) {
conn, err := net.Dial(network, address)
if err != nil {
return nil, err
}
return &Conn{
Conn: conn,
isTLS: false,
tlsCfg: nil,
tlsConn: nil,
noTlsReader: conn,
isOriginal: true,
}, nil
}
func DialTLS(network, address string, certFile, keyFile string) (*Conn, error) {
config, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return nil, err
}
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{config},
}
conn, err := net.Dial(network, address)
if err != nil {
return nil, err
}
return NewClientTlsConn(conn, tlsConfig)
}