rewrite program
This commit is contained in:
parent
0e2f91eee2
commit
50aef48d49
1657
addon_test.go
Normal file
1657
addon_test.go
Normal file
File diff suppressed because it is too large
Load Diff
197
benchmark_test.go
Normal file
197
benchmark_test.go
Normal file
@ -0,0 +1,197 @@
|
||||
package starnet
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func BenchmarkGetRequest(b *testing.B) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("OK"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
resp, err := Get(server.URL)
|
||||
if err != nil {
|
||||
b.Fatalf("Get() error: %v", err)
|
||||
}
|
||||
resp.Body().String()
|
||||
resp.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkGetRequestWithHeaders(b *testing.B) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("OK"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
resp, err := Get(server.URL,
|
||||
WithHeader("X-Custom", "value"),
|
||||
WithUserAgent("BenchmarkAgent"))
|
||||
if err != nil {
|
||||
b.Fatalf("Get() error: %v", err)
|
||||
}
|
||||
resp.Body().String()
|
||||
resp.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkPostRequest(b *testing.B) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("OK"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
testData := []byte("test data for benchmark")
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
resp, err := Post(server.URL, WithBody(testData))
|
||||
if err != nil {
|
||||
b.Fatalf("Post() error: %v", err)
|
||||
}
|
||||
resp.Body().String()
|
||||
resp.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkJSONRequest(b *testing.B) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(`{"status":"ok"}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
type TestData struct {
|
||||
Name string `json:"name"`
|
||||
Value int `json:"value"`
|
||||
}
|
||||
|
||||
data := TestData{Name: "test", Value: 123}
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
resp, err := Post(server.URL, WithJSON(data))
|
||||
if err != nil {
|
||||
b.Fatalf("Post() error: %v", err)
|
||||
}
|
||||
var result map[string]string
|
||||
resp.Body().JSON(&result)
|
||||
resp.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkConcurrentRequests(b *testing.B) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("OK"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
resp, err := Get(server.URL)
|
||||
if err != nil {
|
||||
b.Fatalf("Get() error: %v", err)
|
||||
}
|
||||
resp.Body().String()
|
||||
resp.Close()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkRequestClone(b *testing.B) {
|
||||
req := NewSimpleRequest("https://example.com", "GET").
|
||||
SetHeader("X-Custom", "value").
|
||||
AddQuery("key", "value")
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = req.Clone()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkClientCreation(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = NewClientNoErr()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkRequestCreation(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = NewSimpleRequest("https://example.com", "GET")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkResponseBodyRead(b *testing.B) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("test response data"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Pre-fetch response
|
||||
resp, _ := Get(server.URL, WithAutoFetch(true))
|
||||
defer resp.Close()
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = resp.Body().String()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkDifferentResponseSizes(b *testing.B) {
|
||||
sizes := []int{100, 1024, 10240, 102400} // 100B, 1KB, 10KB, 100KB
|
||||
|
||||
for _, size := range sizes {
|
||||
responseData := make([]byte, size)
|
||||
for i := 0; i < size; i++ {
|
||||
responseData[i] = 'A'
|
||||
}
|
||||
|
||||
b.Run(fmt.Sprintf("Size_%d", size), func(b *testing.B) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write(responseData)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
resp, err := Get(server.URL)
|
||||
if err != nil {
|
||||
b.Fatalf("Get() error: %v", err)
|
||||
}
|
||||
resp.Body().Bytes()
|
||||
resp.Close()
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
145
body_test.go
Normal file
145
body_test.go
Normal file
@ -0,0 +1,145 @@
|
||||
package starnet
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRequestBodyBytes(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
w.Write(body)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
testData := []byte("test data")
|
||||
req := NewSimpleRequest(server.URL, "POST").SetBody(testData)
|
||||
|
||||
resp, err := req.Do()
|
||||
if err != nil {
|
||||
t.Fatalf("Do() error: %v", err)
|
||||
}
|
||||
defer resp.Close()
|
||||
|
||||
body, _ := resp.Body().Bytes()
|
||||
if !bytes.Equal(body, testData) {
|
||||
t.Errorf("Body = %v; want %v", body, testData)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestBodyString(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
w.Write(body)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
testData := "test string data"
|
||||
req := NewSimpleRequest(server.URL, "POST").SetBodyString(testData)
|
||||
|
||||
resp, err := req.Do()
|
||||
if err != nil {
|
||||
t.Fatalf("Do() error: %v", err)
|
||||
}
|
||||
defer resp.Close()
|
||||
|
||||
body, _ := resp.Body().String()
|
||||
if body != testData {
|
||||
t.Errorf("Body = %v; want %v", body, testData)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestBodyReader(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
w.Write(body)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
testData := "test reader data"
|
||||
reader := strings.NewReader(testData)
|
||||
req := NewSimpleRequest(server.URL, "POST").SetBodyReader(reader)
|
||||
|
||||
resp, err := req.Do()
|
||||
if err != nil {
|
||||
t.Fatalf("Do() error: %v", err)
|
||||
}
|
||||
defer resp.Close()
|
||||
|
||||
body, _ := resp.Body().String()
|
||||
if body != testData {
|
||||
t.Errorf("Body = %v; want %v", body, testData)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestJSON(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Header.Get("Content-Type") != ContentTypeJSON {
|
||||
t.Errorf("Content-Type = %v; want %v", r.Header.Get("Content-Type"), ContentTypeJSON)
|
||||
}
|
||||
|
||||
var data map[string]string
|
||||
json.NewDecoder(r.Body).Decode(&data)
|
||||
json.NewEncoder(w).Encode(data)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
testData := map[string]string{
|
||||
"name": "John",
|
||||
"email": "john@example.com",
|
||||
}
|
||||
|
||||
req := NewSimpleRequest(server.URL, "POST").SetJSON(testData)
|
||||
|
||||
resp, err := req.Do()
|
||||
if err != nil {
|
||||
t.Fatalf("Do() error: %v", err)
|
||||
}
|
||||
defer resp.Close()
|
||||
|
||||
var result map[string]string
|
||||
resp.Body().JSON(&result)
|
||||
|
||||
if result["name"] != testData["name"] {
|
||||
t.Errorf("name = %v; want %v", result["name"], testData["name"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestFormData(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
r.ParseForm()
|
||||
data := make(map[string]string)
|
||||
for k, v := range r.Form {
|
||||
if len(v) > 0 {
|
||||
data[k] = v[0]
|
||||
}
|
||||
}
|
||||
json.NewEncoder(w).Encode(data)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
req := NewSimpleRequest(server.URL, "POST").
|
||||
AddFormData("name", "John").
|
||||
AddFormData("email", "john@example.com")
|
||||
|
||||
resp, err := req.Do()
|
||||
if err != nil {
|
||||
t.Fatalf("Do() error: %v", err)
|
||||
}
|
||||
defer resp.Close()
|
||||
|
||||
var result map[string]string
|
||||
resp.Body().JSON(&result)
|
||||
|
||||
if result["name"] != "John" {
|
||||
t.Errorf("name = %v; want John", result["name"])
|
||||
}
|
||||
if result["email"] != "john@example.com" {
|
||||
t.Errorf("email = %v; want john@example.com", result["email"])
|
||||
}
|
||||
}
|
||||
324
client.go
Normal file
324
client.go
Normal file
@ -0,0 +1,324 @@
|
||||
package starnet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Client HTTP 客户端封装
|
||||
type Client struct {
|
||||
client *http.Client
|
||||
opts []RequestOpt
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewClient 创建新的 Client
|
||||
func NewClient(opts ...RequestOpt) (*Client, error) {
|
||||
// 创建基础 Transport
|
||||
baseTransport := &http.Transport{
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
}
|
||||
|
||||
httpClient := &http.Client{
|
||||
Transport: &Transport{base: baseTransport},
|
||||
//Timeout: DefaultTimeout,
|
||||
}
|
||||
|
||||
// 应用选项(如果有)
|
||||
if len(opts) > 0 {
|
||||
// 创建临时请求以应用选项
|
||||
req, err := newRequest(context.Background(), "", http.MethodGet, opts...)
|
||||
if err != nil {
|
||||
return nil, wrapError(err, "create client")
|
||||
}
|
||||
|
||||
/*
|
||||
// 如果选项中有自定义配置,应用到 httpClient
|
||||
if req.config.Network.Timeout > 0 {
|
||||
httpClient.Timeout = req.config.Network.Timeout
|
||||
}
|
||||
|
||||
*/
|
||||
|
||||
// 如果有自定义 Transport
|
||||
if req.config.CustomTransport && req.config.Transport != nil {
|
||||
httpClient.Transport = &Transport{base: req.config.Transport}
|
||||
}
|
||||
}
|
||||
|
||||
return &Client{
|
||||
client: httpClient,
|
||||
opts: opts,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewClientNoErr 创建新的 Client(忽略错误)
|
||||
func NewClientNoErr(opts ...RequestOpt) *Client {
|
||||
client, _ := NewClient(opts...)
|
||||
if client == nil {
|
||||
client = &Client{
|
||||
client: &http.Client{},
|
||||
opts: opts,
|
||||
}
|
||||
}
|
||||
return client
|
||||
}
|
||||
|
||||
// NewClientFromHTTP 从 http.Client 创建 Client
|
||||
func NewClientFromHTTP(httpClient *http.Client) (*Client, error) {
|
||||
if httpClient == nil {
|
||||
return nil, ErrNilClient
|
||||
}
|
||||
|
||||
// 确保 Transport 是我们的自定义类型
|
||||
if httpClient.Transport == nil {
|
||||
httpClient.Transport = &Transport{
|
||||
base: &http.Transport{},
|
||||
}
|
||||
} else {
|
||||
switch t := httpClient.Transport.(type) {
|
||||
case *Transport:
|
||||
// 已经是我们的类型
|
||||
if t.base == nil {
|
||||
t.base = &http.Transport{}
|
||||
}
|
||||
case *http.Transport:
|
||||
// 包装标准 Transport
|
||||
httpClient.Transport = &Transport{
|
||||
base: t,
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported transport type: %T", t)
|
||||
}
|
||||
}
|
||||
|
||||
return &Client{
|
||||
client: httpClient,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// HTTPClient 获取底层 http.Client
|
||||
func (c *Client) HTTPClient() *http.Client {
|
||||
return c.client
|
||||
}
|
||||
|
||||
// RequestOptions 获取默认选项(返回副本)
|
||||
func (c *Client) RequestOptions() []RequestOpt {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
opts := make([]RequestOpt, len(c.opts))
|
||||
copy(opts, c.opts)
|
||||
return opts
|
||||
}
|
||||
|
||||
// SetOptions 设置默认选项
|
||||
func (c *Client) SetOptions(opts ...RequestOpt) *Client {
|
||||
c.mu.Lock()
|
||||
c.opts = opts
|
||||
c.mu.Unlock()
|
||||
return c
|
||||
}
|
||||
|
||||
// AddOptions 追加默认选项
|
||||
func (c *Client) AddOptions(opts ...RequestOpt) *Client {
|
||||
c.mu.Lock()
|
||||
c.opts = append(c.opts, opts...)
|
||||
c.mu.Unlock()
|
||||
return c
|
||||
}
|
||||
|
||||
// Clone 克隆 Client(深拷贝)
|
||||
func (c *Client) Clone() *Client {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
// 克隆 Transport
|
||||
var transport http.RoundTripper
|
||||
if c.client.Transport != nil {
|
||||
switch t := c.client.Transport.(type) {
|
||||
case *Transport:
|
||||
transport = &Transport{
|
||||
base: t.base.Clone(),
|
||||
}
|
||||
case *http.Transport:
|
||||
transport = t.Clone()
|
||||
default:
|
||||
transport = c.client.Transport
|
||||
}
|
||||
}
|
||||
|
||||
return &Client{
|
||||
client: &http.Client{
|
||||
Transport: transport,
|
||||
CheckRedirect: c.client.CheckRedirect,
|
||||
Jar: c.client.Jar,
|
||||
Timeout: c.client.Timeout,
|
||||
},
|
||||
opts: append([]RequestOpt(nil), c.opts...),
|
||||
}
|
||||
}
|
||||
|
||||
// SetDefaultTLSConfig 设置默认 TLS 配置
|
||||
func (c *Client) SetDefaultTLSConfig(tlsConfig *tls.Config) *Client {
|
||||
if transport, ok := c.client.Transport.(*Transport); ok {
|
||||
transport.mu.Lock()
|
||||
if transport.base.TLSClientConfig == nil {
|
||||
transport.base.TLSClientConfig = &tls.Config{}
|
||||
}
|
||||
transport.base.TLSClientConfig = tlsConfig
|
||||
transport.mu.Unlock()
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// SetDefaultSkipTLSVerify 设置默认跳过 TLS 验证
|
||||
func (c *Client) SetDefaultSkipTLSVerify(skip bool) *Client {
|
||||
if transport, ok := c.client.Transport.(*Transport); ok {
|
||||
transport.mu.Lock()
|
||||
if transport.base.TLSClientConfig == nil {
|
||||
transport.base.TLSClientConfig = &tls.Config{}
|
||||
}
|
||||
transport.base.TLSClientConfig.InsecureSkipVerify = skip
|
||||
transport.mu.Unlock()
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// DisableRedirect 禁用重定向
|
||||
func (c *Client) DisableRedirect() *Client {
|
||||
c.client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// EnableRedirect 启用重定向
|
||||
func (c *Client) EnableRedirect() *Client {
|
||||
c.client.CheckRedirect = nil
|
||||
return c
|
||||
}
|
||||
|
||||
// NewRequest 创建新请求
|
||||
func (c *Client) NewRequest(url, method string, opts ...RequestOpt) (*Request, error) {
|
||||
return c.NewRequestWithContext(context.Background(), url, method, opts...)
|
||||
}
|
||||
|
||||
// NewRequestWithContext 创建新请求(带 context)
|
||||
func (c *Client) NewRequestWithContext(ctx context.Context, url, method string, opts ...RequestOpt) (*Request, error) {
|
||||
// 合并 Client 级别和请求级别的选项
|
||||
c.mu.RLock()
|
||||
allOpts := append(append([]RequestOpt(nil), c.opts...), opts...)
|
||||
c.mu.RUnlock()
|
||||
|
||||
req, err := newRequest(ctx, url, method, allOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.client = c
|
||||
req.httpClient = c.client
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// Get 发送 GET 请求
|
||||
func (c *Client) Get(url string, opts ...RequestOpt) (*Response, error) {
|
||||
req, err := c.NewRequest(url, http.MethodGet, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return req.Do()
|
||||
}
|
||||
|
||||
// Post 发送 POST 请求
|
||||
func (c *Client) Post(url string, opts ...RequestOpt) (*Response, error) {
|
||||
req, err := c.NewRequest(url, http.MethodPost, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return req.Do()
|
||||
}
|
||||
|
||||
// Put 发送 PUT 请求
|
||||
func (c *Client) Put(url string, opts ...RequestOpt) (*Response, error) {
|
||||
req, err := c.NewRequest(url, http.MethodPut, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return req.Do()
|
||||
}
|
||||
|
||||
// Delete 发送 DELETE 请求
|
||||
func (c *Client) Delete(url string, opts ...RequestOpt) (*Response, error) {
|
||||
req, err := c.NewRequest(url, http.MethodDelete, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return req.Do()
|
||||
}
|
||||
|
||||
// Head 发送 HEAD 请求
|
||||
func (c *Client) Head(url string, opts ...RequestOpt) (*Response, error) {
|
||||
req, err := c.NewRequest(url, http.MethodHead, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return req.Do()
|
||||
}
|
||||
|
||||
// Patch 发送 PATCH 请求
|
||||
func (c *Client) Patch(url string, opts ...RequestOpt) (*Response, error) {
|
||||
req, err := c.NewRequest(url, http.MethodPatch, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return req.Do()
|
||||
}
|
||||
|
||||
// Options 发送 OPTIONS 请求
|
||||
func (c *Client) Options(url string, opts ...RequestOpt) (*Response, error) {
|
||||
req, err := c.NewRequest(url, http.MethodOptions, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return req.Do()
|
||||
}
|
||||
|
||||
// NewSimpleRequest 创建新请求(忽略错误,支持链式调用)
|
||||
func (c *Client) NewSimpleRequest(url, method string, opts ...RequestOpt) *Request {
|
||||
return c.NewSimpleRequestWithContext(context.Background(), url, method, opts...)
|
||||
}
|
||||
|
||||
// NewSimpleRequestWithContext 创建新请求(带 context,忽略错误)
|
||||
func (c *Client) NewSimpleRequestWithContext(ctx context.Context, url, method string, opts ...RequestOpt) *Request {
|
||||
req, err := c.NewRequestWithContext(ctx, url, method, opts...)
|
||||
if err != nil {
|
||||
// 返回一个带错误的请求,保持与全局 NewSimpleRequest 行为一致
|
||||
return &Request{
|
||||
ctx: ctx,
|
||||
url: url,
|
||||
method: method,
|
||||
err: err,
|
||||
config: &RequestConfig{
|
||||
Headers: make(http.Header),
|
||||
Queries: make(map[string][]string),
|
||||
Body: BodyConfig{
|
||||
FormData: make(map[string][]string),
|
||||
},
|
||||
},
|
||||
client: c,
|
||||
httpClient: c.client,
|
||||
autoFetch: DefaultFetchRespBody,
|
||||
}
|
||||
}
|
||||
return req
|
||||
}
|
||||
223
client_test.go
Normal file
223
client_test.go
Normal file
@ -0,0 +1,223 @@
|
||||
package starnet
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewClient(t *testing.T) {
|
||||
client, err := NewClient()
|
||||
if err != nil {
|
||||
t.Fatalf("NewClient() error: %v", err)
|
||||
}
|
||||
if client == nil {
|
||||
t.Fatal("NewClient() returned nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewClientNoErr(t *testing.T) {
|
||||
client := NewClientNoErr()
|
||||
if client == nil {
|
||||
t.Fatal("NewClientNoErr() returned nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewClientFromHTTP(t *testing.T) {
|
||||
httpClient := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
client, err := NewClientFromHTTP(httpClient)
|
||||
if err != nil {
|
||||
t.Fatalf("NewClientFromHTTP() error: %v", err)
|
||||
}
|
||||
if client == nil {
|
||||
t.Fatal("NewClientFromHTTP() returned nil")
|
||||
}
|
||||
|
||||
// Test with nil client
|
||||
_, err = NewClientFromHTTP(nil)
|
||||
if err == nil {
|
||||
t.Error("NewClientFromHTTP(nil) should return error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientOptions(t *testing.T) {
|
||||
client := NewClientNoErr()
|
||||
|
||||
// Set options
|
||||
client.SetOptions(WithTimeout(5 * time.Second))
|
||||
opts := client.RequestOptions()
|
||||
if len(opts) != 1 {
|
||||
t.Errorf("RequestOptions() length = %v; want 1", len(opts))
|
||||
}
|
||||
|
||||
// Add options
|
||||
client.AddOptions(WithUserAgent("TestAgent"))
|
||||
opts = client.RequestOptions()
|
||||
if len(opts) != 2 {
|
||||
t.Errorf("RequestOptions() length = %v; want 2", len(opts))
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientClone(t *testing.T) {
|
||||
client := NewClientNoErr(WithTimeout(5 * time.Second))
|
||||
cloned := client.Clone()
|
||||
|
||||
if cloned == nil {
|
||||
t.Fatal("Clone() returned nil")
|
||||
}
|
||||
|
||||
// 修改克隆的 client
|
||||
cloned.SetOptions(WithTimeout(10 * time.Second))
|
||||
|
||||
origOpts := client.RequestOptions()
|
||||
clonedOpts := cloned.RequestOptions()
|
||||
|
||||
// 原 client 应该还是 1 个选项
|
||||
if len(origOpts) != 1 {
|
||||
t.Errorf("Original client options = %v; want 1", len(origOpts))
|
||||
}
|
||||
|
||||
// 克隆的 client 应该是 1 个选项(被 SetOptions 覆盖)
|
||||
if len(clonedOpts) != 1 {
|
||||
t.Errorf("Cloned client options = %v; want 1", len(clonedOpts))
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientHTTPMethods(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(r.Method))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClientNoErr()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
method func(string, ...RequestOpt) (*Response, error)
|
||||
want string
|
||||
}{
|
||||
{"GET", client.Get, "GET"},
|
||||
{"POST", client.Post, "POST"},
|
||||
{"PUT", client.Put, "PUT"},
|
||||
{"DELETE", client.Delete, "DELETE"},
|
||||
{"PATCH", client.Patch, "PATCH"},
|
||||
{"HEAD", client.Head, "HEAD"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
resp, err := tt.method(server.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("%s() error: %v", tt.name, err)
|
||||
}
|
||||
defer resp.Close()
|
||||
|
||||
if tt.want != "HEAD" {
|
||||
body, _ := resp.Body().String()
|
||||
if body != tt.want {
|
||||
t.Errorf("Body = %v; want %v", body, tt.want)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientRedirect(t *testing.T) {
|
||||
redirectCount := 0
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if redirectCount < 2 {
|
||||
redirectCount++
|
||||
http.Redirect(w, r, "/redirected", http.StatusFound)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("final"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Test with redirect enabled (default)
|
||||
client := NewClientNoErr()
|
||||
resp, err := client.Get(server.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("Get() error: %v", err)
|
||||
}
|
||||
resp.Close()
|
||||
|
||||
if redirectCount != 2 {
|
||||
t.Errorf("Redirect count = %v; want 2", redirectCount)
|
||||
}
|
||||
|
||||
// Test with redirect disabled
|
||||
redirectCount = 0
|
||||
client.DisableRedirect()
|
||||
resp2, err := client.Get(server.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("Get() error: %v", err)
|
||||
}
|
||||
defer resp2.Close()
|
||||
|
||||
if resp2.StatusCode != http.StatusFound {
|
||||
t.Errorf("StatusCode = %v; want %v", resp2.StatusCode, http.StatusFound)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientTLSConfig(t *testing.T) {
|
||||
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("OK"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Without skip verify (should fail with self-signed cert)
|
||||
client := NewClientNoErr()
|
||||
_, err := client.Get(server.URL)
|
||||
if err == nil {
|
||||
t.Error("Expected TLS error with self-signed cert, got nil")
|
||||
}
|
||||
|
||||
// With skip verify
|
||||
client.SetDefaultSkipTLSVerify(true)
|
||||
resp, err := client.Get(server.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("Get() with skip verify error: %v", err)
|
||||
}
|
||||
defer resp.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("StatusCode = %v; want %v", resp.StatusCode, http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientNewSimpleRequest(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("OK"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClientNoErr()
|
||||
|
||||
req := client.NewSimpleRequest(server.URL, "GET", WithHeader("X-Test", "v"))
|
||||
if req == nil {
|
||||
t.Fatal("NewSimpleRequest returned nil")
|
||||
}
|
||||
if req.Err() != nil {
|
||||
t.Fatalf("NewSimpleRequest err: %v", req.Err())
|
||||
}
|
||||
|
||||
resp, err := req.Do()
|
||||
if err != nil {
|
||||
t.Fatalf("Do() error: %v", err)
|
||||
}
|
||||
defer resp.Close()
|
||||
|
||||
body, _ := resp.Body().String()
|
||||
if body != "OK" {
|
||||
t.Errorf("Body = %v; want OK", body)
|
||||
}
|
||||
}
|
||||
111
concurrent_test.go
Normal file
111
concurrent_test.go
Normal file
@ -0,0 +1,111 @@
|
||||
package starnet
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestConcurrentRequests(t *testing.T) {
|
||||
var counter int64
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
atomic.AddInt64(&counter, 1)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("OK"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClientNoErr()
|
||||
concurrency := 100
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(concurrency)
|
||||
|
||||
for i := 0; i < concurrency; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
resp, err := client.Get(server.URL)
|
||||
if err != nil {
|
||||
t.Errorf("Get() error: %v", err)
|
||||
return
|
||||
}
|
||||
resp.Close()
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if atomic.LoadInt64(&counter) != int64(concurrency) {
|
||||
t.Errorf("counter = %v; want %v", counter, concurrency)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrentClientModification(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClientNoErr()
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(200)
|
||||
|
||||
// 100 goroutines reading
|
||||
for i := 0; i < 100; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
resp, err := client.Get(server.URL)
|
||||
if err != nil {
|
||||
t.Errorf("Get() error: %v", err)
|
||||
return
|
||||
}
|
||||
resp.Close()
|
||||
}()
|
||||
}
|
||||
|
||||
// 100 goroutines modifying options
|
||||
for i := 0; i < 100; i++ {
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
if i%2 == 0 {
|
||||
client.AddOptions(WithTimeout(5 * time.Second))
|
||||
} else {
|
||||
_ = client.RequestOptions()
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestConcurrentRequestClone(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
baseReq := NewSimpleRequest(server.URL, "GET").SetHeader("X-Base", "value")
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(50)
|
||||
|
||||
for i := 0; i < 50; i++ {
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
cloned := baseReq.Clone()
|
||||
// 修复:使用有效的 header 值
|
||||
cloned.SetHeader("X-Index", fmt.Sprintf("%d", i))
|
||||
resp, err := cloned.Do()
|
||||
if err != nil {
|
||||
t.Errorf("Do() error: %v", err)
|
||||
return
|
||||
}
|
||||
resp.Close()
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
149
context.go
Normal file
149
context.go
Normal file
@ -0,0 +1,149 @@
|
||||
package starnet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// contextKey 私有的 context key 类型(防止冲突)
|
||||
type contextKey int
|
||||
|
||||
const (
|
||||
ctxKeyTransport contextKey = iota
|
||||
ctxKeyTLSConfig
|
||||
ctxKeyProxy
|
||||
ctxKeyCustomIP
|
||||
ctxKeyCustomDNS
|
||||
ctxKeyDialTimeout
|
||||
ctxKeyTimeout
|
||||
ctxKeyLookupIP
|
||||
ctxKeyDialFunc
|
||||
)
|
||||
|
||||
// RequestContext 从 context 中提取的请求配置
|
||||
type RequestContext struct {
|
||||
Transport *http.Transport
|
||||
TLSConfig *tls.Config
|
||||
Proxy string
|
||||
CustomIP []string
|
||||
CustomDNS []string
|
||||
DialTimeout time.Duration
|
||||
Timeout time.Duration
|
||||
LookupIPFn func(ctx context.Context, host string) ([]net.IPAddr, error)
|
||||
DialFn func(ctx context.Context, network, addr string) (net.Conn, error)
|
||||
}
|
||||
|
||||
// getRequestContext 从 context 中提取请求配置
|
||||
func getRequestContext(ctx context.Context) *RequestContext {
|
||||
rc := &RequestContext{}
|
||||
|
||||
if v := ctx.Value(ctxKeyTransport); v != nil {
|
||||
rc.Transport, _ = v.(*http.Transport)
|
||||
}
|
||||
if v := ctx.Value(ctxKeyTLSConfig); v != nil {
|
||||
rc.TLSConfig, _ = v.(*tls.Config)
|
||||
}
|
||||
if v := ctx.Value(ctxKeyProxy); v != nil {
|
||||
rc.Proxy, _ = v.(string)
|
||||
}
|
||||
if v := ctx.Value(ctxKeyCustomIP); v != nil {
|
||||
rc.CustomIP, _ = v.([]string)
|
||||
}
|
||||
if v := ctx.Value(ctxKeyCustomDNS); v != nil {
|
||||
rc.CustomDNS, _ = v.([]string)
|
||||
}
|
||||
if v := ctx.Value(ctxKeyDialTimeout); v != nil {
|
||||
rc.DialTimeout, _ = v.(time.Duration)
|
||||
}
|
||||
if v := ctx.Value(ctxKeyTimeout); v != nil {
|
||||
rc.Timeout, _ = v.(time.Duration)
|
||||
}
|
||||
if v := ctx.Value(ctxKeyLookupIP); v != nil {
|
||||
rc.LookupIPFn, _ = v.(func(context.Context, string) ([]net.IPAddr, error))
|
||||
}
|
||||
if v := ctx.Value(ctxKeyDialFunc); v != nil {
|
||||
rc.DialFn, _ = v.(func(context.Context, string, string) (net.Conn, error))
|
||||
}
|
||||
|
||||
return rc
|
||||
}
|
||||
|
||||
// needsDynamicTransport 判断是否需要动态 Transport
|
||||
func needsDynamicTransport(rc *RequestContext) bool {
|
||||
return rc.Transport != nil ||
|
||||
rc.TLSConfig != nil ||
|
||||
rc.Proxy != "" ||
|
||||
rc.DialFn != nil ||
|
||||
(rc.DialTimeout > 0 && rc.DialTimeout != DefaultDialTimeout) ||
|
||||
(rc.Timeout > 0 && rc.Timeout != DefaultTimeout) ||
|
||||
len(rc.CustomIP) > 0 ||
|
||||
len(rc.CustomDNS) > 0 ||
|
||||
rc.LookupIPFn != nil
|
||||
}
|
||||
|
||||
// injectRequestConfig 将请求配置注入到 context
|
||||
func injectRequestConfig(ctx context.Context, config *RequestConfig) context.Context {
|
||||
execCtx := ctx
|
||||
|
||||
// 处理 TLS 配置
|
||||
var tlsConfig *tls.Config
|
||||
|
||||
if config.TLS.Config != nil {
|
||||
tlsConfig = config.TLS.Config.Clone()
|
||||
if config.TLS.SkipVerify {
|
||||
tlsConfig.InsecureSkipVerify = true
|
||||
}
|
||||
} else if config.TLS.SkipVerify {
|
||||
tlsConfig = &tls.Config{
|
||||
NextProtos: []string{"h2", "http/1.1"},
|
||||
InsecureSkipVerify: true,
|
||||
}
|
||||
}
|
||||
|
||||
if tlsConfig != nil {
|
||||
execCtx = context.WithValue(execCtx, ctxKeyTLSConfig, tlsConfig)
|
||||
}
|
||||
|
||||
// 注入代理
|
||||
if config.Network.Proxy != "" {
|
||||
execCtx = context.WithValue(execCtx, ctxKeyProxy, config.Network.Proxy)
|
||||
}
|
||||
|
||||
// 注入自定义 IP
|
||||
if len(config.DNS.CustomIP) > 0 {
|
||||
execCtx = context.WithValue(execCtx, ctxKeyCustomIP, config.DNS.CustomIP)
|
||||
}
|
||||
|
||||
// 注入自定义 DNS
|
||||
if len(config.DNS.CustomDNS) > 0 {
|
||||
execCtx = context.WithValue(execCtx, ctxKeyCustomDNS, config.DNS.CustomDNS)
|
||||
}
|
||||
|
||||
// 总是注入 DialTimeout 和 Timeout(与原始代码一致)
|
||||
if config.Network.DialTimeout > 0 {
|
||||
execCtx = context.WithValue(execCtx, ctxKeyDialTimeout, config.Network.DialTimeout)
|
||||
}
|
||||
if config.Network.Timeout > 0 {
|
||||
execCtx = context.WithValue(execCtx, ctxKeyTimeout, config.Network.Timeout)
|
||||
}
|
||||
|
||||
// 注入 DNS 解析函数
|
||||
if config.DNS.LookupFunc != nil {
|
||||
execCtx = context.WithValue(execCtx, ctxKeyLookupIP, config.DNS.LookupFunc)
|
||||
}
|
||||
|
||||
// 注入 Dial 函数
|
||||
if config.Network.DialFunc != nil {
|
||||
execCtx = context.WithValue(execCtx, ctxKeyDialFunc, config.Network.DialFunc)
|
||||
}
|
||||
|
||||
// 注入自定义 Transport
|
||||
if config.CustomTransport && config.Transport != nil {
|
||||
execCtx = context.WithValue(execCtx, ctxKeyTransport, config.Transport)
|
||||
}
|
||||
|
||||
return execCtx
|
||||
}
|
||||
198
curl_default.go
198
curl_default.go
@ -1,198 +0,0 @@
|
||||
package starnet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
HEADER_FORM_URLENCODE = `application/x-www-form-urlencoded`
|
||||
HEADER_FORM_DATA = `multipart/form-data`
|
||||
HEADER_JSON = `application/json`
|
||||
HEADER_PLAIN = `text/plain`
|
||||
)
|
||||
|
||||
var (
|
||||
DefaultDialTimeout = 5 * time.Second
|
||||
DefaultTimeout = 10 * time.Second
|
||||
DefaultFetchRespBody = false
|
||||
DefaultHttpClient = NewHttpClientNoErr()
|
||||
)
|
||||
|
||||
func UrlEncodeRaw(str string) string {
|
||||
strs := strings.Replace(url.QueryEscape(str), "+", "%20", -1)
|
||||
return strs
|
||||
}
|
||||
|
||||
func UrlEncode(str string) string {
|
||||
return url.QueryEscape(str)
|
||||
}
|
||||
|
||||
func UrlDecode(str string) (string, error) {
|
||||
return url.QueryUnescape(str)
|
||||
}
|
||||
|
||||
func BuildQuery(queryData map[string]string) string {
|
||||
query := url.Values{}
|
||||
for k, v := range queryData {
|
||||
query.Add(k, v)
|
||||
}
|
||||
return query.Encode()
|
||||
}
|
||||
|
||||
// BuildPostForm takes a map of string keys and values, converts it into a URL-encoded query string,
|
||||
// and then converts that string into a byte slice. This function is useful for preparing data for HTTP POST requests,
|
||||
// where the server expects the request body to be URL-encoded form data.
|
||||
//
|
||||
// Parameters:
|
||||
// queryMap: A map where the key-value pairs represent the form data to be sent in the HTTP POST request.
|
||||
//
|
||||
// Returns:
|
||||
// A byte slice representing the URL-encoded form data.
|
||||
func BuildPostForm(queryMap map[string]string) []byte {
|
||||
return []byte(BuildQuery(queryMap))
|
||||
}
|
||||
|
||||
func Get(uri string, opts ...RequestOpt) (*Response, error) {
|
||||
return NewSimpleRequestWithClient(DefaultHttpClient, uri, "GET", opts...).Do()
|
||||
}
|
||||
|
||||
func Post(uri string, opts ...RequestOpt) (*Response, error) {
|
||||
return NewSimpleRequestWithClient(DefaultHttpClient, uri, "POST", opts...).Do()
|
||||
}
|
||||
|
||||
func Options(uri string, opts ...RequestOpt) (*Response, error) {
|
||||
return NewSimpleRequestWithClient(DefaultHttpClient, uri, "OPTIONS", opts...).Do()
|
||||
}
|
||||
|
||||
func Put(uri string, opts ...RequestOpt) (*Response, error) {
|
||||
return NewSimpleRequestWithClient(DefaultHttpClient, uri, "PUT", opts...).Do()
|
||||
}
|
||||
|
||||
func Delete(uri string, opts ...RequestOpt) (*Response, error) {
|
||||
return NewSimpleRequestWithClient(DefaultHttpClient, uri, "DELETE", opts...).Do()
|
||||
}
|
||||
|
||||
func Head(uri string, opts ...RequestOpt) (*Response, error) {
|
||||
return NewSimpleRequestWithClient(DefaultHttpClient, uri, "HEAD", opts...).Do()
|
||||
}
|
||||
|
||||
func Patch(uri string, opts ...RequestOpt) (*Response, error) {
|
||||
return NewSimpleRequestWithClient(DefaultHttpClient, uri, "PATCH", opts...).Do()
|
||||
}
|
||||
|
||||
func Trace(uri string, opts ...RequestOpt) (*Response, error) {
|
||||
return NewSimpleRequestWithClient(DefaultHttpClient, uri, "TRACE", opts...).Do()
|
||||
}
|
||||
|
||||
func Connect(uri string, opts ...RequestOpt) (*Response, error) {
|
||||
return NewSimpleRequestWithClient(DefaultHttpClient, uri, "CONNECT", opts...).Do()
|
||||
}
|
||||
|
||||
func DefaultCheckRedirectFunc(req *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
}
|
||||
|
||||
func DefaultDialFunc(ctx context.Context, netType, addr string) (net.Conn, error) {
|
||||
var lastErr error
|
||||
var addrs []string
|
||||
if dialFn, ok := ctx.Value("dialFunc").(func(context.Context, string, string) (net.Conn, error)); ok {
|
||||
if dialFn != nil {
|
||||
return dialFn(ctx, netType, addr)
|
||||
}
|
||||
}
|
||||
customIP, ok := ctx.Value("customIP").([]string)
|
||||
if !ok {
|
||||
customIP = nil
|
||||
}
|
||||
dialTimeout, ok := ctx.Value("dialTimeout").(time.Duration)
|
||||
if !ok {
|
||||
dialTimeout = DefaultDialTimeout
|
||||
}
|
||||
timeout, ok := ctx.Value("timeout").(time.Duration)
|
||||
if !ok {
|
||||
timeout = DefaultTimeout
|
||||
}
|
||||
lookUpIPfn, ok := ctx.Value("lookUpIP").(func(context.Context, string) ([]net.IPAddr, error))
|
||||
if !ok {
|
||||
lookUpIPfn = net.DefaultResolver.LookupIPAddr
|
||||
}
|
||||
host, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
proxy, ok := ctx.Value("proxy").(string)
|
||||
if !ok {
|
||||
proxy = ""
|
||||
}
|
||||
if proxy == "" && len(customIP) > 0 {
|
||||
for _, v := range customIP {
|
||||
ipAddr := net.ParseIP(v)
|
||||
if ipAddr == nil {
|
||||
return nil, fmt.Errorf("invalid custom ip: %s", customIP)
|
||||
}
|
||||
tmpAddr := net.JoinHostPort(v, port)
|
||||
addrs = append(addrs, tmpAddr)
|
||||
}
|
||||
} else {
|
||||
ipLists, err := lookUpIPfn(ctx, host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, v := range ipLists {
|
||||
tmpAddr := net.JoinHostPort(v.String(), port)
|
||||
addrs = append(addrs, tmpAddr)
|
||||
}
|
||||
}
|
||||
for _, addr := range addrs {
|
||||
c, err := net.DialTimeout(netType, addr, dialTimeout)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
continue
|
||||
}
|
||||
if timeout != 0 {
|
||||
err = c.SetDeadline(time.Now().Add(timeout))
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
return nil, lastErr
|
||||
}
|
||||
|
||||
func DefaultDialTlsFunc(ctx context.Context, netType, addr string) (net.Conn, error) {
|
||||
conn, err := DefaultDialFunc(ctx, netType, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tlsConfig, ok := ctx.Value("tlsConfig").(*tls.Config)
|
||||
if !ok || tlsConfig == nil {
|
||||
return nil, fmt.Errorf("tlsConfig is not set in context")
|
||||
}
|
||||
tlsConn := tls.Client(conn, tlsConfig)
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
return nil, fmt.Errorf("tls handshake failed: %w", err)
|
||||
}
|
||||
return tlsConn, nil
|
||||
}
|
||||
|
||||
func DefaultProxyURL() func(*http.Request) (*url.URL, error) {
|
||||
return func(req *http.Request) (*url.URL, error) {
|
||||
if req == nil {
|
||||
return nil, fmt.Errorf("request is nil")
|
||||
}
|
||||
proxyURL, ok := req.Context().Value("proxy").(string)
|
||||
if !ok || proxyURL == "" {
|
||||
return nil, nil
|
||||
}
|
||||
parsedURL, err := url.Parse(proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse proxy URL: %w", err)
|
||||
}
|
||||
return parsedURL, nil
|
||||
}
|
||||
}
|
||||
728
curl_test.go
728
curl_test.go
@ -1,728 +0,0 @@
|
||||
package starnet
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestUrlEncodeRaw(t *testing.T) {
|
||||
input := "hello world!@#$%^&*()_+-=~`"
|
||||
expected := "hello%20world%21%40%23%24%25%5E%26%2A%28%29_%2B-%3D~%60"
|
||||
result := UrlEncodeRaw(input)
|
||||
if result != expected {
|
||||
t.Errorf("UrlEncodeRaw(%q) = %q; want %q", input, result, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUrlEncode(t *testing.T) {
|
||||
input := "hello world!@#$%^&*()_+-=~`"
|
||||
expected := `hello+world%21%40%23%24%25%5E%26%2A%28%29_%2B-%3D~%60`
|
||||
result := UrlEncode(input)
|
||||
if result != expected {
|
||||
t.Errorf("UrlEncode(%q) = %q; want %q", input, result, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUrlDecode(t *testing.T) {
|
||||
input := "hello%20world%21%40%23%24%25%5E%26*%28%29_%2B-%3D~%60"
|
||||
expected := "hello world!@#$%^&*()_+-=~`"
|
||||
result, err := UrlDecode(input)
|
||||
if err != nil {
|
||||
t.Errorf("UrlDecode(%q) returned error: %v", input, err)
|
||||
}
|
||||
if result != expected {
|
||||
t.Errorf("UrlDecode(%q) = %q; want %q", input, result, expected)
|
||||
}
|
||||
|
||||
// Test for error case
|
||||
invalidInput := "%zz"
|
||||
_, err = UrlDecode(invalidInput)
|
||||
if err == nil {
|
||||
t.Errorf("UrlDecode(%q) expected error, got nil", invalidInput)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPostForm_WithValidInput(t *testing.T) {
|
||||
input := map[string]string{
|
||||
"key1": "value1",
|
||||
"key2": "value2",
|
||||
}
|
||||
|
||||
expected := []byte("key1=value1&key2=value2")
|
||||
|
||||
result := BuildPostForm(input)
|
||||
|
||||
if string(result) != string(expected) {
|
||||
t.Errorf("BuildPostForm(%v) = %v; want %v", input, result, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPostForm_WithEmptyInput(t *testing.T) {
|
||||
input := map[string]string{}
|
||||
|
||||
expected := []byte("")
|
||||
|
||||
result := BuildPostForm(input)
|
||||
|
||||
if string(result) != string(expected) {
|
||||
t.Errorf("BuildPostForm(%v) = %v; want %v", input, result, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPostForm_WithNilInput(t *testing.T) {
|
||||
var input map[string]string
|
||||
|
||||
expected := []byte("")
|
||||
|
||||
result := BuildPostForm(input)
|
||||
|
||||
if string(result) != string(expected) {
|
||||
t.Errorf("BuildPostForm(%v) = %v; want %v", input, result, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRequest(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.Write([]byte(`OK`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
resp, err := Get(server.URL, WithSkipTLSVerify(true), WithHeader("hello", "world"), WithUserAgent("hello world"))
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
body := resp.Body().String()
|
||||
if body != "OK" {
|
||||
t.Errorf("Expected OK, got %v", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostRequest(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
if req.Method != http.MethodPost {
|
||||
t.Errorf("Expected 'POST', got %v", req.Method)
|
||||
}
|
||||
rw.Write([]byte(`OK`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
resp, err := Post(server.URL)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
body := resp.Body().String()
|
||||
if body != "OK" {
|
||||
t.Errorf("Expected OK, got %v", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOptionsRequestWithValidInput(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
if req.Method != http.MethodOptions {
|
||||
t.Errorf("Expected 'OPTIONS', got %v", req.Method)
|
||||
}
|
||||
rw.Write([]byte(`OK`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
resp, err := Options(server.URL)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
body := resp.Body().String()
|
||||
if body != "OK" {
|
||||
t.Errorf("Expected OK, got %v", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPutRequestWithValidInput(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
if req.Method != http.MethodPut {
|
||||
t.Errorf("Expected 'PUT', got %v", req.Method)
|
||||
}
|
||||
rw.Write([]byte(`OK`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
resp, err := Put(server.URL)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
body := resp.Body().String()
|
||||
if body != "OK" {
|
||||
t.Errorf("Expected OK, got %v", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteRequestWithValidInput(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
if req.Method != http.MethodDelete {
|
||||
t.Errorf("Expected 'DELETE', got %v", req.Method)
|
||||
}
|
||||
rw.Write([]byte(`OK`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
resp, err := Delete(server.URL)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
body := resp.Body().String()
|
||||
if body != "OK" {
|
||||
t.Errorf("Expected OK, got %v", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeadRequestWithValidInput(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
if req.Method != http.MethodHead {
|
||||
t.Errorf("Expected 'HEAD', got %v", req.Method)
|
||||
}
|
||||
rw.Write([]byte(`OK`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
resp, err := Head(server.URL)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
body := resp.Body().String()
|
||||
if body == "OK" {
|
||||
t.Errorf("Expected , got %v", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatchRequestWithValidInput(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
if req.Method != http.MethodPatch {
|
||||
t.Errorf("Expected 'PATCH', got %v", req.Method)
|
||||
}
|
||||
rw.Write([]byte(`OK`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
resp, err := Patch(server.URL)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
body := resp.Body().String()
|
||||
if body != "OK" {
|
||||
t.Errorf("Expected OK, got %v", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTraceRequestWithValidInput(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
if req.Method != http.MethodTrace {
|
||||
t.Errorf("Expected 'TRACE', got %v", req.Method)
|
||||
}
|
||||
rw.Write([]byte(`OK`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
resp, err := Trace(server.URL)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
body := resp.Body().String()
|
||||
if body != "OK" {
|
||||
t.Errorf("Expected OK, got %v", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectRequestWithValidInput(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
if req.Method != http.MethodConnect {
|
||||
t.Errorf("Expected 'CONNECT', got %v", req.Method)
|
||||
}
|
||||
rw.Write([]byte(`OK`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
resp, err := Connect(server.URL)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
body := resp.Body().String()
|
||||
if body != "OK" {
|
||||
t.Errorf("Expected OK, got %v", body)
|
||||
}
|
||||
}
|
||||
func TestMethodReturnsCorrectValue(t *testing.T) {
|
||||
req := NewReq("https://example.com")
|
||||
req.SetMethodNoError("GET")
|
||||
if req.Method() != "GET" {
|
||||
t.Errorf("Expected 'GET', got %v", req.Method())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetMethodHandlesInvalidInput(t *testing.T) {
|
||||
req := NewReq("https://example.com")
|
||||
err := req.SetMethod("我是谁")
|
||||
if err == nil {
|
||||
t.Errorf("Expected error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetMethodNoErrorSetsMethodCorrectly(t *testing.T) {
|
||||
req := NewReq("https://example.com")
|
||||
req.SetMethodNoError("POST")
|
||||
if req.Method() != "POST" {
|
||||
t.Errorf("Expected 'POST', got %v", req.Method())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetMethodNoErrorIgnoresInvalidInput(t *testing.T) {
|
||||
req := NewReq("https://example.com")
|
||||
req.SetMethodNoError("你是谁")
|
||||
if req.Method() != "GET" {
|
||||
t.Errorf("Expected '', got %v", req.Method())
|
||||
}
|
||||
}
|
||||
|
||||
func TestUriReturnsCorrectValue(t *testing.T) {
|
||||
req := NewReq("https://example.com")
|
||||
if req.Uri() != "https://example.com" {
|
||||
t.Errorf("Expected 'https://example.com', got %v", req.Uri())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetUriHandlesValidInput(t *testing.T) {
|
||||
req := NewReq("https://example.com")
|
||||
err := req.SetUri("https://newexample.com")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
if req.Uri() != "https://newexample.com" {
|
||||
t.Errorf("Expected 'https://newexample.com', got %v", req.Uri())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetUriHandlesInvalidInput(t *testing.T) {
|
||||
req := NewReq("https://example.com")
|
||||
err := req.SetUri("://invalidurl")
|
||||
if err == nil {
|
||||
t.Errorf("Expected error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetUriNoErrorSetsUriCorrectly(t *testing.T) {
|
||||
req := NewReq("https://example.com")
|
||||
req.SetUriNoError("https://newexample.com")
|
||||
if req.Uri() != "https://newexample.com" {
|
||||
t.Errorf("Expected 'https://newexample.com', got %v", req.Uri())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetUriNoErrorIgnoresInvalidInput(t *testing.T) {
|
||||
req := NewReq("https://example.com")
|
||||
req.SetUriNoError("://invalidurl")
|
||||
if req.Uri() != "https://example.com" {
|
||||
t.Errorf("Expected 'https://example.com', got %v", req.Uri())
|
||||
}
|
||||
}
|
||||
|
||||
type postmanReply struct {
|
||||
Args struct {
|
||||
} `json:"args"`
|
||||
Form map[string]string `json:"form"`
|
||||
Headers map[string]string `json:"headers"`
|
||||
Url string `json:"url"`
|
||||
}
|
||||
|
||||
func TestGet(t *testing.T) {
|
||||
var reply postmanReply
|
||||
resp, err := NewReq("https://postman-echo.com/get").
|
||||
AddHeader("hello", "nononmo").
|
||||
SetAutoCalcContentLengthNoError(true).Do()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
fmt.Println(resp.Proto)
|
||||
err = resp.Body().Unmarshal(&reply)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
fmt.Println(resp.Body().String())
|
||||
fmt.Println(reply.Headers)
|
||||
fmt.Println(resp.Cookies())
|
||||
}
|
||||
|
||||
type testData struct {
|
||||
name string
|
||||
args *Request
|
||||
want func(*Response) error
|
||||
wantErr bool
|
||||
}
|
||||
|
||||
func headerTestData() []testData {
|
||||
return []testData{
|
||||
{
|
||||
name: "addHeader",
|
||||
args: NewReq("https://postman-echo.com/get").
|
||||
AddHeader("b612", "test-data").
|
||||
AddHeader("b612", "test-header").
|
||||
AddSimpleCookie("b612", "test-cookie").
|
||||
SetHeader("User-Agent", "starnet test"),
|
||||
want: func(resp *Response) error {
|
||||
//fmt.Println(resp.Body().String())
|
||||
if resp == nil {
|
||||
return fmt.Errorf("response is nil")
|
||||
}
|
||||
if resp.StatusCode != 200 {
|
||||
return fmt.Errorf("status code is %d", resp.StatusCode)
|
||||
}
|
||||
var reply postmanReply
|
||||
err := resp.Body().Unmarshal(&reply)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if reply.Headers["b612"] != "test-data, test-header" {
|
||||
return fmt.Errorf("header not found")
|
||||
}
|
||||
if reply.Headers["user-agent"] != "starnet test" {
|
||||
return fmt.Errorf("user-agent not found")
|
||||
}
|
||||
if reply.Headers["cookie"] != "b612=test-cookie" {
|
||||
return fmt.Errorf("cookie not found")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "postForm",
|
||||
args: NewSimpleRequest("https://postman-echo.com/post", "POST").
|
||||
AddHeader("b612", "test-data").
|
||||
AddHeader("b612", "test-header").
|
||||
AddSimpleCookie("b612", "test-cookie").
|
||||
SetHeader("User-Agent", "starnet test").
|
||||
//SetHeader("Content-Type", "application/x-www-form-urlencoded").
|
||||
AddFormData("hello", "world").
|
||||
AddFormData("hello2", "world2").
|
||||
SetMethodNoError("POST"),
|
||||
want: func(resp *Response) error {
|
||||
//fmt.Println(resp.Body().String())
|
||||
if resp == nil {
|
||||
return fmt.Errorf("response is nil")
|
||||
}
|
||||
if resp.StatusCode != 200 {
|
||||
return fmt.Errorf("status code is %d", resp.StatusCode)
|
||||
}
|
||||
var reply postmanReply
|
||||
err := resp.Body().Unmarshal(&reply)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if reply.Headers["b612"] != "test-data, test-header" {
|
||||
return fmt.Errorf("header not found")
|
||||
}
|
||||
if reply.Headers["user-agent"] != "starnet test" {
|
||||
return fmt.Errorf("user-agent not found")
|
||||
}
|
||||
if reply.Headers["cookie"] != "b612=test-cookie" {
|
||||
return fmt.Errorf("cookie not found")
|
||||
}
|
||||
if reply.Form["hello"] != "world" {
|
||||
return fmt.Errorf("form data not found")
|
||||
}
|
||||
if reply.Form["hello2"] != "world2" {
|
||||
return fmt.Errorf("form data not found")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
}
|
||||
func TestCurl(t *testing.T) {
|
||||
for _, tt := range headerTestData() {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := Curl(tt.args)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Curl() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if tt.want != nil {
|
||||
if err := tt.want(got); err != nil {
|
||||
t.Errorf("Curl() = %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReqClone(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
if req.Header.Get("hello") != "world" {
|
||||
rw.WriteHeader(http.StatusBadRequest)
|
||||
rw.Write([]byte("hello world failed"))
|
||||
return
|
||||
}
|
||||
rw.Write([]byte(`OK`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
req := NewSimpleRequestWithClient(NewClientFromHttpClientNoError(http.DefaultClient), server.URL, "GET", WithHeader("hello", "world"))
|
||||
resp, err := req.Do()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if resp.StatusCode != 200 {
|
||||
resp.CloseAll()
|
||||
t.Errorf("status code is %d", resp.StatusCode)
|
||||
}
|
||||
resp.CloseAll()
|
||||
req = req.Clone()
|
||||
req.AddHeader("ok", "good")
|
||||
resp, err = req.Do()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if resp.StatusCode != 200 {
|
||||
resp.CloseAll()
|
||||
t.Errorf("status code is %d", resp.StatusCode)
|
||||
}
|
||||
resp.CloseAll()
|
||||
}
|
||||
|
||||
func TestUploadFile(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
if req.Header.Get("hello") != "world" {
|
||||
rw.WriteHeader(http.StatusBadRequest)
|
||||
rw.Write([]byte("hello world failed"))
|
||||
return
|
||||
}
|
||||
files, header, err := req.FormFile("666")
|
||||
if err == nil {
|
||||
fmt.Println(header.Filename)
|
||||
fmt.Println(header.Size)
|
||||
fmt.Println(files.Close())
|
||||
}
|
||||
files, header, err = req.FormFile("777")
|
||||
if err == nil {
|
||||
fmt.Println(header.Filename)
|
||||
fmt.Println(header.Size)
|
||||
fmt.Println(files.Close())
|
||||
}
|
||||
files, header, err = req.FormFile("888")
|
||||
if err == nil {
|
||||
fmt.Println(header.Filename)
|
||||
fmt.Println(header.Size)
|
||||
fmt.Println(files.Close())
|
||||
}
|
||||
rw.Write([]byte(`OK`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
req := NewSimpleRequestWithClient(NewClientFromHttpClientNoError(http.DefaultClient), server.URL, "GET", WithHeader("hello", "world"))
|
||||
req.AddFileWithName("666", "./curl.go", "curl.go")
|
||||
req.AddFile("777", "./go.mod")
|
||||
req.AddFileWithNameAndType("888", "./ping.go", "ping.go", "html")
|
||||
resp, err := req.Do()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if resp.StatusCode != 200 {
|
||||
resp.CloseAll()
|
||||
t.Errorf("status code is %d", resp.StatusCode)
|
||||
}
|
||||
resp.CloseAll()
|
||||
req = req.Clone()
|
||||
req.AddHeader("ok", "good")
|
||||
|
||||
resp, err = req.Do()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if resp.StatusCode != 200 {
|
||||
resp.CloseAll()
|
||||
t.Errorf("status code is %d", resp.StatusCode)
|
||||
}
|
||||
resp.CloseAll()
|
||||
}
|
||||
|
||||
func TestTlsConfig(t *testing.T) {
|
||||
server := httptest.NewTLSServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
if req.Header.Get("hello") != "world" {
|
||||
rw.WriteHeader(http.StatusBadRequest)
|
||||
rw.Write([]byte("hello world failed"))
|
||||
return
|
||||
}
|
||||
rw.Write([]byte(`OK`))
|
||||
}))
|
||||
defer server.Close()
|
||||
client, err := NewHttpClient(WithSkipTLSVerify(false))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
req := client.NewSimpleRequest(server.URL, "GET", WithHeader("hello", "world"))
|
||||
//SetClientSkipVerify(client, true)
|
||||
//req.SetDoRawClient(false)
|
||||
//req.SetDoRawTransport(false)
|
||||
req.SetSkipTLSVerify(true)
|
||||
req.SetProxy("http://127.0.0.1:29992")
|
||||
resp, err := req.Do()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
fmt.Println(resp.Proto)
|
||||
if resp.StatusCode != 200 {
|
||||
resp.CloseAll()
|
||||
t.Errorf("status code is %d", resp.StatusCode)
|
||||
}
|
||||
resp.CloseAll()
|
||||
req = req.Clone()
|
||||
req.AddHeader("ok", "good")
|
||||
resp, err = req.Do()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if resp.StatusCode != 200 {
|
||||
resp.CloseAll()
|
||||
t.Errorf("status code is %d", resp.StatusCode)
|
||||
}
|
||||
resp.CloseAll()
|
||||
req = req.Clone()
|
||||
req.SetSkipTLSVerify(false)
|
||||
resp, err = req.Do()
|
||||
if err == nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHttpPostAndChunked(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
if req.Method != http.MethodPost {
|
||||
t.Errorf("Expected 'POST', got %v", req.Method)
|
||||
}
|
||||
buf := make([]byte, 1024)
|
||||
n, _ := req.Body.Read(buf)
|
||||
if string(buf[:n]) != "hello world" {
|
||||
t.Errorf("Expected body to be 'hello world', got %s", string(buf[:n]))
|
||||
}
|
||||
|
||||
if req.Header.Get("chunked") == "true" {
|
||||
if req.TransferEncoding[0] != "chunked" {
|
||||
t.Errorf("Expected Transfer-Encoding to be 'chunked', got %s", req.Header.Get("Transfer-Encoding"))
|
||||
}
|
||||
} else {
|
||||
if len(req.TransferEncoding) > 0 && req.TransferEncoding[0] == "chunked" {
|
||||
t.Errorf("Expected Transfer-Encoding to not be 'chunked', got %s", req.Header.Get("Transfer-Encoding"))
|
||||
}
|
||||
}
|
||||
rw.Write([]byte(`OK`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
resp, err := Post(server.URL, WithBytes([]byte("hello world")), WithContentLength(-1), WithHeader("Content-Type", "text/plain"),
|
||||
WithHeader("chunked", "true"))
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
body := resp.Body().String()
|
||||
if body != "OK" {
|
||||
t.Errorf("Expected OK, got %v", body)
|
||||
}
|
||||
resp.Close()
|
||||
|
||||
resp, err = Post(server.URL, WithBytes([]byte("hello world")), WithHeader("Content-Type", "text/plain"),
|
||||
WithHeader("chunked", "false"))
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
defer resp.Close()
|
||||
body = resp.Body().String()
|
||||
if body != "OK" {
|
||||
t.Errorf("Expected OK, got %v", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithTimeout(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
time.Sleep(time.Second * 30)
|
||||
rw.Write([]byte(`OK`))
|
||||
}))
|
||||
funcList := []func(string, ...RequestOpt) (*Response, error){
|
||||
Get,
|
||||
Post,
|
||||
Put,
|
||||
Delete,
|
||||
Options,
|
||||
Patch,
|
||||
Head,
|
||||
Trace,
|
||||
Connect,
|
||||
}
|
||||
defer server.Close()
|
||||
for i := 1; i < 30; i++ {
|
||||
go func(i int) {
|
||||
old := time.Now()
|
||||
fn := funcList[i%len(funcList)]
|
||||
resp, err := fn(server.URL, WithTimeout(time.Second*time.Duration(i)))
|
||||
if time.Since(old) > time.Second*time.Duration(i+2) || time.Since(old) < time.Second*time.Duration(i) {
|
||||
t.Errorf("timeout not work")
|
||||
}
|
||||
fmt.Println(time.Since(old))
|
||||
if err == nil {
|
||||
t.Error(err)
|
||||
resp.CloseAll()
|
||||
} else {
|
||||
fmt.Println(err)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
resp, err := Get(server.URL, WithTimeout(time.Second*60))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
} else {
|
||||
fmt.Println(resp.Body().String())
|
||||
if resp.StatusCode != 200 {
|
||||
resp.CloseAll()
|
||||
t.Errorf("status code is %d", resp.StatusCode)
|
||||
}
|
||||
resp.CloseAll()
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigWithClient(t *testing.T) {
|
||||
server := httptest.NewTLSServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
if req.Header.Get("hello") != "world" {
|
||||
rw.WriteHeader(http.StatusBadRequest)
|
||||
rw.Write([]byte("hello world failed"))
|
||||
return
|
||||
}
|
||||
rw.Write([]byte(`OK`))
|
||||
}))
|
||||
defer server.Close()
|
||||
client, err := NewHttpClient(WithSkipTLSVerify(true))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
req := client.NewSimpleRequest(server.URL, "GET", WithHeader("hello", "world"))
|
||||
//SetClientSkipVerify(client, true)
|
||||
//req.SetDoRawClient(false)
|
||||
//req.SetDoRawTransport(false)
|
||||
resp, err := req.Do()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
fmt.Println(resp.Proto)
|
||||
if resp.StatusCode != 200 {
|
||||
resp.CloseAll()
|
||||
t.Errorf("status code is %d", resp.StatusCode)
|
||||
}
|
||||
resp.CloseAll()
|
||||
}
|
||||
@ -1,178 +0,0 @@
|
||||
package starnet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
*http.Client
|
||||
opts []RequestOpt
|
||||
}
|
||||
|
||||
func (c Client) Options() []RequestOpt {
|
||||
return c.opts
|
||||
}
|
||||
|
||||
func (c Client) SetOptions(opts ...RequestOpt) Client {
|
||||
return Client{
|
||||
Client: c.Client,
|
||||
opts: opts,
|
||||
}
|
||||
}
|
||||
|
||||
// NewHttpClient creates a new http.Client with the specified options.
|
||||
func NewHttpClient(opts ...RequestOpt) (Client, error) {
|
||||
req, err := newRequest(context.Background(), "", "", opts...)
|
||||
if err != nil {
|
||||
return Client{}, err
|
||||
}
|
||||
defer func() {
|
||||
req = nil
|
||||
}()
|
||||
cl, err := req.HttpClient()
|
||||
return Client{
|
||||
Client: cl,
|
||||
opts: opts,
|
||||
}, err
|
||||
}
|
||||
|
||||
func NewHttpClientNoErr(opts ...RequestOpt) Client {
|
||||
c, _ := NewHttpClient(opts...)
|
||||
return c
|
||||
}
|
||||
|
||||
func NewClientFromHttpClient(httpClient *http.Client) (Client, error) {
|
||||
if httpClient == nil {
|
||||
return Client{}, fmt.Errorf("httpClient cannot be nil")
|
||||
}
|
||||
|
||||
if httpClient.Transport == nil {
|
||||
httpClient.Transport = &Transport{
|
||||
base: &http.Transport{},
|
||||
}
|
||||
} else {
|
||||
switch t := httpClient.Transport.(type) {
|
||||
case *Transport:
|
||||
if t.base == nil {
|
||||
t.base = &http.Transport{}
|
||||
}
|
||||
case *http.Transport:
|
||||
httpClient.Transport = &Transport{
|
||||
base: t,
|
||||
}
|
||||
default:
|
||||
return Client{}, fmt.Errorf("unsupported transport type: %T", t)
|
||||
}
|
||||
}
|
||||
return Client{
|
||||
Client: httpClient,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func NewClientFromHttpClientNoError(httpClient *http.Client) Client {
|
||||
return Client{Client: httpClient}
|
||||
}
|
||||
|
||||
// DisableRedirect returns whether the request will disable HTTP redirects.
|
||||
// if true, the request will not follow redirects automatically.
|
||||
// for example, if the server responds with a 301 or 302 status code, the request will not automatically follow the redirect.
|
||||
// you will get the original response with the redirect status code and Location header.
|
||||
func (c Client) DisableRedirect() bool {
|
||||
return reflect.ValueOf(c.Client.CheckRedirect).Pointer() == reflect.ValueOf(DefaultCheckRedirectFunc).Pointer()
|
||||
}
|
||||
|
||||
// SetDisableRedirect sets whether the request will disable HTTP redirects.
|
||||
// if true, the request will not follow redirects automatically.
|
||||
// for example, if the server responds with a 301 or 302 status code, the request will not automatically follow the redirect.
|
||||
// you will get the original response with the redirect status code and Location header.
|
||||
func (c Client) SetDisableRedirect(disableRedirect bool) {
|
||||
if disableRedirect {
|
||||
c.Client.CheckRedirect = DefaultCheckRedirectFunc
|
||||
}
|
||||
}
|
||||
|
||||
func (c Client) SetDefaultSkipTLSVerify(skip bool) {
|
||||
if c.Client.Transport == nil {
|
||||
c.Client.Transport = &Transport{
|
||||
base: &http.Transport{},
|
||||
}
|
||||
}
|
||||
if transport, ok := c.Client.Transport.(*Transport); ok {
|
||||
if transport.base.TLSClientConfig == nil {
|
||||
transport.base.TLSClientConfig = &tls.Config{}
|
||||
}
|
||||
transport.base.TLSClientConfig.InsecureSkipVerify = skip
|
||||
} else if transport, ok := c.Client.Transport.(*http.Transport); ok {
|
||||
if transport.TLSClientConfig == nil {
|
||||
transport.TLSClientConfig = &tls.Config{}
|
||||
}
|
||||
transport.TLSClientConfig.InsecureSkipVerify = skip
|
||||
}
|
||||
}
|
||||
|
||||
func (c Client) SetDefaultTLSConfig(tlsConfig *tls.Config) {
|
||||
if c.Client.Transport == nil {
|
||||
c.Client.Transport = &Transport{
|
||||
base: &http.Transport{},
|
||||
}
|
||||
}
|
||||
if transport, ok := c.Client.Transport.(*Transport); ok {
|
||||
transport.base.TLSClientConfig = tlsConfig
|
||||
} else if transport, ok := c.Client.Transport.(*http.Transport); ok {
|
||||
transport.TLSClientConfig = tlsConfig
|
||||
}
|
||||
}
|
||||
|
||||
func (c Client) NewRequest(url, method string, opts ...RequestOpt) (*Request, error) {
|
||||
if c.Client == nil {
|
||||
return nil, fmt.Errorf("http client is nil")
|
||||
}
|
||||
req, err := NewRequestWithContextWithClient(context.Background(), c, url, method, opts...)
|
||||
return req, err
|
||||
}
|
||||
|
||||
func (c Client) NewRequestContext(ctx context.Context, url, method string, opts ...RequestOpt) (*Request, error) {
|
||||
if c.Client == nil {
|
||||
return nil, fmt.Errorf("http client is nil")
|
||||
}
|
||||
req, err := NewRequestWithContextWithClient(ctx, c, url, method, opts...)
|
||||
return req, err
|
||||
}
|
||||
|
||||
func (c Client) NewSimpleRequest(url, method string, opts ...RequestOpt) *Request {
|
||||
req, _ := c.NewRequest(url, method, opts...)
|
||||
return req
|
||||
}
|
||||
|
||||
func (c Client) NewSimpleRequestContext(ctx context.Context, url, method string, opts ...RequestOpt) *Request {
|
||||
req, _ := c.NewRequestContext(ctx, url, method, opts...)
|
||||
return req
|
||||
}
|
||||
|
||||
type Transport struct {
|
||||
base *http.Transport
|
||||
}
|
||||
|
||||
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
if t.base == nil {
|
||||
t.base = &http.Transport{}
|
||||
}
|
||||
transport, ok := req.Context().Value("transport").(*http.Transport)
|
||||
if ok && transport != nil {
|
||||
return transport.RoundTrip(req)
|
||||
}
|
||||
proxy, ok := req.Context().Value("proxy").(string)
|
||||
if ok && proxy != "" {
|
||||
tlsConfig, ok := req.Context().Value("tlsConfig").(*tls.Config)
|
||||
if ok && tlsConfig != nil {
|
||||
tmpTransport := t.base.Clone()
|
||||
tmpTransport.TLSClientConfig = tlsConfig
|
||||
return tmpTransport.RoundTrip(req)
|
||||
}
|
||||
}
|
||||
return t.base.RoundTrip(req)
|
||||
}
|
||||
@ -1,198 +0,0 @@
|
||||
package starnet
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"runtime"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// BenchmarkGetRequest 测试单个 GET 请求的性能
|
||||
func BenchmarkGetRequest(b *testing.B) {
|
||||
// 创建测试服务器
|
||||
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.Write([]byte(`OK`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// 重置计时器,排除设置代码的影响
|
||||
b.ResetTimer()
|
||||
|
||||
// 报告内存分配情况
|
||||
b.ReportAllocs()
|
||||
|
||||
// 运行基准测试
|
||||
for i := 0; i < b.N; i++ {
|
||||
resp, err := Get(server.URL, WithSkipTLSVerify(true))
|
||||
if err != nil {
|
||||
b.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
body := resp.Body().String()
|
||||
if body != "OK" {
|
||||
b.Errorf("Expected OK, got %v", body)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkGetRequestWithHeaders 测试带请求头的 GET 请求性能
|
||||
func BenchmarkGetRequestWithHeaders(b *testing.B) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
// 验证请求头
|
||||
if req.Header.Get("hello") != "world" {
|
||||
rw.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
rw.Write([]byte(`OK`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
resp, err := Get(server.URL,
|
||||
WithSkipTLSVerify(true),
|
||||
WithHeader("hello", "world"),
|
||||
WithUserAgent("hello world"))
|
||||
if err != nil {
|
||||
b.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
body := resp.Body().String()
|
||||
if body != "OK" {
|
||||
b.Errorf("Expected OK, got %v", body)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkPostRequest 测试 POST 请求的性能
|
||||
func BenchmarkPostRequest(b *testing.B) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
// 读取并返回请求体
|
||||
body := make([]byte, req.ContentLength)
|
||||
req.Body.Read(body)
|
||||
rw.Write(body)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
testData := "This is a test payload for POST request"
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
resp, err := Post(server.URL,
|
||||
WithSkipTLSVerify(true),
|
||||
WithBytes([]byte(testData)),
|
||||
WithContentType("text/plain"))
|
||||
if err != nil {
|
||||
b.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
body := resp.Body().String()
|
||||
if body != testData {
|
||||
b.Errorf("Expected %s, got %v", testData, body)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkConcurrentRequests 测试并发请求性能
|
||||
func BenchmarkConcurrentRequests(b *testing.B) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.Write([]byte(`OK`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
// 运行并发基准测试
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
resp, err := Get(server.URL, WithSkipTLSVerify(true))
|
||||
if err != nil {
|
||||
b.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
body := resp.Body().String()
|
||||
if body != "OK" {
|
||||
b.Errorf("Expected OK, got %v", body)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// BenchmarkMemoryUsage 专门测试内存使用情况
|
||||
func BenchmarkMemoryUsage(b *testing.B) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.Write([]byte(`OK`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// 禁用默认的测试时间,只关注内存分配
|
||||
b.ReportAllocs()
|
||||
|
||||
var memStatsStart, memStatsEnd runtime.MemStats
|
||||
runtime.GC()
|
||||
runtime.ReadMemStats(&memStatsStart)
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
resp, err := Get(server.URL, WithSkipTLSVerify(true))
|
||||
if err != nil {
|
||||
b.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
body := resp.Body().String()
|
||||
if body != "OK" {
|
||||
b.Errorf("Expected OK, got %v", body)
|
||||
}
|
||||
}
|
||||
|
||||
runtime.GC()
|
||||
runtime.ReadMemStats(&memStatsEnd)
|
||||
|
||||
// 计算每次操作的平均内存分配
|
||||
allocsPerOp := float64(memStatsEnd.Mallocs-memStatsStart.Mallocs) / float64(b.N)
|
||||
bytesPerOp := float64(memStatsEnd.TotalAlloc-memStatsStart.TotalAlloc) / float64(b.N)
|
||||
|
||||
b.ReportMetric(allocsPerOp, "allocs/op")
|
||||
b.ReportMetric(bytesPerOp, "bytes/op")
|
||||
}
|
||||
|
||||
// BenchmarkDifferentResponseSizes 测试不同响应大小的性能
|
||||
func BenchmarkDifferentResponseSizes(b *testing.B) {
|
||||
// 测试不同大小的响应
|
||||
responseSizes := []int{100, 1024, 10240, 102400} // 100B, 1KB, 10KB, 100KB
|
||||
|
||||
for _, size := range responseSizes {
|
||||
// 生成指定大小的响应数据
|
||||
responseData := make([]byte, size)
|
||||
for i := 0; i < size; i++ {
|
||||
responseData[i] = 'A'
|
||||
}
|
||||
|
||||
b.Run(fmt.Sprintf("Size_%d", size), func(b *testing.B) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.Write(responseData)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
resp, err := Get(server.URL, WithSkipTLSVerify(true))
|
||||
if err != nil {
|
||||
b.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
body := resp.Body().Bytes()
|
||||
if len(body) != size {
|
||||
b.Errorf("Expected size %d, got %d", size, len(body))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
147
defaults.go
Normal file
147
defaults.go
Normal file
@ -0,0 +1,147 @@
|
||||
package starnet
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
defaultClient *Client
|
||||
defaultHTTPClient *http.Client
|
||||
defaultClientOnce sync.Once
|
||||
defaultHTTPOnce sync.Once
|
||||
|
||||
defaultMu sync.RWMutex
|
||||
)
|
||||
|
||||
// DefaultClient 获取默认 Client(单例)
|
||||
func DefaultClient() *Client {
|
||||
defaultMu.RLock()
|
||||
if defaultClient != nil {
|
||||
c := defaultClient
|
||||
defaultMu.RUnlock()
|
||||
return c
|
||||
}
|
||||
defaultMu.RUnlock()
|
||||
|
||||
defaultClientOnce.Do(func() {
|
||||
c := NewClientNoErr()
|
||||
defaultMu.Lock()
|
||||
defaultClient = c
|
||||
defaultMu.Unlock()
|
||||
})
|
||||
|
||||
defaultMu.RLock()
|
||||
c := defaultClient
|
||||
defaultMu.RUnlock()
|
||||
return c
|
||||
}
|
||||
|
||||
// DefaultHTTPClient 获取默认 http.Client(单例)
|
||||
func DefaultHTTPClient() *http.Client {
|
||||
defaultMu.RLock()
|
||||
if defaultHTTPClient != nil {
|
||||
c := defaultHTTPClient
|
||||
defaultMu.RUnlock()
|
||||
return c
|
||||
}
|
||||
defaultMu.RUnlock()
|
||||
|
||||
defaultHTTPOnce.Do(func() {
|
||||
c := &http.Client{
|
||||
Transport: &Transport{
|
||||
base: &http.Transport{
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
},
|
||||
},
|
||||
Timeout: 0, // 由请求级控制超时
|
||||
}
|
||||
defaultMu.Lock()
|
||||
defaultHTTPClient = c
|
||||
defaultMu.Unlock()
|
||||
})
|
||||
|
||||
defaultMu.RLock()
|
||||
c := defaultHTTPClient
|
||||
defaultMu.RUnlock()
|
||||
return c
|
||||
}
|
||||
|
||||
// SetDefaultClient 设置默认 Client
|
||||
func SetDefaultClient(client *Client) {
|
||||
defaultMu.Lock()
|
||||
defer defaultMu.Unlock()
|
||||
|
||||
defaultClient = client
|
||||
// 标记 once 已完成,避免后续 DefaultClient() 再次初始化覆盖
|
||||
defaultClientOnce.Do(func() {})
|
||||
}
|
||||
|
||||
// SetDefaultHTTPClient 设置默认 http.Client
|
||||
func SetDefaultHTTPClient(client *http.Client) {
|
||||
defaultMu.Lock()
|
||||
defer defaultMu.Unlock()
|
||||
|
||||
defaultHTTPClient = client
|
||||
// 标记 once 已完成,避免后续 DefaultHTTPClient() 再次初始化覆盖
|
||||
defaultHTTPOnce.Do(func() {})
|
||||
}
|
||||
|
||||
// Get 发送 GET 请求(使用默认 Client)
|
||||
func Get(url string, opts ...RequestOpt) (*Response, error) {
|
||||
return DefaultClient().Get(url, opts...)
|
||||
}
|
||||
|
||||
// Post 发送 POST 请求(使用默认 Client)
|
||||
func Post(url string, opts ...RequestOpt) (*Response, error) {
|
||||
return DefaultClient().Post(url, opts...)
|
||||
}
|
||||
|
||||
// Put 发送 PUT 请求(使用默认 Client)
|
||||
func Put(url string, opts ...RequestOpt) (*Response, error) {
|
||||
return DefaultClient().Put(url, opts...)
|
||||
}
|
||||
|
||||
// Delete 发送 DELETE 请求(使用默认 Client)
|
||||
func Delete(url string, opts ...RequestOpt) (*Response, error) {
|
||||
return DefaultClient().Delete(url, opts...)
|
||||
}
|
||||
|
||||
// Head 发送 HEAD 请求(使用默认 Client)
|
||||
func Head(url string, opts ...RequestOpt) (*Response, error) {
|
||||
return DefaultClient().Head(url, opts...)
|
||||
}
|
||||
|
||||
// Patch 发送 PATCH 请求(使用默认 Client)
|
||||
func Patch(url string, opts ...RequestOpt) (*Response, error) {
|
||||
return DefaultClient().Patch(url, opts...)
|
||||
}
|
||||
|
||||
// Options 发送 OPTIONS 请求(使用默认 Client)
|
||||
func Options(url string, opts ...RequestOpt) (*Response, error) {
|
||||
return DefaultClient().Options(url, opts...)
|
||||
}
|
||||
|
||||
// Trace 发送 TRACE 请求(使用默认 Client)
|
||||
func Trace(url string, opts ...RequestOpt) (*Response, error) {
|
||||
req, err := DefaultClient().NewRequest(url, http.MethodTrace, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return req.Do()
|
||||
}
|
||||
|
||||
// Connect 发送 CONNECT 请求(使用默认 Client)
|
||||
func Connect(url string, opts ...RequestOpt) (*Response, error) {
|
||||
req, err := DefaultClient().NewRequest(url, http.MethodConnect, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return req.Do()
|
||||
}
|
||||
148
dialer.go
Normal file
148
dialer.go
Normal file
@ -0,0 +1,148 @@
|
||||
package starnet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
// defaultDialFunc 默认 Dial 函数(支持自定义 IP 和 DNS)
|
||||
func defaultDialFunc(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
// 提取配置
|
||||
reqCtx := getRequestContext(ctx)
|
||||
|
||||
dialTimeout := reqCtx.DialTimeout
|
||||
if dialTimeout == 0 {
|
||||
dialTimeout = DefaultDialTimeout
|
||||
}
|
||||
|
||||
timeout := reqCtx.Timeout
|
||||
if timeout == 0 {
|
||||
timeout = DefaultTimeout
|
||||
}
|
||||
|
||||
// 解析地址
|
||||
host, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, wrapError(err, "split host port")
|
||||
}
|
||||
|
||||
// 获取 IP 地址列表
|
||||
var addrs []string
|
||||
|
||||
// 优先级1:直接指定的 IP
|
||||
if len(reqCtx.CustomIP) > 0 {
|
||||
for _, ip := range reqCtx.CustomIP {
|
||||
addrs = append(addrs, net.JoinHostPort(ip, port))
|
||||
}
|
||||
} else {
|
||||
// 优先级2:DNS 解析
|
||||
var ipAddrs []net.IPAddr
|
||||
|
||||
// 使用自定义解析函数
|
||||
if reqCtx.LookupIPFn != nil {
|
||||
ipAddrs, err = reqCtx.LookupIPFn(ctx, host)
|
||||
} else if len(reqCtx.CustomDNS) > 0 {
|
||||
// 使用自定义 DNS 服务器
|
||||
resolver := &net.Resolver{
|
||||
PreferGo: true,
|
||||
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
var lastErr error
|
||||
for _, dnsServer := range reqCtx.CustomDNS {
|
||||
conn, err := net.Dial("udp", net.JoinHostPort(dnsServer, "53"))
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
continue
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
return nil, lastErr
|
||||
},
|
||||
}
|
||||
ipAddrs, err = resolver.LookupIPAddr(ctx, host)
|
||||
} else {
|
||||
// 使用默认解析器
|
||||
ipAddrs, err = net.DefaultResolver.LookupIPAddr(ctx, host)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, wrapError(err, "lookup ip")
|
||||
}
|
||||
|
||||
for _, ipAddr := range ipAddrs {
|
||||
addrs = append(addrs, net.JoinHostPort(ipAddr.String(), port))
|
||||
}
|
||||
}
|
||||
|
||||
// 尝试连接所有地址
|
||||
var lastErr error
|
||||
for _, addr := range addrs {
|
||||
conn, err := net.DialTimeout(network, addr, dialTimeout)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
continue
|
||||
}
|
||||
|
||||
// 设置总超时
|
||||
if timeout > 0 {
|
||||
conn.SetDeadline(time.Now().Add(timeout))
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
if lastErr != nil {
|
||||
return nil, wrapError(lastErr, "dial all addresses failed")
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("no addresses to dial")
|
||||
}
|
||||
|
||||
// defaultDialTLSFunc 默认 TLS Dial 函数
|
||||
func defaultDialTLSFunc(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
// 先建立 TCP 连接
|
||||
conn, err := defaultDialFunc(ctx, network, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 提取 TLS 配置
|
||||
reqCtx := getRequestContext(ctx)
|
||||
tlsConfig := reqCtx.TLSConfig
|
||||
if tlsConfig == nil {
|
||||
tlsConfig = &tls.Config{}
|
||||
}
|
||||
|
||||
// 执行 TLS 握手
|
||||
tlsConn := tls.Client(conn, tlsConfig)
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
conn.Close()
|
||||
return nil, wrapError(err, "tls handshake")
|
||||
}
|
||||
|
||||
return tlsConn, nil
|
||||
}
|
||||
|
||||
/*
|
||||
// defaultProxyFunc 默认代理函数
|
||||
func defaultProxyFunc(req *http.Request) (*url.URL, error) {
|
||||
if req == nil {
|
||||
return nil, fmt.Errorf("request is nil")
|
||||
}
|
||||
|
||||
reqCtx := getRequestContext(req.Context())
|
||||
if reqCtx.Proxy == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
proxyURL, err := url.Parse(reqCtx.Proxy)
|
||||
if err != nil {
|
||||
return nil, wrapError(err, "parse proxy url")
|
||||
}
|
||||
|
||||
return proxyURL, nil
|
||||
}
|
||||
|
||||
*/
|
||||
103
dns_test.go
Normal file
103
dns_test.go
Normal file
@ -0,0 +1,103 @@
|
||||
package starnet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRequestCustomIP(t *testing.T) {
|
||||
customIPs := []string{"1.2.3.4", "5.6.7.8"}
|
||||
req := NewSimpleRequest("http://example.com", "GET").
|
||||
SetCustomIP(customIPs)
|
||||
|
||||
if len(req.config.DNS.CustomIP) != 2 {
|
||||
t.Errorf("CustomIP length = %v; want 2", len(req.config.DNS.CustomIP))
|
||||
}
|
||||
|
||||
for i, ip := range req.config.DNS.CustomIP {
|
||||
if ip != customIPs[i] {
|
||||
t.Errorf("CustomIP[%d] = %v; want %v", i, ip, customIPs[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestCustomIPInvalid(t *testing.T) {
|
||||
req := NewSimpleRequest("http://example.com", "GET").
|
||||
SetCustomIP([]string{"invalid-ip"})
|
||||
|
||||
if req.Err() == nil {
|
||||
t.Error("Expected error for invalid IP, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestCustomDNS(t *testing.T) {
|
||||
dnsServers := []string{"8.8.8.8", "1.1.1.1"}
|
||||
req := NewSimpleRequest("http://example.com", "GET").
|
||||
SetCustomDNS(dnsServers)
|
||||
|
||||
if len(req.config.DNS.CustomDNS) != 2 {
|
||||
t.Errorf("CustomDNS length = %v; want 2", len(req.config.DNS.CustomDNS))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestCustomDNSInvalid(t *testing.T) {
|
||||
req := NewSimpleRequest("http://example.com", "GET").
|
||||
SetCustomDNS([]string{"invalid-dns"})
|
||||
|
||||
if req.Err() == nil {
|
||||
t.Error("Expected error for invalid DNS, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestLookupFunc(t *testing.T) {
|
||||
called := false
|
||||
lookupFunc := func(ctx context.Context, host string) ([]net.IPAddr, error) {
|
||||
called = true
|
||||
return []net.IPAddr{
|
||||
{IP: net.ParseIP("1.2.3.4")},
|
||||
}, nil
|
||||
}
|
||||
|
||||
req := NewSimpleRequest("http://example.com", "GET").
|
||||
SetLookupFunc(lookupFunc)
|
||||
|
||||
if req.config.DNS.LookupFunc == nil {
|
||||
t.Error("LookupFunc not set")
|
||||
}
|
||||
|
||||
// Call the function to verify it works
|
||||
ips, err := req.config.DNS.LookupFunc(context.Background(), "example.com")
|
||||
if err != nil {
|
||||
t.Errorf("LookupFunc error: %v", err)
|
||||
}
|
||||
if !called {
|
||||
t.Error("LookupFunc was not called")
|
||||
}
|
||||
if len(ips) != 1 {
|
||||
t.Errorf("IPs length = %v; want 1", len(ips))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSPriority(t *testing.T) {
|
||||
// CustomIP should have highest priority
|
||||
req := NewSimpleRequest("http://example.com", "GET").
|
||||
SetCustomIP([]string{"1.2.3.4"}).
|
||||
SetCustomDNS([]string{"8.8.8.8"}).
|
||||
SetLookupFunc(func(ctx context.Context, host string) ([]net.IPAddr, error) {
|
||||
return []net.IPAddr{{IP: net.ParseIP("5.6.7.8")}}, nil
|
||||
})
|
||||
|
||||
// CustomIP should be set
|
||||
if len(req.config.DNS.CustomIP) == 0 {
|
||||
t.Error("CustomIP should be set")
|
||||
}
|
||||
|
||||
// Others should also be set (but CustomIP takes priority in actual use)
|
||||
if len(req.config.DNS.CustomDNS) == 0 {
|
||||
t.Error("CustomDNS should be set")
|
||||
}
|
||||
if req.config.DNS.LookupFunc == nil {
|
||||
t.Error("LookupFunc should be set")
|
||||
}
|
||||
}
|
||||
58
errors.go
Normal file
58
errors.go
Normal file
@ -0,0 +1,58 @@
|
||||
package starnet
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrInvalidMethod 无效的 HTTP 方法
|
||||
ErrInvalidMethod = errors.New("starnet: invalid HTTP method")
|
||||
|
||||
// ErrInvalidURL 无效的 URL
|
||||
ErrInvalidURL = errors.New("starnet: invalid URL")
|
||||
|
||||
// ErrInvalidIP 无效的 IP 地址
|
||||
ErrInvalidIP = errors.New("starnet: invalid IP address")
|
||||
|
||||
// ErrInvalidDNS 无效的 DNS 服务器
|
||||
ErrInvalidDNS = errors.New("starnet: invalid DNS server")
|
||||
|
||||
// ErrNilClient HTTP Client 为 nil
|
||||
ErrNilClient = errors.New("starnet: http client is nil")
|
||||
|
||||
// ErrNilReader Reader 为 nil
|
||||
ErrNilReader = errors.New("starnet: reader is nil")
|
||||
|
||||
// ErrFileNotFound 文件不存在
|
||||
ErrFileNotFound = errors.New("starnet: file not found")
|
||||
|
||||
// ErrRequestNotPrepared 请求未准备好
|
||||
ErrRequestNotPrepared = errors.New("starnet: request not prepared")
|
||||
|
||||
// ErrBodyAlreadyConsumed Body 已被消费
|
||||
ErrBodyAlreadyConsumed = errors.New("starnet: response body already consumed")
|
||||
)
|
||||
|
||||
// wrapError 包装错误,添加上下文信息
|
||||
func wrapError(err error, format string, args ...interface{}) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
msg := fmt.Sprintf(format, args...)
|
||||
return fmt.Errorf("%s: %w", msg, err)
|
||||
}
|
||||
|
||||
var (
|
||||
// ErrNilConn indicates a nil net.Conn argument.
|
||||
ErrNilConn = errors.New("starnet: nil connection")
|
||||
|
||||
// ErrNonTLSNotAllowed indicates plain TCP was detected while non-TLS is forbidden.
|
||||
ErrNonTLSNotAllowed = errors.New("starnet: non-TLS connection not allowed")
|
||||
|
||||
// ErrNotTLS indicates caller asked for TLS-only object but conn is plain TCP.
|
||||
ErrNotTLS = errors.New("starnet: connection is not TLS")
|
||||
|
||||
// ErrNoTLSConfig indicates TLS was detected but no usable TLS config is available.
|
||||
ErrNoTLSConfig = errors.New("starnet: no TLS config available")
|
||||
)
|
||||
200
example_test.go
Normal file
200
example_test.go
Normal file
@ -0,0 +1,200 @@
|
||||
package starnet_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"time"
|
||||
|
||||
"b612.me/starnet"
|
||||
)
|
||||
|
||||
func ExampleGet() {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("Hello, World!"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
resp, err := starnet.Get(server.URL)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer resp.Close()
|
||||
|
||||
body, _ := resp.Body().String()
|
||||
fmt.Println(body)
|
||||
// Output: Hello, World!
|
||||
}
|
||||
|
||||
func ExamplePost() {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("Posted"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
resp, err := starnet.Post(server.URL,
|
||||
starnet.WithBodyString("test data"))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer resp.Close()
|
||||
|
||||
body, _ := resp.Body().String()
|
||||
fmt.Println(body)
|
||||
// Output: Posted
|
||||
}
|
||||
|
||||
func ExampleNewSimpleRequest() {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("OK"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
req := starnet.NewSimpleRequest(server.URL, "GET").
|
||||
SetHeader("X-Custom", "value").
|
||||
AddQuery("name", "test")
|
||||
|
||||
resp, err := req.Do()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer resp.Close()
|
||||
|
||||
fmt.Println(resp.StatusCode)
|
||||
// Output: 200
|
||||
}
|
||||
|
||||
func ExampleClient_Get() {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("Client GET"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := starnet.NewClientNoErr(
|
||||
starnet.WithTimeout(10*time.Second),
|
||||
starnet.WithUserAgent("MyApp/1.0"),
|
||||
)
|
||||
|
||||
resp, err := client.Get(server.URL)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer resp.Close()
|
||||
|
||||
body, _ := resp.Body().String()
|
||||
fmt.Println(body)
|
||||
// Output: Client GET
|
||||
}
|
||||
|
||||
func ExampleRequest_SetJSON() {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(`{"status":"ok"}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
type User struct {
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
user := User{Name: "John", Email: "john@example.com"}
|
||||
|
||||
resp, err := starnet.NewSimpleRequest(server.URL, "POST").
|
||||
SetJSON(user).
|
||||
Do()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer resp.Close()
|
||||
|
||||
var result map[string]string
|
||||
resp.Body().JSON(&result)
|
||||
fmt.Println(result["status"])
|
||||
// Output: ok
|
||||
}
|
||||
|
||||
func ExampleRequest_AddFormData() {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
r.ParseForm()
|
||||
fmt.Fprintf(w, "name=%s", r.FormValue("name"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
resp, err := starnet.NewSimpleRequest(server.URL, "POST").
|
||||
AddFormData("name", "John").
|
||||
AddFormData("age", "30").
|
||||
Do()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer resp.Close()
|
||||
|
||||
body, _ := resp.Body().String()
|
||||
fmt.Println(body)
|
||||
// Output: name=John
|
||||
}
|
||||
|
||||
func ExampleRequest_SetSkipTLSVerify() {
|
||||
// This example shows how to skip TLS verification
|
||||
// Useful for testing with self-signed certificates
|
||||
|
||||
req := starnet.NewSimpleRequest("https://self-signed.example.com", "GET").
|
||||
SetSkipTLSVerify(true)
|
||||
|
||||
// In a real scenario, you would call req.Do()
|
||||
fmt.Println(req.Method())
|
||||
// Output: GET
|
||||
}
|
||||
|
||||
func ExampleRequest_Clone() {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("OK"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
baseReq := starnet.NewSimpleRequest(server.URL, "GET").
|
||||
SetHeader("X-API-Key", "secret")
|
||||
|
||||
// Clone and modify
|
||||
req1 := baseReq.Clone().AddQuery("page", "1")
|
||||
req2 := baseReq.Clone().AddQuery("page", "2")
|
||||
|
||||
resp1, _ := req1.Do()
|
||||
resp2, _ := req2.Do()
|
||||
|
||||
defer resp1.Close()
|
||||
defer resp2.Close()
|
||||
|
||||
fmt.Println(resp1.StatusCode, resp2.StatusCode)
|
||||
// Output: 200 200
|
||||
}
|
||||
|
||||
func ExampleClient_SetDefaultSkipTLSVerify() {
|
||||
client := starnet.NewClientNoErr()
|
||||
client.SetDefaultSkipTLSVerify(true)
|
||||
|
||||
// All requests from this client will skip TLS verification
|
||||
// unless overridden at request level
|
||||
|
||||
fmt.Println("Client configured")
|
||||
// Output: Client configured
|
||||
}
|
||||
|
||||
func ExampleWithTimeout() {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
w.Write([]byte("OK"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
resp, err := starnet.Get(server.URL,
|
||||
starnet.WithTimeout(200*time.Millisecond))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer resp.Close()
|
||||
|
||||
fmt.Println(resp.StatusCode)
|
||||
// Output: 200
|
||||
}
|
||||
172
file_upload_test.go
Normal file
172
file_upload_test.go
Normal file
@ -0,0 +1,172 @@
|
||||
package starnet
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRequestAddFileStream(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
err := r.ParseMultipartForm(10 << 20) // 10 MB
|
||||
if err != nil {
|
||||
t.Fatalf("ParseMultipartForm error: %v", err)
|
||||
}
|
||||
|
||||
file, header, err := r.FormFile("file")
|
||||
if err != nil {
|
||||
t.Fatalf("FormFile error: %v", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
content, _ := io.ReadAll(file)
|
||||
w.Write([]byte(header.Filename + ":" + string(content)))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
fileContent := "test file content"
|
||||
reader := strings.NewReader(fileContent)
|
||||
|
||||
req := NewSimpleRequest(server.URL, "POST").
|
||||
AddFileStream("file", "test.txt", int64(len(fileContent)), reader)
|
||||
|
||||
resp, err := req.Do()
|
||||
if err != nil {
|
||||
t.Fatalf("Do() error: %v", err)
|
||||
}
|
||||
defer resp.Close()
|
||||
|
||||
body, _ := resp.Body().String()
|
||||
expected := "test.txt:" + fileContent
|
||||
if body != expected {
|
||||
t.Errorf("Body = %v; want %v", body, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestAddFileWithFormData(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
err := r.ParseMultipartForm(10 << 20)
|
||||
if err != nil {
|
||||
t.Fatalf("ParseMultipartForm error: %v", err)
|
||||
}
|
||||
|
||||
// Check form field
|
||||
name := r.FormValue("name")
|
||||
if name != "John" {
|
||||
t.Errorf("name = %v; want John", name)
|
||||
}
|
||||
|
||||
// Check file
|
||||
file, header, err := r.FormFile("file")
|
||||
if err != nil {
|
||||
t.Fatalf("FormFile error: %v", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
w.Write([]byte("OK:" + header.Filename))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
fileContent := "file data"
|
||||
reader := strings.NewReader(fileContent)
|
||||
|
||||
req := NewSimpleRequest(server.URL, "POST").
|
||||
AddFormData("name", "John").
|
||||
AddFileStream("file", "document.txt", int64(len(fileContent)), reader)
|
||||
|
||||
resp, err := req.Do()
|
||||
if err != nil {
|
||||
t.Fatalf("Do() error: %v", err)
|
||||
}
|
||||
defer resp.Close()
|
||||
|
||||
body, _ := resp.Body().String()
|
||||
if !strings.Contains(body, "document.txt") {
|
||||
t.Errorf("Body should contain filename, got: %v", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestUploadProgress(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
r.ParseMultipartForm(10 << 20)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
progressCalled := false
|
||||
var lastUploaded int64
|
||||
|
||||
fileContent := strings.Repeat("a", 1024*10) // 10KB
|
||||
reader := strings.NewReader(fileContent)
|
||||
|
||||
req := NewSimpleRequest(server.URL, "POST").
|
||||
SetUploadProgress(func(filename string, uploaded, total int64) {
|
||||
progressCalled = true
|
||||
lastUploaded = uploaded
|
||||
if filename != "test.txt" {
|
||||
t.Errorf("filename = %v; want test.txt", filename)
|
||||
}
|
||||
}).
|
||||
AddFileStream("file", "test.txt", int64(len(fileContent)), reader)
|
||||
|
||||
resp, err := req.Do()
|
||||
if err != nil {
|
||||
t.Fatalf("Do() error: %v", err)
|
||||
}
|
||||
defer resp.Close()
|
||||
|
||||
if !progressCalled {
|
||||
t.Error("Progress callback was not called")
|
||||
}
|
||||
|
||||
if lastUploaded != int64(len(fileContent)) {
|
||||
t.Errorf("lastUploaded = %v; want %v", lastUploaded, len(fileContent))
|
||||
}
|
||||
}
|
||||
|
||||
// TestRequestAddFileFromDisk tests uploading a real file from disk
|
||||
func TestRequestAddFileFromDisk(t *testing.T) {
|
||||
// Create a temporary file
|
||||
tmpDir := t.TempDir()
|
||||
tmpFile := filepath.Join(tmpDir, "test.txt")
|
||||
fileContent := []byte("test file content from disk")
|
||||
|
||||
err := os.WriteFile(tmpFile, fileContent, 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("WriteFile error: %v", err)
|
||||
}
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
err := r.ParseMultipartForm(10 << 20)
|
||||
if err != nil {
|
||||
t.Fatalf("ParseMultipartForm error: %v", err)
|
||||
}
|
||||
|
||||
file, header, err := r.FormFile("file")
|
||||
if err != nil {
|
||||
t.Fatalf("FormFile error: %v", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
content, _ := io.ReadAll(file)
|
||||
w.Write([]byte(header.Filename + ":" + string(content)))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
req := NewSimpleRequest(server.URL, "POST").AddFile("file", tmpFile)
|
||||
|
||||
resp, err := req.Do()
|
||||
if err != nil {
|
||||
t.Fatalf("Do() error: %v", err)
|
||||
}
|
||||
defer resp.Close()
|
||||
|
||||
body, _ := resp.Body().String()
|
||||
if !strings.Contains(body, string(fileContent)) {
|
||||
t.Errorf("Body should contain file content, got: %v", body)
|
||||
}
|
||||
}
|
||||
140
header_test.go
Normal file
140
header_test.go
Normal file
@ -0,0 +1,140 @@
|
||||
package starnet
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRequestHeaders(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
headers := make(map[string]string)
|
||||
for k, v := range r.Header {
|
||||
if len(v) > 0 {
|
||||
headers[k] = v[0]
|
||||
}
|
||||
}
|
||||
json.NewEncoder(w).Encode(headers)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
req := NewSimpleRequest(server.URL, "GET").
|
||||
SetHeader("X-Custom-Header", "value1").
|
||||
AddHeader("X-Multi-Header", "value1").
|
||||
AddHeader("X-Multi-Header", "value2")
|
||||
|
||||
resp, err := req.Do()
|
||||
if err != nil {
|
||||
t.Fatalf("Do() error: %v", err)
|
||||
}
|
||||
defer resp.Close()
|
||||
|
||||
var headers map[string]string
|
||||
resp.Body().JSON(&headers)
|
||||
|
||||
if headers["X-Custom-Header"] != "value1" {
|
||||
t.Errorf("X-Custom-Header = %v; want value1", headers["X-Custom-Header"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestCookies(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
cookies := make(map[string]string)
|
||||
for _, cookie := range r.Cookies() {
|
||||
cookies[cookie.Name] = cookie.Value
|
||||
}
|
||||
json.NewEncoder(w).Encode(cookies)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
req := NewSimpleRequest(server.URL, "GET").
|
||||
AddSimpleCookie("session", "abc123").
|
||||
AddSimpleCookie("user", "john")
|
||||
|
||||
resp, err := req.Do()
|
||||
if err != nil {
|
||||
t.Fatalf("Do() error: %v", err)
|
||||
}
|
||||
defer resp.Close()
|
||||
|
||||
var cookies map[string]string
|
||||
resp.Body().JSON(&cookies)
|
||||
|
||||
if cookies["session"] != "abc123" {
|
||||
t.Errorf("session cookie = %v; want abc123", cookies["session"])
|
||||
}
|
||||
if cookies["user"] != "john" {
|
||||
t.Errorf("user cookie = %v; want john", cookies["user"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestUserAgent(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte(r.UserAgent()))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
req := NewSimpleRequest(server.URL, "GET").
|
||||
SetUserAgent("CustomAgent/1.0")
|
||||
|
||||
resp, err := req.Do()
|
||||
if err != nil {
|
||||
t.Fatalf("Do() error: %v", err)
|
||||
}
|
||||
defer resp.Close()
|
||||
|
||||
body, _ := resp.Body().String()
|
||||
if body != "CustomAgent/1.0" {
|
||||
t.Errorf("User-Agent = %v; want CustomAgent/1.0", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestBearerToken(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
auth := r.Header.Get("Authorization")
|
||||
w.Write([]byte(auth))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
req := NewSimpleRequest(server.URL, "GET").
|
||||
SetBearerToken("mytoken123")
|
||||
|
||||
resp, err := req.Do()
|
||||
if err != nil {
|
||||
t.Fatalf("Do() error: %v", err)
|
||||
}
|
||||
defer resp.Close()
|
||||
|
||||
body, _ := resp.Body().String()
|
||||
expected := "Bearer mytoken123"
|
||||
if body != expected {
|
||||
t.Errorf("Authorization = %v; want %v", body, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestBasicAuth(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
username, password, ok := r.BasicAuth()
|
||||
if !ok {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
w.Write([]byte(username + ":" + password))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
req := NewSimpleRequest(server.URL, "GET").
|
||||
SetBasicAuth("user", "pass")
|
||||
|
||||
resp, err := req.Do()
|
||||
if err != nil {
|
||||
t.Fatalf("Do() error: %v", err)
|
||||
}
|
||||
defer resp.Close()
|
||||
|
||||
body, _ := resp.Body().String()
|
||||
if body != "user:pass" {
|
||||
t.Errorf("BasicAuth = %v; want user:pass", body)
|
||||
}
|
||||
}
|
||||
120
httpguts.go
120
httpguts.go
@ -1,120 +0,0 @@
|
||||
package starnet
|
||||
|
||||
import "strings"
|
||||
|
||||
var isTokenTable = [127]bool{
|
||||
'!': true,
|
||||
'#': true,
|
||||
'$': true,
|
||||
'%': true,
|
||||
'&': true,
|
||||
'\'': true,
|
||||
'*': true,
|
||||
'+': true,
|
||||
'-': true,
|
||||
'.': true,
|
||||
'0': true,
|
||||
'1': true,
|
||||
'2': true,
|
||||
'3': true,
|
||||
'4': true,
|
||||
'5': true,
|
||||
'6': true,
|
||||
'7': true,
|
||||
'8': true,
|
||||
'9': true,
|
||||
'A': true,
|
||||
'B': true,
|
||||
'C': true,
|
||||
'D': true,
|
||||
'E': true,
|
||||
'F': true,
|
||||
'G': true,
|
||||
'H': true,
|
||||
'I': true,
|
||||
'J': true,
|
||||
'K': true,
|
||||
'L': true,
|
||||
'M': true,
|
||||
'N': true,
|
||||
'O': true,
|
||||
'P': true,
|
||||
'Q': true,
|
||||
'R': true,
|
||||
'S': true,
|
||||
'T': true,
|
||||
'U': true,
|
||||
'W': true,
|
||||
'V': true,
|
||||
'X': true,
|
||||
'Y': true,
|
||||
'Z': true,
|
||||
'^': true,
|
||||
'_': true,
|
||||
'`': true,
|
||||
'a': true,
|
||||
'b': true,
|
||||
'c': true,
|
||||
'd': true,
|
||||
'e': true,
|
||||
'f': true,
|
||||
'g': true,
|
||||
'h': true,
|
||||
'i': true,
|
||||
'j': true,
|
||||
'k': true,
|
||||
'l': true,
|
||||
'm': true,
|
||||
'n': true,
|
||||
'o': true,
|
||||
'p': true,
|
||||
'q': true,
|
||||
'r': true,
|
||||
's': true,
|
||||
't': true,
|
||||
'u': true,
|
||||
'v': true,
|
||||
'w': true,
|
||||
'x': true,
|
||||
'y': true,
|
||||
'z': true,
|
||||
'|': true,
|
||||
'~': true,
|
||||
}
|
||||
|
||||
func IsTokenRune(r rune) bool {
|
||||
i := int(r)
|
||||
return i < len(isTokenTable) && isTokenTable[i]
|
||||
}
|
||||
|
||||
func validMethod(method string) bool {
|
||||
/*
|
||||
Method = "OPTIONS" ; Section 9.2
|
||||
| "GET" ; Section 9.3
|
||||
| "HEAD" ; Section 9.4
|
||||
| "POST" ; Section 9.5
|
||||
| "PUT" ; Section 9.6
|
||||
| "DELETE" ; Section 9.7
|
||||
| "TRACE" ; Section 9.8
|
||||
| "CONNECT" ; Section 9.9
|
||||
| extension-method
|
||||
extension-method = token
|
||||
token = 1*<any CHAR except CTLs or separators>
|
||||
*/
|
||||
return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1
|
||||
}
|
||||
|
||||
func isNotToken(r rune) bool {
|
||||
return !IsTokenRune(r)
|
||||
}
|
||||
|
||||
func hasPort(s string) bool { return strings.LastIndex(s, ":") > strings.LastIndex(s, "]") }
|
||||
|
||||
// removeEmptyPort strips the empty port in ":port" to ""
|
||||
// as mandated by RFC 3986 Section 6.2.3.
|
||||
func removeEmptyPort(host string) string {
|
||||
if hasPort(host) {
|
||||
return strings.TrimSuffix(host, ":")
|
||||
}
|
||||
return host
|
||||
}
|
||||
258
integration_test.go
Normal file
258
integration_test.go
Normal file
@ -0,0 +1,258 @@
|
||||
package starnet
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// 这些测试使用 httpbin.org 作为测试服务
|
||||
// 可以通过环境变量 STARNET_INTEGRATION_TEST=1 来启用
|
||||
|
||||
func skipIfNoIntegration(t *testing.T) {
|
||||
if os.Getenv("STARNET_INTEGRATION_TEST") != "1" {
|
||||
t.Skip("Skipping integration test. Set STARNET_INTEGRATION_TEST=1 to run")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegrationHTTPBinGet(t *testing.T) {
|
||||
skipIfNoIntegration(t)
|
||||
|
||||
resp, err := Get("https://httpbin.org/get",
|
||||
WithQuery("name", "starnet"),
|
||||
WithQuery("version", "1.0"))
|
||||
if err != nil {
|
||||
t.Fatalf("Get() error: %v", err)
|
||||
}
|
||||
defer resp.Close()
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
t.Errorf("StatusCode = %v; want 200", resp.StatusCode)
|
||||
}
|
||||
|
||||
var result map[string]interface{}
|
||||
err = resp.Body().JSON(&result)
|
||||
if err != nil {
|
||||
t.Fatalf("JSON() error: %v", err)
|
||||
}
|
||||
|
||||
args, ok := result["args"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatal("args not found in response")
|
||||
}
|
||||
|
||||
if args["name"] != "starnet" {
|
||||
t.Errorf("args[name] = %v; want starnet", args["name"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegrationHTTPBinPost(t *testing.T) {
|
||||
skipIfNoIntegration(t)
|
||||
|
||||
type PostData struct {
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
data := PostData{
|
||||
Name: "John Doe",
|
||||
Email: "john@example.com",
|
||||
}
|
||||
|
||||
resp, err := Post("https://httpbin.org/post", WithJSON(data))
|
||||
if err != nil {
|
||||
t.Fatalf("Post() error: %v", err)
|
||||
}
|
||||
defer resp.Close()
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
t.Errorf("StatusCode = %v; want 200", resp.StatusCode)
|
||||
}
|
||||
|
||||
var result map[string]interface{}
|
||||
err = resp.Body().JSON(&result)
|
||||
if err != nil {
|
||||
t.Fatalf("JSON() error: %v", err)
|
||||
}
|
||||
|
||||
jsonData, ok := result["json"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatal("json not found in response")
|
||||
}
|
||||
|
||||
if jsonData["name"] != data.Name {
|
||||
t.Errorf("name = %v; want %v", jsonData["name"], data.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegrationHTTPBinHeaders(t *testing.T) {
|
||||
skipIfNoIntegration(t)
|
||||
|
||||
resp, err := Get("https://httpbin.org/headers",
|
||||
WithHeader("X-Custom-Header", "test-value"),
|
||||
WithUserAgent("Starnet-Test/1.0"))
|
||||
if err != nil {
|
||||
t.Fatalf("Get() error: %v", err)
|
||||
}
|
||||
defer resp.Close()
|
||||
|
||||
var result map[string]interface{}
|
||||
err = resp.Body().JSON(&result)
|
||||
if err != nil {
|
||||
t.Fatalf("JSON() error: %v", err)
|
||||
}
|
||||
|
||||
headers, ok := result["headers"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatal("headers not found in response")
|
||||
}
|
||||
|
||||
if headers["X-Custom-Header"] != "test-value" {
|
||||
t.Errorf("X-Custom-Header = %v; want test-value", headers["X-Custom-Header"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegrationHTTPBinBasicAuth(t *testing.T) {
|
||||
skipIfNoIntegration(t)
|
||||
|
||||
resp, err := Get("https://httpbin.org/basic-auth/user/passwd",
|
||||
WithBasicAuth("user", "passwd"))
|
||||
if err != nil {
|
||||
t.Fatalf("Get() error: %v", err)
|
||||
}
|
||||
defer resp.Close()
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
t.Errorf("StatusCode = %v; want 200", resp.StatusCode)
|
||||
}
|
||||
|
||||
var result map[string]interface{}
|
||||
err = resp.Body().JSON(&result)
|
||||
if err != nil {
|
||||
t.Fatalf("JSON() error: %v", err)
|
||||
}
|
||||
|
||||
if result["authenticated"] != true {
|
||||
t.Error("authenticated should be true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegrationHTTPBinDelay(t *testing.T) {
|
||||
skipIfNoIntegration(t)
|
||||
|
||||
// Test timeout
|
||||
start := time.Now()
|
||||
_, err := Get("https://httpbin.org/delay/3",
|
||||
WithTimeout(1*time.Second))
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if err == nil {
|
||||
t.Error("Expected timeout error, got nil")
|
||||
}
|
||||
|
||||
if elapsed > 2*time.Second {
|
||||
t.Errorf("Timeout took too long: %v", elapsed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegrationHTTPBinRedirect(t *testing.T) {
|
||||
skipIfNoIntegration(t)
|
||||
|
||||
// Test with redirect enabled
|
||||
client := NewClientNoErr()
|
||||
resp, err := client.Get("https://httpbin.org/redirect/2")
|
||||
if err != nil {
|
||||
t.Fatalf("Get() error: %v", err)
|
||||
}
|
||||
defer resp.Close()
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
t.Errorf("StatusCode = %v; want 200 (after redirect)", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Test with redirect disabled
|
||||
client.DisableRedirect()
|
||||
resp2, err := client.Get("https://httpbin.org/redirect/2")
|
||||
if err != nil {
|
||||
t.Fatalf("Get() error: %v", err)
|
||||
}
|
||||
defer resp2.Close()
|
||||
|
||||
if resp2.StatusCode != 302 {
|
||||
t.Errorf("StatusCode = %v; want 302 (redirect disabled)", resp2.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegrationHTTPBinCookies(t *testing.T) {
|
||||
skipIfNoIntegration(t)
|
||||
|
||||
// 创建一个禁用重定向的 Client
|
||||
client := NewClientNoErr()
|
||||
client.DisableRedirect()
|
||||
|
||||
resp, err := client.Get("https://httpbin.org/cookies/set?name=value")
|
||||
if err != nil {
|
||||
t.Fatalf("Get() error: %v", err)
|
||||
}
|
||||
defer resp.Close()
|
||||
|
||||
// 现在应该能获取到 Set-Cookie
|
||||
cookies := resp.Cookies()
|
||||
if len(cookies) == 0 {
|
||||
t.Error("Expected cookies in response")
|
||||
}
|
||||
|
||||
// 验证 cookie
|
||||
found := false
|
||||
for _, cookie := range cookies {
|
||||
if cookie.Name == "name" && cookie.Value == "value" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("Expected cookie 'name=value' not found")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegrationHTTPBinUserAgent(t *testing.T) {
|
||||
skipIfNoIntegration(t)
|
||||
|
||||
customUA := "Starnet-Integration-Test/1.0"
|
||||
resp, err := Get("https://httpbin.org/user-agent",
|
||||
WithUserAgent(customUA))
|
||||
if err != nil {
|
||||
t.Fatalf("Get() error: %v", err)
|
||||
}
|
||||
defer resp.Close()
|
||||
|
||||
var result map[string]interface{}
|
||||
err = resp.Body().JSON(&result)
|
||||
if err != nil {
|
||||
t.Fatalf("JSON() error: %v", err)
|
||||
}
|
||||
|
||||
if result["user-agent"] != customUA {
|
||||
t.Errorf("user-agent = %v; want %v", result["user-agent"], customUA)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegrationHTTPBinGzip(t *testing.T) {
|
||||
skipIfNoIntegration(t)
|
||||
|
||||
resp, err := Get("https://httpbin.org/gzip")
|
||||
if err != nil {
|
||||
t.Fatalf("Get() error: %v", err)
|
||||
}
|
||||
defer resp.Close()
|
||||
|
||||
var result map[string]interface{}
|
||||
err = resp.Body().JSON(&result)
|
||||
if err != nil {
|
||||
t.Fatalf("JSON() error: %v", err)
|
||||
}
|
||||
|
||||
if result["gzipped"] != true {
|
||||
t.Error("Response should be gzipped")
|
||||
}
|
||||
}
|
||||
390
options.go
Normal file
390
options.go
Normal file
@ -0,0 +1,390 @@
|
||||
package starnet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
// WithTimeout 设置请求总超时时间
|
||||
// timeout > 0: 使用该超时
|
||||
// timeout = 0: 使用 Client 默认超时
|
||||
// timeout < 0: 禁用本次请求超时(覆盖 Client.Timeout=0)
|
||||
func WithTimeout(timeout time.Duration) RequestOpt {
|
||||
return func(r *Request) error {
|
||||
r.config.Network.Timeout = timeout
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithDialTimeout 设置连接超时时间
|
||||
func WithDialTimeout(timeout time.Duration) RequestOpt {
|
||||
return func(r *Request) error {
|
||||
r.config.Network.DialTimeout = timeout
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithProxy 设置代理
|
||||
func WithProxy(proxy string) RequestOpt {
|
||||
return func(r *Request) error {
|
||||
r.config.Network.Proxy = proxy
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithDialFunc 设置自定义 Dial 函数
|
||||
func WithDialFunc(fn func(ctx context.Context, network, addr string) (net.Conn, error)) RequestOpt {
|
||||
return func(r *Request) error {
|
||||
r.config.Network.DialFunc = fn
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithTLSConfig 设置 TLS 配置
|
||||
func WithTLSConfig(tlsConfig *tls.Config) RequestOpt {
|
||||
return func(r *Request) error {
|
||||
r.config.TLS.Config = tlsConfig
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithSkipTLSVerify 设置是否跳过 TLS 验证
|
||||
func WithSkipTLSVerify(skip bool) RequestOpt {
|
||||
return func(r *Request) error {
|
||||
r.config.TLS.SkipVerify = skip
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithCustomIP 设置自定义 IP
|
||||
func WithCustomIP(ips []string) RequestOpt {
|
||||
return func(r *Request) error {
|
||||
for _, ip := range ips {
|
||||
if net.ParseIP(ip) == nil {
|
||||
return wrapError(ErrInvalidIP, "ip: %s", ip)
|
||||
}
|
||||
}
|
||||
r.config.DNS.CustomIP = ips
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithAddCustomIP 添加自定义 IP
|
||||
func WithAddCustomIP(ip string) RequestOpt {
|
||||
return func(r *Request) error {
|
||||
if net.ParseIP(ip) == nil {
|
||||
return wrapError(ErrInvalidIP, "ip: %s", ip)
|
||||
}
|
||||
r.config.DNS.CustomIP = append(r.config.DNS.CustomIP, ip)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithCustomDNS 设置自定义 DNS 服务器
|
||||
func WithCustomDNS(dnsServers []string) RequestOpt {
|
||||
return func(r *Request) error {
|
||||
for _, dns := range dnsServers {
|
||||
if net.ParseIP(dns) == nil {
|
||||
return wrapError(ErrInvalidDNS, "dns: %s", dns)
|
||||
}
|
||||
}
|
||||
r.config.DNS.CustomDNS = dnsServers
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithAddCustomDNS 添加自定义 DNS 服务器
|
||||
func WithAddCustomDNS(dns string) RequestOpt {
|
||||
return func(r *Request) error {
|
||||
if net.ParseIP(dns) == nil {
|
||||
return wrapError(ErrInvalidDNS, "dns: %s", dns)
|
||||
}
|
||||
r.config.DNS.CustomDNS = append(r.config.DNS.CustomDNS, dns)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithLookupFunc 设置自定义 DNS 解析函数
|
||||
func WithLookupFunc(fn func(ctx context.Context, host string) ([]net.IPAddr, error)) RequestOpt {
|
||||
return func(r *Request) error {
|
||||
r.config.DNS.LookupFunc = fn
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithHeader 设置 Header
|
||||
func WithHeader(key, value string) RequestOpt {
|
||||
return func(r *Request) error {
|
||||
r.config.Headers.Set(key, value)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithHeaders 批量设置 Headers
|
||||
func WithHeaders(headers map[string]string) RequestOpt {
|
||||
return func(r *Request) error {
|
||||
for k, v := range headers {
|
||||
r.config.Headers.Set(k, v)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithContentType 设置 Content-Type
|
||||
func WithContentType(contentType string) RequestOpt {
|
||||
return func(r *Request) error {
|
||||
r.config.Headers.Set("Content-Type", contentType)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithUserAgent 设置 User-Agent
|
||||
func WithUserAgent(userAgent string) RequestOpt {
|
||||
return func(r *Request) error {
|
||||
r.config.Headers.Set("User-Agent", userAgent)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithBearerToken 设置 Bearer Token
|
||||
func WithBearerToken(token string) RequestOpt {
|
||||
return func(r *Request) error {
|
||||
r.config.Headers.Set("Authorization", "Bearer "+token)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithBasicAuth 设置 Basic 认证
|
||||
func WithBasicAuth(username, password string) RequestOpt {
|
||||
return func(r *Request) error {
|
||||
r.config.BasicAuth = [2]string{username, password}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithCookie 添加 Cookie
|
||||
func WithCookie(name, value, path string) RequestOpt {
|
||||
return func(r *Request) error {
|
||||
r.config.Cookies = append(r.config.Cookies, &http.Cookie{
|
||||
Name: name,
|
||||
Value: value,
|
||||
Path: path,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithSimpleCookie 添加简单 Cookie(path 为 /)
|
||||
func WithSimpleCookie(name, value string) RequestOpt {
|
||||
return func(r *Request) error {
|
||||
r.config.Cookies = append(r.config.Cookies, &http.Cookie{
|
||||
Name: name,
|
||||
Value: value,
|
||||
Path: "/",
|
||||
})
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithCookies 批量添加 Cookies
|
||||
func WithCookies(cookies map[string]string) RequestOpt {
|
||||
return func(r *Request) error {
|
||||
for name, value := range cookies {
|
||||
r.config.Cookies = append(r.config.Cookies, &http.Cookie{
|
||||
Name: name,
|
||||
Value: value,
|
||||
Path: "/",
|
||||
})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithBody 设置请求体(字节)
|
||||
func WithBody(body []byte) RequestOpt {
|
||||
return func(r *Request) error {
|
||||
r.config.Body.Bytes = body
|
||||
r.config.Body.Reader = nil
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithBodyString 设置请求体(字符串)
|
||||
func WithBodyString(body string) RequestOpt {
|
||||
return func(r *Request) error {
|
||||
r.config.Body.Bytes = []byte(body)
|
||||
r.config.Body.Reader = nil
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithBodyReader 设置请求体(Reader)
|
||||
func WithBodyReader(reader io.Reader) RequestOpt {
|
||||
return func(r *Request) error {
|
||||
r.config.Body.Reader = reader
|
||||
r.config.Body.Bytes = nil
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithJSON 设置 JSON 请求体
|
||||
func WithJSON(v interface{}) RequestOpt {
|
||||
return func(r *Request) error {
|
||||
data, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return wrapError(err, "marshal json")
|
||||
}
|
||||
r.config.Headers.Set("Content-Type", ContentTypeJSON)
|
||||
r.config.Body.Bytes = data
|
||||
r.config.Body.Reader = nil
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithFormData 设置表单数据
|
||||
func WithFormData(data map[string][]string) RequestOpt {
|
||||
return func(r *Request) error {
|
||||
r.config.Body.FormData = data
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithFormDataMap 设置表单数据(简化版)
|
||||
func WithFormDataMap(data map[string]string) RequestOpt {
|
||||
return func(r *Request) error {
|
||||
for k, v := range data {
|
||||
r.config.Body.FormData[k] = []string{v}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithAddFormData 添加表单数据
|
||||
func WithAddFormData(key, value string) RequestOpt {
|
||||
return func(r *Request) error {
|
||||
r.config.Body.FormData[key] = append(r.config.Body.FormData[key], value)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithFile 添加文件
|
||||
func WithFile(formName, filePath string) RequestOpt {
|
||||
return func(r *Request) error {
|
||||
stat, err := os.Stat(filePath)
|
||||
if err != nil {
|
||||
return wrapError(ErrFileNotFound, "file: %s", filePath)
|
||||
}
|
||||
|
||||
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
|
||||
FormName: formName,
|
||||
FileName: stat.Name(),
|
||||
FilePath: filePath,
|
||||
FileSize: stat.Size(),
|
||||
FileType: ContentTypeOctetStream,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithFileStream 添加文件流
|
||||
func WithFileStream(formName, fileName string, size int64, reader io.Reader) RequestOpt {
|
||||
return func(r *Request) error {
|
||||
if reader == nil {
|
||||
return ErrNilReader
|
||||
}
|
||||
|
||||
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
|
||||
FormName: formName,
|
||||
FileName: fileName,
|
||||
FileData: reader,
|
||||
FileSize: size,
|
||||
FileType: ContentTypeOctetStream,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithQuery 添加查询参数
|
||||
func WithQuery(key, value string) RequestOpt {
|
||||
return func(r *Request) error {
|
||||
r.config.Queries[key] = append(r.config.Queries[key], value)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithQueries 批量添加查询参数
|
||||
func WithQueries(queries map[string]string) RequestOpt {
|
||||
return func(r *Request) error {
|
||||
for k, v := range queries {
|
||||
r.config.Queries[k] = append(r.config.Queries[k], v)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithContentLength 设置 Content-Length
|
||||
func WithContentLength(length int64) RequestOpt {
|
||||
return func(r *Request) error {
|
||||
r.config.ContentLength = length
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithAutoCalcContentLength 设置是否自动计算 Content-Length
|
||||
func WithAutoCalcContentLength(auto bool) RequestOpt {
|
||||
return func(r *Request) error {
|
||||
r.config.AutoCalcContentLength = auto
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithUploadProgress 设置文件上传进度回调
|
||||
func WithUploadProgress(fn UploadProgressFunc) RequestOpt {
|
||||
return func(r *Request) error {
|
||||
r.config.UploadProgress = fn
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithTransport 设置自定义 Transport
|
||||
func WithTransport(transport *http.Transport) RequestOpt {
|
||||
return func(r *Request) error {
|
||||
r.config.Transport = transport
|
||||
r.config.CustomTransport = true
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithAutoFetch 设置是否自动获取响应体
|
||||
func WithAutoFetch(auto bool) RequestOpt {
|
||||
return func(r *Request) error {
|
||||
r.autoFetch = auto
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithRawRequest 设置原始请求
|
||||
func WithRawRequest(httpReq *http.Request) RequestOpt {
|
||||
return func(r *Request) error {
|
||||
r.httpReq = httpReq
|
||||
r.doRaw = true
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithContext 设置 context
|
||||
func WithContext(ctx context.Context) RequestOpt {
|
||||
return func(r *Request) error {
|
||||
r.ctx = ctx
|
||||
r.httpReq = r.httpReq.WithContext(ctx)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
234
options_test.go
Normal file
234
options_test.go
Normal file
@ -0,0 +1,234 @@
|
||||
package starnet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestWithJSONOpt(t *testing.T) {
|
||||
type payload struct {
|
||||
Name string `json:"name"`
|
||||
Age int `json:"age"`
|
||||
}
|
||||
|
||||
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if ct := r.Header.Get("Content-Type"); ct != ContentTypeJSON {
|
||||
t.Fatalf("content-type=%s", ct)
|
||||
}
|
||||
var p payload
|
||||
if err := json.NewDecoder(r.Body).Decode(&p); err != nil {
|
||||
t.Fatalf("decode err: %v", err)
|
||||
}
|
||||
if p.Name != "alice" || p.Age != 18 {
|
||||
t.Fatalf("payload mismatch: %+v", p)
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
resp, err := Post(s.URL, WithJSON(payload{Name: "alice", Age: 18}))
|
||||
if err != nil {
|
||||
t.Fatalf("Post error: %v", err)
|
||||
}
|
||||
resp.Close()
|
||||
}
|
||||
|
||||
func TestWithFileOpt(t *testing.T) {
|
||||
// temp file + cleanup
|
||||
f, err := os.CreateTemp("", "starnet-upload-*.txt")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.Remove(f.Name())
|
||||
_, _ = f.WriteString("hello-file")
|
||||
_ = f.Close()
|
||||
|
||||
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseMultipartForm(10 << 20); err != nil {
|
||||
t.Fatalf("parse form err: %v", err)
|
||||
}
|
||||
file, header, err := r.FormFile("file")
|
||||
if err != nil {
|
||||
t.Fatalf("form file err: %v", err)
|
||||
}
|
||||
defer file.Close()
|
||||
b, _ := io.ReadAll(file)
|
||||
if header.Filename == "" || string(b) != "hello-file" {
|
||||
t.Fatalf("upload mismatch filename=%q body=%q", header.Filename, string(b))
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
resp, err := Post(s.URL, WithFile("file", f.Name()))
|
||||
if err != nil {
|
||||
t.Fatalf("Post error: %v", err)
|
||||
}
|
||||
resp.Close()
|
||||
}
|
||||
|
||||
func TestWithFileStreamOpt(t *testing.T) {
|
||||
content := "stream-content"
|
||||
reader := strings.NewReader(content)
|
||||
|
||||
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseMultipartForm(10 << 20); err != nil {
|
||||
t.Fatalf("parse form err: %v", err)
|
||||
}
|
||||
file, header, err := r.FormFile("up")
|
||||
if err != nil {
|
||||
t.Fatalf("form file err: %v", err)
|
||||
}
|
||||
defer file.Close()
|
||||
b, _ := io.ReadAll(file)
|
||||
if header.Filename != "a.txt" || string(b) != content {
|
||||
t.Fatalf("upload mismatch filename=%q body=%q", header.Filename, string(b))
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
resp, err := Post(s.URL, WithFileStream("up", "a.txt", int64(len(content)), reader))
|
||||
if err != nil {
|
||||
t.Fatalf("Post error: %v", err)
|
||||
}
|
||||
resp.Close()
|
||||
}
|
||||
|
||||
func TestWithQueryOpt(t *testing.T) {
|
||||
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Query().Get("k") != "v" {
|
||||
t.Fatalf("query mismatch: %v", r.URL.Query())
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
resp, err := Get(s.URL, WithQuery("k", "v"))
|
||||
if err != nil {
|
||||
t.Fatalf("Get error: %v", err)
|
||||
}
|
||||
resp.Close()
|
||||
}
|
||||
|
||||
func TestWithUploadProgressOpt(t *testing.T) {
|
||||
var called int32
|
||||
var last int64
|
||||
content := strings.Repeat("x", 4096)
|
||||
|
||||
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_ = r.ParseMultipartForm(10 << 20)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
resp, err := Post(s.URL,
|
||||
WithUploadProgress(func(filename string, uploaded, total int64) {
|
||||
atomic.StoreInt32(&called, 1)
|
||||
last = uploaded
|
||||
}),
|
||||
WithFileStream("f", "p.txt", int64(len(content)), strings.NewReader(content)),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Post error: %v", err)
|
||||
}
|
||||
resp.Close()
|
||||
|
||||
if atomic.LoadInt32(&called) == 0 {
|
||||
t.Fatal("progress not called")
|
||||
}
|
||||
if last != int64(len(content)) {
|
||||
t.Fatalf("last uploaded=%d want=%d", last, len(content))
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithTransportOpt(t *testing.T) {
|
||||
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
resp, err := Get(s.URL, WithTransport(&http.Transport{}))
|
||||
if err != nil {
|
||||
t.Fatalf("Get error: %v", err)
|
||||
}
|
||||
resp.Close()
|
||||
}
|
||||
|
||||
func TestWithContextOpt(t *testing.T) {
|
||||
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
_, err := Get(s.URL, WithContext(ctx))
|
||||
if err == nil {
|
||||
t.Fatal("expected context timeout error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithCustomDNSOpt_ConfigApplied(t *testing.T) {
|
||||
req := NewSimpleRequest("http://example.com", "GET", WithCustomDNS([]string{"8.8.8.8", "1.1.1.1"}))
|
||||
if req.Err() != nil {
|
||||
t.Fatalf("unexpected err: %v", req.Err())
|
||||
}
|
||||
if len(req.config.DNS.CustomDNS) != 2 {
|
||||
t.Fatalf("custom dns len=%d", len(req.config.DNS.CustomDNS))
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithAddCustomIPOpt(t *testing.T) {
|
||||
req := NewSimpleRequest("http://example.com", "GET", WithAddCustomIP("1.2.3.4"))
|
||||
if req.Err() != nil {
|
||||
t.Fatalf("unexpected err: %v", req.Err())
|
||||
}
|
||||
if len(req.config.DNS.CustomIP) != 1 || req.config.DNS.CustomIP[0] != "1.2.3.4" {
|
||||
t.Fatalf("custom ip mismatch: %v", req.config.DNS.CustomIP)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithCustomIPOpt(t *testing.T) {
|
||||
req := NewSimpleRequest("http://example.com", "GET", WithCustomIP([]string{"1.1.1.1", "8.8.8.8"}))
|
||||
if req.Err() != nil {
|
||||
t.Fatalf("unexpected err: %v", req.Err())
|
||||
}
|
||||
if len(req.config.DNS.CustomIP) != 2 {
|
||||
t.Fatalf("custom ip len=%d", len(req.config.DNS.CustomIP))
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithDialFuncOpt(t *testing.T) {
|
||||
called := int32(0)
|
||||
fn := func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
atomic.StoreInt32(&called, 1)
|
||||
return nil, io.EOF
|
||||
}
|
||||
req := NewSimpleRequest("http://example.com", "GET", WithDialFunc(fn))
|
||||
if req.config.Network.DialFunc == nil {
|
||||
t.Fatal("dial func not set")
|
||||
}
|
||||
_, _ = req.config.Network.DialFunc(context.Background(), "tcp", "x:1")
|
||||
if atomic.LoadInt32(&called) == 0 {
|
||||
t.Fatal("dial func not called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithDialTimeoutOpt(t *testing.T) {
|
||||
req := NewSimpleRequest("http://example.com", "GET", WithDialTimeout(123*time.Millisecond))
|
||||
if req.config.Network.DialTimeout != 123*time.Millisecond {
|
||||
t.Fatalf("dial timeout=%v", req.config.Network.DialTimeout)
|
||||
}
|
||||
}
|
||||
50
proxy_test.go
Normal file
50
proxy_test.go
Normal file
@ -0,0 +1,50 @@
|
||||
package starnet
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRequestProxy(t *testing.T) {
|
||||
// Create a proxy server
|
||||
proxyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Proxy received the request
|
||||
w.Header().Set("X-Proxied", "true")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("proxied"))
|
||||
}))
|
||||
defer proxyServer.Close()
|
||||
|
||||
// Note: This is a simplified test. Real proxy testing requires more setup
|
||||
req := NewSimpleRequest("http://example.com", "GET").
|
||||
SetProxy(proxyServer.URL)
|
||||
|
||||
// Just verify the proxy is set in config
|
||||
if req.config.Network.Proxy != proxyServer.URL {
|
||||
t.Errorf("Proxy = %v; want %v", req.config.Network.Proxy, proxyServer.URL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientLevelProxy(t *testing.T) {
|
||||
proxyURL := "http://proxy.example.com:8080"
|
||||
client := NewClientNoErr(WithProxy(proxyURL))
|
||||
|
||||
req, _ := client.NewRequest("http://example.com", "GET")
|
||||
if req.config.Network.Proxy != proxyURL {
|
||||
t.Errorf("Proxy = %v; want %v", req.config.Network.Proxy, proxyURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestLevelProxyOverride(t *testing.T) {
|
||||
clientProxy := "http://client-proxy.com:8080"
|
||||
requestProxy := "http://request-proxy.com:8080"
|
||||
|
||||
client := NewClientNoErr(WithProxy(clientProxy))
|
||||
req, _ := client.NewRequest("http://example.com", "GET", WithProxy(requestProxy))
|
||||
|
||||
// Request level should override client level
|
||||
if req.config.Network.Proxy != requestProxy {
|
||||
t.Errorf("Proxy = %v; want %v", req.config.Network.Proxy, requestProxy)
|
||||
}
|
||||
}
|
||||
98
query_test.go
Normal file
98
query_test.go
Normal file
@ -0,0 +1,98 @@
|
||||
package starnet
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRequestQuery(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
query := r.URL.Query()
|
||||
result := make(map[string][]string)
|
||||
for k, v := range query {
|
||||
result[k] = v
|
||||
}
|
||||
json.NewEncoder(w).Encode(result)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
req := NewSimpleRequest(server.URL, "GET").
|
||||
AddQuery("name", "John").
|
||||
AddQuery("age", "30").
|
||||
AddQuery("tags", "go").
|
||||
AddQuery("tags", "http")
|
||||
|
||||
resp, err := req.Do()
|
||||
if err != nil {
|
||||
t.Fatalf("Do() error: %v", err)
|
||||
}
|
||||
defer resp.Close()
|
||||
|
||||
var result map[string][]string
|
||||
resp.Body().JSON(&result)
|
||||
|
||||
if len(result["name"]) != 1 || result["name"][0] != "John" {
|
||||
t.Errorf("name = %v; want [John]", result["name"])
|
||||
}
|
||||
if len(result["tags"]) != 2 {
|
||||
t.Errorf("tags length = %v; want 2", len(result["tags"]))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestSetQuery(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
query := r.URL.Query()
|
||||
w.Write([]byte(query.Get("key")))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
req := NewSimpleRequest(server.URL, "GET").
|
||||
SetQuery("key", "value1").
|
||||
SetQuery("key", "value2") // Should overwrite
|
||||
|
||||
resp, err := req.Do()
|
||||
if err != nil {
|
||||
t.Fatalf("Do() error: %v", err)
|
||||
}
|
||||
defer resp.Close()
|
||||
|
||||
body, _ := resp.Body().String()
|
||||
if body != "value2" {
|
||||
t.Errorf("query value = %v; want value2", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestDeleteQuery(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
query := r.URL.Query()
|
||||
result := make(map[string]string)
|
||||
for k := range query {
|
||||
result[k] = query.Get(k)
|
||||
}
|
||||
json.NewEncoder(w).Encode(result)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
req := NewSimpleRequest(server.URL, "GET").
|
||||
AddQuery("keep", "yes").
|
||||
AddQuery("delete", "no").
|
||||
DeleteQuery("delete")
|
||||
|
||||
resp, err := req.Do()
|
||||
if err != nil {
|
||||
t.Fatalf("Do() error: %v", err)
|
||||
}
|
||||
defer resp.Close()
|
||||
|
||||
var result map[string]string
|
||||
resp.Body().JSON(&result)
|
||||
|
||||
if _, exists := result["delete"]; exists {
|
||||
t.Error("delete query should not exist")
|
||||
}
|
||||
if result["keep"] != "yes" {
|
||||
t.Errorf("keep = %v; want yes", result["keep"])
|
||||
}
|
||||
}
|
||||
332
request.go
Normal file
332
request.go
Normal file
@ -0,0 +1,332 @@
|
||||
package starnet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Request HTTP 请求
|
||||
type Request struct {
|
||||
ctx context.Context
|
||||
execCtx context.Context // 执行时的 context(注入了配置)
|
||||
url string
|
||||
method string
|
||||
err error // 累积的错误
|
||||
|
||||
config *RequestConfig
|
||||
client *Client
|
||||
httpClient *http.Client
|
||||
httpReq *http.Request
|
||||
|
||||
applied bool // 是否已应用配置
|
||||
doRaw bool // 是否使用原始请求(不修改)
|
||||
autoFetch bool // 是否自动获取响应体
|
||||
}
|
||||
|
||||
// newRequest 创建新请求(内部使用)
|
||||
func newRequest(ctx context.Context, urlStr string, method string, opts ...RequestOpt) (*Request, error) {
|
||||
if method == "" {
|
||||
method = http.MethodGet
|
||||
}
|
||||
method = strings.ToUpper(method)
|
||||
|
||||
// 创建 http.Request
|
||||
httpReq, err := http.NewRequestWithContext(ctx, method, urlStr, nil)
|
||||
if err != nil {
|
||||
return nil, wrapError(err, "create http request")
|
||||
}
|
||||
|
||||
// 初始化配置
|
||||
config := &RequestConfig{
|
||||
Network: NetworkConfig{
|
||||
DialTimeout: DefaultDialTimeout,
|
||||
Timeout: DefaultTimeout,
|
||||
},
|
||||
Headers: make(http.Header),
|
||||
Queries: make(map[string][]string),
|
||||
Body: BodyConfig{
|
||||
FormData: make(map[string][]string),
|
||||
},
|
||||
}
|
||||
|
||||
// 设置默认 User-Agent
|
||||
config.Headers.Set("User-Agent", DefaultUserAgent)
|
||||
|
||||
// POST 请求默认 Content-Type
|
||||
if method == http.MethodPost {
|
||||
config.Headers.Set("Content-Type", ContentTypeFormURLEncoded)
|
||||
}
|
||||
|
||||
req := &Request{
|
||||
ctx: ctx,
|
||||
url: urlStr,
|
||||
method: method,
|
||||
config: config,
|
||||
httpReq: httpReq,
|
||||
autoFetch: DefaultFetchRespBody,
|
||||
}
|
||||
|
||||
// 应用选项
|
||||
for _, opt := range opts {
|
||||
if opt != nil {
|
||||
if err := opt(req); err != nil {
|
||||
req.err = err
|
||||
return req, nil // 不返回错误,累积到 req.err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// NewRequest 创建新请求
|
||||
func NewRequest(url, method string, opts ...RequestOpt) (*Request, error) {
|
||||
return newRequest(context.Background(), url, method, opts...)
|
||||
}
|
||||
|
||||
// NewRequestWithContext 创建新请求(带 context)
|
||||
func NewRequestWithContext(ctx context.Context, url, method string, opts ...RequestOpt) (*Request, error) {
|
||||
return newRequest(ctx, url, method, opts...)
|
||||
}
|
||||
|
||||
// NewSimpleRequest 创建新请求(忽略错误,支持链式调用)
|
||||
func NewSimpleRequest(url, method string, opts ...RequestOpt) *Request {
|
||||
req, err := newRequest(context.Background(), url, method, opts...)
|
||||
if err != nil {
|
||||
// 返回一个带错误的请求
|
||||
return &Request{
|
||||
ctx: context.Background(),
|
||||
url: url,
|
||||
method: method,
|
||||
err: err,
|
||||
config: &RequestConfig{
|
||||
Headers: make(http.Header),
|
||||
Queries: make(map[string][]string),
|
||||
Body: BodyConfig{
|
||||
FormData: make(map[string][]string),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
return req
|
||||
}
|
||||
|
||||
// NewSimpleRequestWithContext 创建新请求(带 context,忽略错误)
|
||||
func NewSimpleRequestWithContext(ctx context.Context, url, method string, opts ...RequestOpt) *Request {
|
||||
req, err := newRequest(ctx, url, method, opts...)
|
||||
if err != nil {
|
||||
return &Request{
|
||||
ctx: ctx,
|
||||
url: url,
|
||||
method: method,
|
||||
err: err,
|
||||
config: &RequestConfig{
|
||||
Headers: make(http.Header),
|
||||
Queries: make(map[string][]string),
|
||||
Body: BodyConfig{
|
||||
FormData: make(map[string][]string),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
return req
|
||||
}
|
||||
|
||||
// Clone 克隆请求
|
||||
func (r *Request) Clone() *Request {
|
||||
cloned := &Request{
|
||||
ctx: r.ctx,
|
||||
url: r.url,
|
||||
method: r.method,
|
||||
err: r.err,
|
||||
config: r.config.Clone(),
|
||||
client: r.client,
|
||||
httpClient: r.httpClient,
|
||||
applied: false, // 重置应用状态
|
||||
doRaw: r.doRaw,
|
||||
autoFetch: r.autoFetch,
|
||||
}
|
||||
|
||||
// 重新创建 http.Request
|
||||
if !r.doRaw {
|
||||
cloned.httpReq, _ = http.NewRequestWithContext(cloned.ctx, cloned.method, cloned.url, nil)
|
||||
} else {
|
||||
cloned.httpReq = r.httpReq
|
||||
}
|
||||
|
||||
return cloned
|
||||
}
|
||||
|
||||
// Err 获取累积的错误
|
||||
func (r *Request) Err() error {
|
||||
return r.err
|
||||
}
|
||||
|
||||
// Context 获取 context
|
||||
func (r *Request) Context() context.Context {
|
||||
return r.ctx
|
||||
}
|
||||
|
||||
// SetContext 设置 context
|
||||
func (r *Request) SetContext(ctx context.Context) *Request {
|
||||
if r.err != nil {
|
||||
return r
|
||||
}
|
||||
r.ctx = ctx
|
||||
r.httpReq = r.httpReq.WithContext(ctx)
|
||||
return r
|
||||
}
|
||||
|
||||
// Method 获取 HTTP 方法
|
||||
func (r *Request) Method() string {
|
||||
return r.method
|
||||
}
|
||||
|
||||
// SetMethod 设置 HTTP 方法
|
||||
func (r *Request) SetMethod(method string) *Request {
|
||||
if r.err != nil {
|
||||
return r
|
||||
}
|
||||
|
||||
method = strings.ToUpper(method)
|
||||
if !validMethod(method) {
|
||||
r.err = wrapError(ErrInvalidMethod, "method: %s", method)
|
||||
return r
|
||||
}
|
||||
|
||||
r.method = method
|
||||
r.httpReq.Method = method
|
||||
return r
|
||||
}
|
||||
|
||||
// URL 获取 URL
|
||||
func (r *Request) URL() string {
|
||||
return r.url
|
||||
}
|
||||
|
||||
// SetURL 设置 URL
|
||||
func (r *Request) SetURL(urlStr string) *Request {
|
||||
if r.err != nil {
|
||||
return r
|
||||
}
|
||||
|
||||
if r.doRaw {
|
||||
r.err = fmt.Errorf("cannot set URL when using raw request")
|
||||
return r
|
||||
}
|
||||
|
||||
u, err := url.Parse(urlStr)
|
||||
if err != nil {
|
||||
r.err = wrapError(ErrInvalidURL, "url: %s", urlStr)
|
||||
return r
|
||||
}
|
||||
|
||||
r.url = urlStr
|
||||
u.Host = removeEmptyPort(u.Host)
|
||||
r.httpReq.Host = u.Host
|
||||
r.httpReq.URL = u
|
||||
|
||||
// 更新 TLS ServerName
|
||||
if r.config.TLS.Config != nil {
|
||||
r.config.TLS.Config.ServerName = u.Hostname()
|
||||
}
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
// RawRequest 获取底层 http.Request
|
||||
func (r *Request) RawRequest() *http.Request {
|
||||
return r.httpReq
|
||||
}
|
||||
|
||||
// SetRawRequest 设置底层 http.Request(启用原始模式)
|
||||
func (r *Request) SetRawRequest(httpReq *http.Request) *Request {
|
||||
if r.err != nil {
|
||||
return r
|
||||
}
|
||||
r.httpReq = httpReq
|
||||
r.doRaw = true
|
||||
return r
|
||||
}
|
||||
|
||||
// EnableRawMode 启用原始模式(不修改请求)
|
||||
func (r *Request) EnableRawMode() *Request {
|
||||
r.doRaw = true
|
||||
return r
|
||||
}
|
||||
|
||||
// DisableRawMode 禁用原始模式
|
||||
func (r *Request) DisableRawMode() *Request {
|
||||
r.doRaw = false
|
||||
return r
|
||||
}
|
||||
|
||||
// SetAutoFetch 设置是否自动获取响应体
|
||||
func (r *Request) SetAutoFetch(auto bool) *Request {
|
||||
r.autoFetch = auto
|
||||
return r
|
||||
}
|
||||
|
||||
// Do 执行请求
|
||||
func (r *Request) Do() (*Response, error) {
|
||||
// 检查累积的错误
|
||||
if r.err != nil {
|
||||
return nil, r.err
|
||||
}
|
||||
|
||||
// 准备请求
|
||||
if err := r.prepare(); err != nil {
|
||||
return nil, wrapError(err, "prepare request")
|
||||
}
|
||||
|
||||
// 执行请求
|
||||
httpResp, err := r.httpClient.Do(r.httpReq)
|
||||
if err != nil {
|
||||
return &Response{
|
||||
Response: &http.Response{},
|
||||
request: r,
|
||||
httpClient: r.httpClient,
|
||||
body: &Body{},
|
||||
}, wrapError(err, "do request")
|
||||
}
|
||||
|
||||
// 创建响应
|
||||
resp := &Response{
|
||||
Response: httpResp,
|
||||
request: r,
|
||||
httpClient: r.httpClient,
|
||||
body: &Body{
|
||||
raw: httpResp.Body,
|
||||
},
|
||||
}
|
||||
|
||||
// 自动获取响应体
|
||||
if r.autoFetch {
|
||||
resp.body.readAll()
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// Get 发送 GET 请求
|
||||
func (r *Request) Get() (*Response, error) {
|
||||
return r.SetMethod(http.MethodGet).Do()
|
||||
}
|
||||
|
||||
// Post 发送 POST 请求
|
||||
func (r *Request) Post() (*Response, error) {
|
||||
return r.SetMethod(http.MethodPost).Do()
|
||||
}
|
||||
|
||||
// Put 发送 PUT 请求
|
||||
func (r *Request) Put() (*Response, error) {
|
||||
return r.SetMethod(http.MethodPut).Do()
|
||||
}
|
||||
|
||||
// Delete 发送 DELETE 请求
|
||||
func (r *Request) Delete() (*Response, error) {
|
||||
return r.SetMethod(http.MethodDelete).Do()
|
||||
}
|
||||
463
request_body.go
Normal file
463
request_body.go
Normal file
@ -0,0 +1,463 @@
|
||||
package starnet
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// SetBody 设置请求体(字节)
|
||||
func (r *Request) SetBody(body []byte) *Request {
|
||||
if r.err != nil {
|
||||
return r
|
||||
}
|
||||
if r.doRaw {
|
||||
return r
|
||||
}
|
||||
r.config.Body.Bytes = body
|
||||
r.config.Body.Reader = nil
|
||||
return r
|
||||
}
|
||||
|
||||
// SetBodyReader 设置请求体(Reader)
|
||||
func (r *Request) SetBodyReader(reader io.Reader) *Request {
|
||||
if r.err != nil {
|
||||
return r
|
||||
}
|
||||
if r.doRaw {
|
||||
return r
|
||||
}
|
||||
r.config.Body.Reader = reader
|
||||
r.config.Body.Bytes = nil
|
||||
return r
|
||||
}
|
||||
|
||||
// SetBodyString 设置请求体(字符串)
|
||||
func (r *Request) SetBodyString(body string) *Request {
|
||||
return r.SetBody([]byte(body))
|
||||
}
|
||||
|
||||
// SetJSON 设置 JSON 请求体
|
||||
func (r *Request) SetJSON(v interface{}) *Request {
|
||||
if r.err != nil {
|
||||
return r
|
||||
}
|
||||
|
||||
data, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
r.err = wrapError(err, "marshal json")
|
||||
return r
|
||||
}
|
||||
|
||||
return r.SetContentType(ContentTypeJSON).SetBody(data)
|
||||
}
|
||||
|
||||
// SetFormData 设置表单数据(覆盖)
|
||||
func (r *Request) SetFormData(data map[string][]string) *Request {
|
||||
if r.err != nil {
|
||||
return r
|
||||
}
|
||||
if r.doRaw {
|
||||
return r
|
||||
}
|
||||
r.config.Body.FormData = data
|
||||
return r
|
||||
}
|
||||
|
||||
// AddFormData 添加表单数据
|
||||
func (r *Request) AddFormData(key, value string) *Request {
|
||||
if r.err != nil {
|
||||
return r
|
||||
}
|
||||
if r.doRaw {
|
||||
return r
|
||||
}
|
||||
r.config.Body.FormData[key] = append(r.config.Body.FormData[key], value)
|
||||
return r
|
||||
}
|
||||
|
||||
// AddFormDataMap 批量添加表单数据
|
||||
func (r *Request) AddFormDataMap(data map[string]string) *Request {
|
||||
if r.err != nil {
|
||||
return r
|
||||
}
|
||||
if r.doRaw {
|
||||
return r
|
||||
}
|
||||
for k, v := range data {
|
||||
r.config.Body.FormData[k] = append(r.config.Body.FormData[k], v)
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// AddFile 添加文件(从路径)
|
||||
func (r *Request) AddFile(formName, filePath string) *Request {
|
||||
if r.err != nil {
|
||||
return r
|
||||
}
|
||||
|
||||
stat, err := os.Stat(filePath)
|
||||
if err != nil {
|
||||
r.err = wrapError(ErrFileNotFound, "file: %s", filePath)
|
||||
return r
|
||||
}
|
||||
|
||||
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
|
||||
FormName: formName,
|
||||
FileName: stat.Name(),
|
||||
FilePath: filePath,
|
||||
FileSize: stat.Size(),
|
||||
FileType: ContentTypeOctetStream,
|
||||
})
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
// AddFileWithName 添加文件(指定文件名)
|
||||
func (r *Request) AddFileWithName(formName, filePath, fileName string) *Request {
|
||||
if r.err != nil {
|
||||
return r
|
||||
}
|
||||
|
||||
stat, err := os.Stat(filePath)
|
||||
if err != nil {
|
||||
r.err = wrapError(ErrFileNotFound, "file: %s", filePath)
|
||||
return r
|
||||
}
|
||||
|
||||
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
|
||||
FormName: formName,
|
||||
FileName: fileName,
|
||||
FilePath: filePath,
|
||||
FileSize: stat.Size(),
|
||||
FileType: ContentTypeOctetStream,
|
||||
})
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
// AddFileWithType 添加文件(指定 MIME 类型)
|
||||
func (r *Request) AddFileWithType(formName, filePath, fileType string) *Request {
|
||||
if r.err != nil {
|
||||
return r
|
||||
}
|
||||
|
||||
stat, err := os.Stat(filePath)
|
||||
if err != nil {
|
||||
r.err = wrapError(ErrFileNotFound, "file: %s", filePath)
|
||||
return r
|
||||
}
|
||||
|
||||
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
|
||||
FormName: formName,
|
||||
FileName: stat.Name(),
|
||||
FilePath: filePath,
|
||||
FileSize: stat.Size(),
|
||||
FileType: fileType,
|
||||
})
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
// AddFileStream 添加文件流
|
||||
func (r *Request) AddFileStream(formName, fileName string, size int64, reader io.Reader) *Request {
|
||||
if r.err != nil {
|
||||
return r
|
||||
}
|
||||
|
||||
if reader == nil {
|
||||
r.err = ErrNilReader
|
||||
return r
|
||||
}
|
||||
|
||||
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
|
||||
FormName: formName,
|
||||
FileName: fileName,
|
||||
FileData: reader,
|
||||
FileSize: size,
|
||||
FileType: ContentTypeOctetStream,
|
||||
})
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
// AddFileStreamWithType 添加文件流(指定 MIME 类型)
|
||||
func (r *Request) AddFileStreamWithType(formName, fileName, fileType string, size int64, reader io.Reader) *Request {
|
||||
if r.err != nil {
|
||||
return r
|
||||
}
|
||||
|
||||
if reader == nil {
|
||||
r.err = ErrNilReader
|
||||
return r
|
||||
}
|
||||
|
||||
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
|
||||
FormName: formName,
|
||||
FileName: fileName,
|
||||
FileData: reader,
|
||||
FileSize: size,
|
||||
FileType: fileType,
|
||||
})
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
// applyBody 应用请求体
|
||||
func (r *Request) applyBody() error {
|
||||
// 优先级:Reader > Bytes > Files > FormData
|
||||
|
||||
// 1. Reader
|
||||
if r.config.Body.Reader != nil {
|
||||
r.httpReq.Body = io.NopCloser(r.config.Body.Reader)
|
||||
|
||||
// 尝试获取长度
|
||||
switch v := r.config.Body.Reader.(type) {
|
||||
case *bytes.Buffer:
|
||||
r.httpReq.ContentLength = int64(v.Len())
|
||||
case *bytes.Reader:
|
||||
r.httpReq.ContentLength = int64(v.Len())
|
||||
case *strings.Reader:
|
||||
r.httpReq.ContentLength = int64(v.Len())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 2. Bytes
|
||||
if len(r.config.Body.Bytes) > 0 {
|
||||
r.httpReq.Body = io.NopCloser(bytes.NewReader(r.config.Body.Bytes))
|
||||
r.httpReq.ContentLength = int64(len(r.config.Body.Bytes))
|
||||
return nil
|
||||
}
|
||||
|
||||
// 3. Files(multipart/form-data)
|
||||
if len(r.config.Body.Files) > 0 {
|
||||
return r.applyMultipartBody()
|
||||
}
|
||||
|
||||
// 4. FormData(application/x-www-form-urlencoded)
|
||||
if len(r.config.Body.FormData) > 0 {
|
||||
values := url.Values{}
|
||||
for k, vs := range r.config.Body.FormData {
|
||||
for _, v := range vs {
|
||||
values.Add(k, v)
|
||||
}
|
||||
}
|
||||
encoded := values.Encode()
|
||||
r.httpReq.Body = io.NopCloser(strings.NewReader(encoded))
|
||||
r.httpReq.ContentLength = int64(len(encoded))
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// applyMultipartBody 应用 multipart 请求体
|
||||
func (r *Request) applyMultipartBody() error {
|
||||
pr, pw := io.Pipe()
|
||||
writer := multipart.NewWriter(pw)
|
||||
|
||||
// 设置 Content-Type
|
||||
r.httpReq.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
r.httpReq.Body = pr
|
||||
|
||||
// 在 goroutine 中写入数据
|
||||
go func() {
|
||||
defer pw.Close()
|
||||
defer writer.Close()
|
||||
|
||||
// 写入表单字段
|
||||
for k, vs := range r.config.Body.FormData {
|
||||
for _, v := range vs {
|
||||
if err := writer.WriteField(k, v); err != nil {
|
||||
pw.CloseWithError(wrapError(err, "write form field"))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 写入文件
|
||||
for _, file := range r.config.Body.Files {
|
||||
if err := r.writeFile(writer, file); err != nil {
|
||||
pw.CloseWithError(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeFile 写入文件到 multipart writer
|
||||
func (r *Request) writeFile(writer *multipart.Writer, file RequestFile) error {
|
||||
// 创建文件字段
|
||||
part, err := writer.CreateFormFile(file.FormName, file.FileName)
|
||||
if err != nil {
|
||||
return wrapError(err, "create form file")
|
||||
}
|
||||
|
||||
// 获取文件数据源
|
||||
var reader io.Reader
|
||||
if file.FileData != nil {
|
||||
reader = file.FileData
|
||||
} else if file.FilePath != "" {
|
||||
f, err := os.Open(file.FilePath)
|
||||
if err != nil {
|
||||
return wrapError(err, "open file")
|
||||
}
|
||||
defer f.Close()
|
||||
reader = f
|
||||
} else {
|
||||
return ErrNilReader
|
||||
}
|
||||
|
||||
// 复制文件数据(带进度)
|
||||
if r.config.UploadProgress != nil {
|
||||
_, err = copyWithProgress(r.ctx, part, reader, file.FileName, file.FileSize, r.config.UploadProgress)
|
||||
} else {
|
||||
_, err = io.Copy(part, reader)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return wrapError(err, "copy file data")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// prepare 准备请求(应用配置)
|
||||
func (r *Request) prepare() error {
|
||||
if r.applied {
|
||||
return nil
|
||||
}
|
||||
defer func() { r.applied = true }()
|
||||
// 即使 raw 模式也要确保有 httpClient
|
||||
if r.httpClient == nil {
|
||||
var err error
|
||||
r.httpClient, err = r.buildHTTPClient()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
// 原始模式不修改请求内容
|
||||
if r.doRaw {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 应用查询参数
|
||||
if len(r.config.Queries) > 0 {
|
||||
q := r.httpReq.URL.Query()
|
||||
for k, values := range r.config.Queries {
|
||||
for _, v := range values {
|
||||
q.Add(k, v)
|
||||
}
|
||||
}
|
||||
r.httpReq.URL.RawQuery = q.Encode()
|
||||
}
|
||||
|
||||
// 应用 Headers
|
||||
for k, values := range r.config.Headers {
|
||||
for _, v := range values {
|
||||
r.httpReq.Header.Add(k, v)
|
||||
}
|
||||
}
|
||||
|
||||
// 应用 Cookies
|
||||
for _, cookie := range r.config.Cookies {
|
||||
r.httpReq.AddCookie(cookie)
|
||||
}
|
||||
|
||||
// 应用 Basic Auth
|
||||
if r.config.BasicAuth[0] != "" || r.config.BasicAuth[1] != "" {
|
||||
r.httpReq.SetBasicAuth(r.config.BasicAuth[0], r.config.BasicAuth[1])
|
||||
}
|
||||
|
||||
// 应用请求体
|
||||
if err := r.applyBody(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 应用 Content-Length
|
||||
if r.config.ContentLength > 0 {
|
||||
r.httpReq.ContentLength = r.config.ContentLength
|
||||
} else if r.config.ContentLength < 0 {
|
||||
r.httpReq.ContentLength = 0
|
||||
}
|
||||
|
||||
// 自动计算 Content-Length
|
||||
if r.config.AutoCalcContentLength && r.httpReq.Body != nil {
|
||||
data, err := io.ReadAll(r.httpReq.Body)
|
||||
if err != nil {
|
||||
return wrapError(err, "read body for content length")
|
||||
}
|
||||
r.httpReq.ContentLength = int64(len(data))
|
||||
r.httpReq.Body = io.NopCloser(bytes.NewBuffer(data))
|
||||
}
|
||||
|
||||
// 设置 TLS ServerName(如果有 TLS Config)
|
||||
if r.config.TLS.Config != nil && r.httpReq.URL != nil {
|
||||
r.config.TLS.Config.ServerName = r.httpReq.URL.Hostname()
|
||||
}
|
||||
|
||||
// 注入配置到 context
|
||||
r.execCtx = injectRequestConfig(r.ctx, r.config)
|
||||
r.httpReq = r.httpReq.WithContext(r.execCtx)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildHTTPClient 构建 HTTP Client
|
||||
func (r *Request) buildHTTPClient() (*http.Client, error) {
|
||||
applyTimeoutOverride := func(base *http.Client) *http.Client {
|
||||
// 没有 base 时兜底
|
||||
if base == nil {
|
||||
base = &http.Client{}
|
||||
}
|
||||
|
||||
rt := r.config.Network.Timeout
|
||||
|
||||
// 语义:
|
||||
// rt < 0 : 本次请求禁用超时(Timeout = 0)
|
||||
// rt = 0 : 沿用 base.Timeout
|
||||
// rt > 0 : 本次请求超时覆盖
|
||||
if rt == 0 {
|
||||
return base
|
||||
}
|
||||
|
||||
clone := &http.Client{
|
||||
Transport: base.Transport,
|
||||
CheckRedirect: base.CheckRedirect,
|
||||
Jar: base.Jar,
|
||||
}
|
||||
|
||||
if rt < 0 {
|
||||
clone.Timeout = 0
|
||||
} else {
|
||||
clone.Timeout = rt
|
||||
}
|
||||
return clone
|
||||
}
|
||||
|
||||
// 优先使用请求关联的 Client
|
||||
if r.client != nil {
|
||||
return applyTimeoutOverride(r.client.HTTPClient()), nil
|
||||
}
|
||||
|
||||
// 自定义 Transport
|
||||
if r.config.CustomTransport && r.config.Transport != nil {
|
||||
base := &http.Client{
|
||||
Transport: &Transport{base: r.config.Transport},
|
||||
Timeout: 0,
|
||||
}
|
||||
return applyTimeoutOverride(base), nil
|
||||
}
|
||||
|
||||
// 默认全局 client
|
||||
return applyTimeoutOverride(DefaultHTTPClient()), nil
|
||||
}
|
||||
269
request_config.go
Normal file
269
request_config.go
Normal file
@ -0,0 +1,269 @@
|
||||
package starnet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SetTimeout 设置请求总超时时间
|
||||
// timeout > 0: 使用该超时
|
||||
// timeout = 0: 使用 Client 默认超时
|
||||
// timeout < 0: 禁用本次请求超时(覆盖 Client.Timeout=0)
|
||||
func (r *Request) SetTimeout(timeout time.Duration) *Request {
|
||||
if r.err != nil {
|
||||
return r
|
||||
}
|
||||
r.config.Network.Timeout = timeout
|
||||
return r
|
||||
}
|
||||
|
||||
// SetDialTimeout 设置连接超时时间
|
||||
func (r *Request) SetDialTimeout(timeout time.Duration) *Request {
|
||||
if r.err != nil {
|
||||
return r
|
||||
}
|
||||
r.config.Network.DialTimeout = timeout
|
||||
return r
|
||||
}
|
||||
|
||||
// SetProxy 设置代理
|
||||
func (r *Request) SetProxy(proxy string) *Request {
|
||||
if r.err != nil {
|
||||
return r
|
||||
}
|
||||
r.config.Network.Proxy = proxy
|
||||
return r
|
||||
}
|
||||
|
||||
// SetDialFunc 设置自定义 Dial 函数
|
||||
func (r *Request) SetDialFunc(fn func(ctx context.Context, network, addr string) (net.Conn, error)) *Request {
|
||||
if r.err != nil {
|
||||
return r
|
||||
}
|
||||
r.config.Network.DialFunc = fn
|
||||
return r
|
||||
}
|
||||
|
||||
// SetTLSConfig 设置 TLS 配置
|
||||
func (r *Request) SetTLSConfig(tlsConfig *tls.Config) *Request {
|
||||
if r.err != nil {
|
||||
return r
|
||||
}
|
||||
r.config.TLS.Config = tlsConfig
|
||||
return r
|
||||
}
|
||||
|
||||
// SetSkipTLSVerify 设置是否跳过 TLS 验证
|
||||
func (r *Request) SetSkipTLSVerify(skip bool) *Request {
|
||||
if r.err != nil {
|
||||
return r
|
||||
}
|
||||
r.config.TLS.SkipVerify = skip
|
||||
return r
|
||||
}
|
||||
|
||||
// SetCustomIP 设置自定义 IP(直接指定 IP,跳过 DNS)
|
||||
func (r *Request) SetCustomIP(ips []string) *Request {
|
||||
if r.err != nil {
|
||||
return r
|
||||
}
|
||||
|
||||
// 验证 IP 格式
|
||||
for _, ip := range ips {
|
||||
if net.ParseIP(ip) == nil {
|
||||
r.err = wrapError(ErrInvalidIP, "ip: %s", ip)
|
||||
return r
|
||||
}
|
||||
}
|
||||
|
||||
r.config.DNS.CustomIP = ips
|
||||
return r
|
||||
}
|
||||
|
||||
// AddCustomIP 添加自定义 IP
|
||||
func (r *Request) AddCustomIP(ip string) *Request {
|
||||
if r.err != nil {
|
||||
return r
|
||||
}
|
||||
|
||||
if net.ParseIP(ip) == nil {
|
||||
r.err = wrapError(ErrInvalidIP, "ip: %s", ip)
|
||||
return r
|
||||
}
|
||||
|
||||
r.config.DNS.CustomIP = append(r.config.DNS.CustomIP, ip)
|
||||
return r
|
||||
}
|
||||
|
||||
// SetCustomDNS 设置自定义 DNS 服务器
|
||||
func (r *Request) SetCustomDNS(dnsServers []string) *Request {
|
||||
if r.err != nil {
|
||||
return r
|
||||
}
|
||||
|
||||
// 验证 DNS 服务器格式
|
||||
for _, dns := range dnsServers {
|
||||
if net.ParseIP(dns) == nil {
|
||||
r.err = wrapError(ErrInvalidDNS, "dns: %s", dns)
|
||||
return r
|
||||
}
|
||||
}
|
||||
|
||||
r.config.DNS.CustomDNS = dnsServers
|
||||
return r
|
||||
}
|
||||
|
||||
// AddCustomDNS 添加自定义 DNS 服务器
|
||||
func (r *Request) AddCustomDNS(dns string) *Request {
|
||||
if r.err != nil {
|
||||
return r
|
||||
}
|
||||
|
||||
if net.ParseIP(dns) == nil {
|
||||
r.err = wrapError(ErrInvalidDNS, "dns: %s", dns)
|
||||
return r
|
||||
}
|
||||
|
||||
r.config.DNS.CustomDNS = append(r.config.DNS.CustomDNS, dns)
|
||||
return r
|
||||
}
|
||||
|
||||
// SetLookupFunc 设置自定义 DNS 解析函数
|
||||
func (r *Request) SetLookupFunc(fn func(ctx context.Context, host string) ([]net.IPAddr, error)) *Request {
|
||||
if r.err != nil {
|
||||
return r
|
||||
}
|
||||
r.config.DNS.LookupFunc = fn
|
||||
return r
|
||||
}
|
||||
|
||||
// SetBasicAuth 设置 Basic 认证
|
||||
func (r *Request) SetBasicAuth(username, password string) *Request {
|
||||
if r.err != nil {
|
||||
return r
|
||||
}
|
||||
r.config.BasicAuth = [2]string{username, password}
|
||||
return r
|
||||
}
|
||||
|
||||
// SetContentLength 设置 Content-Length
|
||||
func (r *Request) SetContentLength(length int64) *Request {
|
||||
if r.err != nil {
|
||||
return r
|
||||
}
|
||||
r.config.ContentLength = length
|
||||
return r
|
||||
}
|
||||
|
||||
// SetAutoCalcContentLength 设置是否自动计算 Content-Length
|
||||
// 警告:启用后会将整个 body 读入内存
|
||||
func (r *Request) SetAutoCalcContentLength(auto bool) *Request {
|
||||
if r.err != nil {
|
||||
return r
|
||||
}
|
||||
|
||||
if r.doRaw {
|
||||
r.err = fmt.Errorf("cannot set auto calc content length in raw mode")
|
||||
return r
|
||||
}
|
||||
|
||||
r.config.AutoCalcContentLength = auto
|
||||
return r
|
||||
}
|
||||
|
||||
// SetTransport 设置自定义 Transport
|
||||
func (r *Request) SetTransport(transport *http.Transport) *Request {
|
||||
if r.err != nil {
|
||||
return r
|
||||
}
|
||||
r.config.Transport = transport
|
||||
r.config.CustomTransport = true
|
||||
return r
|
||||
}
|
||||
|
||||
// SetUploadProgress 设置文件上传进度回调
|
||||
func (r *Request) SetUploadProgress(fn UploadProgressFunc) *Request {
|
||||
if r.err != nil {
|
||||
return r
|
||||
}
|
||||
r.config.UploadProgress = fn
|
||||
return r
|
||||
}
|
||||
|
||||
// AddQuery 添加查询参数
|
||||
func (r *Request) AddQuery(key, value string) *Request {
|
||||
if r.err != nil {
|
||||
return r
|
||||
}
|
||||
r.config.Queries[key] = append(r.config.Queries[key], value)
|
||||
return r
|
||||
}
|
||||
|
||||
// SetQuery 设置查询参数(覆盖)
|
||||
func (r *Request) SetQuery(key, value string) *Request {
|
||||
if r.err != nil {
|
||||
return r
|
||||
}
|
||||
r.config.Queries[key] = []string{value}
|
||||
return r
|
||||
}
|
||||
|
||||
// SetQueries 设置所有查询参数(覆盖)
|
||||
func (r *Request) SetQueries(queries map[string][]string) *Request {
|
||||
if r.err != nil {
|
||||
return r
|
||||
}
|
||||
r.config.Queries = queries
|
||||
return r
|
||||
}
|
||||
|
||||
// AddQueries 批量添加查询参数
|
||||
func (r *Request) AddQueries(queries map[string]string) *Request {
|
||||
if r.err != nil {
|
||||
return r
|
||||
}
|
||||
for k, v := range queries {
|
||||
r.config.Queries[k] = append(r.config.Queries[k], v)
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// DeleteQuery 删除查询参数
|
||||
func (r *Request) DeleteQuery(key string) *Request {
|
||||
if r.err != nil {
|
||||
return r
|
||||
}
|
||||
delete(r.config.Queries, key)
|
||||
return r
|
||||
}
|
||||
|
||||
// DeleteQueryValue 删除查询参数的特定值
|
||||
func (r *Request) DeleteQueryValue(key, value string) *Request {
|
||||
if r.err != nil {
|
||||
return r
|
||||
}
|
||||
|
||||
values, ok := r.config.Queries[key]
|
||||
if !ok {
|
||||
return r
|
||||
}
|
||||
|
||||
newValues := make([]string, 0, len(values))
|
||||
for _, v := range values {
|
||||
if v != value {
|
||||
newValues = append(newValues, v)
|
||||
}
|
||||
}
|
||||
|
||||
if len(newValues) == 0 {
|
||||
delete(r.config.Queries, key)
|
||||
} else {
|
||||
r.config.Queries[key] = newValues
|
||||
}
|
||||
|
||||
return r
|
||||
}
|
||||
180
request_header.go
Normal file
180
request_header.go
Normal file
@ -0,0 +1,180 @@
|
||||
package starnet
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// SetHeader 设置 Header(覆盖)
|
||||
func (r *Request) SetHeader(key, value string) *Request {
|
||||
if r.err != nil {
|
||||
return r
|
||||
}
|
||||
if r.doRaw {
|
||||
return r
|
||||
}
|
||||
r.config.Headers.Set(key, value)
|
||||
return r
|
||||
}
|
||||
|
||||
// AddHeader 添加 Header
|
||||
func (r *Request) AddHeader(key, value string) *Request {
|
||||
if r.err != nil {
|
||||
return r
|
||||
}
|
||||
if r.doRaw {
|
||||
return r
|
||||
}
|
||||
r.config.Headers.Add(key, value)
|
||||
return r
|
||||
}
|
||||
|
||||
// SetHeaders 设置所有 Headers(覆盖)
|
||||
func (r *Request) SetHeaders(headers http.Header) *Request {
|
||||
if r.err != nil {
|
||||
return r
|
||||
}
|
||||
if r.doRaw {
|
||||
return r
|
||||
}
|
||||
r.config.Headers = headers
|
||||
return r
|
||||
}
|
||||
|
||||
// AddHeaders 批量添加 Headers
|
||||
func (r *Request) AddHeaders(headers map[string]string) *Request {
|
||||
if r.err != nil {
|
||||
return r
|
||||
}
|
||||
if r.doRaw {
|
||||
return r
|
||||
}
|
||||
for k, v := range headers {
|
||||
r.config.Headers.Add(k, v)
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// DeleteHeader 删除 Header
|
||||
func (r *Request) DeleteHeader(key string) *Request {
|
||||
if r.err != nil {
|
||||
return r
|
||||
}
|
||||
if r.doRaw {
|
||||
return r
|
||||
}
|
||||
r.config.Headers.Del(key)
|
||||
return r
|
||||
}
|
||||
|
||||
// GetHeader 获取 Header
|
||||
func (r *Request) GetHeader(key string) string {
|
||||
return r.config.Headers.Get(key)
|
||||
}
|
||||
|
||||
// Headers 获取所有 Headers
|
||||
func (r *Request) Headers() http.Header {
|
||||
return r.config.Headers
|
||||
}
|
||||
|
||||
// SetContentType 设置 Content-Type
|
||||
func (r *Request) SetContentType(contentType string) *Request {
|
||||
return r.SetHeader("Content-Type", contentType)
|
||||
}
|
||||
|
||||
// SetUserAgent 设置 User-Agent
|
||||
func (r *Request) SetUserAgent(userAgent string) *Request {
|
||||
return r.SetHeader("User-Agent", userAgent)
|
||||
}
|
||||
|
||||
// SetReferer 设置 Referer
|
||||
func (r *Request) SetReferer(referer string) *Request {
|
||||
return r.SetHeader("Referer", referer)
|
||||
}
|
||||
|
||||
// SetBearerToken 设置 Bearer Token
|
||||
func (r *Request) SetBearerToken(token string) *Request {
|
||||
return r.SetHeader("Authorization", "Bearer "+token)
|
||||
}
|
||||
|
||||
// AddCookie 添加 Cookie
|
||||
func (r *Request) AddCookie(cookie *http.Cookie) *Request {
|
||||
if r.err != nil {
|
||||
return r
|
||||
}
|
||||
if r.doRaw {
|
||||
return r
|
||||
}
|
||||
r.config.Cookies = append(r.config.Cookies, cookie)
|
||||
return r
|
||||
}
|
||||
|
||||
// AddSimpleCookie 添加简单 Cookie(path 为 /)
|
||||
func (r *Request) AddSimpleCookie(name, value string) *Request {
|
||||
return r.AddCookie(&http.Cookie{
|
||||
Name: name,
|
||||
Value: value,
|
||||
Path: "/",
|
||||
})
|
||||
}
|
||||
|
||||
// AddCookieKV 添加 Cookie(指定 path)
|
||||
func (r *Request) AddCookieKV(name, value, path string) *Request {
|
||||
return r.AddCookie(&http.Cookie{
|
||||
Name: name,
|
||||
Value: value,
|
||||
Path: path,
|
||||
})
|
||||
}
|
||||
|
||||
// SetCookies 设置所有 Cookies(覆盖)
|
||||
func (r *Request) SetCookies(cookies []*http.Cookie) *Request {
|
||||
if r.err != nil {
|
||||
return r
|
||||
}
|
||||
if r.doRaw {
|
||||
return r
|
||||
}
|
||||
r.config.Cookies = cookies
|
||||
return r
|
||||
}
|
||||
|
||||
// AddCookies 批量添加 Cookies
|
||||
func (r *Request) AddCookies(cookies map[string]string) *Request {
|
||||
if r.err != nil {
|
||||
return r
|
||||
}
|
||||
if r.doRaw {
|
||||
return r
|
||||
}
|
||||
for name, value := range cookies {
|
||||
r.config.Cookies = append(r.config.Cookies, &http.Cookie{
|
||||
Name: name,
|
||||
Value: value,
|
||||
Path: "/",
|
||||
})
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// Cookies 获取所有 Cookies
|
||||
func (r *Request) Cookies() []*http.Cookie {
|
||||
return r.config.Cookies
|
||||
}
|
||||
|
||||
// ResetHeaders 重置所有 Headers
|
||||
func (r *Request) ResetHeaders() *Request {
|
||||
if r.err != nil {
|
||||
return r
|
||||
}
|
||||
r.config.Headers = make(http.Header)
|
||||
return r
|
||||
}
|
||||
|
||||
// ResetCookies 重置所有 Cookies
|
||||
func (r *Request) ResetCookies() *Request {
|
||||
if r.err != nil {
|
||||
return r
|
||||
}
|
||||
r.config.Cookies = []*http.Cookie{}
|
||||
return r
|
||||
}
|
||||
172
request_test.go
Normal file
172
request_test.go
Normal file
@ -0,0 +1,172 @@
|
||||
package starnet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewSimpleRequest(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
url string
|
||||
method string
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid GET request",
|
||||
url: "https://example.com",
|
||||
method: "GET",
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid POST request",
|
||||
url: "https://example.com",
|
||||
method: "POST",
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid URL",
|
||||
url: "://invalid",
|
||||
method: "GET",
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req, err := NewRequest(tt.url, tt.method)
|
||||
if tt.expectErr {
|
||||
if err == nil && req.Err() == nil {
|
||||
t.Errorf("NewRequest() expected error, got nil")
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("NewRequest() unexpected error: %v", err)
|
||||
}
|
||||
if req.Method() != strings.ToUpper(tt.method) {
|
||||
t.Errorf("Method = %v; want %v", req.Method(), tt.method)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestMethods(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-Method", r.Method)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(r.Method))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
methods := []string{"GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"}
|
||||
|
||||
for _, method := range methods {
|
||||
t.Run(method, func(t *testing.T) {
|
||||
req := NewSimpleRequest(server.URL, method)
|
||||
resp, err := req.Do()
|
||||
if err != nil {
|
||||
t.Fatalf("Do() error: %v", err)
|
||||
}
|
||||
defer resp.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("StatusCode = %v; want %v", resp.StatusCode, http.StatusOK)
|
||||
}
|
||||
|
||||
if method != "HEAD" {
|
||||
body, _ := resp.Body().String()
|
||||
if body != method {
|
||||
t.Errorf("Body = %v; want %v", body, method)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestSetMethod(t *testing.T) {
|
||||
req := NewSimpleRequest("https://example.com", "GET")
|
||||
|
||||
req.SetMethod("POST")
|
||||
if req.Method() != "POST" {
|
||||
t.Errorf("Method = %v; want POST", req.Method())
|
||||
}
|
||||
|
||||
req.SetMethod("invalid method!")
|
||||
if req.Err() == nil {
|
||||
t.Error("SetMethod with invalid method should set error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestSetURL(t *testing.T) {
|
||||
req := NewSimpleRequest("https://example.com", "GET")
|
||||
|
||||
req.SetURL("https://newexample.com")
|
||||
if req.URL() != "https://newexample.com" {
|
||||
t.Errorf("URL = %v; want https://newexample.com", req.URL())
|
||||
}
|
||||
|
||||
req2 := NewSimpleRequest("https://example.com", "GET")
|
||||
req2.SetURL("://invalid")
|
||||
if req2.Err() == nil {
|
||||
t.Error("SetURL with invalid URL should set error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestClone(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Header.Get("X-Test") != "value" {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("OK"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
req := NewSimpleRequest(server.URL, "GET").
|
||||
SetHeader("X-Test", "value")
|
||||
|
||||
// 第一次请求
|
||||
resp, err := req.Do()
|
||||
if err != nil {
|
||||
t.Fatalf("Do() error: %v", err)
|
||||
}
|
||||
resp.Close()
|
||||
|
||||
// 克隆请求
|
||||
cloned := req.Clone()
|
||||
cloned.SetHeader("X-Extra", "extra")
|
||||
|
||||
// 克隆的请求应该也能成功
|
||||
resp2, err := cloned.Do()
|
||||
if err != nil {
|
||||
t.Fatalf("Cloned Do() error: %v", err)
|
||||
}
|
||||
defer resp2.Close()
|
||||
|
||||
if resp2.StatusCode != http.StatusOK {
|
||||
t.Errorf("Cloned request StatusCode = %v; want %v", resp2.StatusCode, http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestContext(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
req := NewSimpleRequest(server.URL, "GET").SetContext(ctx)
|
||||
_, err := req.Do()
|
||||
if err == nil {
|
||||
t.Error("Expected timeout error, got nil")
|
||||
}
|
||||
}
|
||||
161
response.go
Normal file
161
response.go
Normal file
@ -0,0 +1,161 @@
|
||||
package starnet
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Response HTTP 响应
|
||||
type Response struct {
|
||||
*http.Response
|
||||
request *Request
|
||||
httpClient *http.Client
|
||||
body *Body
|
||||
}
|
||||
|
||||
// Body 响应体
|
||||
type Body struct {
|
||||
raw io.ReadCloser
|
||||
data []byte
|
||||
consumed bool
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// Request 获取原始请求
|
||||
func (r *Response) Request() *Request {
|
||||
return r.request
|
||||
}
|
||||
|
||||
// Body 获取响应体
|
||||
func (r *Response) Body() *Body {
|
||||
return r.body
|
||||
}
|
||||
|
||||
// Close 关闭响应体
|
||||
func (r *Response) Close() error {
|
||||
if r.body != nil && r.body.raw != nil {
|
||||
return r.body.raw.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CloseWithClient 关闭响应体并关闭空闲连接
|
||||
func (r *Response) CloseWithClient() error {
|
||||
if r.httpClient != nil {
|
||||
r.httpClient.CloseIdleConnections()
|
||||
}
|
||||
return r.Close()
|
||||
}
|
||||
|
||||
// readAll 读取所有数据
|
||||
func (b *Body) readAll() error {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
if b.consumed {
|
||||
return nil
|
||||
}
|
||||
|
||||
if b.raw == nil {
|
||||
b.consumed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
data, err := io.ReadAll(b.raw)
|
||||
if err != nil {
|
||||
return wrapError(err, "read response body")
|
||||
}
|
||||
|
||||
b.data = data
|
||||
b.consumed = true
|
||||
b.raw.Close()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Bytes 获取响应体字节
|
||||
func (b *Body) Bytes() ([]byte, error) {
|
||||
if err := b.readAll(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return b.data, nil
|
||||
}
|
||||
|
||||
// String 获取响应体字符串
|
||||
func (b *Body) String() (string, error) {
|
||||
data, err := b.Bytes()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(data), nil
|
||||
}
|
||||
|
||||
// JSON 解析 JSON 响应
|
||||
func (b *Body) JSON(v interface{}) error {
|
||||
data, err := b.Bytes()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return json.Unmarshal(data, v)
|
||||
}
|
||||
|
||||
// Reader 获取 Reader(只能调用一次)
|
||||
func (b *Body) Reader() (io.ReadCloser, error) {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
if b.consumed {
|
||||
if b.data != nil {
|
||||
// 已读取,返回缓存数据的 Reader
|
||||
return io.NopCloser(bytes.NewReader(b.data)), nil
|
||||
}
|
||||
return nil, ErrBodyAlreadyConsumed
|
||||
}
|
||||
|
||||
b.consumed = true
|
||||
return b.raw, nil
|
||||
}
|
||||
|
||||
// IsConsumed 检查是否已消费
|
||||
func (b *Body) IsConsumed() bool {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
return b.consumed
|
||||
}
|
||||
|
||||
// Close 关闭 Body
|
||||
func (b *Body) Close() error {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
if b.raw != nil {
|
||||
return b.raw.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MustBytes 获取响应体字节(忽略错误,失败返回 nil)
|
||||
func (b *Body) MustBytes() []byte {
|
||||
data, err := b.Bytes()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
// MustString 获取响应体字符串(忽略错误,失败返回空串)
|
||||
func (b *Body) MustString() string {
|
||||
s, err := b.String()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// Unmarshal 解析 JSON 响应(兼容旧 API)
|
||||
func (b *Body) Unmarshal(v interface{}) error {
|
||||
return b.JSON(v)
|
||||
}
|
||||
179
response_test.go
Normal file
179
response_test.go
Normal file
@ -0,0 +1,179 @@
|
||||
package starnet
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestResponseBody(t *testing.T) {
|
||||
testData := "test response data"
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte(testData))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
resp, err := Get(server.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("Get() error: %v", err)
|
||||
}
|
||||
defer resp.Close()
|
||||
|
||||
// Test String()
|
||||
body, err := resp.Body().String()
|
||||
if err != nil {
|
||||
t.Fatalf("Body().String() error: %v", err)
|
||||
}
|
||||
if body != testData {
|
||||
t.Errorf("Body = %v; want %v", body, testData)
|
||||
}
|
||||
|
||||
// Test multiple reads (should work because body is cached)
|
||||
body2, err := resp.Body().String()
|
||||
if err != nil {
|
||||
t.Fatalf("Second Body().String() error: %v", err)
|
||||
}
|
||||
if body2 != testData {
|
||||
t.Errorf("Second Body = %v; want %v", body2, testData)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponseJSON(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(`{"name":"John","age":30}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
resp, err := Get(server.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("Get() error: %v", err)
|
||||
}
|
||||
defer resp.Close()
|
||||
|
||||
var result struct {
|
||||
Name string `json:"name"`
|
||||
Age int `json:"age"`
|
||||
}
|
||||
|
||||
err = resp.Body().JSON(&result)
|
||||
if err != nil {
|
||||
t.Fatalf("Body().JSON() error: %v", err)
|
||||
}
|
||||
|
||||
if result.Name != "John" {
|
||||
t.Errorf("Name = %v; want John", result.Name)
|
||||
}
|
||||
if result.Age != 30 {
|
||||
t.Errorf("Age = %v; want 30", result.Age)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponseBytes(t *testing.T) {
|
||||
testData := []byte("binary data")
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write(testData)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
resp, err := Get(server.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("Get() error: %v", err)
|
||||
}
|
||||
defer resp.Close()
|
||||
|
||||
body, err := resp.Body().Bytes()
|
||||
if err != nil {
|
||||
t.Fatalf("Body().Bytes() error: %v", err)
|
||||
}
|
||||
|
||||
if string(body) != string(testData) {
|
||||
t.Errorf("Body = %v; want %v", body, testData)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponseReader(t *testing.T) {
|
||||
testData := "stream data"
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte(testData))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
resp, err := Get(server.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("Get() error: %v", err)
|
||||
}
|
||||
defer resp.Close()
|
||||
|
||||
reader, err := resp.Body().Reader()
|
||||
if err != nil {
|
||||
t.Fatalf("Body().Reader() error: %v", err)
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
body, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadAll() error: %v", err)
|
||||
}
|
||||
|
||||
if string(body) != testData {
|
||||
t.Errorf("Body = %v; want %v", string(body), testData)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponseAutoFetch(t *testing.T) {
|
||||
testData := "auto fetch data"
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte(testData))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// With auto fetch
|
||||
resp, err := Get(server.URL, WithAutoFetch(true))
|
||||
if err != nil {
|
||||
t.Fatalf("Get() error: %v", err)
|
||||
}
|
||||
defer resp.Close()
|
||||
|
||||
if !resp.Body().IsConsumed() {
|
||||
t.Error("Body should be consumed with auto fetch")
|
||||
}
|
||||
|
||||
body, _ := resp.Body().String()
|
||||
if body != testData {
|
||||
t.Errorf("Body = %v; want %v", body, testData)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponseStatusCode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
}{
|
||||
{"OK", http.StatusOK},
|
||||
{"Created", http.StatusCreated},
|
||||
{"BadRequest", http.StatusBadRequest},
|
||||
{"NotFound", http.StatusNotFound},
|
||||
{"InternalServerError", http.StatusInternalServerError},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(tt.statusCode)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
resp, err := Get(server.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("Get() error: %v", err)
|
||||
}
|
||||
defer resp.Close()
|
||||
|
||||
if resp.StatusCode != tt.statusCode {
|
||||
t.Errorf("StatusCode = %v; want %v", resp.StatusCode, tt.statusCode)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
66
timeout_test.go
Normal file
66
timeout_test.go
Normal file
@ -0,0 +1,66 @@
|
||||
package starnet
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestRequestTimeout(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Should timeout
|
||||
req := NewSimpleRequest(server.URL, "GET").SetTimeout(100 * time.Millisecond)
|
||||
_, err := req.Do()
|
||||
if err == nil {
|
||||
t.Error("Expected timeout error, got nil")
|
||||
}
|
||||
|
||||
// Should succeed
|
||||
req2 := NewSimpleRequest(server.URL, "GET").SetTimeout(300 * time.Millisecond)
|
||||
resp, err := req2.Do()
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
if resp != nil {
|
||||
resp.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestDialTimeout(t *testing.T) {
|
||||
// Use a non-routable IP to test dial timeout
|
||||
req := NewSimpleRequest("http://192.0.2.1:80", "GET").
|
||||
SetDialTimeout(100 * time.Millisecond)
|
||||
|
||||
start := time.Now()
|
||||
_, err := req.Do()
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if err == nil {
|
||||
t.Error("Expected dial timeout error, got nil")
|
||||
}
|
||||
|
||||
// Should timeout within reasonable time (not wait forever)
|
||||
if elapsed > 2*time.Second {
|
||||
t.Errorf("Dial timeout took too long: %v", elapsed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientTimeout(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClientNoErr(WithTimeout(100 * time.Millisecond))
|
||||
_, err := client.Get(server.URL)
|
||||
if err == nil {
|
||||
t.Error("Expected timeout error, got nil")
|
||||
}
|
||||
}
|
||||
102
tls_test.go
Normal file
102
tls_test.go
Normal file
@ -0,0 +1,102 @@
|
||||
package starnet
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRequestSkipTLSVerify(t *testing.T) {
|
||||
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("OK"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Without skip verify (should fail)
|
||||
req := NewSimpleRequest(server.URL, "GET")
|
||||
_, err := req.Do()
|
||||
if err == nil {
|
||||
t.Error("Expected TLS error without skip verify, got nil")
|
||||
}
|
||||
|
||||
// With skip verify (should succeed)
|
||||
req2 := NewSimpleRequest(server.URL, "GET").SetSkipTLSVerify(true)
|
||||
resp, err := req2.Do()
|
||||
if err != nil {
|
||||
t.Fatalf("Do() with skip verify error: %v", err)
|
||||
}
|
||||
defer resp.Close()
|
||||
|
||||
body, _ := resp.Body().String()
|
||||
if body != "OK" {
|
||||
t.Errorf("Body = %v; want OK", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestCustomTLSConfig(t *testing.T) {
|
||||
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("OK"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}
|
||||
|
||||
req := NewSimpleRequest(server.URL, "GET").SetTLSConfig(tlsConfig)
|
||||
resp, err := req.Do()
|
||||
if err != nil {
|
||||
t.Fatalf("Do() error: %v", err)
|
||||
}
|
||||
defer resp.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("StatusCode = %v; want %v", resp.StatusCode, http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientDefaultTLSConfig(t *testing.T) {
|
||||
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClientNoErr()
|
||||
client.SetDefaultSkipTLSVerify(true)
|
||||
|
||||
resp, err := client.Get(server.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("Get() error: %v", err)
|
||||
}
|
||||
defer resp.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("StatusCode = %v; want %v", resp.StatusCode, http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestLevelTLSOverride(t *testing.T) {
|
||||
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Client level: skip verify = false
|
||||
client := NewClientNoErr()
|
||||
client.SetDefaultSkipTLSVerify(false)
|
||||
|
||||
// Request level: skip verify = true (should override)
|
||||
resp, err := client.Get(server.URL, WithSkipTLSVerify(true))
|
||||
if err != nil {
|
||||
t.Fatalf("Get() error: %v", err)
|
||||
}
|
||||
defer resp.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("StatusCode = %v; want %v", resp.StatusCode, http.StatusOK)
|
||||
}
|
||||
}
|
||||
55
tlsconfig.go
Normal file
55
tlsconfig.go
Normal file
@ -0,0 +1,55 @@
|
||||
package starnet
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
// GetConfigForClientFunc selects TLS config by hostname/SNI.
|
||||
type GetConfigForClientFunc func(hostname string) (*tls.Config, error)
|
||||
|
||||
// ListenerConfig controls listener behavior.
|
||||
type ListenerConfig struct {
|
||||
// BaseTLSConfig is used for TLS when dynamic selection returns nil.
|
||||
BaseTLSConfig *tls.Config
|
||||
|
||||
// GetConfigForClient selects TLS config for a hostname.
|
||||
GetConfigForClient GetConfigForClientFunc
|
||||
|
||||
// AllowNonTLS allows plain TCP fallback.
|
||||
AllowNonTLS bool
|
||||
|
||||
// SniffTimeout bounds protocol sniffing time. 0 means no timeout.
|
||||
SniffTimeout time.Duration
|
||||
|
||||
// MaxClientHelloBytes limits buffered sniff data.
|
||||
// If <= 0, default 64KiB.
|
||||
MaxClientHelloBytes int
|
||||
|
||||
// Logger is optional.
|
||||
Logger Logger
|
||||
}
|
||||
|
||||
// DefaultListenerConfig returns a conservative default config.
|
||||
func DefaultListenerConfig() ListenerConfig {
|
||||
return ListenerConfig{
|
||||
AllowNonTLS: false,
|
||||
SniffTimeout: 5 * time.Second,
|
||||
MaxClientHelloBytes: 64 * 1024,
|
||||
}
|
||||
}
|
||||
|
||||
// TLSDefaults returns a TLS config baseline.
|
||||
// Caller should set Certificates / GetCertificate as needed.
|
||||
func TLSDefaults() *tls.Config {
|
||||
return &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}
|
||||
}
|
||||
|
||||
// DialConfig controls dialing behavior.
|
||||
type DialConfig struct {
|
||||
Timeout time.Duration
|
||||
LocalAddr net.Addr
|
||||
}
|
||||
699
tlssniffer.go
699
tlssniffer.go
@ -10,269 +10,231 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type myConn struct {
|
||||
reader io.Reader
|
||||
conn net.Conn
|
||||
isReadOnly bool
|
||||
multiReader io.Reader
|
||||
// replayConn replays buffered bytes first, then reads from live conn.
|
||||
type replayConn struct {
|
||||
reader io.Reader
|
||||
conn net.Conn
|
||||
}
|
||||
|
||||
func (c *myConn) Read(p []byte) (int, error) {
|
||||
if c.isReadOnly {
|
||||
return c.reader.Read(p)
|
||||
func newReplayConn(buffered io.Reader, conn net.Conn) *replayConn {
|
||||
return &replayConn{
|
||||
reader: io.MultiReader(buffered, conn),
|
||||
conn: conn,
|
||||
}
|
||||
if c.multiReader == nil {
|
||||
c.multiReader = io.MultiReader(c.reader, c.conn)
|
||||
}
|
||||
return c.multiReader.Read(p)
|
||||
}
|
||||
|
||||
func (c *myConn) Write(p []byte) (int, error) {
|
||||
if c.isReadOnly {
|
||||
return 0, io.ErrClosedPipe
|
||||
}
|
||||
return c.conn.Write(p)
|
||||
}
|
||||
func (c *myConn) Close() error {
|
||||
if c.isReadOnly {
|
||||
return nil
|
||||
}
|
||||
return c.conn.Close()
|
||||
}
|
||||
func (c *myConn) LocalAddr() net.Addr {
|
||||
if c.isReadOnly {
|
||||
return nil
|
||||
}
|
||||
return c.conn.LocalAddr()
|
||||
}
|
||||
func (c *myConn) RemoteAddr() net.Addr {
|
||||
if c.isReadOnly {
|
||||
return nil
|
||||
}
|
||||
return c.conn.RemoteAddr()
|
||||
}
|
||||
func (c *myConn) SetDeadline(t time.Time) error {
|
||||
if c.isReadOnly {
|
||||
return nil
|
||||
}
|
||||
return c.conn.SetDeadline(t)
|
||||
}
|
||||
func (c *myConn) SetReadDeadline(t time.Time) error {
|
||||
if c.isReadOnly {
|
||||
return nil
|
||||
}
|
||||
return c.conn.SetReadDeadline(t)
|
||||
}
|
||||
func (c *myConn) SetWriteDeadline(t time.Time) error {
|
||||
if c.isReadOnly {
|
||||
return nil
|
||||
}
|
||||
return c.conn.SetWriteDeadline(t)
|
||||
func (c *replayConn) Read(p []byte) (int, error) { return c.reader.Read(p) }
|
||||
func (c *replayConn) Write(p []byte) (int, error) { return c.conn.Write(p) }
|
||||
func (c *replayConn) Close() error { return c.conn.Close() }
|
||||
func (c *replayConn) LocalAddr() net.Addr { return c.conn.LocalAddr() }
|
||||
func (c *replayConn) RemoteAddr() net.Addr { return c.conn.RemoteAddr() }
|
||||
func (c *replayConn) SetDeadline(t time.Time) error { return c.conn.SetDeadline(t) }
|
||||
func (c *replayConn) SetReadDeadline(t time.Time) error { return c.conn.SetReadDeadline(t) }
|
||||
func (c *replayConn) SetWriteDeadline(t time.Time) error { return c.conn.SetWriteDeadline(t) }
|
||||
|
||||
// SniffResult describes protocol sniffing result.
|
||||
type SniffResult struct {
|
||||
IsTLS bool
|
||||
Hostname string
|
||||
Buffer *bytes.Buffer
|
||||
}
|
||||
|
||||
type Listener struct {
|
||||
net.Listener
|
||||
cfg *tls.Config
|
||||
getConfigForClient func(hostname string) *tls.Config
|
||||
allowNonTls bool
|
||||
// Sniffer detects protocol and metadata from initial bytes.
|
||||
type Sniffer interface {
|
||||
Sniff(conn net.Conn, maxBytes int) (SniffResult, error)
|
||||
}
|
||||
|
||||
func (l *Listener) GetConfigForClient() func(hostname string) *tls.Config {
|
||||
return l.getConfigForClient
|
||||
}
|
||||
// TLSSniffer is the default sniffer implementation.
|
||||
type TLSSniffer struct{}
|
||||
|
||||
func (l *Listener) SetConfigForClient(getConfigForClient func(hostname string) *tls.Config) {
|
||||
l.getConfigForClient = getConfigForClient
|
||||
}
|
||||
|
||||
func Listen(network, address string) (*Listener, error) {
|
||||
listener, err := net.Listen(network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Listener{Listener: listener}, nil
|
||||
}
|
||||
|
||||
func ListenTLSWithListenConfig(liscfg net.ListenConfig, network, address string, config *tls.Config, getConfigForClient func(hostname string) *tls.Config, allowNonTls bool) (*Listener, error) {
|
||||
listener, err := liscfg.Listen(context.Background(), network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Listener{
|
||||
Listener: listener,
|
||||
cfg: config,
|
||||
getConfigForClient: getConfigForClient,
|
||||
allowNonTls: allowNonTls,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func ListenWithListener(listener net.Listener, config *tls.Config, getConfigForClient func(hostname string) *tls.Config, allowNonTls bool) (*Listener, error) {
|
||||
return &Listener{
|
||||
Listener: listener,
|
||||
cfg: config,
|
||||
getConfigForClient: getConfigForClient,
|
||||
allowNonTls: allowNonTls,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func ListenTLSWithConfig(network, address string, config *tls.Config, getConfigForClient func(hostname string) *tls.Config, allowNonTls bool) (*Listener, error) {
|
||||
listener, err := net.Listen(network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Listener{
|
||||
Listener: listener,
|
||||
cfg: config,
|
||||
getConfigForClient: getConfigForClient,
|
||||
allowNonTls: allowNonTls,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func ListenTLS(network, address string, certFile, keyFile string, allowNonTls bool) (*Listener, error) {
|
||||
config, err := tls.LoadX509KeyPair(certFile, keyFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
// Sniff detects TLS and extracts SNI when possible.
|
||||
func (s TLSSniffer) Sniff(conn net.Conn, maxBytes int) (SniffResult, error) {
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = 64 * 1024
|
||||
}
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{config},
|
||||
}
|
||||
var buf bytes.Buffer
|
||||
limited := &io.LimitedReader{R: conn, N: int64(maxBytes)}
|
||||
tee := io.TeeReader(limited, &buf)
|
||||
|
||||
listener, err := net.Listen(network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Listener{
|
||||
Listener: listener,
|
||||
cfg: tlsConfig,
|
||||
allowNonTls: allowNonTls,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (l *Listener) Accept() (net.Conn, error) {
|
||||
conn, err := l.Listener.Accept()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Conn{
|
||||
Conn: conn,
|
||||
tlsCfg: l.cfg,
|
||||
getConfigForClient: l.getConfigForClient,
|
||||
allowNonTls: l.allowNonTls,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type Conn struct {
|
||||
net.Conn
|
||||
once sync.Once
|
||||
initErr error
|
||||
isTLS bool
|
||||
tlsCfg *tls.Config
|
||||
tlsConn *tls.Conn
|
||||
buffer *bytes.Buffer
|
||||
noTlsReader io.Reader
|
||||
isOriginal bool
|
||||
getConfigForClient func(hostname string) *tls.Config
|
||||
hostname string
|
||||
allowNonTls bool
|
||||
}
|
||||
|
||||
func (c *Conn) Hostname() string {
|
||||
if c.hostname != "" {
|
||||
return c.hostname
|
||||
}
|
||||
if c.isTLS && c.tlsConn != nil {
|
||||
if c.tlsConn.ConnectionState().ServerName != "" {
|
||||
c.hostname = c.tlsConn.ConnectionState().ServerName
|
||||
return c.hostname
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (c *Conn) IsTLS() bool {
|
||||
return c.isTLS
|
||||
}
|
||||
|
||||
func (c *Conn) TlsConn() *tls.Conn {
|
||||
return c.tlsConn
|
||||
}
|
||||
|
||||
func (c *Conn) isTLSConnection() (bool, error) {
|
||||
if c.getConfigForClient == nil {
|
||||
peek := make([]byte, 5)
|
||||
n, err := io.ReadFull(c.Conn, peek)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
isTLS := n >= 3 && peek[0] == 0x16 && peek[1] == 0x03
|
||||
|
||||
c.buffer = bytes.NewBuffer(peek[:n])
|
||||
return isTLS, nil
|
||||
}
|
||||
|
||||
c.buffer = new(bytes.Buffer)
|
||||
r := io.TeeReader(c.Conn, c.buffer)
|
||||
var hello *tls.ClientHelloInfo
|
||||
tls.Server(&myConn{reader: r, isReadOnly: true}, &tls.Config{
|
||||
GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||
hello = new(tls.ClientHelloInfo)
|
||||
*hello = *argHello
|
||||
_ = tls.Server(readOnlyConn{r: tee, raw: conn}, &tls.Config{
|
||||
GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||
cp := *ch
|
||||
hello = &cp
|
||||
return nil, nil
|
||||
},
|
||||
}).Handshake()
|
||||
peek := c.buffer.Bytes()
|
||||
n := len(peek)
|
||||
isTLS := n >= 3 && peek[0] == 0x16 && peek[1] == 0x03
|
||||
if hello == nil {
|
||||
return isTLS, nil
|
||||
|
||||
peek := buf.Bytes()
|
||||
isTLS := len(peek) >= 3 && peek[0] == 0x16 && peek[1] == 0x03
|
||||
|
||||
out := SniffResult{
|
||||
IsTLS: isTLS,
|
||||
Buffer: bytes.NewBuffer(append([]byte(nil), peek...)),
|
||||
}
|
||||
c.hostname = hello.ServerName
|
||||
if c.hostname == "" {
|
||||
c.hostname, _, _ = net.SplitHostPort(c.Conn.LocalAddr().String())
|
||||
if hello != nil {
|
||||
out.Hostname = hello.ServerName
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// readOnlyConn rejects writes/close and reads from a reader.
|
||||
type readOnlyConn struct {
|
||||
r io.Reader
|
||||
raw net.Conn
|
||||
}
|
||||
|
||||
func (c readOnlyConn) Read(p []byte) (int, error) { return c.r.Read(p) }
|
||||
func (c readOnlyConn) Write(p []byte) (int, error) { return 0, io.ErrClosedPipe }
|
||||
func (c readOnlyConn) Close() error { return nil }
|
||||
func (c readOnlyConn) LocalAddr() net.Addr { return c.raw.LocalAddr() }
|
||||
func (c readOnlyConn) RemoteAddr() net.Addr { return c.raw.RemoteAddr() }
|
||||
func (c readOnlyConn) SetDeadline(_ time.Time) error { return nil }
|
||||
func (c readOnlyConn) SetReadDeadline(_ time.Time) error { return nil }
|
||||
func (c readOnlyConn) SetWriteDeadline(_ time.Time) error { return nil }
|
||||
|
||||
// Conn wraps net.Conn with lazy protocol initialization.
|
||||
type Conn struct {
|
||||
net.Conn
|
||||
|
||||
once sync.Once
|
||||
initErr error
|
||||
closeOnce sync.Once
|
||||
|
||||
isTLS bool
|
||||
tlsConn *tls.Conn
|
||||
plainConn net.Conn
|
||||
|
||||
hostname string
|
||||
|
||||
baseTLSConfig *tls.Config
|
||||
getConfigForClient GetConfigForClientFunc
|
||||
allowNonTLS bool
|
||||
sniffer Sniffer
|
||||
sniffTimeout time.Duration
|
||||
maxClientHello int
|
||||
logger Logger
|
||||
stats *Stats
|
||||
skipSniff bool
|
||||
}
|
||||
|
||||
func newConn(raw net.Conn, cfg ListenerConfig, stats *Stats) *Conn {
|
||||
return &Conn{
|
||||
Conn: raw,
|
||||
plainConn: raw,
|
||||
baseTLSConfig: cfg.BaseTLSConfig,
|
||||
getConfigForClient: cfg.GetConfigForClient,
|
||||
allowNonTLS: cfg.AllowNonTLS,
|
||||
sniffer: TLSSniffer{},
|
||||
sniffTimeout: cfg.SniffTimeout,
|
||||
maxClientHello: cfg.MaxClientHelloBytes,
|
||||
logger: cfg.Logger,
|
||||
stats: stats,
|
||||
}
|
||||
return isTLS, nil
|
||||
}
|
||||
|
||||
func (c *Conn) init() {
|
||||
c.once.Do(func() {
|
||||
if c.isOriginal {
|
||||
if c.skipSniff {
|
||||
return
|
||||
}
|
||||
if c.tlsCfg != nil {
|
||||
isTLS, err := c.isTLSConnection()
|
||||
if err != nil {
|
||||
c.initErr = err
|
||||
return
|
||||
}
|
||||
c.isTLS = isTLS
|
||||
if c.baseTLSConfig == nil && c.getConfigForClient == nil {
|
||||
c.isTLS = false
|
||||
return
|
||||
}
|
||||
|
||||
if c.sniffTimeout > 0 {
|
||||
_ = c.Conn.SetReadDeadline(time.Now().Add(c.sniffTimeout))
|
||||
}
|
||||
res, err := c.sniffer.Sniff(c.Conn, c.maxClientHello)
|
||||
if c.sniffTimeout > 0 {
|
||||
_ = c.Conn.SetReadDeadline(time.Time{})
|
||||
}
|
||||
if err != nil {
|
||||
c.initErr = err
|
||||
c.failAndClose("sniff failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
c.isTLS = res.IsTLS
|
||||
c.hostname = res.Hostname
|
||||
|
||||
if c.isTLS {
|
||||
var cfg = c.tlsCfg
|
||||
if c.getConfigForClient != nil {
|
||||
cfg = c.getConfigForClient(c.hostname)
|
||||
if cfg == nil {
|
||||
cfg = c.tlsCfg
|
||||
}
|
||||
if c.stats != nil {
|
||||
c.stats.incTLSDetected()
|
||||
}
|
||||
c.tlsConn = tls.Server(&myConn{
|
||||
reader: c.buffer,
|
||||
conn: c.Conn,
|
||||
isReadOnly: false,
|
||||
}, cfg)
|
||||
} else {
|
||||
if !c.allowNonTls {
|
||||
c.initErr = net.ErrClosed
|
||||
tlsCfg, errCfg := c.selectTLSConfig()
|
||||
if errCfg != nil {
|
||||
c.initErr = errCfg
|
||||
c.failAndClose("tls config select failed: %v", errCfg)
|
||||
return
|
||||
}
|
||||
c.noTlsReader = io.MultiReader(c.buffer, c.Conn)
|
||||
rc := newReplayConn(bytes.NewBuffer(res.Buffer.Bytes()), c.Conn)
|
||||
c.tlsConn = tls.Server(rc, tlsCfg)
|
||||
return
|
||||
}
|
||||
|
||||
if c.stats != nil {
|
||||
c.stats.incPlainDetected()
|
||||
}
|
||||
if !c.allowNonTLS {
|
||||
c.initErr = ErrNonTLSNotAllowed
|
||||
c.failAndClose("plain tcp rejected")
|
||||
return
|
||||
}
|
||||
c.plainConn = newReplayConn(bytes.NewBuffer(res.Buffer.Bytes()), c.Conn)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Conn) failAndClose(format string, v ...interface{}) {
|
||||
if c.stats != nil {
|
||||
c.stats.incInitFailures()
|
||||
}
|
||||
if c.logger != nil {
|
||||
c.logger.Printf("starnet: "+format, v...)
|
||||
}
|
||||
_ = c.Close()
|
||||
}
|
||||
|
||||
func (c *Conn) selectTLSConfig() (*tls.Config, error) {
|
||||
if c.getConfigForClient != nil {
|
||||
cfg, err := c.getConfigForClient(c.hostname)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if cfg != nil {
|
||||
return cfg, nil
|
||||
}
|
||||
}
|
||||
if c.baseTLSConfig != nil {
|
||||
return c.baseTLSConfig, nil
|
||||
}
|
||||
return nil, ErrNoTLSConfig
|
||||
}
|
||||
|
||||
// Hostname returns sniffed SNI hostname (if any).
|
||||
func (c *Conn) Hostname() string {
|
||||
c.init()
|
||||
return c.hostname
|
||||
}
|
||||
|
||||
func (c *Conn) IsTLS() bool {
|
||||
c.init()
|
||||
return c.initErr == nil && c.isTLS
|
||||
}
|
||||
|
||||
func (c *Conn) TLSConn() (*tls.Conn, error) {
|
||||
c.init()
|
||||
if c.initErr != nil {
|
||||
return nil, c.initErr
|
||||
}
|
||||
if !c.isTLS || c.tlsConn == nil {
|
||||
return nil, ErrNotTLS
|
||||
}
|
||||
return c.tlsConn, nil
|
||||
}
|
||||
|
||||
func (c *Conn) Read(b []byte) (int, error) {
|
||||
c.init()
|
||||
if c.initErr != nil {
|
||||
@ -281,7 +243,7 @@ func (c *Conn) Read(b []byte) (int, error) {
|
||||
if c.isTLS {
|
||||
return c.tlsConn.Read(b)
|
||||
}
|
||||
return c.noTlsReader.Read(b)
|
||||
return c.plainConn.Read(b)
|
||||
}
|
||||
|
||||
func (c *Conn) Write(b []byte) (int, error) {
|
||||
@ -289,113 +251,250 @@ func (c *Conn) Write(b []byte) (int, error) {
|
||||
if c.initErr != nil {
|
||||
return 0, c.initErr
|
||||
}
|
||||
|
||||
if c.isTLS {
|
||||
return c.tlsConn.Write(b)
|
||||
}
|
||||
return c.Conn.Write(b)
|
||||
return c.plainConn.Write(b)
|
||||
}
|
||||
|
||||
func (c *Conn) Close() error {
|
||||
if c.isTLS && c.tlsConn != nil {
|
||||
return c.tlsConn.Close()
|
||||
}
|
||||
return c.Conn.Close()
|
||||
var err error
|
||||
c.closeOnce.Do(func() {
|
||||
if c.tlsConn != nil {
|
||||
err = c.tlsConn.Close()
|
||||
} else {
|
||||
err = c.Conn.Close()
|
||||
}
|
||||
if c.stats != nil {
|
||||
c.stats.incClosed()
|
||||
}
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Conn) SetDeadline(t time.Time) error {
|
||||
c.init()
|
||||
if c.initErr != nil {
|
||||
return c.initErr
|
||||
}
|
||||
if c.isTLS && c.tlsConn != nil {
|
||||
return c.tlsConn.SetDeadline(t)
|
||||
}
|
||||
return c.Conn.SetDeadline(t)
|
||||
return c.plainConn.SetDeadline(t)
|
||||
}
|
||||
|
||||
func (c *Conn) SetReadDeadline(t time.Time) error {
|
||||
c.init()
|
||||
if c.initErr != nil {
|
||||
return c.initErr
|
||||
}
|
||||
if c.isTLS && c.tlsConn != nil {
|
||||
return c.tlsConn.SetReadDeadline(t)
|
||||
}
|
||||
return c.Conn.SetReadDeadline(t)
|
||||
return c.plainConn.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (c *Conn) SetWriteDeadline(t time.Time) error {
|
||||
c.init()
|
||||
if c.initErr != nil {
|
||||
return c.initErr
|
||||
}
|
||||
if c.isTLS && c.tlsConn != nil {
|
||||
return c.tlsConn.SetWriteDeadline(t)
|
||||
}
|
||||
return c.Conn.SetWriteDeadline(t)
|
||||
return c.plainConn.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
func (c *Conn) TlsConnection() (*tls.Conn, error) {
|
||||
if c.initErr != nil {
|
||||
return nil, c.initErr
|
||||
}
|
||||
if !c.isTLS {
|
||||
return nil, net.ErrClosed
|
||||
}
|
||||
return c.tlsConn, nil
|
||||
// Listener wraps net.Listener and returns starnet.Conn from Accept.
|
||||
type Listener struct {
|
||||
net.Listener
|
||||
|
||||
mu sync.RWMutex
|
||||
cfg ListenerConfig
|
||||
stats Stats
|
||||
}
|
||||
|
||||
func (c *Conn) OriginalConn() net.Conn {
|
||||
return c.Conn
|
||||
}
|
||||
|
||||
func NewClientTlsConn(conn net.Conn, cfg *tls.Config) (*Conn, error) {
|
||||
if conn == nil {
|
||||
return nil, net.ErrClosed
|
||||
}
|
||||
c := &Conn{
|
||||
Conn: conn,
|
||||
isTLS: true,
|
||||
tlsCfg: cfg,
|
||||
tlsConn: tls.Client(conn, cfg),
|
||||
isOriginal: true,
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func NewServerTlsConn(conn net.Conn, cfg *tls.Config) (*Conn, error) {
|
||||
if conn == nil {
|
||||
return nil, net.ErrClosed
|
||||
}
|
||||
c := &Conn{
|
||||
Conn: conn,
|
||||
isTLS: true,
|
||||
tlsCfg: cfg,
|
||||
tlsConn: tls.Server(conn, cfg),
|
||||
isOriginal: true,
|
||||
}
|
||||
c.init()
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func Dial(network, address string) (*Conn, error) {
|
||||
conn, err := net.Dial(network, address)
|
||||
// Listen creates a plain listener config (no TLS detection).
|
||||
func Listen(network, address string) (*Listener, error) {
|
||||
ln, err := net.Listen(network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cfg := DefaultListenerConfig()
|
||||
cfg.AllowNonTLS = true
|
||||
cfg.BaseTLSConfig = nil
|
||||
cfg.GetConfigForClient = nil
|
||||
return &Listener{Listener: ln, cfg: cfg}, nil
|
||||
}
|
||||
|
||||
// ListenWithConfig creates a listener with full config.
|
||||
func ListenWithConfig(network, address string, cfg ListenerConfig) (*Listener, error) {
|
||||
ln, err := net.Listen(network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Listener{Listener: ln, cfg: normalizeConfig(cfg)}, nil
|
||||
}
|
||||
|
||||
// ListenWithListenConfig creates listener using net.ListenConfig.
|
||||
func ListenWithListenConfig(lc net.ListenConfig, network, address string, cfg ListenerConfig) (*Listener, error) {
|
||||
ln, err := lc.Listen(context.Background(), network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Listener{Listener: ln, cfg: normalizeConfig(cfg)}, nil
|
||||
}
|
||||
|
||||
// ListenTLS creates TLS listener from cert/key paths.
|
||||
func ListenTLS(network, address, certFile, keyFile string, allowNonTLS bool) (*Listener, error) {
|
||||
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cfg := DefaultListenerConfig()
|
||||
cfg.AllowNonTLS = allowNonTLS
|
||||
cfg.BaseTLSConfig = TLSDefaults()
|
||||
cfg.BaseTLSConfig.Certificates = []tls.Certificate{cert}
|
||||
return ListenWithConfig(network, address, cfg)
|
||||
}
|
||||
|
||||
func normalizeConfig(cfg ListenerConfig) ListenerConfig {
|
||||
out := DefaultListenerConfig()
|
||||
out.AllowNonTLS = cfg.AllowNonTLS
|
||||
out.SniffTimeout = cfg.SniffTimeout
|
||||
out.MaxClientHelloBytes = cfg.MaxClientHelloBytes
|
||||
out.BaseTLSConfig = cfg.BaseTLSConfig
|
||||
out.GetConfigForClient = cfg.GetConfigForClient
|
||||
out.Logger = cfg.Logger
|
||||
if out.MaxClientHelloBytes <= 0 {
|
||||
out.MaxClientHelloBytes = 64 * 1024
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// SetConfig atomically replaces listener config for new accepted connections.
|
||||
func (l *Listener) SetConfig(cfg ListenerConfig) {
|
||||
l.mu.Lock()
|
||||
l.cfg = normalizeConfig(cfg)
|
||||
l.mu.Unlock()
|
||||
}
|
||||
|
||||
// Config returns a copy of current config.
|
||||
func (l *Listener) Config() ListenerConfig {
|
||||
l.mu.RLock()
|
||||
cfg := l.cfg
|
||||
l.mu.RUnlock()
|
||||
return cfg
|
||||
}
|
||||
|
||||
// Stats returns current counters snapshot.
|
||||
func (l *Listener) Stats() StatsSnapshot {
|
||||
return l.stats.Snapshot()
|
||||
}
|
||||
|
||||
func (l *Listener) Accept() (net.Conn, error) {
|
||||
raw, err := l.Listener.Accept()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
l.stats.incAccepted()
|
||||
|
||||
l.mu.RLock()
|
||||
cfg := l.cfg
|
||||
l.mu.RUnlock()
|
||||
|
||||
return newConn(raw, cfg, &l.stats), nil
|
||||
}
|
||||
|
||||
// AcceptContext supports cancellation by closing accepted conn when ctx is done early.
|
||||
func (l *Listener) AcceptContext(ctx context.Context) (net.Conn, error) {
|
||||
type result struct {
|
||||
c net.Conn
|
||||
err error
|
||||
}
|
||||
ch := make(chan result, 1)
|
||||
go func() {
|
||||
c, err := l.Accept()
|
||||
ch <- result{c: c, err: err}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case r := <-ch:
|
||||
return r.c, r.err
|
||||
}
|
||||
}
|
||||
|
||||
// Dial creates a plain TCP starnet.Conn.
|
||||
func Dial(network, address string) (*Conn, error) {
|
||||
raw, err := net.Dial(network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cfg := DefaultListenerConfig()
|
||||
cfg.AllowNonTLS = true
|
||||
cfg.BaseTLSConfig = nil
|
||||
cfg.GetConfigForClient = nil
|
||||
c := newConn(raw, cfg, nil)
|
||||
c.isTLS = false
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// DialWithConfig dials with net.Dialer options.
|
||||
func DialWithConfig(network, address string, dc DialConfig) (*Conn, error) {
|
||||
d := net.Dialer{
|
||||
Timeout: dc.Timeout,
|
||||
LocalAddr: dc.LocalAddr,
|
||||
}
|
||||
raw, err := d.Dial(network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cfg := DefaultListenerConfig()
|
||||
cfg.AllowNonTLS = true
|
||||
c := newConn(raw, cfg, nil)
|
||||
c.isTLS = false
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// DialTLSWithConfig creates a TLS client connection wrapper.
|
||||
func DialTLSWithConfig(network, address string, tlsCfg *tls.Config, timeout time.Duration) (*Conn, error) {
|
||||
d := net.Dialer{Timeout: timeout}
|
||||
raw, err := d.Dial(network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tc := tls.Client(raw, tlsCfg)
|
||||
return &Conn{
|
||||
Conn: conn,
|
||||
isTLS: false,
|
||||
tlsCfg: nil,
|
||||
tlsConn: nil,
|
||||
noTlsReader: conn,
|
||||
isOriginal: true,
|
||||
Conn: raw,
|
||||
plainConn: raw,
|
||||
isTLS: true,
|
||||
tlsConn: tc,
|
||||
hostname: "",
|
||||
initErr: nil,
|
||||
allowNonTLS: false,
|
||||
skipSniff: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func DialTLS(network, address string, certFile, keyFile string) (*Conn, error) {
|
||||
config, err := tls.LoadX509KeyPair(certFile, keyFile)
|
||||
// DialTLS creates TLS client conn from cert/key paths.
|
||||
func DialTLS(network, address, certFile, keyFile string) (*Conn, error) {
|
||||
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{config},
|
||||
}
|
||||
|
||||
conn, err := net.Dial(network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewClientTlsConn(conn, tlsConfig)
|
||||
cfg := TLSDefaults()
|
||||
cfg.Certificates = []tls.Certificate{cert}
|
||||
return DialTLSWithConfig(network, address, cfg, 0)
|
||||
}
|
||||
|
||||
func WrapListener(listener net.Listener, cfg ListenerConfig) (*Listener, error) {
|
||||
if listener == nil {
|
||||
return nil, ErrNilConn
|
||||
}
|
||||
return &Listener{
|
||||
Listener: listener,
|
||||
cfg: normalizeConfig(cfg),
|
||||
}, nil
|
||||
}
|
||||
|
||||
691
tlssniffer_test.go
Normal file
691
tlssniffer_test.go
Normal file
@ -0,0 +1,691 @@
|
||||
package starnet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"io"
|
||||
"math/big"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ---------- cert helpers ----------
|
||||
|
||||
func genSelfSignedCertPEM(t *testing.T, dnsNames ...string) (certPEM, keyPEM []byte) {
|
||||
t.Helper()
|
||||
|
||||
priv, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateKey: %v", err)
|
||||
}
|
||||
|
||||
serial, err := rand.Int(rand.Reader, big.NewInt(1<<62))
|
||||
if err != nil {
|
||||
t.Fatalf("serial: %v", err)
|
||||
}
|
||||
|
||||
tpl := &x509.Certificate{
|
||||
SerialNumber: serial,
|
||||
Subject: pkix.Name{
|
||||
CommonName: "starnet-test",
|
||||
},
|
||||
NotBefore: time.Now().Add(-time.Hour),
|
||||
NotAfter: time.Now().Add(24 * time.Hour),
|
||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
|
||||
DNSNames: dnsNames,
|
||||
}
|
||||
|
||||
der, err := x509.CreateCertificate(rand.Reader, tpl, tpl, &priv.PublicKey, priv)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateCertificate: %v", err)
|
||||
}
|
||||
|
||||
certPEM = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der})
|
||||
keyPEM = pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)})
|
||||
return certPEM, keyPEM
|
||||
}
|
||||
|
||||
func genSelfSignedCert(t *testing.T, dnsNames ...string) tls.Certificate {
|
||||
t.Helper()
|
||||
certPEM, keyPEM := genSelfSignedCertPEM(t, dnsNames...)
|
||||
cert, err := tls.X509KeyPair(certPEM, keyPEM)
|
||||
if err != nil {
|
||||
t.Fatalf("X509KeyPair: %v", err)
|
||||
}
|
||||
return cert
|
||||
}
|
||||
|
||||
func writeTempCertFiles(t *testing.T, dnsNames ...string) (certFile, keyFile string, cleanup func()) {
|
||||
t.Helper()
|
||||
|
||||
certPEM, keyPEM := genSelfSignedCertPEM(t, dnsNames...)
|
||||
|
||||
cf, err := os.CreateTemp("", "starnet-cert-*.pem")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateTemp cert: %v", err)
|
||||
}
|
||||
kf, err := os.CreateTemp("", "starnet-key-*.pem")
|
||||
if err != nil {
|
||||
_ = cf.Close()
|
||||
_ = os.Remove(cf.Name())
|
||||
t.Fatalf("CreateTemp key: %v", err)
|
||||
}
|
||||
|
||||
if _, err := cf.Write(certPEM); err != nil {
|
||||
t.Fatalf("write cert: %v", err)
|
||||
}
|
||||
if _, err := kf.Write(keyPEM); err != nil {
|
||||
t.Fatalf("write key: %v", err)
|
||||
}
|
||||
_ = cf.Close()
|
||||
_ = kf.Close()
|
||||
|
||||
return cf.Name(), kf.Name(), func() {
|
||||
_ = os.Remove(cf.Name())
|
||||
_ = os.Remove(kf.Name())
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- server helpers ----------
|
||||
|
||||
func startEchoServer(t *testing.T, cfg ListenerConfig) (*Listener, string, func()) {
|
||||
t.Helper()
|
||||
|
||||
ln, err := ListenWithConfig("tcp", "127.0.0.1:0", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("ListenWithConfig: %v", err)
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
stop := make(chan struct{})
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for {
|
||||
c, err := ln.Accept()
|
||||
if err != nil {
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
go func(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
_, _ = io.Copy(conn, conn)
|
||||
}(c)
|
||||
}
|
||||
}()
|
||||
|
||||
cleanup := func() {
|
||||
close(stop)
|
||||
_ = ln.Close()
|
||||
wg.Wait()
|
||||
}
|
||||
return ln, ln.Addr().String(), cleanup
|
||||
}
|
||||
|
||||
// ---------- tests ----------
|
||||
|
||||
func TestListen(t *testing.T) {
|
||||
ln, err := Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("Listen: %v", err)
|
||||
}
|
||||
defer ln.Close()
|
||||
|
||||
go func() {
|
||||
c, err := ln.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer c.Close()
|
||||
_, _ = io.Copy(c, c)
|
||||
}()
|
||||
|
||||
c, err := net.Dial("tcp", ln.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("dial: %v", err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
msg := []byte("x")
|
||||
if _, err := c.Write(msg); err != nil {
|
||||
t.Fatalf("write: %v", err)
|
||||
}
|
||||
buf := make([]byte, 1)
|
||||
if _, err := io.ReadFull(c, buf); err != nil {
|
||||
t.Fatalf("read: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListenWithListenConfig(t *testing.T) {
|
||||
lc := net.ListenConfig{}
|
||||
cfg := DefaultListenerConfig()
|
||||
cfg.AllowNonTLS = true
|
||||
|
||||
ln, err := ListenWithListenConfig(lc, "tcp", "127.0.0.1:0", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("ListenWithListenConfig: %v", err)
|
||||
}
|
||||
defer ln.Close()
|
||||
|
||||
go func() {
|
||||
c, err := ln.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer c.Close()
|
||||
_, _ = io.Copy(c, c)
|
||||
}()
|
||||
|
||||
c, err := net.Dial("tcp", ln.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("dial: %v", err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
msg := []byte("ok")
|
||||
if _, err := c.Write(msg); err != nil {
|
||||
t.Fatalf("write: %v", err)
|
||||
}
|
||||
buf := make([]byte, 2)
|
||||
if _, err := io.ReadFull(c, buf); err != nil {
|
||||
t.Fatalf("read: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListenerSetConfig(t *testing.T) {
|
||||
cfg := DefaultListenerConfig()
|
||||
cfg.AllowNonTLS = true
|
||||
|
||||
ln, err := ListenWithConfig("tcp", "127.0.0.1:0", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("listen: %v", err)
|
||||
}
|
||||
defer ln.Close()
|
||||
|
||||
cfg2 := cfg
|
||||
cfg2.SniffTimeout = time.Second
|
||||
ln.SetConfig(cfg2)
|
||||
|
||||
got := ln.Config()
|
||||
if got.SniffTimeout != time.Second {
|
||||
t.Fatalf("SetConfig not applied")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPlainAllowed(t *testing.T) {
|
||||
cfg := DefaultListenerConfig()
|
||||
cfg.AllowNonTLS = true
|
||||
cfg.BaseTLSConfig = nil
|
||||
|
||||
_, addr, cleanup := startEchoServer(t, cfg)
|
||||
defer cleanup()
|
||||
|
||||
c, err := net.Dial("tcp", addr)
|
||||
if err != nil {
|
||||
t.Fatalf("Dial: %v", err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
msg := []byte("hello-plain")
|
||||
if _, err := c.Write(msg); err != nil {
|
||||
t.Fatalf("write: %v", err)
|
||||
}
|
||||
buf := make([]byte, len(msg))
|
||||
if _, err := io.ReadFull(c, buf); err != nil {
|
||||
t.Fatalf("read: %v", err)
|
||||
}
|
||||
if string(buf) != string(msg) {
|
||||
t.Fatalf("echo mismatch: got=%q want=%q", string(buf), string(msg))
|
||||
}
|
||||
}
|
||||
|
||||
func TestPlainRejectedWhenNonTLSDisabled(t *testing.T) {
|
||||
cert := genSelfSignedCert(t, "localhost")
|
||||
base := TLSDefaults()
|
||||
base.Certificates = []tls.Certificate{cert}
|
||||
|
||||
cfg := DefaultListenerConfig()
|
||||
cfg.AllowNonTLS = false
|
||||
cfg.BaseTLSConfig = base
|
||||
|
||||
_, addr, cleanup := startEchoServer(t, cfg)
|
||||
defer cleanup()
|
||||
|
||||
c, err := net.Dial("tcp", addr)
|
||||
if err != nil {
|
||||
t.Fatalf("Dial: %v", err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
_, _ = c.Write([]byte("plain"))
|
||||
_ = c.SetReadDeadline(time.Now().Add(500 * time.Millisecond))
|
||||
b := make([]byte, 1)
|
||||
_, err = c.Read(b)
|
||||
if err == nil {
|
||||
t.Fatalf("expected read error due to non-tls rejection")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTLSHandshakeAndEcho(t *testing.T) {
|
||||
cert := genSelfSignedCert(t, "localhost")
|
||||
base := TLSDefaults()
|
||||
base.Certificates = []tls.Certificate{cert}
|
||||
|
||||
cfg := DefaultListenerConfig()
|
||||
cfg.AllowNonTLS = false
|
||||
cfg.BaseTLSConfig = base
|
||||
|
||||
_, addr, cleanup := startEchoServer(t, cfg)
|
||||
defer cleanup()
|
||||
|
||||
tc, err := tls.Dial("tcp", addr, &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
ServerName: "localhost",
|
||||
MinVersion: tls.VersionTLS12,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("tls dial: %v", err)
|
||||
}
|
||||
defer tc.Close()
|
||||
|
||||
msg := []byte("hello-tls")
|
||||
if _, err := tc.Write(msg); err != nil {
|
||||
t.Fatalf("tls write: %v", err)
|
||||
}
|
||||
buf := make([]byte, len(msg))
|
||||
if _, err := io.ReadFull(tc, buf); err != nil {
|
||||
t.Fatalf("tls read: %v", err)
|
||||
}
|
||||
if string(buf) != string(msg) {
|
||||
t.Fatalf("tls echo mismatch: got=%q want=%q", string(buf), string(msg))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDynamicConfigBySNI(t *testing.T) {
|
||||
certA := genSelfSignedCert(t, "a.local")
|
||||
certB := genSelfSignedCert(t, "b.local")
|
||||
|
||||
base := TLSDefaults()
|
||||
base.Certificates = []tls.Certificate{certA}
|
||||
|
||||
cfg := DefaultListenerConfig()
|
||||
cfg.AllowNonTLS = false
|
||||
cfg.BaseTLSConfig = base
|
||||
cfg.GetConfigForClient = func(host string) (*tls.Config, error) {
|
||||
if host == "b.local" {
|
||||
b := TLSDefaults()
|
||||
b.Certificates = []tls.Certificate{certB}
|
||||
return b, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
_, addr, cleanup := startEchoServer(t, cfg)
|
||||
defer cleanup()
|
||||
|
||||
tc, err := tls.Dial("tcp", addr, &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
ServerName: "b.local",
|
||||
MinVersion: tls.VersionTLS12,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("tls dial: %v", err)
|
||||
}
|
||||
defer tc.Close()
|
||||
|
||||
if !tc.ConnectionState().HandshakeComplete {
|
||||
t.Fatalf("handshake not complete")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetConfigForClientError(t *testing.T) {
|
||||
cert := genSelfSignedCert(t, "localhost")
|
||||
base := TLSDefaults()
|
||||
base.Certificates = []tls.Certificate{cert}
|
||||
|
||||
cfg := DefaultListenerConfig()
|
||||
cfg.AllowNonTLS = false
|
||||
cfg.BaseTLSConfig = base
|
||||
cfg.GetConfigForClient = func(host string) (*tls.Config, error) {
|
||||
return nil, errors.New("boom")
|
||||
}
|
||||
|
||||
_, addr, cleanup := startEchoServer(t, cfg)
|
||||
defer cleanup()
|
||||
|
||||
_, err := tls.Dial("tcp", addr, &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
ServerName: "localhost",
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatalf("expected tls dial failure due to selector error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAcceptContextCancel(t *testing.T) {
|
||||
cfg := DefaultListenerConfig()
|
||||
cfg.AllowNonTLS = true
|
||||
|
||||
ln, err := ListenWithConfig("tcp", "127.0.0.1:0", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("listen: %v", err)
|
||||
}
|
||||
defer ln.Close()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
_, err = ln.AcceptContext(ctx)
|
||||
if err == nil {
|
||||
t.Fatalf("expected context timeout/cancel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestListenerStats(t *testing.T) {
|
||||
cfg := DefaultListenerConfig()
|
||||
cfg.AllowNonTLS = true
|
||||
|
||||
ln, addr, cleanup := startEchoServer(t, cfg)
|
||||
defer cleanup()
|
||||
|
||||
c, err := net.Dial("tcp", addr)
|
||||
if err != nil {
|
||||
t.Fatalf("dial: %v", err)
|
||||
}
|
||||
_, _ = c.Write([]byte("x"))
|
||||
_ = c.Close()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
s := ln.Stats()
|
||||
if s.Accepted == 0 {
|
||||
t.Fatalf("expected accepted > 0")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDialAndDialWithConfig(t *testing.T) {
|
||||
nl, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("listen: %v", err)
|
||||
}
|
||||
defer nl.Close()
|
||||
|
||||
go func() {
|
||||
c, err := nl.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer c.Close()
|
||||
_, _ = io.Copy(c, c)
|
||||
}()
|
||||
|
||||
c1, err := Dial("tcp", nl.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("Dial: %v", err)
|
||||
}
|
||||
defer c1.Close()
|
||||
|
||||
msg := []byte("abc")
|
||||
if _, err := c1.Write(msg); err != nil {
|
||||
t.Fatalf("c1 write: %v", err)
|
||||
}
|
||||
got := make([]byte, 3)
|
||||
if _, err := io.ReadFull(c1, got); err != nil {
|
||||
t.Fatalf("c1 read: %v", err)
|
||||
}
|
||||
|
||||
c2, err := DialWithConfig("tcp", nl.Addr().String(), DialConfig{Timeout: time.Second})
|
||||
if err != nil {
|
||||
t.Fatalf("DialWithConfig: %v", err)
|
||||
}
|
||||
defer c2.Close()
|
||||
}
|
||||
|
||||
func TestListenTLS_FileAPI(t *testing.T) {
|
||||
certFile, keyFile, cleanupFiles := writeTempCertFiles(t, "localhost")
|
||||
defer cleanupFiles()
|
||||
|
||||
ln, err := ListenTLS("tcp", "127.0.0.1:0", certFile, keyFile, false)
|
||||
if err != nil {
|
||||
t.Fatalf("ListenTLS: %v", err)
|
||||
}
|
||||
defer ln.Close()
|
||||
|
||||
go func() {
|
||||
c, err := ln.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer c.Close()
|
||||
_, _ = io.Copy(c, c)
|
||||
}()
|
||||
|
||||
tc, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
ServerName: "localhost",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("tls dial: %v", err)
|
||||
}
|
||||
defer tc.Close()
|
||||
|
||||
msg := []byte("hi")
|
||||
if _, err := tc.Write(msg); err != nil {
|
||||
t.Fatalf("tls write: %v", err)
|
||||
}
|
||||
out := make([]byte, 2)
|
||||
if _, err := io.ReadFull(tc, out); err != nil {
|
||||
t.Fatalf("tls read: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDialTLSWithConfig(t *testing.T) {
|
||||
cert := genSelfSignedCert(t, "localhost")
|
||||
base := TLSDefaults()
|
||||
base.Certificates = []tls.Certificate{cert}
|
||||
|
||||
cfg := DefaultListenerConfig()
|
||||
cfg.BaseTLSConfig = base
|
||||
cfg.AllowNonTLS = false
|
||||
|
||||
ln, err := ListenWithConfig("tcp", "127.0.0.1:0", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("listen: %v", err)
|
||||
}
|
||||
defer ln.Close()
|
||||
|
||||
go func() {
|
||||
c, err := ln.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer c.Close()
|
||||
_, _ = io.Copy(c, c)
|
||||
}()
|
||||
|
||||
clientCfg := &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
ServerName: "localhost",
|
||||
}
|
||||
c, err := DialTLSWithConfig("tcp", ln.Addr().String(), clientCfg, time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("DialTLSWithConfig: %v", err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
if !c.IsTLS() {
|
||||
t.Fatalf("expected IsTLS true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDialTLS_FileAPI(t *testing.T) {
|
||||
cert := genSelfSignedCert(t, "localhost")
|
||||
base := TLSDefaults()
|
||||
base.Certificates = []tls.Certificate{cert}
|
||||
cfg := DefaultListenerConfig()
|
||||
cfg.BaseTLSConfig = base
|
||||
cfg.AllowNonTLS = false
|
||||
|
||||
ln, err := ListenWithConfig("tcp", "127.0.0.1:0", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("listen: %v", err)
|
||||
}
|
||||
defer ln.Close()
|
||||
|
||||
go func() {
|
||||
c, err := ln.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer c.Close()
|
||||
_, _ = io.Copy(c, c)
|
||||
}()
|
||||
|
||||
clientCertFile, clientKeyFile, cleanupFiles := writeTempCertFiles(t, "localhost")
|
||||
defer cleanupFiles()
|
||||
|
||||
c, err := DialTLS("tcp", ln.Addr().String(), clientCertFile, clientKeyFile)
|
||||
if err != nil {
|
||||
t.Fatalf("DialTLS: %v", err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
if !c.IsTLS() {
|
||||
t.Fatalf("expected IsTLS true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnIsTLS_PlainAndTLS(t *testing.T) {
|
||||
// ---- plain case ----
|
||||
plainCfg := DefaultListenerConfig()
|
||||
plainCfg.AllowNonTLS = true
|
||||
|
||||
ln1, err := ListenWithConfig("tcp", "127.0.0.1:0", plainCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("listen1: %v", err)
|
||||
}
|
||||
defer ln1.Close()
|
||||
|
||||
plainDone := make(chan *Conn, 1)
|
||||
plainErr := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
nc, err := ln1.Accept()
|
||||
if err != nil {
|
||||
plainErr <- err
|
||||
return
|
||||
}
|
||||
sc, ok := nc.(*Conn)
|
||||
if !ok {
|
||||
_ = nc.Close()
|
||||
plainErr <- errors.New("accepted conn is not *Conn")
|
||||
return
|
||||
}
|
||||
plainDone <- sc
|
||||
|
||||
// block until client sends one byte, then close
|
||||
buf := make([]byte, 1)
|
||||
_, _ = sc.Read(buf)
|
||||
_ = sc.Close()
|
||||
}()
|
||||
|
||||
c1, err := net.Dial("tcp", ln1.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("dial1: %v", err)
|
||||
}
|
||||
if _, err := c1.Write([]byte("p")); err != nil {
|
||||
_ = c1.Close()
|
||||
t.Fatalf("plain client write: %v", err)
|
||||
}
|
||||
_ = c1.Close()
|
||||
|
||||
select {
|
||||
case err := <-plainErr:
|
||||
t.Fatalf("plain server error: %v", err)
|
||||
case sc1 := <-plainDone:
|
||||
if sc1.IsTLS() {
|
||||
t.Fatalf("plain conn should not be TLS")
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatalf("timeout waiting plain side")
|
||||
}
|
||||
|
||||
// ---- tls case ----
|
||||
cert := genSelfSignedCert(t, "localhost")
|
||||
tlsBase := TLSDefaults()
|
||||
tlsBase.Certificates = []tls.Certificate{cert}
|
||||
|
||||
tlsCfg := DefaultListenerConfig()
|
||||
tlsCfg.BaseTLSConfig = tlsBase
|
||||
tlsCfg.AllowNonTLS = false
|
||||
|
||||
ln2, err := ListenWithConfig("tcp", "127.0.0.1:0", tlsCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("listen2: %v", err)
|
||||
}
|
||||
defer ln2.Close()
|
||||
|
||||
tlsDone := make(chan *Conn, 1)
|
||||
tlsErr := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
nc, err := ln2.Accept()
|
||||
if err != nil {
|
||||
tlsErr <- err
|
||||
return
|
||||
}
|
||||
sc, ok := nc.(*Conn)
|
||||
if !ok {
|
||||
_ = nc.Close()
|
||||
tlsErr <- errors.New("accepted conn is not *Conn")
|
||||
return
|
||||
}
|
||||
tlsDone <- sc
|
||||
|
||||
// key point: wait for real data to ensure TLS handshake/path is executed
|
||||
buf := make([]byte, 1)
|
||||
_, _ = sc.Read(buf)
|
||||
_ = sc.Close()
|
||||
}()
|
||||
|
||||
d := &net.Dialer{Timeout: 2 * time.Second}
|
||||
tc, err := tls.DialWithDialer(d, "tcp", ln2.Addr().String(), &tls.Config{
|
||||
InsecureSkipVerify: true, // test only
|
||||
ServerName: "localhost",
|
||||
MinVersion: tls.VersionTLS12,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("tls dial: %v", err)
|
||||
}
|
||||
if _, err := tc.Write([]byte("t")); err != nil {
|
||||
_ = tc.Close()
|
||||
t.Fatalf("tls client write: %v", err)
|
||||
}
|
||||
_ = tc.Close()
|
||||
|
||||
select {
|
||||
case err := <-tlsErr:
|
||||
t.Fatalf("tls server error: %v", err)
|
||||
case sc2 := <-tlsDone:
|
||||
if !sc2.IsTLS() {
|
||||
t.Fatalf("tls conn should be TLS")
|
||||
}
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatalf("timeout waiting tls side")
|
||||
}
|
||||
}
|
||||
43
tlsstats.go
Normal file
43
tlsstats.go
Normal file
@ -0,0 +1,43 @@
|
||||
package starnet
|
||||
|
||||
import "sync/atomic"
|
||||
|
||||
// StatsSnapshot is a read-only copy of runtime counters.
|
||||
type StatsSnapshot struct {
|
||||
Accepted uint64
|
||||
TLSDetected uint64
|
||||
PlainDetected uint64
|
||||
InitFailures uint64
|
||||
Closed uint64
|
||||
}
|
||||
|
||||
// Stats provides lock-free counters.
|
||||
type Stats struct {
|
||||
accepted uint64
|
||||
tlsDetected uint64
|
||||
plainDetected uint64
|
||||
initFailures uint64
|
||||
closed uint64
|
||||
}
|
||||
|
||||
func (s *Stats) incAccepted() { atomic.AddUint64(&s.accepted, 1) }
|
||||
func (s *Stats) incTLSDetected() { atomic.AddUint64(&s.tlsDetected, 1) }
|
||||
func (s *Stats) incPlainDetected() { atomic.AddUint64(&s.plainDetected, 1) }
|
||||
func (s *Stats) incInitFailures() { atomic.AddUint64(&s.initFailures, 1) }
|
||||
func (s *Stats) incClosed() { atomic.AddUint64(&s.closed, 1) }
|
||||
|
||||
// Snapshot returns a stable view of counters.
|
||||
func (s *Stats) Snapshot() StatsSnapshot {
|
||||
return StatsSnapshot{
|
||||
Accepted: atomic.LoadUint64(&s.accepted),
|
||||
TLSDetected: atomic.LoadUint64(&s.tlsDetected),
|
||||
PlainDetected: atomic.LoadUint64(&s.plainDetected),
|
||||
InitFailures: atomic.LoadUint64(&s.initFailures),
|
||||
Closed: atomic.LoadUint64(&s.closed),
|
||||
}
|
||||
}
|
||||
|
||||
// Logger is a minimal logging abstraction.
|
||||
type Logger interface {
|
||||
Printf(format string, v ...interface{})
|
||||
}
|
||||
97
transport.go
Normal file
97
transport.go
Normal file
@ -0,0 +1,97 @@
|
||||
package starnet
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Transport 自定义 Transport(支持请求级配置)
|
||||
type Transport struct {
|
||||
base *http.Transport
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// RoundTrip 实现 http.RoundTripper 接口
|
||||
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
// 确保 base 已初始化
|
||||
if t.base == nil {
|
||||
t.mu.Lock()
|
||||
if t.base == nil {
|
||||
t.base = &http.Transport{
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
}
|
||||
}
|
||||
t.mu.Unlock()
|
||||
}
|
||||
|
||||
// 提取请求级别的配置
|
||||
reqCtx := getRequestContext(req.Context())
|
||||
|
||||
// 优先级1:完全自定义的 transport
|
||||
if reqCtx.Transport != nil {
|
||||
return reqCtx.Transport.RoundTrip(req)
|
||||
}
|
||||
|
||||
// 优先级2:需要动态配置
|
||||
if needsDynamicTransport(reqCtx) {
|
||||
dynamicTransport := t.buildDynamicTransport(reqCtx)
|
||||
return dynamicTransport.RoundTrip(req)
|
||||
}
|
||||
|
||||
// 优先级3:使用基础 transport
|
||||
t.mu.RLock()
|
||||
defer t.mu.RUnlock()
|
||||
return t.base.RoundTrip(req)
|
||||
}
|
||||
|
||||
// buildDynamicTransport 构建动态 Transport
|
||||
func (t *Transport) buildDynamicTransport(rc *RequestContext) *http.Transport {
|
||||
t.mu.RLock()
|
||||
transport := t.base.Clone()
|
||||
t.mu.RUnlock()
|
||||
|
||||
// 应用 TLS 配置(即使为 nil 也要检查 SkipVerify)
|
||||
if rc.TLSConfig != nil {
|
||||
transport.TLSClientConfig = rc.TLSConfig
|
||||
}
|
||||
|
||||
// 应用代理配置
|
||||
if rc.Proxy != "" {
|
||||
proxyURL, err := url.Parse(rc.Proxy)
|
||||
if err == nil {
|
||||
transport.Proxy = http.ProxyURL(proxyURL)
|
||||
}
|
||||
}
|
||||
|
||||
// 应用自定义 Dial 函数
|
||||
if rc.DialFn != nil {
|
||||
transport.DialContext = rc.DialFn
|
||||
} else if len(rc.CustomIP) > 0 || len(rc.CustomDNS) > 0 || rc.DialTimeout > 0 || rc.LookupIPFn != nil {
|
||||
// 使用默认 Dial 函数(会从 context 读取配置)
|
||||
transport.DialContext = defaultDialFunc
|
||||
transport.DialTLSContext = defaultDialTLSFunc
|
||||
}
|
||||
|
||||
return transport
|
||||
}
|
||||
|
||||
// Base 获取基础 Transport
|
||||
func (t *Transport) Base() *http.Transport {
|
||||
t.mu.RLock()
|
||||
defer t.mu.RUnlock()
|
||||
return t.base
|
||||
}
|
||||
|
||||
// SetBase 设置基础 Transport
|
||||
func (t *Transport) SetBase(base *http.Transport) {
|
||||
t.mu.Lock()
|
||||
t.base = base
|
||||
t.mu.Unlock()
|
||||
}
|
||||
131
types.go
Normal file
131
types.go
Normal file
@ -0,0 +1,131 @@
|
||||
package starnet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// HTTP Content-Type 常量
|
||||
const (
|
||||
ContentTypeFormURLEncoded = "application/x-www-form-urlencoded"
|
||||
ContentTypeFormData = "multipart/form-data"
|
||||
ContentTypeJSON = "application/json"
|
||||
ContentTypeXML = "application/xml"
|
||||
ContentTypePlain = "text/plain"
|
||||
ContentTypeHTML = "text/html"
|
||||
ContentTypeOctetStream = "application/octet-stream"
|
||||
)
|
||||
|
||||
// 默认配置
|
||||
const (
|
||||
DefaultDialTimeout = 5 * time.Second
|
||||
DefaultTimeout = 10 * time.Second
|
||||
DefaultUserAgent = "Starnet/1.0.0"
|
||||
DefaultFetchRespBody = false
|
||||
)
|
||||
|
||||
// RequestFile 表示要上传的文件
|
||||
type RequestFile struct {
|
||||
FormName string // 表单字段名
|
||||
FileName string // 文件名
|
||||
FilePath string // 文件路径(如果从文件读取)
|
||||
FileData io.Reader // 文件数据流
|
||||
FileSize int64 // 文件大小
|
||||
FileType string // MIME 类型
|
||||
}
|
||||
|
||||
// UploadProgressFunc 文件上传进度回调函数
|
||||
type UploadProgressFunc func(filename string, uploaded int64, total int64)
|
||||
|
||||
// NetworkConfig 网络配置
|
||||
type NetworkConfig struct {
|
||||
Proxy string // 代理地址
|
||||
DialTimeout time.Duration // 连接超时
|
||||
Timeout time.Duration // 总超时
|
||||
DialFunc func(ctx context.Context, network, addr string) (net.Conn, error)
|
||||
}
|
||||
|
||||
// TLSConfig TLS 配置
|
||||
type TLSConfig struct {
|
||||
Config *tls.Config // TLS 配置
|
||||
SkipVerify bool // 跳过证书验证
|
||||
}
|
||||
|
||||
// DNSConfig DNS 配置
|
||||
type DNSConfig struct {
|
||||
CustomIP []string // 直接指定 IP(最高优先级)
|
||||
CustomDNS []string // 自定义 DNS 服务器
|
||||
LookupFunc func(ctx context.Context, host string) ([]net.IPAddr, error) // 自定义解析函数
|
||||
}
|
||||
|
||||
// BodyConfig 请求体配置
|
||||
type BodyConfig struct {
|
||||
Bytes []byte // 原始字节
|
||||
Reader io.Reader // 数据流
|
||||
FormData map[string][]string // 表单数据
|
||||
Files []RequestFile // 文件列表
|
||||
}
|
||||
|
||||
// RequestConfig 请求配置(内部使用)
|
||||
type RequestConfig struct {
|
||||
Network NetworkConfig
|
||||
TLS TLSConfig
|
||||
DNS DNSConfig
|
||||
Body BodyConfig
|
||||
Headers http.Header
|
||||
Cookies []*http.Cookie
|
||||
Queries map[string][]string
|
||||
|
||||
// 其他配置
|
||||
BasicAuth [2]string // Basic 认证
|
||||
ContentLength int64 // 手动设置的 Content-Length
|
||||
AutoCalcContentLength bool // 自动计算 Content-Length
|
||||
UploadProgress UploadProgressFunc // 上传进度回调
|
||||
|
||||
// Transport 配置
|
||||
CustomTransport bool // 是否使用自定义 Transport
|
||||
Transport *http.Transport // 自定义 Transport
|
||||
}
|
||||
|
||||
// Clone 克隆配置
|
||||
func (c *RequestConfig) Clone() *RequestConfig {
|
||||
return &RequestConfig{
|
||||
Network: NetworkConfig{
|
||||
Proxy: c.Network.Proxy,
|
||||
DialTimeout: c.Network.DialTimeout,
|
||||
Timeout: c.Network.Timeout,
|
||||
DialFunc: c.Network.DialFunc,
|
||||
},
|
||||
TLS: TLSConfig{
|
||||
Config: cloneTLSConfig(c.TLS.Config),
|
||||
SkipVerify: c.TLS.SkipVerify,
|
||||
},
|
||||
DNS: DNSConfig{
|
||||
CustomIP: cloneStringSlice(c.DNS.CustomIP),
|
||||
CustomDNS: cloneStringSlice(c.DNS.CustomDNS),
|
||||
LookupFunc: c.DNS.LookupFunc,
|
||||
},
|
||||
Body: BodyConfig{
|
||||
Bytes: cloneBytes(c.Body.Bytes),
|
||||
Reader: c.Body.Reader, // Reader 不可克隆
|
||||
FormData: cloneStringMapSlice(c.Body.FormData),
|
||||
Files: cloneFiles(c.Body.Files),
|
||||
},
|
||||
Headers: cloneHeader(c.Headers),
|
||||
Cookies: cloneCookies(c.Cookies),
|
||||
Queries: cloneStringMapSlice(c.Queries),
|
||||
BasicAuth: c.BasicAuth,
|
||||
ContentLength: c.ContentLength,
|
||||
AutoCalcContentLength: c.AutoCalcContentLength,
|
||||
UploadProgress: c.UploadProgress,
|
||||
CustomTransport: c.CustomTransport,
|
||||
Transport: c.Transport, // Transport 共享
|
||||
}
|
||||
}
|
||||
|
||||
// RequestOpt 请求选项函数
|
||||
type RequestOpt func(*Request) error
|
||||
212
utils.go
Normal file
212
utils.go
Normal file
@ -0,0 +1,212 @@
|
||||
package starnet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// validMethod 验证 HTTP 方法是否有效
|
||||
func validMethod(method string) bool {
|
||||
return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1
|
||||
}
|
||||
|
||||
// isNotToken 检查字符是否不是 token 字符
|
||||
func isNotToken(r rune) bool {
|
||||
return !isTokenRune(r)
|
||||
}
|
||||
|
||||
// isTokenRune 检查字符是否是 token 字符
|
||||
func isTokenRune(r rune) bool {
|
||||
i := int(r)
|
||||
return i < 127 && isTokenTable[i]
|
||||
}
|
||||
|
||||
// isTokenTable token 字符表
|
||||
var isTokenTable = [127]bool{
|
||||
'!': true, '#': true, '$': true, '%': true, '&': true, '\'': true, '*': true,
|
||||
'+': true, '-': true, '.': true, '0': true, '1': true, '2': true, '3': true,
|
||||
'4': true, '5': true, '6': true, '7': true, '8': true, '9': true, 'A': true,
|
||||
'B': true, 'C': true, 'D': true, 'E': true, 'F': true, 'G': true, 'H': true,
|
||||
'I': true, 'J': true, 'K': true, 'L': true, 'M': true, 'N': true, 'O': true,
|
||||
'P': true, 'Q': true, 'R': true, 'S': true, 'T': true, 'U': true, 'V': true,
|
||||
'W': true, 'X': true, 'Y': true, 'Z': true, '^': true, '_': true, '`': true,
|
||||
'a': true, 'b': true, 'c': true, 'd': true, 'e': true, 'f': true, 'g': true,
|
||||
'h': true, 'i': true, 'j': true, 'k': true, 'l': true, 'm': true, 'n': true,
|
||||
'o': true, 'p': true, 'q': true, 'r': true, 's': true, 't': true, 'u': true,
|
||||
'v': true, 'w': true, 'x': true, 'y': true, 'z': true, '|': true, '~': true,
|
||||
}
|
||||
|
||||
// hasPort 检查地址是否包含端口
|
||||
func hasPort(s string) bool {
|
||||
return strings.LastIndex(s, ":") > strings.LastIndex(s, "]")
|
||||
}
|
||||
|
||||
// removeEmptyPort 移除空端口
|
||||
func removeEmptyPort(host string) string {
|
||||
if hasPort(host) {
|
||||
return strings.TrimSuffix(host, ":")
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
||||
// UrlEncode URL 编码
|
||||
func UrlEncode(str string) string {
|
||||
return url.QueryEscape(str)
|
||||
}
|
||||
|
||||
// UrlEncodeRaw URL 编码(空格编码为 %20)
|
||||
func UrlEncodeRaw(str string) string {
|
||||
return strings.Replace(url.QueryEscape(str), "+", "%20", -1)
|
||||
}
|
||||
|
||||
// UrlDecode URL 解码
|
||||
func UrlDecode(str string) (string, error) {
|
||||
return url.QueryUnescape(str)
|
||||
}
|
||||
|
||||
// BuildQuery 构建查询字符串
|
||||
func BuildQuery(data map[string]string) string {
|
||||
query := url.Values{}
|
||||
for k, v := range data {
|
||||
query.Add(k, v)
|
||||
}
|
||||
return query.Encode()
|
||||
}
|
||||
|
||||
// BuildPostForm 构建 POST 表单数据
|
||||
func BuildPostForm(data map[string]string) []byte {
|
||||
return []byte(BuildQuery(data))
|
||||
}
|
||||
|
||||
// cloneHeader 克隆 Header
|
||||
func cloneHeader(h http.Header) http.Header {
|
||||
if h == nil {
|
||||
return make(http.Header)
|
||||
}
|
||||
newHeader := make(http.Header, len(h))
|
||||
for k, v := range h {
|
||||
newHeader[k] = append([]string(nil), v...)
|
||||
}
|
||||
return newHeader
|
||||
}
|
||||
|
||||
// cloneCookies 克隆 Cookies
|
||||
func cloneCookies(cookies []*http.Cookie) []*http.Cookie {
|
||||
if cookies == nil {
|
||||
return nil
|
||||
}
|
||||
newCookies := make([]*http.Cookie, len(cookies))
|
||||
for i, c := range cookies {
|
||||
newCookies[i] = &http.Cookie{
|
||||
Name: c.Name,
|
||||
Value: c.Value,
|
||||
Path: c.Path,
|
||||
Domain: c.Domain,
|
||||
Expires: c.Expires,
|
||||
RawExpires: c.RawExpires,
|
||||
MaxAge: c.MaxAge,
|
||||
Secure: c.Secure,
|
||||
HttpOnly: c.HttpOnly,
|
||||
SameSite: c.SameSite,
|
||||
Raw: c.Raw,
|
||||
Unparsed: append([]string(nil), c.Unparsed...),
|
||||
}
|
||||
}
|
||||
return newCookies
|
||||
}
|
||||
|
||||
// cloneStringMapSlice 克隆 map[string][]string
|
||||
func cloneStringMapSlice(m map[string][]string) map[string][]string {
|
||||
if m == nil {
|
||||
return make(map[string][]string)
|
||||
}
|
||||
newMap := make(map[string][]string, len(m))
|
||||
for k, v := range m {
|
||||
newMap[k] = append([]string(nil), v...)
|
||||
}
|
||||
return newMap
|
||||
}
|
||||
|
||||
// cloneBytes 克隆字节切片
|
||||
func cloneBytes(b []byte) []byte {
|
||||
if b == nil {
|
||||
return nil
|
||||
}
|
||||
newBytes := make([]byte, len(b))
|
||||
copy(newBytes, b)
|
||||
return newBytes
|
||||
}
|
||||
|
||||
// cloneStringSlice 克隆字符串切片
|
||||
func cloneStringSlice(s []string) []string {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
newSlice := make([]string, len(s))
|
||||
copy(newSlice, s)
|
||||
return newSlice
|
||||
}
|
||||
|
||||
// cloneFiles 克隆文件列表
|
||||
func cloneFiles(files []RequestFile) []RequestFile {
|
||||
if files == nil {
|
||||
return nil
|
||||
}
|
||||
newFiles := make([]RequestFile, len(files))
|
||||
copy(newFiles, files)
|
||||
return newFiles
|
||||
}
|
||||
|
||||
// cloneTLSConfig 克隆 TLS 配置
|
||||
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
|
||||
if cfg == nil {
|
||||
return nil
|
||||
}
|
||||
return cfg.Clone()
|
||||
}
|
||||
|
||||
// copyWithProgress 带进度的复制
|
||||
func copyWithProgress(ctx context.Context, dst io.Writer, src io.Reader, filename string, total int64, progress UploadProgressFunc) (int64, error) {
|
||||
if progress == nil {
|
||||
return io.Copy(dst, src)
|
||||
}
|
||||
|
||||
var written int64
|
||||
buf := make([]byte, 32*1024) // 32KB buffer
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return written, ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
nr, err := src.Read(buf)
|
||||
if nr > 0 {
|
||||
nw, ew := dst.Write(buf[:nr])
|
||||
if nw > 0 {
|
||||
written += int64(nw)
|
||||
// 同步调用进度回调(不使用 goroutine)
|
||||
progress(filename, written, total)
|
||||
}
|
||||
if ew != nil {
|
||||
return written, ew
|
||||
}
|
||||
if nr != nw {
|
||||
return written, io.ErrShortWrite
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
// 最后一次进度回调
|
||||
progress(filename, written, total)
|
||||
return written, nil
|
||||
}
|
||||
return written, err
|
||||
}
|
||||
}
|
||||
}
|
||||
284
utils_test.go
Normal file
284
utils_test.go
Normal file
@ -0,0 +1,284 @@
|
||||
package starnet
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestUrlEncodeRaw(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "basic string with space",
|
||||
input: "hello world",
|
||||
expected: "hello%20world",
|
||||
},
|
||||
{
|
||||
name: "special characters",
|
||||
input: "hello world!@#$%^&*()_+-=~`",
|
||||
expected: "hello%20world%21%40%23%24%25%5E%26%2A%28%29_%2B-%3D~%60",
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "chinese characters",
|
||||
input: "你好世界",
|
||||
expected: "%E4%BD%A0%E5%A5%BD%E4%B8%96%E7%95%8C",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := UrlEncodeRaw(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("UrlEncodeRaw(%q) = %q; want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUrlEncode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "space encoded as plus",
|
||||
input: "hello world",
|
||||
expected: "hello+world",
|
||||
},
|
||||
{
|
||||
name: "special characters",
|
||||
input: "hello world!@#$%^&*()_+-=~`",
|
||||
expected: "hello+world%21%40%23%24%25%5E%26%2A%28%29_%2B-%3D~%60",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := UrlEncode(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("UrlEncode(%q) = %q; want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUrlDecode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "basic decode",
|
||||
input: "hello%20world",
|
||||
expected: "hello world",
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "plus to space",
|
||||
input: "hello+world",
|
||||
expected: "hello world",
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "special characters",
|
||||
input: "hello%20world%21%40%23%24%25%5E%26*%28%29_%2B-%3D~%60",
|
||||
expected: "hello world!@#$%^&*()_+-=~`",
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid encoding",
|
||||
input: "%zz",
|
||||
expected: "",
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := UrlDecode(tt.input)
|
||||
if tt.expectErr {
|
||||
if err == nil {
|
||||
t.Errorf("UrlDecode(%q) expected error, got nil", tt.input)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("UrlDecode(%q) unexpected error: %v", tt.input, err)
|
||||
}
|
||||
if result != tt.expected {
|
||||
t.Errorf("UrlDecode(%q) = %q; want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildQuery(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input map[string]string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "single parameter",
|
||||
input: map[string]string{
|
||||
"key": "value",
|
||||
},
|
||||
expected: "key=value",
|
||||
},
|
||||
{
|
||||
name: "empty map",
|
||||
input: map[string]string{},
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "nil map",
|
||||
input: nil,
|
||||
expected: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := BuildQuery(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("BuildQuery(%v) = %q; want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPostForm(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input map[string]string
|
||||
expected []byte
|
||||
}{
|
||||
{
|
||||
name: "basic form",
|
||||
input: map[string]string{
|
||||
"key1": "value1",
|
||||
},
|
||||
expected: []byte("key1=value1"),
|
||||
},
|
||||
{
|
||||
name: "empty map",
|
||||
input: map[string]string{},
|
||||
expected: []byte(""),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := BuildPostForm(tt.input)
|
||||
if string(result) != string(tt.expected) {
|
||||
t.Errorf("BuildPostForm(%v) = %v; want %v", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidMethod(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
expected bool
|
||||
}{
|
||||
{"GET", "GET", true},
|
||||
{"POST", "POST", true},
|
||||
{"PUT", "PUT", true},
|
||||
{"DELETE", "DELETE", true},
|
||||
{"PATCH", "PATCH", true},
|
||||
{"OPTIONS", "OPTIONS", true},
|
||||
{"HEAD", "HEAD", true},
|
||||
{"TRACE", "TRACE", true},
|
||||
{"CONNECT", "CONNECT", true},
|
||||
{"invalid with space", "GET POST", false},
|
||||
{"invalid with special char", "GET<>", false},
|
||||
{"empty", "", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := validMethod(tt.method)
|
||||
if result != tt.expected {
|
||||
t.Errorf("validMethod(%q) = %v; want %v", tt.method, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloneCookies_FullFields(t *testing.T) {
|
||||
expire := time.Now().Add(2 * time.Hour)
|
||||
|
||||
src := []*http.Cookie{
|
||||
{
|
||||
Name: "sid",
|
||||
Value: "abc123",
|
||||
Path: "/",
|
||||
Domain: "example.com",
|
||||
Expires: expire,
|
||||
RawExpires: expire.UTC().Format(time.RFC1123),
|
||||
MaxAge: 3600,
|
||||
Secure: true,
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
Raw: "sid=abc123; Path=/; HttpOnly",
|
||||
Unparsed: []string{"Priority=High", "Partitioned"},
|
||||
},
|
||||
}
|
||||
|
||||
got := cloneCookies(src)
|
||||
if got == nil || len(got) != 1 {
|
||||
t.Fatalf("cloneCookies() len=%v; want 1", len(got))
|
||||
}
|
||||
|
||||
// 指针应不同(不是浅拷贝)
|
||||
if got[0] == src[0] {
|
||||
t.Fatal("cookie pointer should be different (deep copy expected)")
|
||||
}
|
||||
|
||||
// 字段值应一致
|
||||
s := src[0]
|
||||
g := got[0]
|
||||
if g.Name != s.Name ||
|
||||
g.Value != s.Value ||
|
||||
g.Path != s.Path ||
|
||||
g.Domain != s.Domain ||
|
||||
!g.Expires.Equal(s.Expires) ||
|
||||
g.RawExpires != s.RawExpires ||
|
||||
g.MaxAge != s.MaxAge ||
|
||||
g.Secure != s.Secure ||
|
||||
g.HttpOnly != s.HttpOnly ||
|
||||
g.SameSite != s.SameSite ||
|
||||
g.Raw != s.Raw {
|
||||
t.Fatalf("cloned cookie fields mismatch:\n got=%+v\n src=%+v", g, s)
|
||||
}
|
||||
|
||||
// Unparsed 内容一致
|
||||
if len(g.Unparsed) != len(s.Unparsed) {
|
||||
t.Fatalf("Unparsed len=%d; want %d", len(g.Unparsed), len(s.Unparsed))
|
||||
}
|
||||
for i := range s.Unparsed {
|
||||
if g.Unparsed[i] != s.Unparsed[i] {
|
||||
t.Fatalf("Unparsed[%d]=%q; want %q", i, g.Unparsed[i], s.Unparsed[i])
|
||||
}
|
||||
}
|
||||
|
||||
// 验证 Unparsed 是深拷贝(修改源不影响目标)
|
||||
src[0].Unparsed[0] = "Modified=Yes"
|
||||
if got[0].Unparsed[0] == "Modified=Yes" {
|
||||
t.Fatal("Unparsed should be deep-copied, but was affected by source mutation")
|
||||
}
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user