Compare commits
24 Commits
Author | SHA1 | Date | |
---|---|---|---|
b90c59d6e7 | |||
4e154cc17b | |||
67b0025f9c | |||
c4fa62536a | |||
260ceb90ed | |||
d260181adf | |||
e3b7369e12 | |||
4e17fee681 | |||
a8eed30db5 | |||
c1eaf43058 | |||
9f5aca124d | |||
54958724e7 | |||
7a17672149 | |||
44b807d3d1 | |||
0d847462b3 | |||
deed4207ea | |||
f6363fed07 | |||
1de78f2f06 | |||
d0122a9771 | |||
319518d71d | |||
be3df9703e | |||
b92288bbc9 | |||
0805549006 | |||
033272f38a |
1
.gitignore
vendored
Normal file
1
.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
|||||||
|
.idea
|
198
curl_default.go
Normal file
198
curl_default.go
Normal file
@ -0,0 +1,198 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
HEADER_FORM_URLENCODE = `application/x-www-form-urlencoded`
|
||||||
|
HEADER_FORM_DATA = `multipart/form-data`
|
||||||
|
HEADER_JSON = `application/json`
|
||||||
|
HEADER_PLAIN = `text/plain`
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
DefaultDialTimeout = 5 * time.Second
|
||||||
|
DefaultTimeout = 10 * time.Second
|
||||||
|
DefaultFetchRespBody = false
|
||||||
|
DefaultHttpClient = NewHttpClientNoErr()
|
||||||
|
)
|
||||||
|
|
||||||
|
func UrlEncodeRaw(str string) string {
|
||||||
|
strs := strings.Replace(url.QueryEscape(str), "+", "%20", -1)
|
||||||
|
return strs
|
||||||
|
}
|
||||||
|
|
||||||
|
func UrlEncode(str string) string {
|
||||||
|
return url.QueryEscape(str)
|
||||||
|
}
|
||||||
|
|
||||||
|
func UrlDecode(str string) (string, error) {
|
||||||
|
return url.QueryUnescape(str)
|
||||||
|
}
|
||||||
|
|
||||||
|
func BuildQuery(queryData map[string]string) string {
|
||||||
|
query := url.Values{}
|
||||||
|
for k, v := range queryData {
|
||||||
|
query.Add(k, v)
|
||||||
|
}
|
||||||
|
return query.Encode()
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildPostForm takes a map of string keys and values, converts it into a URL-encoded query string,
|
||||||
|
// and then converts that string into a byte slice. This function is useful for preparing data for HTTP POST requests,
|
||||||
|
// where the server expects the request body to be URL-encoded form data.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// queryMap: A map where the key-value pairs represent the form data to be sent in the HTTP POST request.
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// A byte slice representing the URL-encoded form data.
|
||||||
|
func BuildPostForm(queryMap map[string]string) []byte {
|
||||||
|
return []byte(BuildQuery(queryMap))
|
||||||
|
}
|
||||||
|
|
||||||
|
func Get(uri string, opts ...RequestOpt) (*Response, error) {
|
||||||
|
return NewSimpleRequestWithClient(DefaultHttpClient, uri, "GET", opts...).Do()
|
||||||
|
}
|
||||||
|
|
||||||
|
func Post(uri string, opts ...RequestOpt) (*Response, error) {
|
||||||
|
return NewSimpleRequestWithClient(DefaultHttpClient, uri, "POST", opts...).Do()
|
||||||
|
}
|
||||||
|
|
||||||
|
func Options(uri string, opts ...RequestOpt) (*Response, error) {
|
||||||
|
return NewSimpleRequestWithClient(DefaultHttpClient, uri, "OPTIONS", opts...).Do()
|
||||||
|
}
|
||||||
|
|
||||||
|
func Put(uri string, opts ...RequestOpt) (*Response, error) {
|
||||||
|
return NewSimpleRequestWithClient(DefaultHttpClient, uri, "PUT", opts...).Do()
|
||||||
|
}
|
||||||
|
|
||||||
|
func Delete(uri string, opts ...RequestOpt) (*Response, error) {
|
||||||
|
return NewSimpleRequestWithClient(DefaultHttpClient, uri, "DELETE", opts...).Do()
|
||||||
|
}
|
||||||
|
|
||||||
|
func Head(uri string, opts ...RequestOpt) (*Response, error) {
|
||||||
|
return NewSimpleRequestWithClient(DefaultHttpClient, uri, "HEAD", opts...).Do()
|
||||||
|
}
|
||||||
|
|
||||||
|
func Patch(uri string, opts ...RequestOpt) (*Response, error) {
|
||||||
|
return NewSimpleRequestWithClient(DefaultHttpClient, uri, "PATCH", opts...).Do()
|
||||||
|
}
|
||||||
|
|
||||||
|
func Trace(uri string, opts ...RequestOpt) (*Response, error) {
|
||||||
|
return NewSimpleRequestWithClient(DefaultHttpClient, uri, "TRACE", opts...).Do()
|
||||||
|
}
|
||||||
|
|
||||||
|
func Connect(uri string, opts ...RequestOpt) (*Response, error) {
|
||||||
|
return NewSimpleRequestWithClient(DefaultHttpClient, uri, "CONNECT", opts...).Do()
|
||||||
|
}
|
||||||
|
|
||||||
|
func DefaultCheckRedirectFunc(req *http.Request, via []*http.Request) error {
|
||||||
|
return http.ErrUseLastResponse
|
||||||
|
}
|
||||||
|
|
||||||
|
func DefaultDialFunc(ctx context.Context, netType, addr string) (net.Conn, error) {
|
||||||
|
var lastErr error
|
||||||
|
var addrs []string
|
||||||
|
if dialFn, ok := ctx.Value("dialFunc").(func(context.Context, string, string) (net.Conn, error)); ok {
|
||||||
|
if dialFn != nil {
|
||||||
|
return dialFn(ctx, netType, addr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
customIP, ok := ctx.Value("customIP").([]string)
|
||||||
|
if !ok {
|
||||||
|
customIP = nil
|
||||||
|
}
|
||||||
|
dialTimeout, ok := ctx.Value("dialTimeout").(time.Duration)
|
||||||
|
if !ok {
|
||||||
|
dialTimeout = DefaultDialTimeout
|
||||||
|
}
|
||||||
|
timeout, ok := ctx.Value("timeout").(time.Duration)
|
||||||
|
if !ok {
|
||||||
|
timeout = DefaultTimeout
|
||||||
|
}
|
||||||
|
lookUpIPfn, ok := ctx.Value("lookUpIP").(func(context.Context, string) ([]net.IPAddr, error))
|
||||||
|
if !ok {
|
||||||
|
lookUpIPfn = net.DefaultResolver.LookupIPAddr
|
||||||
|
}
|
||||||
|
host, port, err := net.SplitHostPort(addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
proxy, ok := ctx.Value("proxy").(string)
|
||||||
|
if !ok {
|
||||||
|
proxy = ""
|
||||||
|
}
|
||||||
|
if proxy == "" && len(customIP) > 0 {
|
||||||
|
for _, v := range customIP {
|
||||||
|
ipAddr := net.ParseIP(v)
|
||||||
|
if ipAddr == nil {
|
||||||
|
return nil, fmt.Errorf("invalid custom ip: %s", customIP)
|
||||||
|
}
|
||||||
|
tmpAddr := net.JoinHostPort(v, port)
|
||||||
|
addrs = append(addrs, tmpAddr)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
ipLists, err := lookUpIPfn(ctx, host)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for _, v := range ipLists {
|
||||||
|
tmpAddr := net.JoinHostPort(v.String(), port)
|
||||||
|
addrs = append(addrs, tmpAddr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, addr := range addrs {
|
||||||
|
c, err := net.DialTimeout(netType, addr, dialTimeout)
|
||||||
|
if err != nil {
|
||||||
|
lastErr = err
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if timeout != 0 {
|
||||||
|
err = c.SetDeadline(time.Now().Add(timeout))
|
||||||
|
}
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
return nil, lastErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func DefaultDialTlsFunc(ctx context.Context, netType, addr string) (net.Conn, error) {
|
||||||
|
conn, err := DefaultDialFunc(ctx, netType, addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
tlsConfig, ok := ctx.Value("tlsConfig").(*tls.Config)
|
||||||
|
if !ok || tlsConfig == nil {
|
||||||
|
return nil, fmt.Errorf("tlsConfig is not set in context")
|
||||||
|
}
|
||||||
|
tlsConn := tls.Client(conn, tlsConfig)
|
||||||
|
if err := tlsConn.Handshake(); err != nil {
|
||||||
|
return nil, fmt.Errorf("tls handshake failed: %w", err)
|
||||||
|
}
|
||||||
|
return tlsConn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func DefaultProxyURL() func(*http.Request) (*url.URL, error) {
|
||||||
|
return func(req *http.Request) (*url.URL, error) {
|
||||||
|
if req == nil {
|
||||||
|
return nil, fmt.Errorf("request is nil")
|
||||||
|
}
|
||||||
|
proxyURL, ok := req.Context().Value("proxy").(string)
|
||||||
|
if !ok || proxyURL == "" {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
parsedURL, err := url.Parse(proxyURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse proxy URL: %w", err)
|
||||||
|
}
|
||||||
|
return parsedURL, nil
|
||||||
|
}
|
||||||
|
}
|
698
curl_test.go
Normal file
698
curl_test.go
Normal file
@ -0,0 +1,698 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestUrlEncodeRaw(t *testing.T) {
|
||||||
|
input := "hello world!@#$%^&*()_+-=~`"
|
||||||
|
expected := "hello%20world%21%40%23%24%25%5E%26%2A%28%29_%2B-%3D~%60"
|
||||||
|
result := UrlEncodeRaw(input)
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf("UrlEncodeRaw(%q) = %q; want %q", input, result, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUrlEncode(t *testing.T) {
|
||||||
|
input := "hello world!@#$%^&*()_+-=~`"
|
||||||
|
expected := `hello+world%21%40%23%24%25%5E%26%2A%28%29_%2B-%3D~%60`
|
||||||
|
result := UrlEncode(input)
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf("UrlEncode(%q) = %q; want %q", input, result, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUrlDecode(t *testing.T) {
|
||||||
|
input := "hello%20world%21%40%23%24%25%5E%26*%28%29_%2B-%3D~%60"
|
||||||
|
expected := "hello world!@#$%^&*()_+-=~`"
|
||||||
|
result, err := UrlDecode(input)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("UrlDecode(%q) returned error: %v", input, err)
|
||||||
|
}
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf("UrlDecode(%q) = %q; want %q", input, result, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test for error case
|
||||||
|
invalidInput := "%zz"
|
||||||
|
_, err = UrlDecode(invalidInput)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("UrlDecode(%q) expected error, got nil", invalidInput)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildPostForm_WithValidInput(t *testing.T) {
|
||||||
|
input := map[string]string{
|
||||||
|
"key1": "value1",
|
||||||
|
"key2": "value2",
|
||||||
|
}
|
||||||
|
|
||||||
|
expected := []byte("key1=value1&key2=value2")
|
||||||
|
|
||||||
|
result := BuildPostForm(input)
|
||||||
|
|
||||||
|
if string(result) != string(expected) {
|
||||||
|
t.Errorf("BuildPostForm(%v) = %v; want %v", input, result, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildPostForm_WithEmptyInput(t *testing.T) {
|
||||||
|
input := map[string]string{}
|
||||||
|
|
||||||
|
expected := []byte("")
|
||||||
|
|
||||||
|
result := BuildPostForm(input)
|
||||||
|
|
||||||
|
if string(result) != string(expected) {
|
||||||
|
t.Errorf("BuildPostForm(%v) = %v; want %v", input, result, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildPostForm_WithNilInput(t *testing.T) {
|
||||||
|
var input map[string]string
|
||||||
|
|
||||||
|
expected := []byte("")
|
||||||
|
|
||||||
|
result := BuildPostForm(input)
|
||||||
|
|
||||||
|
if string(result) != string(expected) {
|
||||||
|
t.Errorf("BuildPostForm(%v) = %v; want %v", input, result, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetRequest(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
rw.Write([]byte(`OK`))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
resp, err := Get(server.URL, WithSkipTLSVerify(true), WithHeader("hello", "world"), WithUserAgent("hello world"))
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
body := resp.Body().String()
|
||||||
|
if body != "OK" {
|
||||||
|
t.Errorf("Expected OK, got %v", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostRequest(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
if req.Method != http.MethodPost {
|
||||||
|
t.Errorf("Expected 'POST', got %v", req.Method)
|
||||||
|
}
|
||||||
|
rw.Write([]byte(`OK`))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
resp, err := Post(server.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
body := resp.Body().String()
|
||||||
|
if body != "OK" {
|
||||||
|
t.Errorf("Expected OK, got %v", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOptionsRequestWithValidInput(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
if req.Method != http.MethodOptions {
|
||||||
|
t.Errorf("Expected 'OPTIONS', got %v", req.Method)
|
||||||
|
}
|
||||||
|
rw.Write([]byte(`OK`))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
resp, err := Options(server.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
body := resp.Body().String()
|
||||||
|
if body != "OK" {
|
||||||
|
t.Errorf("Expected OK, got %v", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPutRequestWithValidInput(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
if req.Method != http.MethodPut {
|
||||||
|
t.Errorf("Expected 'PUT', got %v", req.Method)
|
||||||
|
}
|
||||||
|
rw.Write([]byte(`OK`))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
resp, err := Put(server.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
body := resp.Body().String()
|
||||||
|
if body != "OK" {
|
||||||
|
t.Errorf("Expected OK, got %v", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteRequestWithValidInput(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
if req.Method != http.MethodDelete {
|
||||||
|
t.Errorf("Expected 'DELETE', got %v", req.Method)
|
||||||
|
}
|
||||||
|
rw.Write([]byte(`OK`))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
resp, err := Delete(server.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
body := resp.Body().String()
|
||||||
|
if body != "OK" {
|
||||||
|
t.Errorf("Expected OK, got %v", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHeadRequestWithValidInput(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
if req.Method != http.MethodHead {
|
||||||
|
t.Errorf("Expected 'HEAD', got %v", req.Method)
|
||||||
|
}
|
||||||
|
rw.Write([]byte(`OK`))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
resp, err := Head(server.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
body := resp.Body().String()
|
||||||
|
if body == "OK" {
|
||||||
|
t.Errorf("Expected , got %v", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPatchRequestWithValidInput(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
if req.Method != http.MethodPatch {
|
||||||
|
t.Errorf("Expected 'PATCH', got %v", req.Method)
|
||||||
|
}
|
||||||
|
rw.Write([]byte(`OK`))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
resp, err := Patch(server.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
body := resp.Body().String()
|
||||||
|
if body != "OK" {
|
||||||
|
t.Errorf("Expected OK, got %v", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTraceRequestWithValidInput(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
if req.Method != http.MethodTrace {
|
||||||
|
t.Errorf("Expected 'TRACE', got %v", req.Method)
|
||||||
|
}
|
||||||
|
rw.Write([]byte(`OK`))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
resp, err := Trace(server.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
body := resp.Body().String()
|
||||||
|
if body != "OK" {
|
||||||
|
t.Errorf("Expected OK, got %v", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnectRequestWithValidInput(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
if req.Method != http.MethodConnect {
|
||||||
|
t.Errorf("Expected 'CONNECT', got %v", req.Method)
|
||||||
|
}
|
||||||
|
rw.Write([]byte(`OK`))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
resp, err := Connect(server.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
body := resp.Body().String()
|
||||||
|
if body != "OK" {
|
||||||
|
t.Errorf("Expected OK, got %v", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
func TestMethodReturnsCorrectValue(t *testing.T) {
|
||||||
|
req := NewReq("https://example.com")
|
||||||
|
req.SetMethodNoError("GET")
|
||||||
|
if req.Method() != "GET" {
|
||||||
|
t.Errorf("Expected 'GET', got %v", req.Method())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetMethodHandlesInvalidInput(t *testing.T) {
|
||||||
|
req := NewReq("https://example.com")
|
||||||
|
err := req.SetMethod("我是谁")
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("Expected error, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetMethodNoErrorSetsMethodCorrectly(t *testing.T) {
|
||||||
|
req := NewReq("https://example.com")
|
||||||
|
req.SetMethodNoError("POST")
|
||||||
|
if req.Method() != "POST" {
|
||||||
|
t.Errorf("Expected 'POST', got %v", req.Method())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetMethodNoErrorIgnoresInvalidInput(t *testing.T) {
|
||||||
|
req := NewReq("https://example.com")
|
||||||
|
req.SetMethodNoError("你是谁")
|
||||||
|
if req.Method() != "GET" {
|
||||||
|
t.Errorf("Expected '', got %v", req.Method())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUriReturnsCorrectValue(t *testing.T) {
|
||||||
|
req := NewReq("https://example.com")
|
||||||
|
if req.Uri() != "https://example.com" {
|
||||||
|
t.Errorf("Expected 'https://example.com', got %v", req.Uri())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetUriHandlesValidInput(t *testing.T) {
|
||||||
|
req := NewReq("https://example.com")
|
||||||
|
err := req.SetUri("https://newexample.com")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if req.Uri() != "https://newexample.com" {
|
||||||
|
t.Errorf("Expected 'https://newexample.com', got %v", req.Uri())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetUriHandlesInvalidInput(t *testing.T) {
|
||||||
|
req := NewReq("https://example.com")
|
||||||
|
err := req.SetUri("://invalidurl")
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("Expected error, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetUriNoErrorSetsUriCorrectly(t *testing.T) {
|
||||||
|
req := NewReq("https://example.com")
|
||||||
|
req.SetUriNoError("https://newexample.com")
|
||||||
|
if req.Uri() != "https://newexample.com" {
|
||||||
|
t.Errorf("Expected 'https://newexample.com', got %v", req.Uri())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetUriNoErrorIgnoresInvalidInput(t *testing.T) {
|
||||||
|
req := NewReq("https://example.com")
|
||||||
|
req.SetUriNoError("://invalidurl")
|
||||||
|
if req.Uri() != "https://example.com" {
|
||||||
|
t.Errorf("Expected 'https://example.com', got %v", req.Uri())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type postmanReply struct {
|
||||||
|
Args struct {
|
||||||
|
} `json:"args"`
|
||||||
|
Form map[string]string `json:"form"`
|
||||||
|
Headers map[string]string `json:"headers"`
|
||||||
|
Url string `json:"url"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGet(t *testing.T) {
|
||||||
|
var reply postmanReply
|
||||||
|
resp, err := NewReq("https://postman-echo.com/get").
|
||||||
|
AddHeader("hello", "nononmo").
|
||||||
|
SetAutoCalcContentLengthNoError(true).Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
fmt.Println(resp.Proto)
|
||||||
|
err = resp.Body().Unmarshal(&reply)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
fmt.Println(resp.Body().String())
|
||||||
|
fmt.Println(reply.Headers)
|
||||||
|
fmt.Println(resp.Cookies())
|
||||||
|
}
|
||||||
|
|
||||||
|
type testData struct {
|
||||||
|
name string
|
||||||
|
args *Request
|
||||||
|
want func(*Response) error
|
||||||
|
wantErr bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func headerTestData() []testData {
|
||||||
|
return []testData{
|
||||||
|
{
|
||||||
|
name: "addHeader",
|
||||||
|
args: NewReq("https://postman-echo.com/get").
|
||||||
|
AddHeader("b612", "test-data").
|
||||||
|
AddHeader("b612", "test-header").
|
||||||
|
AddSimpleCookie("b612", "test-cookie").
|
||||||
|
SetHeader("User-Agent", "starnet test"),
|
||||||
|
want: func(resp *Response) error {
|
||||||
|
//fmt.Println(resp.Body().String())
|
||||||
|
if resp == nil {
|
||||||
|
return fmt.Errorf("response is nil")
|
||||||
|
}
|
||||||
|
if resp.StatusCode != 200 {
|
||||||
|
return fmt.Errorf("status code is %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
var reply postmanReply
|
||||||
|
err := resp.Body().Unmarshal(&reply)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if reply.Headers["b612"] != "test-data, test-header" {
|
||||||
|
return fmt.Errorf("header not found")
|
||||||
|
}
|
||||||
|
if reply.Headers["user-agent"] != "starnet test" {
|
||||||
|
return fmt.Errorf("user-agent not found")
|
||||||
|
}
|
||||||
|
if reply.Headers["cookie"] != "b612=test-cookie" {
|
||||||
|
return fmt.Errorf("cookie not found")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "postForm",
|
||||||
|
args: NewSimpleRequest("https://postman-echo.com/post", "POST").
|
||||||
|
AddHeader("b612", "test-data").
|
||||||
|
AddHeader("b612", "test-header").
|
||||||
|
AddSimpleCookie("b612", "test-cookie").
|
||||||
|
SetHeader("User-Agent", "starnet test").
|
||||||
|
//SetHeader("Content-Type", "application/x-www-form-urlencoded").
|
||||||
|
AddFormData("hello", "world").
|
||||||
|
AddFormData("hello2", "world2").
|
||||||
|
SetMethodNoError("POST"),
|
||||||
|
want: func(resp *Response) error {
|
||||||
|
//fmt.Println(resp.Body().String())
|
||||||
|
if resp == nil {
|
||||||
|
return fmt.Errorf("response is nil")
|
||||||
|
}
|
||||||
|
if resp.StatusCode != 200 {
|
||||||
|
return fmt.Errorf("status code is %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
var reply postmanReply
|
||||||
|
err := resp.Body().Unmarshal(&reply)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if reply.Headers["b612"] != "test-data, test-header" {
|
||||||
|
return fmt.Errorf("header not found")
|
||||||
|
}
|
||||||
|
if reply.Headers["user-agent"] != "starnet test" {
|
||||||
|
return fmt.Errorf("user-agent not found")
|
||||||
|
}
|
||||||
|
if reply.Headers["cookie"] != "b612=test-cookie" {
|
||||||
|
return fmt.Errorf("cookie not found")
|
||||||
|
}
|
||||||
|
if reply.Form["hello"] != "world" {
|
||||||
|
return fmt.Errorf("form data not found")
|
||||||
|
}
|
||||||
|
if reply.Form["hello2"] != "world2" {
|
||||||
|
return fmt.Errorf("form data not found")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
func TestCurl(t *testing.T) {
|
||||||
|
for _, tt := range headerTestData() {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := Curl(tt.args)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("Curl() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if tt.want != nil {
|
||||||
|
if err := tt.want(got); err != nil {
|
||||||
|
t.Errorf("Curl() = %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReqClone(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
if req.Header.Get("hello") != "world" {
|
||||||
|
rw.WriteHeader(http.StatusBadRequest)
|
||||||
|
rw.Write([]byte("hello world failed"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
rw.Write([]byte(`OK`))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
req := NewSimpleRequestWithClient(NewClientFromHttpClientNoError(http.DefaultClient), server.URL, "GET", WithHeader("hello", "world"))
|
||||||
|
resp, err := req.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
if resp.StatusCode != 200 {
|
||||||
|
resp.CloseAll()
|
||||||
|
t.Errorf("status code is %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
resp.CloseAll()
|
||||||
|
req = req.Clone()
|
||||||
|
req.AddHeader("ok", "good")
|
||||||
|
resp, err = req.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
if resp.StatusCode != 200 {
|
||||||
|
resp.CloseAll()
|
||||||
|
t.Errorf("status code is %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
resp.CloseAll()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUploadFile(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
if req.Header.Get("hello") != "world" {
|
||||||
|
rw.WriteHeader(http.StatusBadRequest)
|
||||||
|
rw.Write([]byte("hello world failed"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
files, header, err := req.FormFile("666")
|
||||||
|
if err == nil {
|
||||||
|
fmt.Println(header.Filename)
|
||||||
|
fmt.Println(header.Size)
|
||||||
|
fmt.Println(files.Close())
|
||||||
|
}
|
||||||
|
files, header, err = req.FormFile("777")
|
||||||
|
if err == nil {
|
||||||
|
fmt.Println(header.Filename)
|
||||||
|
fmt.Println(header.Size)
|
||||||
|
fmt.Println(files.Close())
|
||||||
|
}
|
||||||
|
files, header, err = req.FormFile("888")
|
||||||
|
if err == nil {
|
||||||
|
fmt.Println(header.Filename)
|
||||||
|
fmt.Println(header.Size)
|
||||||
|
fmt.Println(files.Close())
|
||||||
|
}
|
||||||
|
rw.Write([]byte(`OK`))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
req := NewSimpleRequestWithClient(NewClientFromHttpClientNoError(http.DefaultClient), server.URL, "GET", WithHeader("hello", "world"))
|
||||||
|
req.AddFileWithName("666", "./curl.go", "curl.go")
|
||||||
|
req.AddFile("777", "./go.mod")
|
||||||
|
req.AddFileWithNameAndType("888", "./ping.go", "ping.go", "html")
|
||||||
|
resp, err := req.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
if resp.StatusCode != 200 {
|
||||||
|
resp.CloseAll()
|
||||||
|
t.Errorf("status code is %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
resp.CloseAll()
|
||||||
|
req = req.Clone()
|
||||||
|
req.AddHeader("ok", "good")
|
||||||
|
|
||||||
|
resp, err = req.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
if resp.StatusCode != 200 {
|
||||||
|
resp.CloseAll()
|
||||||
|
t.Errorf("status code is %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
resp.CloseAll()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTlsConfig(t *testing.T) {
|
||||||
|
server := httptest.NewTLSServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
if req.Header.Get("hello") != "world" {
|
||||||
|
rw.WriteHeader(http.StatusBadRequest)
|
||||||
|
rw.Write([]byte("hello world failed"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
rw.Write([]byte(`OK`))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
client, err := NewHttpClient(WithSkipTLSVerify(false))
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
req := client.NewSimpleRequest(server.URL, "GET", WithHeader("hello", "world"))
|
||||||
|
//SetClientSkipVerify(client, true)
|
||||||
|
//req.SetDoRawClient(false)
|
||||||
|
//req.SetDoRawTransport(false)
|
||||||
|
req.SetSkipTLSVerify(true)
|
||||||
|
req.SetProxy("http://127.0.0.1:29992")
|
||||||
|
resp, err := req.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
fmt.Println(resp.Proto)
|
||||||
|
if resp.StatusCode != 200 {
|
||||||
|
resp.CloseAll()
|
||||||
|
t.Errorf("status code is %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
resp.CloseAll()
|
||||||
|
req = req.Clone()
|
||||||
|
req.AddHeader("ok", "good")
|
||||||
|
resp, err = req.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
if resp.StatusCode != 200 {
|
||||||
|
resp.CloseAll()
|
||||||
|
t.Errorf("status code is %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
resp.CloseAll()
|
||||||
|
req = req.Clone()
|
||||||
|
req.SetSkipTLSVerify(false)
|
||||||
|
resp, err = req.Do()
|
||||||
|
if err == nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHttpPostAndChunked(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
if req.Method != http.MethodPost {
|
||||||
|
t.Errorf("Expected 'POST', got %v", req.Method)
|
||||||
|
}
|
||||||
|
buf := make([]byte, 1024)
|
||||||
|
n, _ := req.Body.Read(buf)
|
||||||
|
if string(buf[:n]) != "hello world" {
|
||||||
|
t.Errorf("Expected body to be 'hello world', got %s", string(buf[:n]))
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Header.Get("chunked") == "true" {
|
||||||
|
if req.TransferEncoding[0] != "chunked" {
|
||||||
|
t.Errorf("Expected Transfer-Encoding to be 'chunked', got %s", req.Header.Get("Transfer-Encoding"))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if len(req.TransferEncoding) > 0 && req.TransferEncoding[0] == "chunked" {
|
||||||
|
t.Errorf("Expected Transfer-Encoding to not be 'chunked', got %s", req.Header.Get("Transfer-Encoding"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
rw.Write([]byte(`OK`))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
resp, err := Post(server.URL, WithBytes([]byte("hello world")), WithContentLength(-1), WithHeader("Content-Type", "text/plain"),
|
||||||
|
WithHeader("chunked", "true"))
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
body := resp.Body().String()
|
||||||
|
if body != "OK" {
|
||||||
|
t.Errorf("Expected OK, got %v", body)
|
||||||
|
}
|
||||||
|
resp.Close()
|
||||||
|
|
||||||
|
resp, err = Post(server.URL, WithBytes([]byte("hello world")), WithHeader("Content-Type", "text/plain"),
|
||||||
|
WithHeader("chunked", "false"))
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
body = resp.Body().String()
|
||||||
|
if body != "OK" {
|
||||||
|
t.Errorf("Expected OK, got %v", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithTimeout(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
time.Sleep(time.Second * 30)
|
||||||
|
rw.Write([]byte(`OK`))
|
||||||
|
}))
|
||||||
|
funcList := []func(string, ...RequestOpt) (*Response, error){
|
||||||
|
Get,
|
||||||
|
Post,
|
||||||
|
Put,
|
||||||
|
Delete,
|
||||||
|
Options,
|
||||||
|
Patch,
|
||||||
|
Head,
|
||||||
|
Trace,
|
||||||
|
Connect,
|
||||||
|
}
|
||||||
|
defer server.Close()
|
||||||
|
for i := 1; i < 30; i++ {
|
||||||
|
go func(i int) {
|
||||||
|
old := time.Now()
|
||||||
|
fn := funcList[i%len(funcList)]
|
||||||
|
resp, err := fn(server.URL, WithTimeout(time.Second*time.Duration(i)))
|
||||||
|
if time.Since(old) > time.Second*time.Duration(i+2) || time.Since(old) < time.Second*time.Duration(i) {
|
||||||
|
t.Errorf("timeout not work")
|
||||||
|
}
|
||||||
|
fmt.Println(time.Since(old))
|
||||||
|
if err == nil {
|
||||||
|
t.Error(err)
|
||||||
|
resp.CloseAll()
|
||||||
|
} else {
|
||||||
|
fmt.Println(err)
|
||||||
|
}
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
resp, err := Get(server.URL, WithTimeout(time.Second*60))
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
} else {
|
||||||
|
fmt.Println(resp.Body().String())
|
||||||
|
if resp.StatusCode != 200 {
|
||||||
|
resp.CloseAll()
|
||||||
|
t.Errorf("status code is %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
resp.CloseAll()
|
||||||
|
}
|
||||||
|
}
|
165
curl_transport.go
Normal file
165
curl_transport.go
Normal file
@ -0,0 +1,165 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"reflect"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Client struct {
|
||||||
|
*http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewHttpClient creates a new http.Client with the specified options.
|
||||||
|
func NewHttpClient(opts ...RequestOpt) (Client, error) {
|
||||||
|
req, err := newRequest(context.Background(), "", "", opts...)
|
||||||
|
if err != nil {
|
||||||
|
return Client{}, err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
req = nil
|
||||||
|
}()
|
||||||
|
cl, err := req.HttpClient()
|
||||||
|
return Client{
|
||||||
|
Client: cl,
|
||||||
|
}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewHttpClientNoErr(opts ...RequestOpt) Client {
|
||||||
|
c, _ := NewHttpClient(opts...)
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewClientFromHttpClient(httpClient *http.Client) (Client, error) {
|
||||||
|
if httpClient == nil {
|
||||||
|
return Client{}, fmt.Errorf("httpClient cannot be nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if httpClient.Transport == nil {
|
||||||
|
httpClient.Transport = &Transport{
|
||||||
|
base: &http.Transport{},
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
switch t := httpClient.Transport.(type) {
|
||||||
|
case *Transport:
|
||||||
|
if t.base == nil {
|
||||||
|
t.base = &http.Transport{}
|
||||||
|
}
|
||||||
|
case *http.Transport:
|
||||||
|
httpClient.Transport = &Transport{
|
||||||
|
base: t,
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return Client{}, fmt.Errorf("unsupported transport type: %T", t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Client{
|
||||||
|
Client: httpClient,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewClientFromHttpClientNoError(httpClient *http.Client) Client {
|
||||||
|
return Client{Client: httpClient}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DisableRedirect returns whether the request will disable HTTP redirects.
|
||||||
|
// if true, the request will not follow redirects automatically.
|
||||||
|
// for example, if the server responds with a 301 or 302 status code, the request will not automatically follow the redirect.
|
||||||
|
// you will get the original response with the redirect status code and Location header.
|
||||||
|
func (c Client) DisableRedirect() bool {
|
||||||
|
return reflect.ValueOf(c.Client.CheckRedirect).Pointer() == reflect.ValueOf(DefaultCheckRedirectFunc).Pointer()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDisableRedirect sets whether the request will disable HTTP redirects.
|
||||||
|
// if true, the request will not follow redirects automatically.
|
||||||
|
// for example, if the server responds with a 301 or 302 status code, the request will not automatically follow the redirect.
|
||||||
|
// you will get the original response with the redirect status code and Location header.
|
||||||
|
func (c Client) SetDisableRedirect(disableRedirect bool) {
|
||||||
|
if disableRedirect {
|
||||||
|
c.Client.CheckRedirect = DefaultCheckRedirectFunc
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c Client) SetDefaultSkipTLSVerify(skip bool) {
|
||||||
|
if c.Client.Transport == nil {
|
||||||
|
c.Client.Transport = &Transport{
|
||||||
|
base: &http.Transport{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if transport, ok := c.Client.Transport.(*Transport); ok {
|
||||||
|
if transport.base.TLSClientConfig == nil {
|
||||||
|
transport.base.TLSClientConfig = &tls.Config{}
|
||||||
|
}
|
||||||
|
transport.base.TLSClientConfig.InsecureSkipVerify = skip
|
||||||
|
} else if transport, ok := c.Client.Transport.(*http.Transport); ok {
|
||||||
|
if transport.TLSClientConfig == nil {
|
||||||
|
transport.TLSClientConfig = &tls.Config{}
|
||||||
|
}
|
||||||
|
transport.TLSClientConfig.InsecureSkipVerify = skip
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c Client) SetDefaultTLSConfig(tlsConfig *tls.Config) {
|
||||||
|
if c.Client.Transport == nil {
|
||||||
|
c.Client.Transport = &Transport{
|
||||||
|
base: &http.Transport{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if transport, ok := c.Client.Transport.(*Transport); ok {
|
||||||
|
transport.base.TLSClientConfig = tlsConfig
|
||||||
|
} else if transport, ok := c.Client.Transport.(*http.Transport); ok {
|
||||||
|
transport.TLSClientConfig = tlsConfig
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c Client) NewRequest(url, method string, opts ...RequestOpt) (*Request, error) {
|
||||||
|
if c.Client == nil {
|
||||||
|
return nil, fmt.Errorf("http client is nil")
|
||||||
|
}
|
||||||
|
req, err := NewRequestWithContextWithClient(context.Background(), c, url, method, opts...)
|
||||||
|
return req, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c Client) NewRequestContext(ctx context.Context, url, method string, opts ...RequestOpt) (*Request, error) {
|
||||||
|
if c.Client == nil {
|
||||||
|
return nil, fmt.Errorf("http client is nil")
|
||||||
|
}
|
||||||
|
req, err := NewRequestWithContextWithClient(ctx, c, url, method, opts...)
|
||||||
|
return req, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c Client) NewSimpleRequest(url, method string, opts ...RequestOpt) *Request {
|
||||||
|
req, _ := c.NewRequest(url, method, opts...)
|
||||||
|
return req
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c Client) NewSimpleRequestContext(ctx context.Context, url, method string, opts ...RequestOpt) *Request {
|
||||||
|
req, _ := c.NewRequestContext(ctx, url, method, opts...)
|
||||||
|
return req
|
||||||
|
}
|
||||||
|
|
||||||
|
type Transport struct {
|
||||||
|
base *http.Transport
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
if t.base == nil {
|
||||||
|
t.base = &http.Transport{}
|
||||||
|
}
|
||||||
|
transport, ok := req.Context().Value("transport").(*http.Transport)
|
||||||
|
if ok && transport != nil {
|
||||||
|
return transport.RoundTrip(req)
|
||||||
|
}
|
||||||
|
proxy, ok := req.Context().Value("proxy").(string)
|
||||||
|
if ok && proxy != "" {
|
||||||
|
tlsConfig, ok := req.Context().Value("tlsConfig").(*tls.Config)
|
||||||
|
if ok && tlsConfig != nil {
|
||||||
|
tmpTransport := t.base.Clone()
|
||||||
|
tmpTransport.TLSClientConfig = tlsConfig
|
||||||
|
return tmpTransport.RoundTrip(req)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return t.base.RoundTrip(req)
|
||||||
|
}
|
198
curlbench_test.go
Normal file
198
curlbench_test.go
Normal file
@ -0,0 +1,198 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"runtime"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// BenchmarkGetRequest 测试单个 GET 请求的性能
|
||||||
|
func BenchmarkGetRequest(b *testing.B) {
|
||||||
|
// 创建测试服务器
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
rw.Write([]byte(`OK`))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
// 重置计时器,排除设置代码的影响
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
// 报告内存分配情况
|
||||||
|
b.ReportAllocs()
|
||||||
|
|
||||||
|
// 运行基准测试
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
resp, err := Get(server.URL, WithSkipTLSVerify(true))
|
||||||
|
if err != nil {
|
||||||
|
b.Errorf("Unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
body := resp.Body().String()
|
||||||
|
if body != "OK" {
|
||||||
|
b.Errorf("Expected OK, got %v", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkGetRequestWithHeaders 测试带请求头的 GET 请求性能
|
||||||
|
func BenchmarkGetRequestWithHeaders(b *testing.B) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
// 验证请求头
|
||||||
|
if req.Header.Get("hello") != "world" {
|
||||||
|
rw.WriteHeader(http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
rw.Write([]byte(`OK`))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
b.ReportAllocs()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
resp, err := Get(server.URL,
|
||||||
|
WithSkipTLSVerify(true),
|
||||||
|
WithHeader("hello", "world"),
|
||||||
|
WithUserAgent("hello world"))
|
||||||
|
if err != nil {
|
||||||
|
b.Errorf("Unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
body := resp.Body().String()
|
||||||
|
if body != "OK" {
|
||||||
|
b.Errorf("Expected OK, got %v", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkPostRequest 测试 POST 请求的性能
|
||||||
|
func BenchmarkPostRequest(b *testing.B) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
// 读取并返回请求体
|
||||||
|
body := make([]byte, req.ContentLength)
|
||||||
|
req.Body.Read(body)
|
||||||
|
rw.Write(body)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
testData := "This is a test payload for POST request"
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
b.ReportAllocs()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
resp, err := Post(server.URL,
|
||||||
|
WithSkipTLSVerify(true),
|
||||||
|
WithBytes([]byte(testData)),
|
||||||
|
WithContentType("text/plain"))
|
||||||
|
if err != nil {
|
||||||
|
b.Errorf("Unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
body := resp.Body().String()
|
||||||
|
if body != testData {
|
||||||
|
b.Errorf("Expected %s, got %v", testData, body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkConcurrentRequests 测试并发请求性能
|
||||||
|
func BenchmarkConcurrentRequests(b *testing.B) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
rw.Write([]byte(`OK`))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
b.ReportAllocs()
|
||||||
|
|
||||||
|
// 运行并发基准测试
|
||||||
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
|
for pb.Next() {
|
||||||
|
resp, err := Get(server.URL, WithSkipTLSVerify(true))
|
||||||
|
if err != nil {
|
||||||
|
b.Errorf("Unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
body := resp.Body().String()
|
||||||
|
if body != "OK" {
|
||||||
|
b.Errorf("Expected OK, got %v", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkMemoryUsage 专门测试内存使用情况
|
||||||
|
func BenchmarkMemoryUsage(b *testing.B) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
rw.Write([]byte(`OK`))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
// 禁用默认的测试时间,只关注内存分配
|
||||||
|
b.ReportAllocs()
|
||||||
|
|
||||||
|
var memStatsStart, memStatsEnd runtime.MemStats
|
||||||
|
runtime.GC()
|
||||||
|
runtime.ReadMemStats(&memStatsStart)
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
resp, err := Get(server.URL, WithSkipTLSVerify(true))
|
||||||
|
if err != nil {
|
||||||
|
b.Errorf("Unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
body := resp.Body().String()
|
||||||
|
if body != "OK" {
|
||||||
|
b.Errorf("Expected OK, got %v", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
runtime.GC()
|
||||||
|
runtime.ReadMemStats(&memStatsEnd)
|
||||||
|
|
||||||
|
// 计算每次操作的平均内存分配
|
||||||
|
allocsPerOp := float64(memStatsEnd.Mallocs-memStatsStart.Mallocs) / float64(b.N)
|
||||||
|
bytesPerOp := float64(memStatsEnd.TotalAlloc-memStatsStart.TotalAlloc) / float64(b.N)
|
||||||
|
|
||||||
|
b.ReportMetric(allocsPerOp, "allocs/op")
|
||||||
|
b.ReportMetric(bytesPerOp, "bytes/op")
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkDifferentResponseSizes 测试不同响应大小的性能
|
||||||
|
func BenchmarkDifferentResponseSizes(b *testing.B) {
|
||||||
|
// 测试不同大小的响应
|
||||||
|
responseSizes := []int{100, 1024, 10240, 102400} // 100B, 1KB, 10KB, 100KB
|
||||||
|
|
||||||
|
for _, size := range responseSizes {
|
||||||
|
// 生成指定大小的响应数据
|
||||||
|
responseData := make([]byte, size)
|
||||||
|
for i := 0; i < size; i++ {
|
||||||
|
responseData[i] = 'A'
|
||||||
|
}
|
||||||
|
|
||||||
|
b.Run(fmt.Sprintf("Size_%d", size), func(b *testing.B) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
rw.Write(responseData)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
b.ReportAllocs()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
resp, err := Get(server.URL, WithSkipTLSVerify(true))
|
||||||
|
if err != nil {
|
||||||
|
b.Errorf("Unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
body := resp.Body().Bytes()
|
||||||
|
if len(body) != size {
|
||||||
|
b.Errorf("Expected size %d, got %d", size, len(body))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
2
go.mod
2
go.mod
@ -1,5 +1,3 @@
|
|||||||
module b612.me/starnet
|
module b612.me/starnet
|
||||||
|
|
||||||
go 1.16
|
go 1.16
|
||||||
|
|
||||||
require b612.me/stario v0.0.5
|
|
||||||
|
13
go.sum
13
go.sum
@ -1,13 +0,0 @@
|
|||||||
b612.me/stario v0.0.5 h1:Q1OGF+8eOoK49zMzkyh80GWaMuknhey6+PWJJL9ZuNo=
|
|
||||||
b612.me/stario v0.0.5/go.mod h1:or4ssWcxQSjMeu+hRKEgtp0X517b3zdlEOAms8Qscvw=
|
|
||||||
golang.org/x/crypto v0.0.0-20220313003712-b769efc7c000 h1:SL+8VVnkqyshUSz5iNnXtrBQzvFF2SkROm6t5RczFAE=
|
|
||||||
golang.org/x/crypto v0.0.0-20220313003712-b769efc7c000/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
|
||||||
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
|
||||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
|
||||||
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
|
||||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1 h1:SrN+KX8Art/Sf4HNj6Zcz06G7VEz+7w9tdXTPOZ7+l4=
|
|
||||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
|
||||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1 h1:v+OssWQX+hTHEmOBgwxdZxK4zHq3yOs8F9J7mk0PY8E=
|
|
||||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
|
||||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
|
||||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
|
120
httpguts.go
Normal file
120
httpguts.go
Normal file
@ -0,0 +1,120 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
|
var isTokenTable = [127]bool{
|
||||||
|
'!': true,
|
||||||
|
'#': true,
|
||||||
|
'$': true,
|
||||||
|
'%': true,
|
||||||
|
'&': true,
|
||||||
|
'\'': true,
|
||||||
|
'*': true,
|
||||||
|
'+': true,
|
||||||
|
'-': true,
|
||||||
|
'.': true,
|
||||||
|
'0': true,
|
||||||
|
'1': true,
|
||||||
|
'2': true,
|
||||||
|
'3': true,
|
||||||
|
'4': true,
|
||||||
|
'5': true,
|
||||||
|
'6': true,
|
||||||
|
'7': true,
|
||||||
|
'8': true,
|
||||||
|
'9': true,
|
||||||
|
'A': true,
|
||||||
|
'B': true,
|
||||||
|
'C': true,
|
||||||
|
'D': true,
|
||||||
|
'E': true,
|
||||||
|
'F': true,
|
||||||
|
'G': true,
|
||||||
|
'H': true,
|
||||||
|
'I': true,
|
||||||
|
'J': true,
|
||||||
|
'K': true,
|
||||||
|
'L': true,
|
||||||
|
'M': true,
|
||||||
|
'N': true,
|
||||||
|
'O': true,
|
||||||
|
'P': true,
|
||||||
|
'Q': true,
|
||||||
|
'R': true,
|
||||||
|
'S': true,
|
||||||
|
'T': true,
|
||||||
|
'U': true,
|
||||||
|
'W': true,
|
||||||
|
'V': true,
|
||||||
|
'X': true,
|
||||||
|
'Y': true,
|
||||||
|
'Z': true,
|
||||||
|
'^': true,
|
||||||
|
'_': true,
|
||||||
|
'`': true,
|
||||||
|
'a': true,
|
||||||
|
'b': true,
|
||||||
|
'c': true,
|
||||||
|
'd': true,
|
||||||
|
'e': true,
|
||||||
|
'f': true,
|
||||||
|
'g': true,
|
||||||
|
'h': true,
|
||||||
|
'i': true,
|
||||||
|
'j': true,
|
||||||
|
'k': true,
|
||||||
|
'l': true,
|
||||||
|
'm': true,
|
||||||
|
'n': true,
|
||||||
|
'o': true,
|
||||||
|
'p': true,
|
||||||
|
'q': true,
|
||||||
|
'r': true,
|
||||||
|
's': true,
|
||||||
|
't': true,
|
||||||
|
'u': true,
|
||||||
|
'v': true,
|
||||||
|
'w': true,
|
||||||
|
'x': true,
|
||||||
|
'y': true,
|
||||||
|
'z': true,
|
||||||
|
'|': true,
|
||||||
|
'~': true,
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsTokenRune(r rune) bool {
|
||||||
|
i := int(r)
|
||||||
|
return i < len(isTokenTable) && isTokenTable[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
func validMethod(method string) bool {
|
||||||
|
/*
|
||||||
|
Method = "OPTIONS" ; Section 9.2
|
||||||
|
| "GET" ; Section 9.3
|
||||||
|
| "HEAD" ; Section 9.4
|
||||||
|
| "POST" ; Section 9.5
|
||||||
|
| "PUT" ; Section 9.6
|
||||||
|
| "DELETE" ; Section 9.7
|
||||||
|
| "TRACE" ; Section 9.8
|
||||||
|
| "CONNECT" ; Section 9.9
|
||||||
|
| extension-method
|
||||||
|
extension-method = token
|
||||||
|
token = 1*<any CHAR except CTLs or separators>
|
||||||
|
*/
|
||||||
|
return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1
|
||||||
|
}
|
||||||
|
|
||||||
|
func isNotToken(r rune) bool {
|
||||||
|
return !IsTokenRune(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasPort(s string) bool { return strings.LastIndex(s, ":") > strings.LastIndex(s, "]") }
|
||||||
|
|
||||||
|
// removeEmptyPort strips the empty port in ":port" to ""
|
||||||
|
// as mandated by RFC 3986 Section 6.2.3.
|
||||||
|
func removeEmptyPort(host string) string {
|
||||||
|
if hasPort(host) {
|
||||||
|
return strings.TrimSuffix(host, ":")
|
||||||
|
}
|
||||||
|
return host
|
||||||
|
}
|
13
ping.go
13
ping.go
@ -33,6 +33,7 @@ func getICMP(seq uint16) ICMP {
|
|||||||
|
|
||||||
func sendICMPRequest(icmp ICMP, destAddr *net.IPAddr, timeout time.Duration) (PingResult, error) {
|
func sendICMPRequest(icmp ICMP, destAddr *net.IPAddr, timeout time.Duration) (PingResult, error) {
|
||||||
var res PingResult
|
var res PingResult
|
||||||
|
res.RemoteIP = destAddr.String()
|
||||||
conn, err := net.DialIP("ip:icmp", nil, destAddr)
|
conn, err := net.DialIP("ip:icmp", nil, destAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return res, err
|
return res, err
|
||||||
@ -84,6 +85,7 @@ func checkSum(data []byte) uint16 {
|
|||||||
type PingResult struct {
|
type PingResult struct {
|
||||||
Duration time.Duration
|
Duration time.Duration
|
||||||
RecvCount int
|
RecvCount int
|
||||||
|
RemoteIP string
|
||||||
}
|
}
|
||||||
|
|
||||||
func Ping(ip string, seq int, timeout time.Duration) (PingResult, error) {
|
func Ping(ip string, seq int, timeout time.Duration) (PingResult, error) {
|
||||||
@ -95,3 +97,14 @@ func Ping(ip string, seq int, timeout time.Duration) (PingResult, error) {
|
|||||||
icmp := getICMP(uint16(seq))
|
icmp := getICMP(uint16(seq))
|
||||||
return sendICMPRequest(icmp, ipAddr, timeout)
|
return sendICMPRequest(icmp, ipAddr, timeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func IsIpPingable(ip string, timeout time.Duration, retryLimit int) bool {
|
||||||
|
for i := 0; i < retryLimit; i++ {
|
||||||
|
_, err := Ping(ip, 29, timeout)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
@ -7,5 +7,9 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func Test_Ping(t *testing.T) {
|
func Test_Ping(t *testing.T) {
|
||||||
fmt.Println(Ping("baidu.com", 0, time.Second*2))
|
fmt.Println(Ping("baidu.com", 29, time.Second*2))
|
||||||
|
fmt.Println(Ping("www.b612.me", 29, time.Second*2))
|
||||||
|
fmt.Println(IsIpPingable("baidu.com", time.Second*2, 3))
|
||||||
|
fmt.Println(IsIpPingable("www.b612.me", time.Second*2, 3))
|
||||||
|
|
||||||
}
|
}
|
||||||
|
317
que.go
317
que.go
@ -1,317 +0,0 @@
|
|||||||
package starnet
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"encoding/binary"
|
|
||||||
"errors"
|
|
||||||
"os"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// 识别头
|
|
||||||
var header = []byte{11, 27, 19, 96, 12, 25, 02, 20}
|
|
||||||
|
|
||||||
// MsgQueue 为基本的信息单位
|
|
||||||
type MsgQueue struct {
|
|
||||||
ID uint16
|
|
||||||
Msg []byte
|
|
||||||
Conn interface{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// StarQueue 为流数据中的消息队列分发
|
|
||||||
type StarQueue struct {
|
|
||||||
count int64
|
|
||||||
Encode bool
|
|
||||||
Reserve uint16
|
|
||||||
Msgid uint16
|
|
||||||
MsgPool chan MsgQueue
|
|
||||||
UnFinMsg sync.Map
|
|
||||||
LastID int //= -1
|
|
||||||
ctx context.Context
|
|
||||||
cancel context.CancelFunc
|
|
||||||
duration time.Duration
|
|
||||||
EncodeFunc func([]byte) []byte
|
|
||||||
DecodeFunc func([]byte) []byte
|
|
||||||
//restoreMu sync.Mutex
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewQueueCtx(ctx context.Context, count int64) *StarQueue {
|
|
||||||
var que StarQueue
|
|
||||||
que.Encode = false
|
|
||||||
que.count = count
|
|
||||||
que.MsgPool = make(chan MsgQueue, count)
|
|
||||||
if ctx == nil {
|
|
||||||
que.ctx, que.cancel = context.WithCancel(context.Background())
|
|
||||||
} else {
|
|
||||||
que.ctx, que.cancel = context.WithCancel(ctx)
|
|
||||||
}
|
|
||||||
que.duration = 0
|
|
||||||
return &que
|
|
||||||
}
|
|
||||||
func NewQueueWithCount(count int64) *StarQueue {
|
|
||||||
return NewQueueCtx(nil, count)
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewQueue 建立一个新消息队列
|
|
||||||
func NewQueue() *StarQueue {
|
|
||||||
return NewQueueWithCount(32)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Uint32ToByte 4位uint32转byte
|
|
||||||
func Uint32ToByte(src uint32) []byte {
|
|
||||||
res := make([]byte, 4)
|
|
||||||
res[3] = uint8(src)
|
|
||||||
res[2] = uint8(src >> 8)
|
|
||||||
res[1] = uint8(src >> 16)
|
|
||||||
res[0] = uint8(src >> 24)
|
|
||||||
return res
|
|
||||||
}
|
|
||||||
|
|
||||||
// ByteToUint32 byte转4位uint32
|
|
||||||
func ByteToUint32(src []byte) uint32 {
|
|
||||||
var res uint32
|
|
||||||
buffer := bytes.NewBuffer(src)
|
|
||||||
binary.Read(buffer, binary.BigEndian, &res)
|
|
||||||
return res
|
|
||||||
}
|
|
||||||
|
|
||||||
// Uint16ToByte 2位uint16转byte
|
|
||||||
func Uint16ToByte(src uint16) []byte {
|
|
||||||
res := make([]byte, 2)
|
|
||||||
res[1] = uint8(src)
|
|
||||||
res[0] = uint8(src >> 8)
|
|
||||||
return res
|
|
||||||
}
|
|
||||||
|
|
||||||
// ByteToUint16 用于byte转uint16
|
|
||||||
func ByteToUint16(src []byte) uint16 {
|
|
||||||
var res uint16
|
|
||||||
buffer := bytes.NewBuffer(src)
|
|
||||||
binary.Read(buffer, binary.BigEndian, &res)
|
|
||||||
return res
|
|
||||||
}
|
|
||||||
|
|
||||||
// BuildMessage 生成编码后的信息用于发送
|
|
||||||
func (que *StarQueue) BuildMessage(src []byte) []byte {
|
|
||||||
var buff bytes.Buffer
|
|
||||||
que.Msgid++
|
|
||||||
if que.Encode {
|
|
||||||
src = que.EncodeFunc(src)
|
|
||||||
}
|
|
||||||
length := uint32(len(src))
|
|
||||||
buff.Write(header)
|
|
||||||
buff.Write(Uint32ToByte(length))
|
|
||||||
buff.Write(Uint16ToByte(que.Msgid))
|
|
||||||
buff.Write(src)
|
|
||||||
return buff.Bytes()
|
|
||||||
}
|
|
||||||
|
|
||||||
// BuildHeader 生成编码后的Header用于发送
|
|
||||||
func (que *StarQueue) BuildHeader(length uint32) []byte {
|
|
||||||
var buff bytes.Buffer
|
|
||||||
que.Msgid++
|
|
||||||
buff.Write(header)
|
|
||||||
buff.Write(Uint32ToByte(length))
|
|
||||||
buff.Write(Uint16ToByte(que.Msgid))
|
|
||||||
return buff.Bytes()
|
|
||||||
}
|
|
||||||
|
|
||||||
type unFinMsg struct {
|
|
||||||
ID uint16
|
|
||||||
LengthRecv uint32
|
|
||||||
// HeaderMsg 信息头,应当为14位:8位识别码+4位长度码+2位id
|
|
||||||
HeaderMsg []byte
|
|
||||||
RecvMsg []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
func (que *StarQueue) push2list(msg MsgQueue) {
|
|
||||||
que.MsgPool <- msg
|
|
||||||
}
|
|
||||||
|
|
||||||
// ParseMessage 用于解析收到的msg信息
|
|
||||||
func (que *StarQueue) ParseMessage(msg []byte, conn interface{}) error {
|
|
||||||
return que.parseMessage(msg, conn)
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseMessage 用于解析收到的msg信息
|
|
||||||
func (que *StarQueue) parseMessage(msg []byte, conn interface{}) error {
|
|
||||||
tmp, ok := que.UnFinMsg.Load(conn)
|
|
||||||
if ok { //存在未完成的信息
|
|
||||||
lastMsg := tmp.(*unFinMsg)
|
|
||||||
headerLen := len(lastMsg.HeaderMsg)
|
|
||||||
if headerLen < 14 { //未完成头标题
|
|
||||||
//传输的数据不能填充header头
|
|
||||||
if len(msg) < 14-headerLen {
|
|
||||||
//加入header头并退出
|
|
||||||
lastMsg.HeaderMsg = bytesMerge(lastMsg.HeaderMsg, msg)
|
|
||||||
que.UnFinMsg.Store(conn, lastMsg)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
//获取14字节完整的header
|
|
||||||
header := msg[0 : 14-headerLen]
|
|
||||||
lastMsg.HeaderMsg = bytesMerge(lastMsg.HeaderMsg, header)
|
|
||||||
//检查收到的header是否为认证header
|
|
||||||
//若不是,丢弃并重新来过
|
|
||||||
if !checkHeader(lastMsg.HeaderMsg[0:8]) {
|
|
||||||
que.UnFinMsg.Delete(conn)
|
|
||||||
if len(msg) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return que.parseMessage(msg, conn)
|
|
||||||
}
|
|
||||||
//获得本数据包长度
|
|
||||||
lastMsg.LengthRecv = ByteToUint32(lastMsg.HeaderMsg[8:12])
|
|
||||||
//获得本数据包ID
|
|
||||||
lastMsg.ID = ByteToUint16(lastMsg.HeaderMsg[12:14])
|
|
||||||
//存入列表
|
|
||||||
que.UnFinMsg.Store(conn, lastMsg)
|
|
||||||
msg = msg[14-headerLen:]
|
|
||||||
if uint32(len(msg)) < lastMsg.LengthRecv {
|
|
||||||
lastMsg.RecvMsg = msg
|
|
||||||
que.UnFinMsg.Store(conn, lastMsg)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if uint32(len(msg)) >= lastMsg.LengthRecv {
|
|
||||||
lastMsg.RecvMsg = msg[0:lastMsg.LengthRecv]
|
|
||||||
if que.Encode {
|
|
||||||
lastMsg.RecvMsg = que.DecodeFunc(lastMsg.RecvMsg)
|
|
||||||
}
|
|
||||||
msg = msg[lastMsg.LengthRecv:]
|
|
||||||
storeMsg := MsgQueue{
|
|
||||||
ID: lastMsg.ID,
|
|
||||||
Msg: lastMsg.RecvMsg,
|
|
||||||
Conn: conn,
|
|
||||||
}
|
|
||||||
//que.restoreMu.Lock()
|
|
||||||
que.push2list(storeMsg)
|
|
||||||
//que.restoreMu.Unlock()
|
|
||||||
que.UnFinMsg.Delete(conn)
|
|
||||||
return que.parseMessage(msg, conn)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
lastID := int(lastMsg.LengthRecv) - len(lastMsg.RecvMsg)
|
|
||||||
if lastID < 0 {
|
|
||||||
que.UnFinMsg.Delete(conn)
|
|
||||||
return que.parseMessage(msg, conn)
|
|
||||||
}
|
|
||||||
if len(msg) >= lastID {
|
|
||||||
lastMsg.RecvMsg = bytesMerge(lastMsg.RecvMsg, msg[0:lastID])
|
|
||||||
if que.Encode {
|
|
||||||
lastMsg.RecvMsg = que.DecodeFunc(lastMsg.RecvMsg)
|
|
||||||
}
|
|
||||||
storeMsg := MsgQueue{
|
|
||||||
ID: lastMsg.ID,
|
|
||||||
Msg: lastMsg.RecvMsg,
|
|
||||||
Conn: conn,
|
|
||||||
}
|
|
||||||
que.push2list(storeMsg)
|
|
||||||
que.UnFinMsg.Delete(conn)
|
|
||||||
if len(msg) == lastID {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
msg = msg[lastID:]
|
|
||||||
return que.parseMessage(msg, conn)
|
|
||||||
}
|
|
||||||
lastMsg.RecvMsg = bytesMerge(lastMsg.RecvMsg, msg)
|
|
||||||
que.UnFinMsg.Store(conn, lastMsg)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(msg) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
var start int
|
|
||||||
if start = searchHeader(msg); start == -1 {
|
|
||||||
return errors.New("data format error")
|
|
||||||
}
|
|
||||||
msg = msg[start:]
|
|
||||||
lastMsg := unFinMsg{}
|
|
||||||
que.UnFinMsg.Store(conn, &lastMsg)
|
|
||||||
return que.parseMessage(msg, conn)
|
|
||||||
}
|
|
||||||
|
|
||||||
func checkHeader(msg []byte) bool {
|
|
||||||
if len(msg) != 8 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
for k, v := range msg {
|
|
||||||
if v != header[k] {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func searchHeader(msg []byte) int {
|
|
||||||
if len(msg) < 8 {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
for k, v := range msg {
|
|
||||||
find := 0
|
|
||||||
if v == header[0] {
|
|
||||||
for k2, v2 := range header {
|
|
||||||
if msg[k+k2] == v2 {
|
|
||||||
find++
|
|
||||||
} else {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if find == 8 {
|
|
||||||
return k
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return -1
|
|
||||||
}
|
|
||||||
|
|
||||||
func bytesMerge(src ...[]byte) []byte {
|
|
||||||
var buff bytes.Buffer
|
|
||||||
for _, v := range src {
|
|
||||||
buff.Write(v)
|
|
||||||
}
|
|
||||||
return buff.Bytes()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Restore 获取收到的信息
|
|
||||||
func (que *StarQueue) Restore() (MsgQueue, error) {
|
|
||||||
if que.duration.Seconds() == 0 {
|
|
||||||
que.duration = 86400 * time.Second
|
|
||||||
}
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-que.ctx.Done():
|
|
||||||
return MsgQueue{}, errors.New("Stoped By External Function Call")
|
|
||||||
case <-time.After(que.duration):
|
|
||||||
if que.duration != 0 {
|
|
||||||
return MsgQueue{}, os.ErrDeadlineExceeded
|
|
||||||
}
|
|
||||||
case data, ok := <-que.MsgPool:
|
|
||||||
if !ok {
|
|
||||||
return MsgQueue{}, os.ErrClosed
|
|
||||||
}
|
|
||||||
return data, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// RestoreOne 获取收到的一个信息
|
|
||||||
//兼容性修改
|
|
||||||
func (que *StarQueue) RestoreOne() (MsgQueue, error) {
|
|
||||||
return que.Restore()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stop 立即停止Restore
|
|
||||||
func (que *StarQueue) Stop() {
|
|
||||||
que.cancel()
|
|
||||||
}
|
|
||||||
|
|
||||||
// RestoreDuration Restore最大超时时间
|
|
||||||
func (que *StarQueue) RestoreDuration(tm time.Duration) {
|
|
||||||
que.duration = tm
|
|
||||||
}
|
|
||||||
|
|
||||||
func (que *StarQueue) RestoreChan() <-chan MsgQueue {
|
|
||||||
return que.MsgPool
|
|
||||||
}
|
|
42
que_test.go
42
que_test.go
@ -1,42 +0,0 @@
|
|||||||
package starnet
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
func Test_QueSpeed(t *testing.T) {
|
|
||||||
que := NewQueueWithCount(0)
|
|
||||||
stop := make(chan struct{}, 1)
|
|
||||||
que.RestoreDuration(time.Second * 10)
|
|
||||||
var count int64
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-stop:
|
|
||||||
//fmt.Println(count)
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
_, err := que.RestoreOne()
|
|
||||||
if err == nil {
|
|
||||||
count++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
cp := 0
|
|
||||||
stoped := time.After(time.Second * 10)
|
|
||||||
data := que.BuildMessage([]byte("hello"))
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-stoped:
|
|
||||||
fmt.Println(count, cp)
|
|
||||||
stop <- struct{}{}
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
que.ParseMessage(data, "lala")
|
|
||||||
cp++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
401
tlssniffer.go
Normal file
401
tlssniffer.go
Normal file
@ -0,0 +1,401 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type myConn struct {
|
||||||
|
reader io.Reader
|
||||||
|
conn net.Conn
|
||||||
|
isReadOnly bool
|
||||||
|
multiReader io.Reader
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *myConn) Read(p []byte) (int, error) {
|
||||||
|
if c.isReadOnly {
|
||||||
|
return c.reader.Read(p)
|
||||||
|
}
|
||||||
|
if c.multiReader == nil {
|
||||||
|
c.multiReader = io.MultiReader(c.reader, c.conn)
|
||||||
|
}
|
||||||
|
return c.multiReader.Read(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *myConn) Write(p []byte) (int, error) {
|
||||||
|
if c.isReadOnly {
|
||||||
|
return 0, io.ErrClosedPipe
|
||||||
|
}
|
||||||
|
return c.conn.Write(p)
|
||||||
|
}
|
||||||
|
func (c *myConn) Close() error {
|
||||||
|
if c.isReadOnly {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return c.conn.Close()
|
||||||
|
}
|
||||||
|
func (c *myConn) LocalAddr() net.Addr {
|
||||||
|
if c.isReadOnly {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return c.conn.LocalAddr()
|
||||||
|
}
|
||||||
|
func (c *myConn) RemoteAddr() net.Addr {
|
||||||
|
if c.isReadOnly {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return c.conn.RemoteAddr()
|
||||||
|
}
|
||||||
|
func (c *myConn) SetDeadline(t time.Time) error {
|
||||||
|
if c.isReadOnly {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return c.conn.SetDeadline(t)
|
||||||
|
}
|
||||||
|
func (c *myConn) SetReadDeadline(t time.Time) error {
|
||||||
|
if c.isReadOnly {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return c.conn.SetReadDeadline(t)
|
||||||
|
}
|
||||||
|
func (c *myConn) SetWriteDeadline(t time.Time) error {
|
||||||
|
if c.isReadOnly {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return c.conn.SetWriteDeadline(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Listener struct {
|
||||||
|
net.Listener
|
||||||
|
cfg *tls.Config
|
||||||
|
getConfigForClient func(hostname string) *tls.Config
|
||||||
|
allowNonTls bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Listener) GetConfigForClient() func(hostname string) *tls.Config {
|
||||||
|
return l.getConfigForClient
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Listener) SetConfigForClient(getConfigForClient func(hostname string) *tls.Config) {
|
||||||
|
l.getConfigForClient = getConfigForClient
|
||||||
|
}
|
||||||
|
|
||||||
|
func Listen(network, address string) (*Listener, error) {
|
||||||
|
listener, err := net.Listen(network, address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &Listener{Listener: listener}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ListenTLSWithListenConfig(liscfg net.ListenConfig, network, address string, config *tls.Config, getConfigForClient func(hostname string) *tls.Config, allowNonTls bool) (*Listener, error) {
|
||||||
|
listener, err := liscfg.Listen(context.Background(), network, address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &Listener{
|
||||||
|
Listener: listener,
|
||||||
|
cfg: config,
|
||||||
|
getConfigForClient: getConfigForClient,
|
||||||
|
allowNonTls: allowNonTls,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ListenWithListener(listener net.Listener, config *tls.Config, getConfigForClient func(hostname string) *tls.Config, allowNonTls bool) (*Listener, error) {
|
||||||
|
return &Listener{
|
||||||
|
Listener: listener,
|
||||||
|
cfg: config,
|
||||||
|
getConfigForClient: getConfigForClient,
|
||||||
|
allowNonTls: allowNonTls,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ListenTLSWithConfig(network, address string, config *tls.Config, getConfigForClient func(hostname string) *tls.Config, allowNonTls bool) (*Listener, error) {
|
||||||
|
listener, err := net.Listen(network, address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &Listener{
|
||||||
|
Listener: listener,
|
||||||
|
cfg: config,
|
||||||
|
getConfigForClient: getConfigForClient,
|
||||||
|
allowNonTls: allowNonTls,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ListenTLS(network, address string, certFile, keyFile string, allowNonTls bool) (*Listener, error) {
|
||||||
|
config, err := tls.LoadX509KeyPair(certFile, keyFile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConfig := &tls.Config{
|
||||||
|
Certificates: []tls.Certificate{config},
|
||||||
|
}
|
||||||
|
|
||||||
|
listener, err := net.Listen(network, address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Listener{
|
||||||
|
Listener: listener,
|
||||||
|
cfg: tlsConfig,
|
||||||
|
allowNonTls: allowNonTls,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Listener) Accept() (net.Conn, error) {
|
||||||
|
conn, err := l.Listener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &Conn{
|
||||||
|
Conn: conn,
|
||||||
|
tlsCfg: l.cfg,
|
||||||
|
getConfigForClient: l.getConfigForClient,
|
||||||
|
allowNonTls: l.allowNonTls,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type Conn struct {
|
||||||
|
net.Conn
|
||||||
|
once sync.Once
|
||||||
|
initErr error
|
||||||
|
isTLS bool
|
||||||
|
tlsCfg *tls.Config
|
||||||
|
tlsConn *tls.Conn
|
||||||
|
buffer *bytes.Buffer
|
||||||
|
noTlsReader io.Reader
|
||||||
|
isOriginal bool
|
||||||
|
getConfigForClient func(hostname string) *tls.Config
|
||||||
|
hostname string
|
||||||
|
allowNonTls bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) Hostname() string {
|
||||||
|
if c.hostname != "" {
|
||||||
|
return c.hostname
|
||||||
|
}
|
||||||
|
if c.isTLS && c.tlsConn != nil {
|
||||||
|
if c.tlsConn.ConnectionState().ServerName != "" {
|
||||||
|
c.hostname = c.tlsConn.ConnectionState().ServerName
|
||||||
|
return c.hostname
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) IsTLS() bool {
|
||||||
|
return c.isTLS
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) TlsConn() *tls.Conn {
|
||||||
|
return c.tlsConn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) isTLSConnection() (bool, error) {
|
||||||
|
if c.getConfigForClient == nil {
|
||||||
|
peek := make([]byte, 5)
|
||||||
|
n, err := io.ReadFull(c.Conn, peek)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
isTLS := n >= 3 && peek[0] == 0x16 && peek[1] == 0x03
|
||||||
|
|
||||||
|
c.buffer = bytes.NewBuffer(peek[:n])
|
||||||
|
return isTLS, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
c.buffer = new(bytes.Buffer)
|
||||||
|
r := io.TeeReader(c.Conn, c.buffer)
|
||||||
|
var hello *tls.ClientHelloInfo
|
||||||
|
tls.Server(&myConn{reader: r, isReadOnly: true}, &tls.Config{
|
||||||
|
GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||||
|
hello = new(tls.ClientHelloInfo)
|
||||||
|
*hello = *argHello
|
||||||
|
return nil, nil
|
||||||
|
},
|
||||||
|
}).Handshake()
|
||||||
|
peek := c.buffer.Bytes()
|
||||||
|
n := len(peek)
|
||||||
|
isTLS := n >= 3 && peek[0] == 0x16 && peek[1] == 0x03
|
||||||
|
if hello == nil {
|
||||||
|
return isTLS, nil
|
||||||
|
}
|
||||||
|
c.hostname = hello.ServerName
|
||||||
|
if c.hostname == "" {
|
||||||
|
c.hostname, _, _ = net.SplitHostPort(c.Conn.LocalAddr().String())
|
||||||
|
}
|
||||||
|
return isTLS, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) init() {
|
||||||
|
c.once.Do(func() {
|
||||||
|
if c.isOriginal {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if c.tlsCfg != nil {
|
||||||
|
isTLS, err := c.isTLSConnection()
|
||||||
|
if err != nil {
|
||||||
|
c.initErr = err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.isTLS = isTLS
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.isTLS {
|
||||||
|
var cfg = c.tlsCfg
|
||||||
|
if c.getConfigForClient != nil {
|
||||||
|
cfg = c.getConfigForClient(c.hostname)
|
||||||
|
if cfg == nil {
|
||||||
|
cfg = c.tlsCfg
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.tlsConn = tls.Server(&myConn{
|
||||||
|
reader: c.buffer,
|
||||||
|
conn: c.Conn,
|
||||||
|
isReadOnly: false,
|
||||||
|
}, cfg)
|
||||||
|
} else {
|
||||||
|
if !c.allowNonTls {
|
||||||
|
c.initErr = net.ErrClosed
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.noTlsReader = io.MultiReader(c.buffer, c.Conn)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) Read(b []byte) (int, error) {
|
||||||
|
c.init()
|
||||||
|
if c.initErr != nil {
|
||||||
|
return 0, c.initErr
|
||||||
|
}
|
||||||
|
if c.isTLS {
|
||||||
|
return c.tlsConn.Read(b)
|
||||||
|
}
|
||||||
|
return c.noTlsReader.Read(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) Write(b []byte) (int, error) {
|
||||||
|
c.init()
|
||||||
|
if c.initErr != nil {
|
||||||
|
return 0, c.initErr
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.isTLS {
|
||||||
|
return c.tlsConn.Write(b)
|
||||||
|
}
|
||||||
|
return c.Conn.Write(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) Close() error {
|
||||||
|
if c.isTLS && c.tlsConn != nil {
|
||||||
|
return c.tlsConn.Close()
|
||||||
|
}
|
||||||
|
return c.Conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) SetDeadline(t time.Time) error {
|
||||||
|
if c.isTLS && c.tlsConn != nil {
|
||||||
|
return c.tlsConn.SetDeadline(t)
|
||||||
|
}
|
||||||
|
return c.Conn.SetDeadline(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) SetReadDeadline(t time.Time) error {
|
||||||
|
if c.isTLS && c.tlsConn != nil {
|
||||||
|
return c.tlsConn.SetReadDeadline(t)
|
||||||
|
}
|
||||||
|
return c.Conn.SetReadDeadline(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) SetWriteDeadline(t time.Time) error {
|
||||||
|
if c.isTLS && c.tlsConn != nil {
|
||||||
|
return c.tlsConn.SetWriteDeadline(t)
|
||||||
|
}
|
||||||
|
return c.Conn.SetWriteDeadline(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) TlsConnection() (*tls.Conn, error) {
|
||||||
|
if c.initErr != nil {
|
||||||
|
return nil, c.initErr
|
||||||
|
}
|
||||||
|
if !c.isTLS {
|
||||||
|
return nil, net.ErrClosed
|
||||||
|
}
|
||||||
|
return c.tlsConn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) OriginalConn() net.Conn {
|
||||||
|
return c.Conn
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewClientTlsConn(conn net.Conn, cfg *tls.Config) (*Conn, error) {
|
||||||
|
if conn == nil {
|
||||||
|
return nil, net.ErrClosed
|
||||||
|
}
|
||||||
|
c := &Conn{
|
||||||
|
Conn: conn,
|
||||||
|
isTLS: true,
|
||||||
|
tlsCfg: cfg,
|
||||||
|
tlsConn: tls.Client(conn, cfg),
|
||||||
|
isOriginal: true,
|
||||||
|
}
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewServerTlsConn(conn net.Conn, cfg *tls.Config) (*Conn, error) {
|
||||||
|
if conn == nil {
|
||||||
|
return nil, net.ErrClosed
|
||||||
|
}
|
||||||
|
c := &Conn{
|
||||||
|
Conn: conn,
|
||||||
|
isTLS: true,
|
||||||
|
tlsCfg: cfg,
|
||||||
|
tlsConn: tls.Server(conn, cfg),
|
||||||
|
isOriginal: true,
|
||||||
|
}
|
||||||
|
c.init()
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func Dial(network, address string) (*Conn, error) {
|
||||||
|
conn, err := net.Dial(network, address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &Conn{
|
||||||
|
Conn: conn,
|
||||||
|
isTLS: false,
|
||||||
|
tlsCfg: nil,
|
||||||
|
tlsConn: nil,
|
||||||
|
noTlsReader: conn,
|
||||||
|
isOriginal: true,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func DialTLS(network, address string, certFile, keyFile string) (*Conn, error) {
|
||||||
|
config, err := tls.LoadX509KeyPair(certFile, keyFile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConfig := &tls.Config{
|
||||||
|
Certificates: []tls.Certificate{config},
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := net.Dial(network, address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return NewClientTlsConn(conn, tlsConfig)
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user