Compare commits

..

22 Commits

Author SHA1 Message Date
4568e17f06
fix: 修复核心bug并完善API
- 修复NewRequest系列函数不返回opt错误的问题
- 修复prepare()幂等性问题,支持请求重试
- 修复defaultDialTLSFunc的ServerName解析错误
- 修复Client.Clone()并发安全问题
- 补齐Client.Trace/Connect方法
- 新增Request.HTTPClient/Client方法
- 增强NewSimpleRequest错误处理的健壮性
2026-03-10 19:55:37 +08:00
1bb30514ec
bug fix:tls自定义时,没有设置servername的问题 2026-03-08 21:38:45 +08:00
50aef48d49
rewrite program 2026-03-08 20:19:40 +08:00
0e2f91eee2
fix:使用Client时,设置的参数不生效 2025-10-14 10:08:53 +08:00
b90c59d6e7
修改版本号 2025-08-21 21:40:29 +08:00
4e154cc17b
update benchmark 2025-08-21 21:37:21 +08:00
67b0025f9c
更新content-length的默认处理方式 2025-08-21 19:17:19 +08:00
c4fa62536a
为client新增部分函数 2025-08-21 15:32:19 +08:00
260ceb90ed
重构http Client部分 2025-08-21 15:02:02 +08:00
d260181adf
update 2025-08-15 15:07:51 +08:00
e3b7369e12
bug fix:nil pointer error 2025-08-13 10:16:08 +08:00
4e17fee681
bug fix 2025-07-14 18:38:31 +08:00
a8eed30db5
add http client control 2025-07-14 18:23:14 +08:00
c1eaf43058 update 2025-06-17 12:36:57 +08:00
9f5aca124d update 2025-06-17 12:09:12 +08:00
54958724e7 bug fix 2025-06-13 17:16:38 +08:00
7a17672149 update tls sniffer 2025-06-12 16:50:47 +08:00
44b807d3d1 update 2025-06-06 15:43:38 +08:00
0d847462b3 bug fix:nil pointer 2025-04-28 13:19:45 +08:00
deed4207ea bug fix 2024-08-30 23:44:49 +08:00
f6363fed07 move starqueue from starnet to stario 2024-08-18 17:18:52 +08:00
1de78f2f06 rewrite curl.go 2024-08-08 22:03:10 +08:00
42 changed files with 8954 additions and 879 deletions

1
.gitignore vendored Normal file
View File

@ -0,0 +1 @@
.idea

1657
addon_test.go Normal file

File diff suppressed because it is too large Load Diff

197
benchmark_test.go Normal file
View File

@ -0,0 +1,197 @@
package starnet
import (
"fmt"
"net/http"
"net/http/httptest"
"testing"
)
func BenchmarkGetRequest(b *testing.B) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("OK"))
}))
defer server.Close()
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
resp, err := Get(server.URL)
if err != nil {
b.Fatalf("Get() error: %v", err)
}
resp.Body().String()
resp.Close()
}
}
func BenchmarkGetRequestWithHeaders(b *testing.B) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("OK"))
}))
defer server.Close()
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
resp, err := Get(server.URL,
WithHeader("X-Custom", "value"),
WithUserAgent("BenchmarkAgent"))
if err != nil {
b.Fatalf("Get() error: %v", err)
}
resp.Body().String()
resp.Close()
}
}
func BenchmarkPostRequest(b *testing.B) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("OK"))
}))
defer server.Close()
testData := []byte("test data for benchmark")
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
resp, err := Post(server.URL, WithBody(testData))
if err != nil {
b.Fatalf("Post() error: %v", err)
}
resp.Body().String()
resp.Close()
}
}
func BenchmarkJSONRequest(b *testing.B) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{"status":"ok"}`))
}))
defer server.Close()
type TestData struct {
Name string `json:"name"`
Value int `json:"value"`
}
data := TestData{Name: "test", Value: 123}
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
resp, err := Post(server.URL, WithJSON(data))
if err != nil {
b.Fatalf("Post() error: %v", err)
}
var result map[string]string
resp.Body().JSON(&result)
resp.Close()
}
}
func BenchmarkConcurrentRequests(b *testing.B) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("OK"))
}))
defer server.Close()
b.ResetTimer()
b.ReportAllocs()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
resp, err := Get(server.URL)
if err != nil {
b.Fatalf("Get() error: %v", err)
}
resp.Body().String()
resp.Close()
}
})
}
func BenchmarkRequestClone(b *testing.B) {
req := NewSimpleRequest("https://example.com", "GET").
SetHeader("X-Custom", "value").
AddQuery("key", "value")
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
_ = req.Clone()
}
}
func BenchmarkClientCreation(b *testing.B) {
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
_ = NewClientNoErr()
}
}
func BenchmarkRequestCreation(b *testing.B) {
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
_ = NewSimpleRequest("https://example.com", "GET")
}
}
func BenchmarkResponseBodyRead(b *testing.B) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("test response data"))
}))
defer server.Close()
// Pre-fetch response
resp, _ := Get(server.URL, WithAutoFetch(true))
defer resp.Close()
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
_, _ = resp.Body().String()
}
}
func BenchmarkDifferentResponseSizes(b *testing.B) {
sizes := []int{100, 1024, 10240, 102400} // 100B, 1KB, 10KB, 100KB
for _, size := range sizes {
responseData := make([]byte, size)
for i := 0; i < size; i++ {
responseData[i] = 'A'
}
b.Run(fmt.Sprintf("Size_%d", size), func(b *testing.B) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write(responseData)
}))
defer server.Close()
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
resp, err := Get(server.URL)
if err != nil {
b.Fatalf("Get() error: %v", err)
}
resp.Body().Bytes()
resp.Close()
}
})
}
}

145
body_test.go Normal file
View File

@ -0,0 +1,145 @@
package starnet
import (
"bytes"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func TestRequestBodyBytes(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
w.Write(body)
}))
defer server.Close()
testData := []byte("test data")
req := NewSimpleRequest(server.URL, "POST").SetBody(testData)
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
body, _ := resp.Body().Bytes()
if !bytes.Equal(body, testData) {
t.Errorf("Body = %v; want %v", body, testData)
}
}
func TestRequestBodyString(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
w.Write(body)
}))
defer server.Close()
testData := "test string data"
req := NewSimpleRequest(server.URL, "POST").SetBodyString(testData)
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
body, _ := resp.Body().String()
if body != testData {
t.Errorf("Body = %v; want %v", body, testData)
}
}
func TestRequestBodyReader(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
w.Write(body)
}))
defer server.Close()
testData := "test reader data"
reader := strings.NewReader(testData)
req := NewSimpleRequest(server.URL, "POST").SetBodyReader(reader)
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
body, _ := resp.Body().String()
if body != testData {
t.Errorf("Body = %v; want %v", body, testData)
}
}
func TestRequestJSON(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("Content-Type") != ContentTypeJSON {
t.Errorf("Content-Type = %v; want %v", r.Header.Get("Content-Type"), ContentTypeJSON)
}
var data map[string]string
json.NewDecoder(r.Body).Decode(&data)
json.NewEncoder(w).Encode(data)
}))
defer server.Close()
testData := map[string]string{
"name": "John",
"email": "john@example.com",
}
req := NewSimpleRequest(server.URL, "POST").SetJSON(testData)
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
var result map[string]string
resp.Body().JSON(&result)
if result["name"] != testData["name"] {
t.Errorf("name = %v; want %v", result["name"], testData["name"])
}
}
func TestRequestFormData(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r.ParseForm()
data := make(map[string]string)
for k, v := range r.Form {
if len(v) > 0 {
data[k] = v[0]
}
}
json.NewEncoder(w).Encode(data)
}))
defer server.Close()
req := NewSimpleRequest(server.URL, "POST").
AddFormData("name", "John").
AddFormData("email", "john@example.com")
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
var result map[string]string
resp.Body().JSON(&result)
if result["name"] != "John" {
t.Errorf("name = %v; want John", result["name"])
}
if result["email"] != "john@example.com" {
t.Errorf("email = %v; want john@example.com", result["email"])
}
}

345
client.go Normal file
View File

@ -0,0 +1,345 @@
package starnet
import (
"context"
"crypto/tls"
"fmt"
"net/http"
"sync"
"time"
)
// Client HTTP 客户端封装
type Client struct {
client *http.Client
opts []RequestOpt
mu sync.RWMutex
}
// NewClient 创建新的 Client
func NewClient(opts ...RequestOpt) (*Client, error) {
// 创建基础 Transport
baseTransport := &http.Transport{
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
httpClient := &http.Client{
Transport: &Transport{base: baseTransport},
//Timeout: DefaultTimeout,
}
// 应用选项(如果有)
if len(opts) > 0 {
// 创建临时请求以应用选项
req, err := newRequest(context.Background(), "", http.MethodGet, opts...)
if err != nil {
return nil, wrapError(err, "create client")
}
/*
// 如果选项中有自定义配置,应用到 httpClient
if req.config.Network.Timeout > 0 {
httpClient.Timeout = req.config.Network.Timeout
}
*/
// 如果有自定义 Transport
if req.config.CustomTransport && req.config.Transport != nil {
httpClient.Transport = &Transport{base: req.config.Transport}
}
}
return &Client{
client: httpClient,
opts: opts,
}, nil
}
// NewClientNoErr 创建新的 Client忽略错误
func NewClientNoErr(opts ...RequestOpt) *Client {
client, _ := NewClient(opts...)
if client == nil {
client = &Client{
client: &http.Client{},
opts: opts,
}
}
return client
}
// NewClientFromHTTP 从 http.Client 创建 Client
func NewClientFromHTTP(httpClient *http.Client) (*Client, error) {
if httpClient == nil {
return nil, ErrNilClient
}
// 确保 Transport 是我们的自定义类型
if httpClient.Transport == nil {
httpClient.Transport = &Transport{
base: &http.Transport{},
}
} else {
switch t := httpClient.Transport.(type) {
case *Transport:
// 已经是我们的类型
if t.base == nil {
t.base = &http.Transport{}
}
case *http.Transport:
// 包装标准 Transport
httpClient.Transport = &Transport{
base: t,
}
default:
return nil, fmt.Errorf("unsupported transport type: %T", t)
}
}
return &Client{
client: httpClient,
}, nil
}
// HTTPClient 获取底层 http.Client
func (c *Client) HTTPClient() *http.Client {
return c.client
}
// RequestOptions 获取默认选项(返回副本)
func (c *Client) RequestOptions() []RequestOpt {
c.mu.RLock()
defer c.mu.RUnlock()
opts := make([]RequestOpt, len(c.opts))
copy(opts, c.opts)
return opts
}
// SetOptions 设置默认选项
func (c *Client) SetOptions(opts ...RequestOpt) *Client {
c.mu.Lock()
c.opts = opts
c.mu.Unlock()
return c
}
// AddOptions 追加默认选项
func (c *Client) AddOptions(opts ...RequestOpt) *Client {
c.mu.Lock()
c.opts = append(c.opts, opts...)
c.mu.Unlock()
return c
}
// Clone 克隆 Client深拷贝
func (c *Client) Clone() *Client {
c.mu.RLock()
defer c.mu.RUnlock()
// 克隆 Transport
var transport http.RoundTripper
if c.client.Transport != nil {
switch t := c.client.Transport.(type) {
case *Transport:
transport = &Transport{
base: t.base.Clone(),
}
case *http.Transport:
transport = t.Clone()
default:
transport = c.client.Transport
}
}
return &Client{
client: &http.Client{
Transport: transport,
CheckRedirect: c.client.CheckRedirect,
Jar: c.client.Jar,
Timeout: c.client.Timeout,
},
opts: append([]RequestOpt(nil), c.opts...),
}
}
// SetDefaultTLSConfig 设置默认 TLS 配置
func (c *Client) SetDefaultTLSConfig(tlsConfig *tls.Config) *Client {
if transport, ok := c.client.Transport.(*Transport); ok {
transport.mu.Lock()
if tlsConfig != nil {
transport.base.TLSClientConfig = tlsConfig.Clone()
} else {
transport.base.TLSClientConfig = nil
}
transport.mu.Unlock()
}
return c
}
// SetDefaultSkipTLSVerify 设置默认跳过 TLS 验证
func (c *Client) SetDefaultSkipTLSVerify(skip bool) *Client {
if transport, ok := c.client.Transport.(*Transport); ok {
transport.mu.Lock()
if transport.base.TLSClientConfig == nil {
transport.base.TLSClientConfig = &tls.Config{}
} else {
transport.base.TLSClientConfig = transport.base.TLSClientConfig.Clone()
}
transport.base.TLSClientConfig.InsecureSkipVerify = skip
transport.mu.Unlock()
}
return c
}
// DisableRedirect 禁用重定向
func (c *Client) DisableRedirect() *Client {
c.client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
}
return c
}
// EnableRedirect 启用重定向
func (c *Client) EnableRedirect() *Client {
c.client.CheckRedirect = nil
return c
}
// NewRequest 创建新请求
func (c *Client) NewRequest(url, method string, opts ...RequestOpt) (*Request, error) {
return c.NewRequestWithContext(context.Background(), url, method, opts...)
}
// NewRequestWithContext 创建新请求(带 context
func (c *Client) NewRequestWithContext(ctx context.Context, url, method string, opts ...RequestOpt) (*Request, error) {
// 合并 Client 级别和请求级别的选项
c.mu.RLock()
allOpts := append(append([]RequestOpt(nil), c.opts...), opts...)
c.mu.RUnlock()
req, err := newRequest(ctx, url, method, allOpts...)
if err != nil {
return nil, err
}
req.client = c
req.httpClient = c.client
return req, nil
}
// Get 发送 GET 请求
func (c *Client) Get(url string, opts ...RequestOpt) (*Response, error) {
req, err := c.NewRequest(url, http.MethodGet, opts...)
if err != nil {
return nil, err
}
return req.Do()
}
// Post 发送 POST 请求
func (c *Client) Post(url string, opts ...RequestOpt) (*Response, error) {
req, err := c.NewRequest(url, http.MethodPost, opts...)
if err != nil {
return nil, err
}
return req.Do()
}
// Put 发送 PUT 请求
func (c *Client) Put(url string, opts ...RequestOpt) (*Response, error) {
req, err := c.NewRequest(url, http.MethodPut, opts...)
if err != nil {
return nil, err
}
return req.Do()
}
// Delete 发送 DELETE 请求
func (c *Client) Delete(url string, opts ...RequestOpt) (*Response, error) {
req, err := c.NewRequest(url, http.MethodDelete, opts...)
if err != nil {
return nil, err
}
return req.Do()
}
// Head 发送 HEAD 请求
func (c *Client) Head(url string, opts ...RequestOpt) (*Response, error) {
req, err := c.NewRequest(url, http.MethodHead, opts...)
if err != nil {
return nil, err
}
return req.Do()
}
// Patch 发送 PATCH 请求
func (c *Client) Patch(url string, opts ...RequestOpt) (*Response, error) {
req, err := c.NewRequest(url, http.MethodPatch, opts...)
if err != nil {
return nil, err
}
return req.Do()
}
// Options 发送 OPTIONS 请求
func (c *Client) Options(url string, opts ...RequestOpt) (*Response, error) {
req, err := c.NewRequest(url, http.MethodOptions, opts...)
if err != nil {
return nil, err
}
return req.Do()
}
// NewSimpleRequest 创建新请求(忽略错误,支持链式调用)
func (c *Client) NewSimpleRequest(url, method string, opts ...RequestOpt) *Request {
return c.NewSimpleRequestWithContext(context.Background(), url, method, opts...)
}
// NewSimpleRequestWithContext 创建新请求(带 context忽略错误
func (c *Client) NewSimpleRequestWithContext(ctx context.Context, url, method string, opts ...RequestOpt) *Request {
req, err := c.NewRequestWithContext(ctx, url, method, opts...)
if err != nil {
// 返回一个带错误的请求,保持与全局 NewSimpleRequest 行为一致
return &Request{
ctx: ctx,
url: url,
method: method,
err: err,
config: &RequestConfig{
Headers: make(http.Header),
Queries: make(map[string][]string),
Body: BodyConfig{
FormData: make(map[string][]string),
},
},
client: c,
httpClient: c.client,
autoFetch: DefaultFetchRespBody,
}
}
return req
}
// Trace 发送 TRACE 请求
func (c *Client) Trace(url string, opts ...RequestOpt) (*Response, error) {
req, err := c.NewRequest(url, http.MethodTrace, opts...)
if err != nil {
return nil, err
}
return req.Do()
}
// Connect 发送 CONNECT 请求
func (c *Client) Connect(url string, opts ...RequestOpt) (*Response, error) {
req, err := c.NewRequest(url, http.MethodConnect, opts...)
if err != nil {
return nil, err
}
return req.Do()
}

223
client_test.go Normal file
View File

@ -0,0 +1,223 @@
package starnet
import (
"net/http"
"net/http/httptest"
"testing"
"time"
)
func TestNewClient(t *testing.T) {
client, err := NewClient()
if err != nil {
t.Fatalf("NewClient() error: %v", err)
}
if client == nil {
t.Fatal("NewClient() returned nil")
}
}
func TestNewClientNoErr(t *testing.T) {
client := NewClientNoErr()
if client == nil {
t.Fatal("NewClientNoErr() returned nil")
}
}
func TestNewClientFromHTTP(t *testing.T) {
httpClient := &http.Client{
Timeout: 10 * time.Second,
}
client, err := NewClientFromHTTP(httpClient)
if err != nil {
t.Fatalf("NewClientFromHTTP() error: %v", err)
}
if client == nil {
t.Fatal("NewClientFromHTTP() returned nil")
}
// Test with nil client
_, err = NewClientFromHTTP(nil)
if err == nil {
t.Error("NewClientFromHTTP(nil) should return error")
}
}
func TestClientOptions(t *testing.T) {
client := NewClientNoErr()
// Set options
client.SetOptions(WithTimeout(5 * time.Second))
opts := client.RequestOptions()
if len(opts) != 1 {
t.Errorf("RequestOptions() length = %v; want 1", len(opts))
}
// Add options
client.AddOptions(WithUserAgent("TestAgent"))
opts = client.RequestOptions()
if len(opts) != 2 {
t.Errorf("RequestOptions() length = %v; want 2", len(opts))
}
}
func TestClientClone(t *testing.T) {
client := NewClientNoErr(WithTimeout(5 * time.Second))
cloned := client.Clone()
if cloned == nil {
t.Fatal("Clone() returned nil")
}
// 修改克隆的 client
cloned.SetOptions(WithTimeout(10 * time.Second))
origOpts := client.RequestOptions()
clonedOpts := cloned.RequestOptions()
// 原 client 应该还是 1 个选项
if len(origOpts) != 1 {
t.Errorf("Original client options = %v; want 1", len(origOpts))
}
// 克隆的 client 应该是 1 个选项(被 SetOptions 覆盖)
if len(clonedOpts) != 1 {
t.Errorf("Cloned client options = %v; want 1", len(clonedOpts))
}
}
func TestClientHTTPMethods(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte(r.Method))
}))
defer server.Close()
client := NewClientNoErr()
tests := []struct {
name string
method func(string, ...RequestOpt) (*Response, error)
want string
}{
{"GET", client.Get, "GET"},
{"POST", client.Post, "POST"},
{"PUT", client.Put, "PUT"},
{"DELETE", client.Delete, "DELETE"},
{"PATCH", client.Patch, "PATCH"},
{"HEAD", client.Head, "HEAD"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
resp, err := tt.method(server.URL)
if err != nil {
t.Fatalf("%s() error: %v", tt.name, err)
}
defer resp.Close()
if tt.want != "HEAD" {
body, _ := resp.Body().String()
if body != tt.want {
t.Errorf("Body = %v; want %v", body, tt.want)
}
}
})
}
}
func TestClientRedirect(t *testing.T) {
redirectCount := 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if redirectCount < 2 {
redirectCount++
http.Redirect(w, r, "/redirected", http.StatusFound)
return
}
w.WriteHeader(http.StatusOK)
w.Write([]byte("final"))
}))
defer server.Close()
// Test with redirect enabled (default)
client := NewClientNoErr()
resp, err := client.Get(server.URL)
if err != nil {
t.Fatalf("Get() error: %v", err)
}
resp.Close()
if redirectCount != 2 {
t.Errorf("Redirect count = %v; want 2", redirectCount)
}
// Test with redirect disabled
redirectCount = 0
client.DisableRedirect()
resp2, err := client.Get(server.URL)
if err != nil {
t.Fatalf("Get() error: %v", err)
}
defer resp2.Close()
if resp2.StatusCode != http.StatusFound {
t.Errorf("StatusCode = %v; want %v", resp2.StatusCode, http.StatusFound)
}
}
func TestClientTLSConfig(t *testing.T) {
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
}))
defer server.Close()
// Without skip verify (should fail with self-signed cert)
client := NewClientNoErr()
_, err := client.Get(server.URL)
if err == nil {
t.Error("Expected TLS error with self-signed cert, got nil")
}
// With skip verify
client.SetDefaultSkipTLSVerify(true)
resp, err := client.Get(server.URL)
if err != nil {
t.Fatalf("Get() with skip verify error: %v", err)
}
defer resp.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("StatusCode = %v; want %v", resp.StatusCode, http.StatusOK)
}
}
func TestClientNewSimpleRequest(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
}))
defer server.Close()
client := NewClientNoErr()
req := client.NewSimpleRequest(server.URL, "GET", WithHeader("X-Test", "v"))
if req == nil {
t.Fatal("NewSimpleRequest returned nil")
}
if req.Err() != nil {
t.Fatalf("NewSimpleRequest err: %v", req.Err())
}
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
body, _ := resp.Body().String()
if body != "OK" {
t.Errorf("Body = %v; want OK", body)
}
}

111
concurrent_test.go Normal file
View File

@ -0,0 +1,111 @@
package starnet
import (
"fmt"
"net/http"
"net/http/httptest"
"sync"
"sync/atomic"
"testing"
"time"
)
func TestConcurrentRequests(t *testing.T) {
var counter int64
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
atomic.AddInt64(&counter, 1)
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
}))
defer server.Close()
client := NewClientNoErr()
concurrency := 100
var wg sync.WaitGroup
wg.Add(concurrency)
for i := 0; i < concurrency; i++ {
go func() {
defer wg.Done()
resp, err := client.Get(server.URL)
if err != nil {
t.Errorf("Get() error: %v", err)
return
}
resp.Close()
}()
}
wg.Wait()
if atomic.LoadInt64(&counter) != int64(concurrency) {
t.Errorf("counter = %v; want %v", counter, concurrency)
}
}
func TestConcurrentClientModification(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
client := NewClientNoErr()
var wg sync.WaitGroup
wg.Add(200)
// 100 goroutines reading
for i := 0; i < 100; i++ {
go func() {
defer wg.Done()
resp, err := client.Get(server.URL)
if err != nil {
t.Errorf("Get() error: %v", err)
return
}
resp.Close()
}()
}
// 100 goroutines modifying options
for i := 0; i < 100; i++ {
go func(i int) {
defer wg.Done()
if i%2 == 0 {
client.AddOptions(WithTimeout(5 * time.Second))
} else {
_ = client.RequestOptions()
}
}(i)
}
wg.Wait()
}
func TestConcurrentRequestClone(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
baseReq := NewSimpleRequest(server.URL, "GET").SetHeader("X-Base", "value")
var wg sync.WaitGroup
wg.Add(50)
for i := 0; i < 50; i++ {
go func(i int) {
defer wg.Done()
cloned := baseReq.Clone()
// 修复:使用有效的 header 值
cloned.SetHeader("X-Index", fmt.Sprintf("%d", i))
resp, err := cloned.Do()
if err != nil {
t.Errorf("Do() error: %v", err)
return
}
resp.Close()
}(i)
}
wg.Wait()
}

149
context.go Normal file
View File

@ -0,0 +1,149 @@
package starnet
import (
"context"
"crypto/tls"
"net"
"net/http"
"time"
)
// contextKey 私有的 context key 类型(防止冲突)
type contextKey int
const (
ctxKeyTransport contextKey = iota
ctxKeyTLSConfig
ctxKeyProxy
ctxKeyCustomIP
ctxKeyCustomDNS
ctxKeyDialTimeout
ctxKeyTimeout
ctxKeyLookupIP
ctxKeyDialFunc
)
// RequestContext 从 context 中提取的请求配置
type RequestContext struct {
Transport *http.Transport
TLSConfig *tls.Config
Proxy string
CustomIP []string
CustomDNS []string
DialTimeout time.Duration
Timeout time.Duration
LookupIPFn func(ctx context.Context, host string) ([]net.IPAddr, error)
DialFn func(ctx context.Context, network, addr string) (net.Conn, error)
}
// getRequestContext 从 context 中提取请求配置
func getRequestContext(ctx context.Context) *RequestContext {
rc := &RequestContext{}
if v := ctx.Value(ctxKeyTransport); v != nil {
rc.Transport, _ = v.(*http.Transport)
}
if v := ctx.Value(ctxKeyTLSConfig); v != nil {
rc.TLSConfig, _ = v.(*tls.Config)
}
if v := ctx.Value(ctxKeyProxy); v != nil {
rc.Proxy, _ = v.(string)
}
if v := ctx.Value(ctxKeyCustomIP); v != nil {
rc.CustomIP, _ = v.([]string)
}
if v := ctx.Value(ctxKeyCustomDNS); v != nil {
rc.CustomDNS, _ = v.([]string)
}
if v := ctx.Value(ctxKeyDialTimeout); v != nil {
rc.DialTimeout, _ = v.(time.Duration)
}
if v := ctx.Value(ctxKeyTimeout); v != nil {
rc.Timeout, _ = v.(time.Duration)
}
if v := ctx.Value(ctxKeyLookupIP); v != nil {
rc.LookupIPFn, _ = v.(func(context.Context, string) ([]net.IPAddr, error))
}
if v := ctx.Value(ctxKeyDialFunc); v != nil {
rc.DialFn, _ = v.(func(context.Context, string, string) (net.Conn, error))
}
return rc
}
// needsDynamicTransport 判断是否需要动态 Transport
func needsDynamicTransport(rc *RequestContext) bool {
return rc.Transport != nil ||
rc.TLSConfig != nil ||
rc.Proxy != "" ||
rc.DialFn != nil ||
(rc.DialTimeout > 0 && rc.DialTimeout != DefaultDialTimeout) ||
(rc.Timeout > 0 && rc.Timeout != DefaultTimeout) ||
len(rc.CustomIP) > 0 ||
len(rc.CustomDNS) > 0 ||
rc.LookupIPFn != nil
}
// injectRequestConfig 将请求配置注入到 context
func injectRequestConfig(ctx context.Context, config *RequestConfig) context.Context {
execCtx := ctx
// 处理 TLS 配置
var tlsConfig *tls.Config
if config.TLS.Config != nil {
tlsConfig = config.TLS.Config.Clone()
if config.TLS.SkipVerify {
tlsConfig.InsecureSkipVerify = true
}
} else if config.TLS.SkipVerify {
tlsConfig = &tls.Config{
NextProtos: []string{"h2", "http/1.1"},
InsecureSkipVerify: true,
}
}
if tlsConfig != nil {
execCtx = context.WithValue(execCtx, ctxKeyTLSConfig, tlsConfig)
}
// 注入代理
if config.Network.Proxy != "" {
execCtx = context.WithValue(execCtx, ctxKeyProxy, config.Network.Proxy)
}
// 注入自定义 IP
if len(config.DNS.CustomIP) > 0 {
execCtx = context.WithValue(execCtx, ctxKeyCustomIP, config.DNS.CustomIP)
}
// 注入自定义 DNS
if len(config.DNS.CustomDNS) > 0 {
execCtx = context.WithValue(execCtx, ctxKeyCustomDNS, config.DNS.CustomDNS)
}
// 总是注入 DialTimeout 和 Timeout与原始代码一致
if config.Network.DialTimeout > 0 {
execCtx = context.WithValue(execCtx, ctxKeyDialTimeout, config.Network.DialTimeout)
}
if config.Network.Timeout > 0 {
execCtx = context.WithValue(execCtx, ctxKeyTimeout, config.Network.Timeout)
}
// 注入 DNS 解析函数
if config.DNS.LookupFunc != nil {
execCtx = context.WithValue(execCtx, ctxKeyLookupIP, config.DNS.LookupFunc)
}
// 注入 Dial 函数
if config.Network.DialFunc != nil {
execCtx = context.WithValue(execCtx, ctxKeyDialFunc, config.Network.DialFunc)
}
// 注入自定义 Transport
if config.CustomTransport && config.Transport != nil {
execCtx = context.WithValue(execCtx, ctxKeyTransport, config.Transport)
}
return execCtx
}

463
curl.go
View File

@ -1,463 +0,0 @@
package starnet
import (
"bytes"
"context"
"crypto/rand"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"os"
"strings"
"time"
"b612.me/stario"
)
const (
HEADER_FORM_URLENCODE = `application/x-www-form-urlencoded`
HEADER_FORM_DATA = `multipart/form-data`
HEADER_JSON = `application/json`
HEADER_PLAIN = `text/plain`
)
type RequestFile struct {
UploadFile string
UploadForm map[string]string
UploadName string
}
type Request struct {
Url string
RespURL string
Method string
RecvData []byte
RecvContentLength int64
RecvIo io.Writer
RespHeader http.Header
RespCookies []*http.Cookie
RespHttpCode int
Location *url.URL
CircleBuffer *stario.StarBuffer
respReader io.ReadCloser
respOrigin *http.Response
reqOrigin *http.Request
RequestOpts
}
type RequestOpts struct {
RequestFile
PostBuffer io.Reader
Process func(float64)
Proxy string
Timeout time.Duration
DialTimeout time.Duration
ReqHeader http.Header
ReqCookies []*http.Cookie
WriteRecvData bool
SkipTLSVerify bool
CustomTransport *http.Transport
Queries map[string]string
DisableRedirect bool
TlsConfig *tls.Config
}
type RequestOpt func(opt *RequestOpts)
func WithDialTimeout(timeout time.Duration) RequestOpt {
return func(opt *RequestOpts) {
opt.DialTimeout = timeout
}
}
func WithTimeout(timeout time.Duration) RequestOpt {
return func(opt *RequestOpts) {
opt.Timeout = timeout
}
}
func WithHeader(key, val string) RequestOpt {
return func(opt *RequestOpts) {
opt.ReqHeader.Set(key, val)
}
}
func WithTlsConfig(tlscfg *tls.Config) RequestOpt {
return func(opt *RequestOpts) {
opt.TlsConfig = tlscfg
}
}
func WithHeaderMap(header map[string]string) RequestOpt {
return func(opt *RequestOpts) {
for key, val := range header {
opt.ReqHeader.Set(key, val)
}
}
}
func WithHeaderAdd(key, val string) RequestOpt {
return func(opt *RequestOpts) {
opt.ReqHeader.Add(key, val)
}
}
func WithReader(r io.Reader) RequestOpt {
return func(opt *RequestOpts) {
opt.PostBuffer = r
}
}
func WithFetchRespBody(fetch bool) RequestOpt {
return func(opt *RequestOpts) {
opt.WriteRecvData = fetch
}
}
func WithCookies(ck []*http.Cookie) RequestOpt {
return func(opt *RequestOpts) {
opt.ReqCookies = ck
}
}
func WithCookie(key, val, path string) RequestOpt {
return func(opt *RequestOpts) {
opt.ReqCookies = append(opt.ReqCookies, &http.Cookie{Name: key, Value: val, Path: path})
}
}
func WithCookieMap(header map[string]string, path string) RequestOpt {
return func(opt *RequestOpts) {
for key, val := range header {
opt.ReqCookies = append(opt.ReqCookies, &http.Cookie{Name: key, Value: val, Path: path})
}
}
}
func WithQueries(queries map[string]string) RequestOpt {
return func(opt *RequestOpts) {
opt.Queries = queries
}
}
func WithProxy(proxy string) RequestOpt {
return func(opt *RequestOpts) {
opt.Proxy = proxy
}
}
func WithProcess(fn func(float64)) RequestOpt {
return func(opt *RequestOpts) {
opt.Process = fn
}
}
func WithContentType(ct string) RequestOpt {
return func(opt *RequestOpts) {
opt.ReqHeader.Set("Content-Type", ct)
}
}
func WithUserAgent(ua string) RequestOpt {
return func(opt *RequestOpts) {
opt.ReqHeader.Set("User-Agent", ua)
}
}
func WithCustomTransport(hs *http.Transport) RequestOpt {
return func(opt *RequestOpts) {
opt.CustomTransport = hs
}
}
func WithSkipTLSVerify(skip bool) RequestOpt {
return func(opt *RequestOpts) {
opt.SkipTLSVerify = skip
}
}
func WithDisableRedirect(disable bool) RequestOpt {
return func(opt *RequestOpts) {
opt.DisableRedirect = disable
}
}
func NewRequests(url string, rawdata []byte, method string, opts ...RequestOpt) Request {
req := Request{
RequestOpts: RequestOpts{
Timeout: 30 * time.Second,
DialTimeout: 15 * time.Second,
WriteRecvData: true,
},
Url: url,
Method: method,
}
if rawdata != nil {
req.PostBuffer = bytes.NewBuffer(rawdata)
}
req.ReqHeader = make(http.Header)
if strings.ToUpper(method) == "POST" {
req.ReqHeader.Set("Content-Type", HEADER_FORM_URLENCODE)
}
req.ReqHeader.Set("User-Agent", "B612 / 1.1.0")
for _, v := range opts {
v(&req.RequestOpts)
}
if req.CustomTransport == nil {
req.CustomTransport = &http.Transport{}
}
if req.SkipTLSVerify {
if req.CustomTransport.TLSClientConfig == nil {
req.CustomTransport.TLSClientConfig = &tls.Config{}
}
req.CustomTransport.TLSClientConfig.InsecureSkipVerify = true
}
if req.TlsConfig != nil {
req.CustomTransport.TLSClientConfig = req.TlsConfig
}
req.CustomTransport.DialContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
c, err := net.DialTimeout(netw, addr, req.DialTimeout)
if err != nil {
return nil, err
}
if req.Timeout != 0 {
c.SetDeadline(time.Now().Add(req.Timeout))
}
return c, nil
}
return req
}
func (curl *Request) ResetReqHeader() {
curl.ReqHeader = make(http.Header)
}
func (curl *Request) ResetReqCookies() {
curl.ReqCookies = []*http.Cookie{}
}
func (curl *Request) AddSimpleCookie(key, value string) {
curl.ReqCookies = append(curl.ReqCookies, &http.Cookie{Name: key, Value: value, Path: "/"})
}
func (curl *Request) AddCookie(key, value, path string) {
curl.ReqCookies = append(curl.ReqCookies, &http.Cookie{Name: key, Value: value, Path: path})
}
func randomBoundary() string {
var buf [30]byte
_, err := io.ReadFull(rand.Reader, buf[:])
if err != nil {
panic(err)
}
return fmt.Sprintf("%x", buf[:])
}
func Curl(curl Request) (resps Request, err error) {
var fpsrc *os.File
if curl.RequestFile.UploadFile != "" {
fpsrc, err = os.Open(curl.UploadFile)
if err != nil {
return
}
defer fpsrc.Close()
boundary := randomBoundary()
boundarybytes := []byte("\r\n--" + boundary + "\r\n")
endbytes := []byte("\r\n--" + boundary + "--\r\n")
fpstat, _ := fpsrc.Stat()
filebig := float64(fpstat.Size())
sum, n := 0, 0
fpdst := stario.NewStarBuffer(1048576)
if curl.UploadForm != nil {
for k, v := range curl.UploadForm {
header := fmt.Sprintf("Content-Disposition: form-data; name=\"%s\";\r\nContent-Type: x-www-form-urlencoded \r\n\r\n", k)
fpdst.Write(boundarybytes)
fpdst.Write([]byte(header))
fpdst.Write([]byte(v))
}
}
header := fmt.Sprintf("Content-Disposition: form-data; name=\"%s\"; filename=\"%s\"\r\nContent-Type: application/octet-stream\r\n\r\n", curl.UploadName, fpstat.Name())
fpdst.Write(boundarybytes)
fpdst.Write([]byte(header))
go func() {
for {
bufs := make([]byte, 393213)
n, err = fpsrc.Read(bufs)
if err != nil {
if err == io.EOF {
if n != 0 {
fpdst.Write(bufs[0:n])
if curl.Process != nil {
go curl.Process(float64(sum+n) / filebig * 100)
}
}
break
}
return
}
sum += n
if curl.Process != nil {
go curl.Process(float64(sum+n) / filebig * 100)
}
fpdst.Write(bufs[0:n])
}
fpdst.Write(endbytes)
fpdst.Write(nil)
}()
curl.CircleBuffer = fpdst
curl.ReqHeader.Set("Content-Type", "multipart/form-data;boundary="+boundary)
}
req, resp, err := netcurl(curl)
if err != nil {
return Request{}, err
}
if resp.Request != nil && resp.Request.URL != nil {
curl.RespURL = resp.Request.URL.String()
}
curl.reqOrigin = req
curl.respOrigin = resp
curl.Location, _ = resp.Location()
curl.RespHttpCode = resp.StatusCode
curl.RespHeader = resp.Header
curl.RespCookies = resp.Cookies()
curl.RecvContentLength = resp.ContentLength
readFunc := func(reader io.ReadCloser, writer io.Writer) error {
lengthall := resp.ContentLength
defer reader.Close()
var lengthsum int
buf := make([]byte, 65535)
for {
n, err := reader.Read(buf)
if n != 0 {
_, err := writer.Write(buf[:n])
lengthsum += n
if curl.Process != nil {
go curl.Process(float64(lengthsum) / float64(lengthall) * 100.00)
}
if err != nil {
return err
}
}
if err != nil && err != io.EOF {
return err
} else if err == io.EOF {
return nil
}
}
}
if curl.WriteRecvData {
buf := bytes.NewBuffer([]byte{})
err = readFunc(resp.Body, buf)
if err != nil {
return
}
curl.RecvData = buf.Bytes()
} else {
curl.respReader = resp.Body
}
if curl.RecvIo != nil {
if curl.WriteRecvData {
_, err = curl.RecvIo.Write(curl.RecvData)
} else {
err = readFunc(resp.Body, curl.RecvIo)
if err != nil {
return
}
}
}
return curl, err
}
// RespBodyReader Only works when WriteRecvData set to false
func (curl *Request) RespBodyReader() io.ReadCloser {
return curl.respReader
}
func netcurl(curl Request) (*http.Request, *http.Response, error) {
var req *http.Request
var err error
if curl.Method == "" {
return nil, nil, errors.New("Error Method Not Entered")
}
if curl.PostBuffer != nil {
req, err = http.NewRequest(curl.Method, curl.Url, curl.PostBuffer)
} else if curl.CircleBuffer != nil && curl.CircleBuffer.Len() > 0 {
req, err = http.NewRequest(curl.Method, curl.Url, curl.CircleBuffer)
} else {
req, err = http.NewRequest(curl.Method, curl.Url, nil)
}
if curl.Queries != nil {
sid := req.URL.Query()
for k, v := range curl.Queries {
sid.Add(k, v)
}
req.URL.RawQuery = sid.Encode()
}
if err != nil {
return nil, nil, err
}
req.Header = curl.ReqHeader
if len(curl.ReqCookies) != 0 {
for _, v := range curl.ReqCookies {
req.AddCookie(v)
}
}
if curl.Proxy != "" {
purl, err := url.Parse(curl.Proxy)
if err != nil {
return nil, nil, err
}
curl.CustomTransport.Proxy = http.ProxyURL(purl)
}
client := &http.Client{
Transport: curl.CustomTransport,
}
if curl.DisableRedirect {
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
}
}
resp, err := client.Do(req)
return req, resp, err
}
func UrlEncodeRaw(str string) string {
strs := strings.Replace(url.QueryEscape(str), "+", "%20", -1)
return strs
}
func UrlEncode(str string) string {
return url.QueryEscape(str)
}
func UrlDecode(str string) (string, error) {
return url.QueryUnescape(str)
}
func BuildQuery(queryData map[string]string) string {
query := url.Values{}
for k, v := range queryData {
query.Add(k, v)
}
return query.Encode()
}
func BuildPostForm(queryMap map[string]string) []byte {
query := url.Values{}
for k, v := range queryMap {
query.Add(k, v)
}
return []byte(query.Encode())
}
func (r Request) Resopnse() *http.Response {
return r.respOrigin
}
func (r Request) Request() *http.Request {
return r.reqOrigin
}

147
defaults.go Normal file
View File

@ -0,0 +1,147 @@
package starnet
import (
"net/http"
"sync"
"time"
)
var (
defaultClient *Client
defaultHTTPClient *http.Client
defaultClientOnce sync.Once
defaultHTTPOnce sync.Once
defaultMu sync.RWMutex
)
// DefaultClient 获取默认 Client单例
func DefaultClient() *Client {
defaultMu.RLock()
if defaultClient != nil {
c := defaultClient
defaultMu.RUnlock()
return c
}
defaultMu.RUnlock()
defaultClientOnce.Do(func() {
c := NewClientNoErr()
defaultMu.Lock()
defaultClient = c
defaultMu.Unlock()
})
defaultMu.RLock()
c := defaultClient
defaultMu.RUnlock()
return c
}
// DefaultHTTPClient 获取默认 http.Client单例
func DefaultHTTPClient() *http.Client {
defaultMu.RLock()
if defaultHTTPClient != nil {
c := defaultHTTPClient
defaultMu.RUnlock()
return c
}
defaultMu.RUnlock()
defaultHTTPOnce.Do(func() {
c := &http.Client{
Transport: &Transport{
base: &http.Transport{
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
},
},
Timeout: 0, // 由请求级控制超时
}
defaultMu.Lock()
defaultHTTPClient = c
defaultMu.Unlock()
})
defaultMu.RLock()
c := defaultHTTPClient
defaultMu.RUnlock()
return c
}
// SetDefaultClient 设置默认 Client
func SetDefaultClient(client *Client) {
defaultMu.Lock()
defer defaultMu.Unlock()
defaultClient = client
// 标记 once 已完成,避免后续 DefaultClient() 再次初始化覆盖
defaultClientOnce.Do(func() {})
}
// SetDefaultHTTPClient 设置默认 http.Client
func SetDefaultHTTPClient(client *http.Client) {
defaultMu.Lock()
defer defaultMu.Unlock()
defaultHTTPClient = client
// 标记 once 已完成,避免后续 DefaultHTTPClient() 再次初始化覆盖
defaultHTTPOnce.Do(func() {})
}
// Get 发送 GET 请求(使用默认 Client
func Get(url string, opts ...RequestOpt) (*Response, error) {
return DefaultClient().Get(url, opts...)
}
// Post 发送 POST 请求(使用默认 Client
func Post(url string, opts ...RequestOpt) (*Response, error) {
return DefaultClient().Post(url, opts...)
}
// Put 发送 PUT 请求(使用默认 Client
func Put(url string, opts ...RequestOpt) (*Response, error) {
return DefaultClient().Put(url, opts...)
}
// Delete 发送 DELETE 请求(使用默认 Client
func Delete(url string, opts ...RequestOpt) (*Response, error) {
return DefaultClient().Delete(url, opts...)
}
// Head 发送 HEAD 请求(使用默认 Client
func Head(url string, opts ...RequestOpt) (*Response, error) {
return DefaultClient().Head(url, opts...)
}
// Patch 发送 PATCH 请求(使用默认 Client
func Patch(url string, opts ...RequestOpt) (*Response, error) {
return DefaultClient().Patch(url, opts...)
}
// Options 发送 OPTIONS 请求(使用默认 Client
func Options(url string, opts ...RequestOpt) (*Response, error) {
return DefaultClient().Options(url, opts...)
}
// Trace 发送 TRACE 请求(使用默认 Client
func Trace(url string, opts ...RequestOpt) (*Response, error) {
req, err := DefaultClient().NewRequest(url, http.MethodTrace, opts...)
if err != nil {
return nil, err
}
return req.Do()
}
// Connect 发送 CONNECT 请求(使用默认 Client
func Connect(url string, opts ...RequestOpt) (*Response, error) {
req, err := DefaultClient().NewRequest(url, http.MethodConnect, opts...)
if err != nil {
return nil, err
}
return req.Do()
}

163
dialer.go Normal file
View File

@ -0,0 +1,163 @@
package starnet
import (
"context"
"crypto/tls"
"fmt"
"net"
"strings"
"time"
)
// defaultDialFunc 默认 Dial 函数(支持自定义 IP 和 DNS
func defaultDialFunc(ctx context.Context, network, addr string) (net.Conn, error) {
// 提取配置
reqCtx := getRequestContext(ctx)
dialTimeout := reqCtx.DialTimeout
if dialTimeout == 0 {
dialTimeout = DefaultDialTimeout
}
timeout := reqCtx.Timeout
if timeout == 0 {
timeout = DefaultTimeout
}
// 解析地址
host, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, wrapError(err, "split host port")
}
// 获取 IP 地址列表
var addrs []string
// 优先级1直接指定的 IP
if len(reqCtx.CustomIP) > 0 {
for _, ip := range reqCtx.CustomIP {
addrs = append(addrs, net.JoinHostPort(ip, port))
}
} else {
// 优先级2DNS 解析
var ipAddrs []net.IPAddr
// 使用自定义解析函数
if reqCtx.LookupIPFn != nil {
ipAddrs, err = reqCtx.LookupIPFn(ctx, host)
} else if len(reqCtx.CustomDNS) > 0 {
// 使用自定义 DNS 服务器
resolver := &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
var lastErr error
for _, dnsServer := range reqCtx.CustomDNS {
conn, err := net.Dial("udp", net.JoinHostPort(dnsServer, "53"))
if err != nil {
lastErr = err
continue
}
return conn, nil
}
return nil, lastErr
},
}
ipAddrs, err = resolver.LookupIPAddr(ctx, host)
} else {
// 使用默认解析器
ipAddrs, err = net.DefaultResolver.LookupIPAddr(ctx, host)
}
if err != nil {
return nil, wrapError(err, "lookup ip")
}
for _, ipAddr := range ipAddrs {
addrs = append(addrs, net.JoinHostPort(ipAddr.String(), port))
}
}
// 尝试连接所有地址
var lastErr error
for _, addr := range addrs {
conn, err := net.DialTimeout(network, addr, dialTimeout)
if err != nil {
lastErr = err
continue
}
// 设置总超时
if timeout > 0 {
conn.SetDeadline(time.Now().Add(timeout))
}
return conn, nil
}
if lastErr != nil {
return nil, wrapError(lastErr, "dial all addresses failed")
}
return nil, fmt.Errorf("no addresses to dial")
}
// defaultDialTLSFunc 默认 TLS Dial 函数
func defaultDialTLSFunc(ctx context.Context, network, addr string) (net.Conn, error) {
// 先建立 TCP 连接
conn, err := defaultDialFunc(ctx, network, addr)
if err != nil {
return nil, err
}
// 提取 TLS 配置
reqCtx := getRequestContext(ctx)
tlsConfig := reqCtx.TLSConfig
if tlsConfig == nil {
tlsConfig = &tls.Config{}
}
// ← 新增:如果 ServerName 为空且没有 InsecureSkipVerify自动设置
if tlsConfig.ServerName == "" && !tlsConfig.InsecureSkipVerify {
host, _, err := net.SplitHostPort(addr)
if err != nil {
if idx := strings.LastIndex(addr, ":"); idx > 0 {
host = addr[:idx]
} else {
host = addr
}
}
tlsConfig = tlsConfig.Clone() // 避免修改原 config
tlsConfig.ServerName = host
}
// 执行 TLS 握手
tlsConn := tls.Client(conn, tlsConfig)
if err := tlsConn.Handshake(); err != nil {
conn.Close()
return nil, wrapError(err, "tls handshake")
}
return tlsConn, nil
}
/*
// defaultProxyFunc 默认代理函数
func defaultProxyFunc(req *http.Request) (*url.URL, error) {
if req == nil {
return nil, fmt.Errorf("request is nil")
}
reqCtx := getRequestContext(req.Context())
if reqCtx.Proxy == "" {
return nil, nil
}
proxyURL, err := url.Parse(reqCtx.Proxy)
if err != nil {
return nil, wrapError(err, "parse proxy url")
}
return proxyURL, nil
}
*/

103
dns_test.go Normal file
View File

@ -0,0 +1,103 @@
package starnet
import (
"context"
"net"
"testing"
)
func TestRequestCustomIP(t *testing.T) {
customIPs := []string{"1.2.3.4", "5.6.7.8"}
req := NewSimpleRequest("http://example.com", "GET").
SetCustomIP(customIPs)
if len(req.config.DNS.CustomIP) != 2 {
t.Errorf("CustomIP length = %v; want 2", len(req.config.DNS.CustomIP))
}
for i, ip := range req.config.DNS.CustomIP {
if ip != customIPs[i] {
t.Errorf("CustomIP[%d] = %v; want %v", i, ip, customIPs[i])
}
}
}
func TestRequestCustomIPInvalid(t *testing.T) {
req := NewSimpleRequest("http://example.com", "GET").
SetCustomIP([]string{"invalid-ip"})
if req.Err() == nil {
t.Error("Expected error for invalid IP, got nil")
}
}
func TestRequestCustomDNS(t *testing.T) {
dnsServers := []string{"8.8.8.8", "1.1.1.1"}
req := NewSimpleRequest("http://example.com", "GET").
SetCustomDNS(dnsServers)
if len(req.config.DNS.CustomDNS) != 2 {
t.Errorf("CustomDNS length = %v; want 2", len(req.config.DNS.CustomDNS))
}
}
func TestRequestCustomDNSInvalid(t *testing.T) {
req := NewSimpleRequest("http://example.com", "GET").
SetCustomDNS([]string{"invalid-dns"})
if req.Err() == nil {
t.Error("Expected error for invalid DNS, got nil")
}
}
func TestRequestLookupFunc(t *testing.T) {
called := false
lookupFunc := func(ctx context.Context, host string) ([]net.IPAddr, error) {
called = true
return []net.IPAddr{
{IP: net.ParseIP("1.2.3.4")},
}, nil
}
req := NewSimpleRequest("http://example.com", "GET").
SetLookupFunc(lookupFunc)
if req.config.DNS.LookupFunc == nil {
t.Error("LookupFunc not set")
}
// Call the function to verify it works
ips, err := req.config.DNS.LookupFunc(context.Background(), "example.com")
if err != nil {
t.Errorf("LookupFunc error: %v", err)
}
if !called {
t.Error("LookupFunc was not called")
}
if len(ips) != 1 {
t.Errorf("IPs length = %v; want 1", len(ips))
}
}
func TestDNSPriority(t *testing.T) {
// CustomIP should have highest priority
req := NewSimpleRequest("http://example.com", "GET").
SetCustomIP([]string{"1.2.3.4"}).
SetCustomDNS([]string{"8.8.8.8"}).
SetLookupFunc(func(ctx context.Context, host string) ([]net.IPAddr, error) {
return []net.IPAddr{{IP: net.ParseIP("5.6.7.8")}}, nil
})
// CustomIP should be set
if len(req.config.DNS.CustomIP) == 0 {
t.Error("CustomIP should be set")
}
// Others should also be set (but CustomIP takes priority in actual use)
if len(req.config.DNS.CustomDNS) == 0 {
t.Error("CustomDNS should be set")
}
if req.config.DNS.LookupFunc == nil {
t.Error("LookupFunc should be set")
}
}

58
errors.go Normal file
View File

@ -0,0 +1,58 @@
package starnet
import (
"errors"
"fmt"
)
var (
// ErrInvalidMethod 无效的 HTTP 方法
ErrInvalidMethod = errors.New("starnet: invalid HTTP method")
// ErrInvalidURL 无效的 URL
ErrInvalidURL = errors.New("starnet: invalid URL")
// ErrInvalidIP 无效的 IP 地址
ErrInvalidIP = errors.New("starnet: invalid IP address")
// ErrInvalidDNS 无效的 DNS 服务器
ErrInvalidDNS = errors.New("starnet: invalid DNS server")
// ErrNilClient HTTP Client 为 nil
ErrNilClient = errors.New("starnet: http client is nil")
// ErrNilReader Reader 为 nil
ErrNilReader = errors.New("starnet: reader is nil")
// ErrFileNotFound 文件不存在
ErrFileNotFound = errors.New("starnet: file not found")
// ErrRequestNotPrepared 请求未准备好
ErrRequestNotPrepared = errors.New("starnet: request not prepared")
// ErrBodyAlreadyConsumed Body 已被消费
ErrBodyAlreadyConsumed = errors.New("starnet: response body already consumed")
)
// wrapError 包装错误,添加上下文信息
func wrapError(err error, format string, args ...interface{}) error {
if err == nil {
return nil
}
msg := fmt.Sprintf(format, args...)
return fmt.Errorf("%s: %w", msg, err)
}
var (
// ErrNilConn indicates a nil net.Conn argument.
ErrNilConn = errors.New("starnet: nil connection")
// ErrNonTLSNotAllowed indicates plain TCP was detected while non-TLS is forbidden.
ErrNonTLSNotAllowed = errors.New("starnet: non-TLS connection not allowed")
// ErrNotTLS indicates caller asked for TLS-only object but conn is plain TCP.
ErrNotTLS = errors.New("starnet: connection is not TLS")
// ErrNoTLSConfig indicates TLS was detected but no usable TLS config is available.
ErrNoTLSConfig = errors.New("starnet: no TLS config available")
)

200
example_test.go Normal file
View File

@ -0,0 +1,200 @@
package starnet_test
import (
"fmt"
"net/http"
"net/http/httptest"
"time"
"b612.me/starnet"
)
func ExampleGet() {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("Hello, World!"))
}))
defer server.Close()
resp, err := starnet.Get(server.URL)
if err != nil {
panic(err)
}
defer resp.Close()
body, _ := resp.Body().String()
fmt.Println(body)
// Output: Hello, World!
}
func ExamplePost() {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("Posted"))
}))
defer server.Close()
resp, err := starnet.Post(server.URL,
starnet.WithBodyString("test data"))
if err != nil {
panic(err)
}
defer resp.Close()
body, _ := resp.Body().String()
fmt.Println(body)
// Output: Posted
}
func ExampleNewSimpleRequest() {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("OK"))
}))
defer server.Close()
req := starnet.NewSimpleRequest(server.URL, "GET").
SetHeader("X-Custom", "value").
AddQuery("name", "test")
resp, err := req.Do()
if err != nil {
panic(err)
}
defer resp.Close()
fmt.Println(resp.StatusCode)
// Output: 200
}
func ExampleClient_Get() {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("Client GET"))
}))
defer server.Close()
client := starnet.NewClientNoErr(
starnet.WithTimeout(10*time.Second),
starnet.WithUserAgent("MyApp/1.0"),
)
resp, err := client.Get(server.URL)
if err != nil {
panic(err)
}
defer resp.Close()
body, _ := resp.Body().String()
fmt.Println(body)
// Output: Client GET
}
func ExampleRequest_SetJSON() {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{"status":"ok"}`))
}))
defer server.Close()
type User struct {
Name string `json:"name"`
Email string `json:"email"`
}
user := User{Name: "John", Email: "john@example.com"}
resp, err := starnet.NewSimpleRequest(server.URL, "POST").
SetJSON(user).
Do()
if err != nil {
panic(err)
}
defer resp.Close()
var result map[string]string
resp.Body().JSON(&result)
fmt.Println(result["status"])
// Output: ok
}
func ExampleRequest_AddFormData() {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r.ParseForm()
fmt.Fprintf(w, "name=%s", r.FormValue("name"))
}))
defer server.Close()
resp, err := starnet.NewSimpleRequest(server.URL, "POST").
AddFormData("name", "John").
AddFormData("age", "30").
Do()
if err != nil {
panic(err)
}
defer resp.Close()
body, _ := resp.Body().String()
fmt.Println(body)
// Output: name=John
}
func ExampleRequest_SetSkipTLSVerify() {
// This example shows how to skip TLS verification
// Useful for testing with self-signed certificates
req := starnet.NewSimpleRequest("https://self-signed.example.com", "GET").
SetSkipTLSVerify(true)
// In a real scenario, you would call req.Do()
fmt.Println(req.Method())
// Output: GET
}
func ExampleRequest_Clone() {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("OK"))
}))
defer server.Close()
baseReq := starnet.NewSimpleRequest(server.URL, "GET").
SetHeader("X-API-Key", "secret")
// Clone and modify
req1 := baseReq.Clone().AddQuery("page", "1")
req2 := baseReq.Clone().AddQuery("page", "2")
resp1, _ := req1.Do()
resp2, _ := req2.Do()
defer resp1.Close()
defer resp2.Close()
fmt.Println(resp1.StatusCode, resp2.StatusCode)
// Output: 200 200
}
func ExampleClient_SetDefaultSkipTLSVerify() {
client := starnet.NewClientNoErr()
client.SetDefaultSkipTLSVerify(true)
// All requests from this client will skip TLS verification
// unless overridden at request level
fmt.Println("Client configured")
// Output: Client configured
}
func ExampleWithTimeout() {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(100 * time.Millisecond)
w.Write([]byte("OK"))
}))
defer server.Close()
resp, err := starnet.Get(server.URL,
starnet.WithTimeout(200*time.Millisecond))
if err != nil {
panic(err)
}
defer resp.Close()
fmt.Println(resp.StatusCode)
// Output: 200
}

172
file_upload_test.go Normal file
View File

@ -0,0 +1,172 @@
package starnet
import (
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
)
func TestRequestAddFileStream(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
err := r.ParseMultipartForm(10 << 20) // 10 MB
if err != nil {
t.Fatalf("ParseMultipartForm error: %v", err)
}
file, header, err := r.FormFile("file")
if err != nil {
t.Fatalf("FormFile error: %v", err)
}
defer file.Close()
content, _ := io.ReadAll(file)
w.Write([]byte(header.Filename + ":" + string(content)))
}))
defer server.Close()
fileContent := "test file content"
reader := strings.NewReader(fileContent)
req := NewSimpleRequest(server.URL, "POST").
AddFileStream("file", "test.txt", int64(len(fileContent)), reader)
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
body, _ := resp.Body().String()
expected := "test.txt:" + fileContent
if body != expected {
t.Errorf("Body = %v; want %v", body, expected)
}
}
func TestRequestAddFileWithFormData(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
err := r.ParseMultipartForm(10 << 20)
if err != nil {
t.Fatalf("ParseMultipartForm error: %v", err)
}
// Check form field
name := r.FormValue("name")
if name != "John" {
t.Errorf("name = %v; want John", name)
}
// Check file
file, header, err := r.FormFile("file")
if err != nil {
t.Fatalf("FormFile error: %v", err)
}
defer file.Close()
w.Write([]byte("OK:" + header.Filename))
}))
defer server.Close()
fileContent := "file data"
reader := strings.NewReader(fileContent)
req := NewSimpleRequest(server.URL, "POST").
AddFormData("name", "John").
AddFileStream("file", "document.txt", int64(len(fileContent)), reader)
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
body, _ := resp.Body().String()
if !strings.Contains(body, "document.txt") {
t.Errorf("Body should contain filename, got: %v", body)
}
}
func TestRequestUploadProgress(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r.ParseMultipartForm(10 << 20)
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
progressCalled := false
var lastUploaded int64
fileContent := strings.Repeat("a", 1024*10) // 10KB
reader := strings.NewReader(fileContent)
req := NewSimpleRequest(server.URL, "POST").
SetUploadProgress(func(filename string, uploaded, total int64) {
progressCalled = true
lastUploaded = uploaded
if filename != "test.txt" {
t.Errorf("filename = %v; want test.txt", filename)
}
}).
AddFileStream("file", "test.txt", int64(len(fileContent)), reader)
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
if !progressCalled {
t.Error("Progress callback was not called")
}
if lastUploaded != int64(len(fileContent)) {
t.Errorf("lastUploaded = %v; want %v", lastUploaded, len(fileContent))
}
}
// TestRequestAddFileFromDisk tests uploading a real file from disk
func TestRequestAddFileFromDisk(t *testing.T) {
// Create a temporary file
tmpDir := t.TempDir()
tmpFile := filepath.Join(tmpDir, "test.txt")
fileContent := []byte("test file content from disk")
err := os.WriteFile(tmpFile, fileContent, 0644)
if err != nil {
t.Fatalf("WriteFile error: %v", err)
}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
err := r.ParseMultipartForm(10 << 20)
if err != nil {
t.Fatalf("ParseMultipartForm error: %v", err)
}
file, header, err := r.FormFile("file")
if err != nil {
t.Fatalf("FormFile error: %v", err)
}
defer file.Close()
content, _ := io.ReadAll(file)
w.Write([]byte(header.Filename + ":" + string(content)))
}))
defer server.Close()
req := NewSimpleRequest(server.URL, "POST").AddFile("file", tmpFile)
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
body, _ := resp.Body().String()
if !strings.Contains(body, string(fileContent)) {
t.Errorf("Body should contain file content, got: %v", body)
}
}

2
go.mod
View File

@ -1,5 +1,3 @@
module b612.me/starnet
go 1.16
require b612.me/stario v0.0.9

47
go.sum
View File

@ -1,47 +0,0 @@
b612.me/stario v0.0.9 h1:bFDlejUJMwZ12a09snZJspQsOlkqpDAl9qKPEYOGWCk=
b612.me/stario v0.0.9/go.mod h1:x4D/x8zA5SC0pj/uJAi4FyG5p4j5UZoMEZfvuRR6VNw=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA=
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4=
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
golang.org/x/term v0.18.0 h1:FcHjZXDMxI8mM3nwhX9HlKop4C0YQvCVCdwYl2wOtE8=
golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=

140
header_test.go Normal file
View File

@ -0,0 +1,140 @@
package starnet
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
)
func TestRequestHeaders(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
headers := make(map[string]string)
for k, v := range r.Header {
if len(v) > 0 {
headers[k] = v[0]
}
}
json.NewEncoder(w).Encode(headers)
}))
defer server.Close()
req := NewSimpleRequest(server.URL, "GET").
SetHeader("X-Custom-Header", "value1").
AddHeader("X-Multi-Header", "value1").
AddHeader("X-Multi-Header", "value2")
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
var headers map[string]string
resp.Body().JSON(&headers)
if headers["X-Custom-Header"] != "value1" {
t.Errorf("X-Custom-Header = %v; want value1", headers["X-Custom-Header"])
}
}
func TestRequestCookies(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
cookies := make(map[string]string)
for _, cookie := range r.Cookies() {
cookies[cookie.Name] = cookie.Value
}
json.NewEncoder(w).Encode(cookies)
}))
defer server.Close()
req := NewSimpleRequest(server.URL, "GET").
AddSimpleCookie("session", "abc123").
AddSimpleCookie("user", "john")
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
var cookies map[string]string
resp.Body().JSON(&cookies)
if cookies["session"] != "abc123" {
t.Errorf("session cookie = %v; want abc123", cookies["session"])
}
if cookies["user"] != "john" {
t.Errorf("user cookie = %v; want john", cookies["user"])
}
}
func TestRequestUserAgent(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(r.UserAgent()))
}))
defer server.Close()
req := NewSimpleRequest(server.URL, "GET").
SetUserAgent("CustomAgent/1.0")
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
body, _ := resp.Body().String()
if body != "CustomAgent/1.0" {
t.Errorf("User-Agent = %v; want CustomAgent/1.0", body)
}
}
func TestRequestBearerToken(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
auth := r.Header.Get("Authorization")
w.Write([]byte(auth))
}))
defer server.Close()
req := NewSimpleRequest(server.URL, "GET").
SetBearerToken("mytoken123")
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
body, _ := resp.Body().String()
expected := "Bearer mytoken123"
if body != expected {
t.Errorf("Authorization = %v; want %v", body, expected)
}
}
func TestRequestBasicAuth(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
username, password, ok := r.BasicAuth()
if !ok {
w.WriteHeader(http.StatusUnauthorized)
return
}
w.Write([]byte(username + ":" + password))
}))
defer server.Close()
req := NewSimpleRequest(server.URL, "GET").
SetBasicAuth("user", "pass")
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
body, _ := resp.Body().String()
if body != "user:pass" {
t.Errorf("BasicAuth = %v; want user:pass", body)
}
}

258
integration_test.go Normal file
View File

@ -0,0 +1,258 @@
package starnet
import (
"os"
"testing"
"time"
)
// 这些测试使用 httpbin.org 作为测试服务
// 可以通过环境变量 STARNET_INTEGRATION_TEST=1 来启用
func skipIfNoIntegration(t *testing.T) {
if os.Getenv("STARNET_INTEGRATION_TEST") != "1" {
t.Skip("Skipping integration test. Set STARNET_INTEGRATION_TEST=1 to run")
}
}
func TestIntegrationHTTPBinGet(t *testing.T) {
skipIfNoIntegration(t)
resp, err := Get("https://httpbin.org/get",
WithQuery("name", "starnet"),
WithQuery("version", "1.0"))
if err != nil {
t.Fatalf("Get() error: %v", err)
}
defer resp.Close()
if resp.StatusCode != 200 {
t.Errorf("StatusCode = %v; want 200", resp.StatusCode)
}
var result map[string]interface{}
err = resp.Body().JSON(&result)
if err != nil {
t.Fatalf("JSON() error: %v", err)
}
args, ok := result["args"].(map[string]interface{})
if !ok {
t.Fatal("args not found in response")
}
if args["name"] != "starnet" {
t.Errorf("args[name] = %v; want starnet", args["name"])
}
}
func TestIntegrationHTTPBinPost(t *testing.T) {
skipIfNoIntegration(t)
type PostData struct {
Name string `json:"name"`
Email string `json:"email"`
}
data := PostData{
Name: "John Doe",
Email: "john@example.com",
}
resp, err := Post("https://httpbin.org/post", WithJSON(data))
if err != nil {
t.Fatalf("Post() error: %v", err)
}
defer resp.Close()
if resp.StatusCode != 200 {
t.Errorf("StatusCode = %v; want 200", resp.StatusCode)
}
var result map[string]interface{}
err = resp.Body().JSON(&result)
if err != nil {
t.Fatalf("JSON() error: %v", err)
}
jsonData, ok := result["json"].(map[string]interface{})
if !ok {
t.Fatal("json not found in response")
}
if jsonData["name"] != data.Name {
t.Errorf("name = %v; want %v", jsonData["name"], data.Name)
}
}
func TestIntegrationHTTPBinHeaders(t *testing.T) {
skipIfNoIntegration(t)
resp, err := Get("https://httpbin.org/headers",
WithHeader("X-Custom-Header", "test-value"),
WithUserAgent("Starnet-Test/1.0"))
if err != nil {
t.Fatalf("Get() error: %v", err)
}
defer resp.Close()
var result map[string]interface{}
err = resp.Body().JSON(&result)
if err != nil {
t.Fatalf("JSON() error: %v", err)
}
headers, ok := result["headers"].(map[string]interface{})
if !ok {
t.Fatal("headers not found in response")
}
if headers["X-Custom-Header"] != "test-value" {
t.Errorf("X-Custom-Header = %v; want test-value", headers["X-Custom-Header"])
}
}
func TestIntegrationHTTPBinBasicAuth(t *testing.T) {
skipIfNoIntegration(t)
resp, err := Get("https://httpbin.org/basic-auth/user/passwd",
WithBasicAuth("user", "passwd"))
if err != nil {
t.Fatalf("Get() error: %v", err)
}
defer resp.Close()
if resp.StatusCode != 200 {
t.Errorf("StatusCode = %v; want 200", resp.StatusCode)
}
var result map[string]interface{}
err = resp.Body().JSON(&result)
if err != nil {
t.Fatalf("JSON() error: %v", err)
}
if result["authenticated"] != true {
t.Error("authenticated should be true")
}
}
func TestIntegrationHTTPBinDelay(t *testing.T) {
skipIfNoIntegration(t)
// Test timeout
start := time.Now()
_, err := Get("https://httpbin.org/delay/3",
WithTimeout(1*time.Second))
elapsed := time.Since(start)
if err == nil {
t.Error("Expected timeout error, got nil")
}
if elapsed > 2*time.Second {
t.Errorf("Timeout took too long: %v", elapsed)
}
}
func TestIntegrationHTTPBinRedirect(t *testing.T) {
skipIfNoIntegration(t)
// Test with redirect enabled
client := NewClientNoErr()
resp, err := client.Get("https://httpbin.org/redirect/2")
if err != nil {
t.Fatalf("Get() error: %v", err)
}
defer resp.Close()
if resp.StatusCode != 200 {
t.Errorf("StatusCode = %v; want 200 (after redirect)", resp.StatusCode)
}
// Test with redirect disabled
client.DisableRedirect()
resp2, err := client.Get("https://httpbin.org/redirect/2")
if err != nil {
t.Fatalf("Get() error: %v", err)
}
defer resp2.Close()
if resp2.StatusCode != 302 {
t.Errorf("StatusCode = %v; want 302 (redirect disabled)", resp2.StatusCode)
}
}
func TestIntegrationHTTPBinCookies(t *testing.T) {
skipIfNoIntegration(t)
// 创建一个禁用重定向的 Client
client := NewClientNoErr()
client.DisableRedirect()
resp, err := client.Get("https://httpbin.org/cookies/set?name=value")
if err != nil {
t.Fatalf("Get() error: %v", err)
}
defer resp.Close()
// 现在应该能获取到 Set-Cookie
cookies := resp.Cookies()
if len(cookies) == 0 {
t.Error("Expected cookies in response")
}
// 验证 cookie
found := false
for _, cookie := range cookies {
if cookie.Name == "name" && cookie.Value == "value" {
found = true
break
}
}
if !found {
t.Error("Expected cookie 'name=value' not found")
}
}
func TestIntegrationHTTPBinUserAgent(t *testing.T) {
skipIfNoIntegration(t)
customUA := "Starnet-Integration-Test/1.0"
resp, err := Get("https://httpbin.org/user-agent",
WithUserAgent(customUA))
if err != nil {
t.Fatalf("Get() error: %v", err)
}
defer resp.Close()
var result map[string]interface{}
err = resp.Body().JSON(&result)
if err != nil {
t.Fatalf("JSON() error: %v", err)
}
if result["user-agent"] != customUA {
t.Errorf("user-agent = %v; want %v", result["user-agent"], customUA)
}
}
func TestIntegrationHTTPBinGzip(t *testing.T) {
skipIfNoIntegration(t)
resp, err := Get("https://httpbin.org/gzip")
if err != nil {
t.Fatalf("Get() error: %v", err)
}
defer resp.Close()
var result map[string]interface{}
err = resp.Body().JSON(&result)
if err != nil {
t.Fatalf("JSON() error: %v", err)
}
if result["gzipped"] != true {
t.Error("Response should be gzipped")
}
}

390
options.go Normal file
View File

@ -0,0 +1,390 @@
package starnet
import (
"context"
"crypto/tls"
"encoding/json"
"io"
"net"
"net/http"
"os"
"time"
)
// WithTimeout 设置请求总超时时间
// timeout > 0: 使用该超时
// timeout = 0: 使用 Client 默认超时
// timeout < 0: 禁用本次请求超时(覆盖 Client.Timeout=0
func WithTimeout(timeout time.Duration) RequestOpt {
return func(r *Request) error {
r.config.Network.Timeout = timeout
return nil
}
}
// WithDialTimeout 设置连接超时时间
func WithDialTimeout(timeout time.Duration) RequestOpt {
return func(r *Request) error {
r.config.Network.DialTimeout = timeout
return nil
}
}
// WithProxy 设置代理
func WithProxy(proxy string) RequestOpt {
return func(r *Request) error {
r.config.Network.Proxy = proxy
return nil
}
}
// WithDialFunc 设置自定义 Dial 函数
func WithDialFunc(fn func(ctx context.Context, network, addr string) (net.Conn, error)) RequestOpt {
return func(r *Request) error {
r.config.Network.DialFunc = fn
return nil
}
}
// WithTLSConfig 设置 TLS 配置
func WithTLSConfig(tlsConfig *tls.Config) RequestOpt {
return func(r *Request) error {
r.config.TLS.Config = tlsConfig
return nil
}
}
// WithSkipTLSVerify 设置是否跳过 TLS 验证
func WithSkipTLSVerify(skip bool) RequestOpt {
return func(r *Request) error {
r.config.TLS.SkipVerify = skip
return nil
}
}
// WithCustomIP 设置自定义 IP
func WithCustomIP(ips []string) RequestOpt {
return func(r *Request) error {
for _, ip := range ips {
if net.ParseIP(ip) == nil {
return wrapError(ErrInvalidIP, "ip: %s", ip)
}
}
r.config.DNS.CustomIP = ips
return nil
}
}
// WithAddCustomIP 添加自定义 IP
func WithAddCustomIP(ip string) RequestOpt {
return func(r *Request) error {
if net.ParseIP(ip) == nil {
return wrapError(ErrInvalidIP, "ip: %s", ip)
}
r.config.DNS.CustomIP = append(r.config.DNS.CustomIP, ip)
return nil
}
}
// WithCustomDNS 设置自定义 DNS 服务器
func WithCustomDNS(dnsServers []string) RequestOpt {
return func(r *Request) error {
for _, dns := range dnsServers {
if net.ParseIP(dns) == nil {
return wrapError(ErrInvalidDNS, "dns: %s", dns)
}
}
r.config.DNS.CustomDNS = dnsServers
return nil
}
}
// WithAddCustomDNS 添加自定义 DNS 服务器
func WithAddCustomDNS(dns string) RequestOpt {
return func(r *Request) error {
if net.ParseIP(dns) == nil {
return wrapError(ErrInvalidDNS, "dns: %s", dns)
}
r.config.DNS.CustomDNS = append(r.config.DNS.CustomDNS, dns)
return nil
}
}
// WithLookupFunc 设置自定义 DNS 解析函数
func WithLookupFunc(fn func(ctx context.Context, host string) ([]net.IPAddr, error)) RequestOpt {
return func(r *Request) error {
r.config.DNS.LookupFunc = fn
return nil
}
}
// WithHeader 设置 Header
func WithHeader(key, value string) RequestOpt {
return func(r *Request) error {
r.config.Headers.Set(key, value)
return nil
}
}
// WithHeaders 批量设置 Headers
func WithHeaders(headers map[string]string) RequestOpt {
return func(r *Request) error {
for k, v := range headers {
r.config.Headers.Set(k, v)
}
return nil
}
}
// WithContentType 设置 Content-Type
func WithContentType(contentType string) RequestOpt {
return func(r *Request) error {
r.config.Headers.Set("Content-Type", contentType)
return nil
}
}
// WithUserAgent 设置 User-Agent
func WithUserAgent(userAgent string) RequestOpt {
return func(r *Request) error {
r.config.Headers.Set("User-Agent", userAgent)
return nil
}
}
// WithBearerToken 设置 Bearer Token
func WithBearerToken(token string) RequestOpt {
return func(r *Request) error {
r.config.Headers.Set("Authorization", "Bearer "+token)
return nil
}
}
// WithBasicAuth 设置 Basic 认证
func WithBasicAuth(username, password string) RequestOpt {
return func(r *Request) error {
r.config.BasicAuth = [2]string{username, password}
return nil
}
}
// WithCookie 添加 Cookie
func WithCookie(name, value, path string) RequestOpt {
return func(r *Request) error {
r.config.Cookies = append(r.config.Cookies, &http.Cookie{
Name: name,
Value: value,
Path: path,
})
return nil
}
}
// WithSimpleCookie 添加简单 Cookiepath 为 /
func WithSimpleCookie(name, value string) RequestOpt {
return func(r *Request) error {
r.config.Cookies = append(r.config.Cookies, &http.Cookie{
Name: name,
Value: value,
Path: "/",
})
return nil
}
}
// WithCookies 批量添加 Cookies
func WithCookies(cookies map[string]string) RequestOpt {
return func(r *Request) error {
for name, value := range cookies {
r.config.Cookies = append(r.config.Cookies, &http.Cookie{
Name: name,
Value: value,
Path: "/",
})
}
return nil
}
}
// WithBody 设置请求体(字节)
func WithBody(body []byte) RequestOpt {
return func(r *Request) error {
r.config.Body.Bytes = body
r.config.Body.Reader = nil
return nil
}
}
// WithBodyString 设置请求体(字符串)
func WithBodyString(body string) RequestOpt {
return func(r *Request) error {
r.config.Body.Bytes = []byte(body)
r.config.Body.Reader = nil
return nil
}
}
// WithBodyReader 设置请求体Reader
func WithBodyReader(reader io.Reader) RequestOpt {
return func(r *Request) error {
r.config.Body.Reader = reader
r.config.Body.Bytes = nil
return nil
}
}
// WithJSON 设置 JSON 请求体
func WithJSON(v interface{}) RequestOpt {
return func(r *Request) error {
data, err := json.Marshal(v)
if err != nil {
return wrapError(err, "marshal json")
}
r.config.Headers.Set("Content-Type", ContentTypeJSON)
r.config.Body.Bytes = data
r.config.Body.Reader = nil
return nil
}
}
// WithFormData 设置表单数据
func WithFormData(data map[string][]string) RequestOpt {
return func(r *Request) error {
r.config.Body.FormData = data
return nil
}
}
// WithFormDataMap 设置表单数据(简化版)
func WithFormDataMap(data map[string]string) RequestOpt {
return func(r *Request) error {
for k, v := range data {
r.config.Body.FormData[k] = []string{v}
}
return nil
}
}
// WithAddFormData 添加表单数据
func WithAddFormData(key, value string) RequestOpt {
return func(r *Request) error {
r.config.Body.FormData[key] = append(r.config.Body.FormData[key], value)
return nil
}
}
// WithFile 添加文件
func WithFile(formName, filePath string) RequestOpt {
return func(r *Request) error {
stat, err := os.Stat(filePath)
if err != nil {
return wrapError(ErrFileNotFound, "file: %s", filePath)
}
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
FormName: formName,
FileName: stat.Name(),
FilePath: filePath,
FileSize: stat.Size(),
FileType: ContentTypeOctetStream,
})
return nil
}
}
// WithFileStream 添加文件流
func WithFileStream(formName, fileName string, size int64, reader io.Reader) RequestOpt {
return func(r *Request) error {
if reader == nil {
return ErrNilReader
}
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
FormName: formName,
FileName: fileName,
FileData: reader,
FileSize: size,
FileType: ContentTypeOctetStream,
})
return nil
}
}
// WithQuery 添加查询参数
func WithQuery(key, value string) RequestOpt {
return func(r *Request) error {
r.config.Queries[key] = append(r.config.Queries[key], value)
return nil
}
}
// WithQueries 批量添加查询参数
func WithQueries(queries map[string]string) RequestOpt {
return func(r *Request) error {
for k, v := range queries {
r.config.Queries[k] = append(r.config.Queries[k], v)
}
return nil
}
}
// WithContentLength 设置 Content-Length
func WithContentLength(length int64) RequestOpt {
return func(r *Request) error {
r.config.ContentLength = length
return nil
}
}
// WithAutoCalcContentLength 设置是否自动计算 Content-Length
func WithAutoCalcContentLength(auto bool) RequestOpt {
return func(r *Request) error {
r.config.AutoCalcContentLength = auto
return nil
}
}
// WithUploadProgress 设置文件上传进度回调
func WithUploadProgress(fn UploadProgressFunc) RequestOpt {
return func(r *Request) error {
r.config.UploadProgress = fn
return nil
}
}
// WithTransport 设置自定义 Transport
func WithTransport(transport *http.Transport) RequestOpt {
return func(r *Request) error {
r.config.Transport = transport
r.config.CustomTransport = true
return nil
}
}
// WithAutoFetch 设置是否自动获取响应体
func WithAutoFetch(auto bool) RequestOpt {
return func(r *Request) error {
r.autoFetch = auto
return nil
}
}
// WithRawRequest 设置原始请求
func WithRawRequest(httpReq *http.Request) RequestOpt {
return func(r *Request) error {
r.httpReq = httpReq
r.doRaw = true
return nil
}
}
// WithContext 设置 context
func WithContext(ctx context.Context) RequestOpt {
return func(r *Request) error {
r.ctx = ctx
r.httpReq = r.httpReq.WithContext(ctx)
return nil
}
}

234
options_test.go Normal file
View File

@ -0,0 +1,234 @@
package starnet
import (
"context"
"encoding/json"
"io"
"net"
"net/http"
"net/http/httptest"
"os"
"strings"
"sync/atomic"
"testing"
"time"
)
func TestWithJSONOpt(t *testing.T) {
type payload struct {
Name string `json:"name"`
Age int `json:"age"`
}
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if ct := r.Header.Get("Content-Type"); ct != ContentTypeJSON {
t.Fatalf("content-type=%s", ct)
}
var p payload
if err := json.NewDecoder(r.Body).Decode(&p); err != nil {
t.Fatalf("decode err: %v", err)
}
if p.Name != "alice" || p.Age != 18 {
t.Fatalf("payload mismatch: %+v", p)
}
w.WriteHeader(http.StatusOK)
}))
defer s.Close()
resp, err := Post(s.URL, WithJSON(payload{Name: "alice", Age: 18}))
if err != nil {
t.Fatalf("Post error: %v", err)
}
resp.Close()
}
func TestWithFileOpt(t *testing.T) {
// temp file + cleanup
f, err := os.CreateTemp("", "starnet-upload-*.txt")
if err != nil {
t.Fatal(err)
}
defer os.Remove(f.Name())
_, _ = f.WriteString("hello-file")
_ = f.Close()
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseMultipartForm(10 << 20); err != nil {
t.Fatalf("parse form err: %v", err)
}
file, header, err := r.FormFile("file")
if err != nil {
t.Fatalf("form file err: %v", err)
}
defer file.Close()
b, _ := io.ReadAll(file)
if header.Filename == "" || string(b) != "hello-file" {
t.Fatalf("upload mismatch filename=%q body=%q", header.Filename, string(b))
}
w.WriteHeader(http.StatusOK)
}))
defer s.Close()
resp, err := Post(s.URL, WithFile("file", f.Name()))
if err != nil {
t.Fatalf("Post error: %v", err)
}
resp.Close()
}
func TestWithFileStreamOpt(t *testing.T) {
content := "stream-content"
reader := strings.NewReader(content)
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseMultipartForm(10 << 20); err != nil {
t.Fatalf("parse form err: %v", err)
}
file, header, err := r.FormFile("up")
if err != nil {
t.Fatalf("form file err: %v", err)
}
defer file.Close()
b, _ := io.ReadAll(file)
if header.Filename != "a.txt" || string(b) != content {
t.Fatalf("upload mismatch filename=%q body=%q", header.Filename, string(b))
}
w.WriteHeader(http.StatusOK)
}))
defer s.Close()
resp, err := Post(s.URL, WithFileStream("up", "a.txt", int64(len(content)), reader))
if err != nil {
t.Fatalf("Post error: %v", err)
}
resp.Close()
}
func TestWithQueryOpt(t *testing.T) {
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Query().Get("k") != "v" {
t.Fatalf("query mismatch: %v", r.URL.Query())
}
w.WriteHeader(http.StatusOK)
}))
defer s.Close()
resp, err := Get(s.URL, WithQuery("k", "v"))
if err != nil {
t.Fatalf("Get error: %v", err)
}
resp.Close()
}
func TestWithUploadProgressOpt(t *testing.T) {
var called int32
var last int64
content := strings.Repeat("x", 4096)
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_ = r.ParseMultipartForm(10 << 20)
w.WriteHeader(http.StatusOK)
}))
defer s.Close()
resp, err := Post(s.URL,
WithUploadProgress(func(filename string, uploaded, total int64) {
atomic.StoreInt32(&called, 1)
last = uploaded
}),
WithFileStream("f", "p.txt", int64(len(content)), strings.NewReader(content)),
)
if err != nil {
t.Fatalf("Post error: %v", err)
}
resp.Close()
if atomic.LoadInt32(&called) == 0 {
t.Fatal("progress not called")
}
if last != int64(len(content)) {
t.Fatalf("last uploaded=%d want=%d", last, len(content))
}
}
func TestWithTransportOpt(t *testing.T) {
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer s.Close()
resp, err := Get(s.URL, WithTransport(&http.Transport{}))
if err != nil {
t.Fatalf("Get error: %v", err)
}
resp.Close()
}
func TestWithContextOpt(t *testing.T) {
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(200 * time.Millisecond)
w.WriteHeader(http.StatusOK)
}))
defer s.Close()
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
defer cancel()
_, err := Get(s.URL, WithContext(ctx))
if err == nil {
t.Fatal("expected context timeout error")
}
}
func TestWithCustomDNSOpt_ConfigApplied(t *testing.T) {
req := NewSimpleRequest("http://example.com", "GET", WithCustomDNS([]string{"8.8.8.8", "1.1.1.1"}))
if req.Err() != nil {
t.Fatalf("unexpected err: %v", req.Err())
}
if len(req.config.DNS.CustomDNS) != 2 {
t.Fatalf("custom dns len=%d", len(req.config.DNS.CustomDNS))
}
}
func TestWithAddCustomIPOpt(t *testing.T) {
req := NewSimpleRequest("http://example.com", "GET", WithAddCustomIP("1.2.3.4"))
if req.Err() != nil {
t.Fatalf("unexpected err: %v", req.Err())
}
if len(req.config.DNS.CustomIP) != 1 || req.config.DNS.CustomIP[0] != "1.2.3.4" {
t.Fatalf("custom ip mismatch: %v", req.config.DNS.CustomIP)
}
}
func TestWithCustomIPOpt(t *testing.T) {
req := NewSimpleRequest("http://example.com", "GET", WithCustomIP([]string{"1.1.1.1", "8.8.8.8"}))
if req.Err() != nil {
t.Fatalf("unexpected err: %v", req.Err())
}
if len(req.config.DNS.CustomIP) != 2 {
t.Fatalf("custom ip len=%d", len(req.config.DNS.CustomIP))
}
}
func TestWithDialFuncOpt(t *testing.T) {
called := int32(0)
fn := func(ctx context.Context, network, addr string) (net.Conn, error) {
atomic.StoreInt32(&called, 1)
return nil, io.EOF
}
req := NewSimpleRequest("http://example.com", "GET", WithDialFunc(fn))
if req.config.Network.DialFunc == nil {
t.Fatal("dial func not set")
}
_, _ = req.config.Network.DialFunc(context.Background(), "tcp", "x:1")
if atomic.LoadInt32(&called) == 0 {
t.Fatal("dial func not called")
}
}
func TestWithDialTimeoutOpt(t *testing.T) {
req := NewSimpleRequest("http://example.com", "GET", WithDialTimeout(123*time.Millisecond))
if req.config.Network.DialTimeout != 123*time.Millisecond {
t.Fatalf("dial timeout=%v", req.config.Network.DialTimeout)
}
}

50
proxy_test.go Normal file
View File

@ -0,0 +1,50 @@
package starnet
import (
"net/http"
"net/http/httptest"
"testing"
)
func TestRequestProxy(t *testing.T) {
// Create a proxy server
proxyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Proxy received the request
w.Header().Set("X-Proxied", "true")
w.WriteHeader(http.StatusOK)
w.Write([]byte("proxied"))
}))
defer proxyServer.Close()
// Note: This is a simplified test. Real proxy testing requires more setup
req := NewSimpleRequest("http://example.com", "GET").
SetProxy(proxyServer.URL)
// Just verify the proxy is set in config
if req.config.Network.Proxy != proxyServer.URL {
t.Errorf("Proxy = %v; want %v", req.config.Network.Proxy, proxyServer.URL)
}
}
func TestClientLevelProxy(t *testing.T) {
proxyURL := "http://proxy.example.com:8080"
client := NewClientNoErr(WithProxy(proxyURL))
req, _ := client.NewRequest("http://example.com", "GET")
if req.config.Network.Proxy != proxyURL {
t.Errorf("Proxy = %v; want %v", req.config.Network.Proxy, proxyURL)
}
}
func TestRequestLevelProxyOverride(t *testing.T) {
clientProxy := "http://client-proxy.com:8080"
requestProxy := "http://request-proxy.com:8080"
client := NewClientNoErr(WithProxy(clientProxy))
req, _ := client.NewRequest("http://example.com", "GET", WithProxy(requestProxy))
// Request level should override client level
if req.config.Network.Proxy != requestProxy {
t.Errorf("Proxy = %v; want %v", req.config.Network.Proxy, requestProxy)
}
}

325
que.go
View File

@ -1,325 +0,0 @@
package starnet
import (
"bytes"
"context"
"encoding/binary"
"errors"
"fmt"
"os"
"sync"
"time"
)
var ErrDeadlineExceeded error = errors.New("deadline exceeded")
// 识别头
var header = []byte{11, 27, 19, 96, 12, 25, 02, 20}
// MsgQueue 为基本的信息单位
type MsgQueue struct {
ID uint16
Msg []byte
Conn interface{}
}
// StarQueue 为流数据中的消息队列分发
type StarQueue struct {
maxLength uint32
count int64
Encode bool
msgID uint16
msgPool chan MsgQueue
unFinMsg sync.Map
lastID int //= -1
ctx context.Context
cancel context.CancelFunc
duration time.Duration
EncodeFunc func([]byte) []byte
DecodeFunc func([]byte) []byte
//restoreMu sync.Mutex
}
func NewQueueCtx(ctx context.Context, count int64, maxMsgLength uint32) *StarQueue {
var q StarQueue
q.Encode = false
q.count = count
q.maxLength = maxMsgLength
q.msgPool = make(chan MsgQueue, count)
if ctx == nil {
q.ctx, q.cancel = context.WithCancel(context.Background())
} else {
q.ctx, q.cancel = context.WithCancel(ctx)
}
q.duration = 0
return &q
}
func NewQueueWithCount(count int64) *StarQueue {
return NewQueueCtx(nil, count, 0)
}
// NewQueue 建立一个新消息队列
func NewQueue() *StarQueue {
return NewQueueWithCount(32)
}
// Uint32ToByte 4位uint32转byte
func Uint32ToByte(src uint32) []byte {
res := make([]byte, 4)
res[3] = uint8(src)
res[2] = uint8(src >> 8)
res[1] = uint8(src >> 16)
res[0] = uint8(src >> 24)
return res
}
// ByteToUint32 byte转4位uint32
func ByteToUint32(src []byte) uint32 {
var res uint32
buffer := bytes.NewBuffer(src)
binary.Read(buffer, binary.BigEndian, &res)
return res
}
// Uint16ToByte 2位uint16转byte
func Uint16ToByte(src uint16) []byte {
res := make([]byte, 2)
res[1] = uint8(src)
res[0] = uint8(src >> 8)
return res
}
// ByteToUint16 用于byte转uint16
func ByteToUint16(src []byte) uint16 {
var res uint16
buffer := bytes.NewBuffer(src)
binary.Read(buffer, binary.BigEndian, &res)
return res
}
// BuildMessage 生成编码后的信息用于发送
func (q *StarQueue) BuildMessage(src []byte) []byte {
var buff bytes.Buffer
q.msgID++
if q.Encode {
src = q.EncodeFunc(src)
}
length := uint32(len(src))
buff.Write(header)
buff.Write(Uint32ToByte(length))
buff.Write(Uint16ToByte(q.msgID))
buff.Write(src)
return buff.Bytes()
}
// BuildHeader 生成编码后的Header用于发送
func (q *StarQueue) BuildHeader(length uint32) []byte {
var buff bytes.Buffer
q.msgID++
buff.Write(header)
buff.Write(Uint32ToByte(length))
buff.Write(Uint16ToByte(q.msgID))
return buff.Bytes()
}
type unFinMsg struct {
ID uint16
LengthRecv uint32
// HeaderMsg 信息头应当为14位8位识别码+4位长度码+2位id
HeaderMsg []byte
RecvMsg []byte
}
func (q *StarQueue) push2list(msg MsgQueue) {
q.msgPool <- msg
}
// ParseMessage 用于解析收到的msg信息
func (q *StarQueue) ParseMessage(msg []byte, conn interface{}) error {
return q.parseMessage(msg, conn)
}
// parseMessage 用于解析收到的msg信息
func (q *StarQueue) parseMessage(msg []byte, conn interface{}) error {
tmp, ok := q.unFinMsg.Load(conn)
if ok { //存在未完成的信息
lastMsg := tmp.(*unFinMsg)
headerLen := len(lastMsg.HeaderMsg)
if headerLen < 14 { //未完成头标题
//传输的数据不能填充header头
if len(msg) < 14-headerLen {
//加入header头并退出
lastMsg.HeaderMsg = bytesMerge(lastMsg.HeaderMsg, msg)
q.unFinMsg.Store(conn, lastMsg)
return nil
}
//获取14字节完整的header
header := msg[0 : 14-headerLen]
lastMsg.HeaderMsg = bytesMerge(lastMsg.HeaderMsg, header)
//检查收到的header是否为认证header
//若不是,丢弃并重新来过
if !checkHeader(lastMsg.HeaderMsg[0:8]) {
q.unFinMsg.Delete(conn)
if len(msg) == 0 {
return nil
}
return q.parseMessage(msg, conn)
}
//获得本数据包长度
lastMsg.LengthRecv = ByteToUint32(lastMsg.HeaderMsg[8:12])
if q.maxLength != 0 && lastMsg.LengthRecv > q.maxLength {
q.unFinMsg.Delete(conn)
return fmt.Errorf("msg length is %d ,too large than %d", lastMsg.LengthRecv, q.maxLength)
}
//获得本数据包ID
lastMsg.ID = ByteToUint16(lastMsg.HeaderMsg[12:14])
//存入列表
q.unFinMsg.Store(conn, lastMsg)
msg = msg[14-headerLen:]
if uint32(len(msg)) < lastMsg.LengthRecv {
lastMsg.RecvMsg = msg
q.unFinMsg.Store(conn, lastMsg)
return nil
}
if uint32(len(msg)) >= lastMsg.LengthRecv {
lastMsg.RecvMsg = msg[0:lastMsg.LengthRecv]
if q.Encode {
lastMsg.RecvMsg = q.DecodeFunc(lastMsg.RecvMsg)
}
msg = msg[lastMsg.LengthRecv:]
storeMsg := MsgQueue{
ID: lastMsg.ID,
Msg: lastMsg.RecvMsg,
Conn: conn,
}
//q.restoreMu.Lock()
q.push2list(storeMsg)
//q.restoreMu.Unlock()
q.unFinMsg.Delete(conn)
return q.parseMessage(msg, conn)
}
} else {
lastID := int(lastMsg.LengthRecv) - len(lastMsg.RecvMsg)
if lastID < 0 {
q.unFinMsg.Delete(conn)
return q.parseMessage(msg, conn)
}
if len(msg) >= lastID {
lastMsg.RecvMsg = bytesMerge(lastMsg.RecvMsg, msg[0:lastID])
if q.Encode {
lastMsg.RecvMsg = q.DecodeFunc(lastMsg.RecvMsg)
}
storeMsg := MsgQueue{
ID: lastMsg.ID,
Msg: lastMsg.RecvMsg,
Conn: conn,
}
q.push2list(storeMsg)
q.unFinMsg.Delete(conn)
if len(msg) == lastID {
return nil
}
msg = msg[lastID:]
return q.parseMessage(msg, conn)
}
lastMsg.RecvMsg = bytesMerge(lastMsg.RecvMsg, msg)
q.unFinMsg.Store(conn, lastMsg)
return nil
}
}
if len(msg) == 0 {
return nil
}
var start int
if start = searchHeader(msg); start == -1 {
return errors.New("data format error")
}
msg = msg[start:]
lastMsg := unFinMsg{}
q.unFinMsg.Store(conn, &lastMsg)
return q.parseMessage(msg, conn)
}
func checkHeader(msg []byte) bool {
if len(msg) != 8 {
return false
}
for k, v := range msg {
if v != header[k] {
return false
}
}
return true
}
func searchHeader(msg []byte) int {
if len(msg) < 8 {
return 0
}
for k, v := range msg {
find := 0
if v == header[0] {
for k2, v2 := range header {
if msg[k+k2] == v2 {
find++
} else {
break
}
}
if find == 8 {
return k
}
}
}
return -1
}
func bytesMerge(src ...[]byte) []byte {
var buff bytes.Buffer
for _, v := range src {
buff.Write(v)
}
return buff.Bytes()
}
// Restore 获取收到的信息
func (q *StarQueue) Restore() (MsgQueue, error) {
if q.duration.Seconds() == 0 {
q.duration = 86400 * time.Second
}
for {
select {
case <-q.ctx.Done():
return MsgQueue{}, errors.New("Stoped By External Function Call")
case <-time.After(q.duration):
if q.duration != 0 {
return MsgQueue{}, ErrDeadlineExceeded
}
case data, ok := <-q.msgPool:
if !ok {
return MsgQueue{}, os.ErrClosed
}
return data, nil
}
}
}
// RestoreOne 获取收到的一个信息
// 兼容性修改
func (q *StarQueue) RestoreOne() (MsgQueue, error) {
return q.Restore()
}
// Stop 立即停止Restore
func (q *StarQueue) Stop() {
q.cancel()
}
// RestoreDuration Restore最大超时时间
func (q *StarQueue) RestoreDuration(tm time.Duration) {
q.duration = tm
}
func (q *StarQueue) RestoreChan() <-chan MsgQueue {
return q.msgPool
}

View File

@ -1,42 +0,0 @@
package starnet
import (
"fmt"
"testing"
"time"
)
func Test_QueSpeed(t *testing.T) {
que := NewQueueWithCount(0)
stop := make(chan struct{}, 1)
que.RestoreDuration(time.Second * 10)
var count int64
go func() {
for {
select {
case <-stop:
//fmt.Println(count)
return
default:
}
_, err := que.RestoreOne()
if err == nil {
count++
}
}
}()
cp := 0
stoped := time.After(time.Second * 10)
data := que.BuildMessage([]byte("hello"))
for {
select {
case <-stoped:
fmt.Println(count, cp)
stop <- struct{}{}
return
default:
que.ParseMessage(data, "lala")
cp++
}
}
}

98
query_test.go Normal file
View File

@ -0,0 +1,98 @@
package starnet
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
)
func TestRequestQuery(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
query := r.URL.Query()
result := make(map[string][]string)
for k, v := range query {
result[k] = v
}
json.NewEncoder(w).Encode(result)
}))
defer server.Close()
req := NewSimpleRequest(server.URL, "GET").
AddQuery("name", "John").
AddQuery("age", "30").
AddQuery("tags", "go").
AddQuery("tags", "http")
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
var result map[string][]string
resp.Body().JSON(&result)
if len(result["name"]) != 1 || result["name"][0] != "John" {
t.Errorf("name = %v; want [John]", result["name"])
}
if len(result["tags"]) != 2 {
t.Errorf("tags length = %v; want 2", len(result["tags"]))
}
}
func TestRequestSetQuery(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
query := r.URL.Query()
w.Write([]byte(query.Get("key")))
}))
defer server.Close()
req := NewSimpleRequest(server.URL, "GET").
SetQuery("key", "value1").
SetQuery("key", "value2") // Should overwrite
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
body, _ := resp.Body().String()
if body != "value2" {
t.Errorf("query value = %v; want value2", body)
}
}
func TestRequestDeleteQuery(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
query := r.URL.Query()
result := make(map[string]string)
for k := range query {
result[k] = query.Get(k)
}
json.NewEncoder(w).Encode(result)
}))
defer server.Close()
req := NewSimpleRequest(server.URL, "GET").
AddQuery("keep", "yes").
AddQuery("delete", "no").
DeleteQuery("delete")
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
var result map[string]string
resp.Body().JSON(&result)
if _, exists := result["delete"]; exists {
t.Error("delete query should not exist")
}
if result["keep"] != "yes" {
t.Errorf("keep = %v; want yes", result["keep"])
}
}

373
request.go Normal file
View File

@ -0,0 +1,373 @@
package starnet
import (
"context"
"fmt"
"net/http"
"net/url"
"strings"
)
// Request HTTP 请求
type Request struct {
ctx context.Context
execCtx context.Context // 执行时的 context注入了配置
url string
method string
err error // 累积的错误
config *RequestConfig
client *Client
httpClient *http.Client
httpReq *http.Request
applied bool // 是否已应用配置
doRaw bool // 是否使用原始请求(不修改)
autoFetch bool // 是否自动获取响应体
}
// newRequest 创建新请求(内部使用)
func newRequest(ctx context.Context, urlStr string, method string, opts ...RequestOpt) (*Request, error) {
if method == "" {
method = http.MethodGet
}
method = strings.ToUpper(method)
// 创建 http.Request
httpReq, err := http.NewRequestWithContext(ctx, method, urlStr, nil)
if err != nil {
return nil, wrapError(err, "create http request")
}
// 初始化配置
config := &RequestConfig{
Network: NetworkConfig{
DialTimeout: DefaultDialTimeout,
Timeout: DefaultTimeout,
},
Headers: make(http.Header),
Queries: make(map[string][]string),
Body: BodyConfig{
FormData: make(map[string][]string),
},
}
// 设置默认 User-Agent
config.Headers.Set("User-Agent", DefaultUserAgent)
// POST 请求默认 Content-Type
if method == http.MethodPost {
config.Headers.Set("Content-Type", ContentTypeFormURLEncoded)
}
req := &Request{
ctx: ctx,
url: urlStr,
method: method,
config: config,
httpReq: httpReq,
autoFetch: DefaultFetchRespBody,
}
// 应用选项
for _, opt := range opts {
if opt != nil {
if err := opt(req); err != nil {
req.err = err
return req, nil // 不返回错误,累积到 req.err
}
}
}
return req, nil
}
// NewRequest 创建新请求
func NewRequest(url, method string, opts ...RequestOpt) (*Request, error) {
req, err := newRequest(context.Background(), url, method, opts...)
if err != nil {
return nil, err
}
if req.err != nil {
return nil, req.err
}
return req, nil
}
// NewRequestWithContext 创建新请求(带 context
func NewRequestWithContext(ctx context.Context, url, method string, opts ...RequestOpt) (*Request, error) {
req, err := newRequest(ctx, url, method, opts...)
if err != nil {
return nil, err
}
// 新增
if req.err != nil {
return nil, req.err
}
return req, nil
}
// NewSimpleRequest 创建新请求(忽略错误,支持链式调用)
func NewSimpleRequest(url, method string, opts ...RequestOpt) *Request {
req, err := newRequest(context.Background(), url, method, opts...)
if err != nil {
// 返回一个带错误的请求
return &Request{
ctx: context.Background(),
url: url,
method: method,
err: err,
config: &RequestConfig{
Headers: make(http.Header),
Queries: make(map[string][]string),
Body: BodyConfig{
FormData: make(map[string][]string),
},
},
}
}
return req
}
// NewSimpleRequestWithContext 创建新请求(带 context忽略错误
func NewSimpleRequestWithContext(ctx context.Context, url, method string, opts ...RequestOpt) *Request {
req, err := newRequest(ctx, url, method, opts...)
if err != nil {
return &Request{
ctx: ctx,
url: url,
method: method,
err: err,
config: &RequestConfig{
Headers: make(http.Header),
Queries: make(map[string][]string),
Body: BodyConfig{
FormData: make(map[string][]string),
},
},
}
}
return req
}
// Clone 克隆请求
func (r *Request) Clone() *Request {
cloned := &Request{
ctx: r.ctx,
url: r.url,
method: r.method,
err: r.err,
config: r.config.Clone(),
client: r.client,
httpClient: r.httpClient,
applied: false, // 重置应用状态
doRaw: r.doRaw,
autoFetch: r.autoFetch,
}
// 重新创建 http.Request
if !r.doRaw {
cloned.httpReq, _ = http.NewRequestWithContext(cloned.ctx, cloned.method, cloned.url, nil)
} else {
cloned.httpReq = r.httpReq
}
return cloned
}
// Err 获取累积的错误
func (r *Request) Err() error {
return r.err
}
// Context 获取 context
func (r *Request) Context() context.Context {
return r.ctx
}
// SetContext 设置 context
func (r *Request) SetContext(ctx context.Context) *Request {
if r.err != nil {
return r
}
r.ctx = ctx
r.httpReq = r.httpReq.WithContext(ctx)
return r
}
// Method 获取 HTTP 方法
func (r *Request) Method() string {
return r.method
}
// SetMethod 设置 HTTP 方法
func (r *Request) SetMethod(method string) *Request {
if r.err != nil {
return r
}
method = strings.ToUpper(method)
if !validMethod(method) {
r.err = wrapError(ErrInvalidMethod, "method: %s", method)
return r
}
r.method = method
r.httpReq.Method = method
return r
}
// URL 获取 URL
func (r *Request) URL() string {
return r.url
}
// SetURL 设置 URL
func (r *Request) SetURL(urlStr string) *Request {
if r.err != nil {
return r
}
if r.doRaw {
r.err = fmt.Errorf("cannot set URL when using raw request")
return r
}
u, err := url.Parse(urlStr)
if err != nil {
r.err = wrapError(ErrInvalidURL, "url: %s", urlStr)
return r
}
r.url = urlStr
u.Host = removeEmptyPort(u.Host)
r.httpReq.Host = u.Host
r.httpReq.URL = u
// 更新 TLS ServerName
if r.config.TLS.Config != nil {
r.config.TLS.Config.ServerName = u.Hostname()
}
return r
}
// RawRequest 获取底层 http.Request
func (r *Request) RawRequest() *http.Request {
return r.httpReq
}
// SetRawRequest 设置底层 http.Request启用原始模式
func (r *Request) SetRawRequest(httpReq *http.Request) *Request {
if r.err != nil {
return r
}
r.httpReq = httpReq
r.doRaw = true
if httpReq == nil {
r.err = fmt.Errorf("httpReq cannot be nil")
return r
}
return r
}
// EnableRawMode 启用原始模式(不修改请求)
func (r *Request) EnableRawMode() *Request {
r.doRaw = true
return r
}
// DisableRawMode 禁用原始模式
func (r *Request) DisableRawMode() *Request {
r.doRaw = false
return r
}
// SetAutoFetch 设置是否自动获取响应体
func (r *Request) SetAutoFetch(auto bool) *Request {
r.autoFetch = auto
return r
}
// HTTPClient 获取底层 http.Client只读
func (r *Request) HTTPClient() (*http.Client, error) {
if r.err != nil {
return nil, r.err
}
if r.httpClient != nil {
return r.httpClient, nil
}
// 如果还没构建,先准备
if err := r.prepare(); err != nil {
return nil, err
}
return r.httpClient, nil
}
// Client 获取关联的 Client只读
func (r *Request) Client() *Client {
return r.client
}
// Do 执行请求
func (r *Request) Do() (*Response, error) {
// 检查累积的错误
if r.err != nil {
return nil, r.err
}
// 准备请求
if err := r.prepare(); err != nil {
return nil, wrapError(err, "prepare request")
}
// 执行请求
httpResp, err := r.httpClient.Do(r.httpReq)
if err != nil {
return &Response{
Response: &http.Response{},
request: r,
httpClient: r.httpClient,
body: &Body{},
}, wrapError(err, "do request")
}
// 创建响应
resp := &Response{
Response: httpResp,
request: r,
httpClient: r.httpClient,
body: &Body{
raw: httpResp.Body,
},
}
// 自动获取响应体
if r.autoFetch {
resp.body.readAll()
}
return resp, nil
}
// Get 发送 GET 请求
func (r *Request) Get() (*Response, error) {
return r.SetMethod(http.MethodGet).Do()
}
// Post 发送 POST 请求
func (r *Request) Post() (*Response, error) {
return r.SetMethod(http.MethodPost).Do()
}
// Put 发送 PUT 请求
func (r *Request) Put() (*Response, error) {
return r.SetMethod(http.MethodPut).Do()
}
// Delete 发送 DELETE 请求
func (r *Request) Delete() (*Response, error) {
return r.SetMethod(http.MethodDelete).Do()
}

465
request_body.go Normal file
View File

@ -0,0 +1,465 @@
package starnet
import (
"bytes"
"encoding/json"
"io"
"mime/multipart"
"net/http"
"net/url"
"os"
"strings"
)
// SetBody 设置请求体(字节)
func (r *Request) SetBody(body []byte) *Request {
if r.err != nil {
return r
}
if r.doRaw {
return r
}
r.config.Body.Bytes = body
r.config.Body.Reader = nil
return r
}
// SetBodyReader 设置请求体Reader
func (r *Request) SetBodyReader(reader io.Reader) *Request {
if r.err != nil {
return r
}
if r.doRaw {
return r
}
r.config.Body.Reader = reader
r.config.Body.Bytes = nil
return r
}
// SetBodyString 设置请求体(字符串)
func (r *Request) SetBodyString(body string) *Request {
return r.SetBody([]byte(body))
}
// SetJSON 设置 JSON 请求体
func (r *Request) SetJSON(v interface{}) *Request {
if r.err != nil {
return r
}
data, err := json.Marshal(v)
if err != nil {
r.err = wrapError(err, "marshal json")
return r
}
return r.SetContentType(ContentTypeJSON).SetBody(data)
}
// SetFormData 设置表单数据(覆盖)
func (r *Request) SetFormData(data map[string][]string) *Request {
if r.err != nil {
return r
}
if r.doRaw {
return r
}
r.config.Body.FormData = data
return r
}
// AddFormData 添加表单数据
func (r *Request) AddFormData(key, value string) *Request {
if r.err != nil {
return r
}
if r.doRaw {
return r
}
r.config.Body.FormData[key] = append(r.config.Body.FormData[key], value)
return r
}
// AddFormDataMap 批量添加表单数据
func (r *Request) AddFormDataMap(data map[string]string) *Request {
if r.err != nil {
return r
}
if r.doRaw {
return r
}
for k, v := range data {
r.config.Body.FormData[k] = append(r.config.Body.FormData[k], v)
}
return r
}
// AddFile 添加文件(从路径)
func (r *Request) AddFile(formName, filePath string) *Request {
if r.err != nil {
return r
}
stat, err := os.Stat(filePath)
if err != nil {
r.err = wrapError(ErrFileNotFound, "file: %s", filePath)
return r
}
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
FormName: formName,
FileName: stat.Name(),
FilePath: filePath,
FileSize: stat.Size(),
FileType: ContentTypeOctetStream,
})
return r
}
// AddFileWithName 添加文件(指定文件名)
func (r *Request) AddFileWithName(formName, filePath, fileName string) *Request {
if r.err != nil {
return r
}
stat, err := os.Stat(filePath)
if err != nil {
r.err = wrapError(ErrFileNotFound, "file: %s", filePath)
return r
}
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
FormName: formName,
FileName: fileName,
FilePath: filePath,
FileSize: stat.Size(),
FileType: ContentTypeOctetStream,
})
return r
}
// AddFileWithType 添加文件(指定 MIME 类型)
func (r *Request) AddFileWithType(formName, filePath, fileType string) *Request {
if r.err != nil {
return r
}
stat, err := os.Stat(filePath)
if err != nil {
r.err = wrapError(ErrFileNotFound, "file: %s", filePath)
return r
}
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
FormName: formName,
FileName: stat.Name(),
FilePath: filePath,
FileSize: stat.Size(),
FileType: fileType,
})
return r
}
// AddFileStream 添加文件流
func (r *Request) AddFileStream(formName, fileName string, size int64, reader io.Reader) *Request {
if r.err != nil {
return r
}
if reader == nil {
r.err = ErrNilReader
return r
}
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
FormName: formName,
FileName: fileName,
FileData: reader,
FileSize: size,
FileType: ContentTypeOctetStream,
})
return r
}
// AddFileStreamWithType 添加文件流(指定 MIME 类型)
func (r *Request) AddFileStreamWithType(formName, fileName, fileType string, size int64, reader io.Reader) *Request {
if r.err != nil {
return r
}
if reader == nil {
r.err = ErrNilReader
return r
}
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
FormName: formName,
FileName: fileName,
FileData: reader,
FileSize: size,
FileType: fileType,
})
return r
}
// applyBody 应用请求体
func (r *Request) applyBody() error {
// 优先级Reader > Bytes > Files > FormData
// 1. Reader
if r.config.Body.Reader != nil {
r.httpReq.Body = io.NopCloser(r.config.Body.Reader)
// 尝试获取长度
switch v := r.config.Body.Reader.(type) {
case *bytes.Buffer:
r.httpReq.ContentLength = int64(v.Len())
case *bytes.Reader:
r.httpReq.ContentLength = int64(v.Len())
case *strings.Reader:
r.httpReq.ContentLength = int64(v.Len())
}
return nil
}
// 2. Bytes
if len(r.config.Body.Bytes) > 0 {
r.httpReq.Body = io.NopCloser(bytes.NewReader(r.config.Body.Bytes))
r.httpReq.ContentLength = int64(len(r.config.Body.Bytes))
return nil
}
// 3. Filesmultipart/form-data
if len(r.config.Body.Files) > 0 {
return r.applyMultipartBody()
}
// 4. FormDataapplication/x-www-form-urlencoded
if len(r.config.Body.FormData) > 0 {
values := url.Values{}
for k, vs := range r.config.Body.FormData {
for _, v := range vs {
values.Add(k, v)
}
}
encoded := values.Encode()
r.httpReq.Body = io.NopCloser(strings.NewReader(encoded))
r.httpReq.ContentLength = int64(len(encoded))
return nil
}
return nil
}
// applyMultipartBody 应用 multipart 请求体
func (r *Request) applyMultipartBody() error {
pr, pw := io.Pipe()
writer := multipart.NewWriter(pw)
// 设置 Content-Type
r.httpReq.Header.Set("Content-Type", writer.FormDataContentType())
r.httpReq.Body = pr
// 在 goroutine 中写入数据
go func() {
defer pw.Close()
defer writer.Close()
// 写入表单字段
for k, vs := range r.config.Body.FormData {
for _, v := range vs {
if err := writer.WriteField(k, v); err != nil {
pw.CloseWithError(wrapError(err, "write form field"))
return
}
}
}
// 写入文件
for _, file := range r.config.Body.Files {
if err := r.writeFile(writer, file); err != nil {
pw.CloseWithError(err)
return
}
}
}()
return nil
}
// writeFile 写入文件到 multipart writer
func (r *Request) writeFile(writer *multipart.Writer, file RequestFile) error {
// 创建文件字段
part, err := writer.CreateFormFile(file.FormName, file.FileName)
if err != nil {
return wrapError(err, "create form file")
}
// 获取文件数据源
var reader io.Reader
if file.FileData != nil {
reader = file.FileData
} else if file.FilePath != "" {
f, err := os.Open(file.FilePath)
if err != nil {
return wrapError(err, "open file")
}
defer f.Close()
reader = f
} else {
return ErrNilReader
}
// 复制文件数据(带进度)
if r.config.UploadProgress != nil {
_, err = copyWithProgress(r.ctx, part, reader, file.FileName, file.FileSize, r.config.UploadProgress)
} else {
_, err = io.Copy(part, reader)
}
if err != nil {
return wrapError(err, "copy file data")
}
return nil
}
// prepare 准备请求(应用配置)
func (r *Request) prepare() error {
if r.applied {
return nil
}
// 即使 raw 模式也要确保有 httpClient
if r.httpClient == nil {
var err error
r.httpClient, err = r.buildHTTPClient()
if err != nil {
return err // ← 失败时不设置 applied
}
}
// 原始模式不修改请求内容
if r.doRaw {
r.applied = true
return nil
}
// 应用查询参数
if len(r.config.Queries) > 0 {
q := r.httpReq.URL.Query()
for k, values := range r.config.Queries {
for _, v := range values {
q.Add(k, v)
}
}
r.httpReq.URL.RawQuery = q.Encode()
}
// 应用 Headers
for k, values := range r.config.Headers {
for _, v := range values {
r.httpReq.Header.Add(k, v)
}
}
// 应用 Cookies
for _, cookie := range r.config.Cookies {
r.httpReq.AddCookie(cookie)
}
// 应用 Basic Auth
if r.config.BasicAuth[0] != "" || r.config.BasicAuth[1] != "" {
r.httpReq.SetBasicAuth(r.config.BasicAuth[0], r.config.BasicAuth[1])
}
// 应用请求体
if err := r.applyBody(); err != nil {
return err
}
// 应用 Content-Length
if r.config.ContentLength > 0 {
r.httpReq.ContentLength = r.config.ContentLength
} else if r.config.ContentLength < 0 {
r.httpReq.ContentLength = 0
}
// 自动计算 Content-Length
if r.config.AutoCalcContentLength && r.httpReq.Body != nil {
data, err := io.ReadAll(r.httpReq.Body)
if err != nil {
return wrapError(err, "read body for content length")
}
r.httpReq.ContentLength = int64(len(data))
r.httpReq.Body = io.NopCloser(bytes.NewBuffer(data))
}
// 设置 TLS ServerName如果有 TLS Config
if r.config.TLS.Config != nil && r.httpReq.URL != nil {
r.config.TLS.Config.ServerName = r.httpReq.URL.Hostname()
}
// 注入配置到 context
r.execCtx = injectRequestConfig(r.ctx, r.config)
r.httpReq = r.httpReq.WithContext(r.execCtx)
r.applied = true
return nil
}
// buildHTTPClient 构建 HTTP Client
func (r *Request) buildHTTPClient() (*http.Client, error) {
applyTimeoutOverride := func(base *http.Client) *http.Client {
// 没有 base 时兜底
if base == nil {
base = &http.Client{}
}
rt := r.config.Network.Timeout
// 语义:
// rt < 0 : 本次请求禁用超时Timeout = 0
// rt = 0 : 沿用 base.Timeout
// rt > 0 : 本次请求超时覆盖
if rt == 0 {
return base
}
clone := &http.Client{
Transport: base.Transport,
CheckRedirect: base.CheckRedirect,
Jar: base.Jar,
}
if rt < 0 {
clone.Timeout = 0
} else {
clone.Timeout = rt
}
return clone
}
// 优先使用请求关联的 Client
if r.client != nil {
return applyTimeoutOverride(r.client.HTTPClient()), nil
}
// 自定义 Transport
if r.config.CustomTransport && r.config.Transport != nil {
base := &http.Client{
Transport: &Transport{base: r.config.Transport},
Timeout: 0,
}
return applyTimeoutOverride(base), nil
}
// 默认全局 client
return applyTimeoutOverride(DefaultHTTPClient()), nil
}

269
request_config.go Normal file
View File

@ -0,0 +1,269 @@
package starnet
import (
"context"
"crypto/tls"
"fmt"
"net"
"net/http"
"time"
)
// SetTimeout 设置请求总超时时间
// timeout > 0: 使用该超时
// timeout = 0: 使用 Client 默认超时
// timeout < 0: 禁用本次请求超时(覆盖 Client.Timeout=0
func (r *Request) SetTimeout(timeout time.Duration) *Request {
if r.err != nil {
return r
}
r.config.Network.Timeout = timeout
return r
}
// SetDialTimeout 设置连接超时时间
func (r *Request) SetDialTimeout(timeout time.Duration) *Request {
if r.err != nil {
return r
}
r.config.Network.DialTimeout = timeout
return r
}
// SetProxy 设置代理
func (r *Request) SetProxy(proxy string) *Request {
if r.err != nil {
return r
}
r.config.Network.Proxy = proxy
return r
}
// SetDialFunc 设置自定义 Dial 函数
func (r *Request) SetDialFunc(fn func(ctx context.Context, network, addr string) (net.Conn, error)) *Request {
if r.err != nil {
return r
}
r.config.Network.DialFunc = fn
return r
}
// SetTLSConfig 设置 TLS 配置
func (r *Request) SetTLSConfig(tlsConfig *tls.Config) *Request {
if r.err != nil {
return r
}
r.config.TLS.Config = tlsConfig
return r
}
// SetSkipTLSVerify 设置是否跳过 TLS 验证
func (r *Request) SetSkipTLSVerify(skip bool) *Request {
if r.err != nil {
return r
}
r.config.TLS.SkipVerify = skip
return r
}
// SetCustomIP 设置自定义 IP直接指定 IP跳过 DNS
func (r *Request) SetCustomIP(ips []string) *Request {
if r.err != nil {
return r
}
// 验证 IP 格式
for _, ip := range ips {
if net.ParseIP(ip) == nil {
r.err = wrapError(ErrInvalidIP, "ip: %s", ip)
return r
}
}
r.config.DNS.CustomIP = ips
return r
}
// AddCustomIP 添加自定义 IP
func (r *Request) AddCustomIP(ip string) *Request {
if r.err != nil {
return r
}
if net.ParseIP(ip) == nil {
r.err = wrapError(ErrInvalidIP, "ip: %s", ip)
return r
}
r.config.DNS.CustomIP = append(r.config.DNS.CustomIP, ip)
return r
}
// SetCustomDNS 设置自定义 DNS 服务器
func (r *Request) SetCustomDNS(dnsServers []string) *Request {
if r.err != nil {
return r
}
// 验证 DNS 服务器格式
for _, dns := range dnsServers {
if net.ParseIP(dns) == nil {
r.err = wrapError(ErrInvalidDNS, "dns: %s", dns)
return r
}
}
r.config.DNS.CustomDNS = dnsServers
return r
}
// AddCustomDNS 添加自定义 DNS 服务器
func (r *Request) AddCustomDNS(dns string) *Request {
if r.err != nil {
return r
}
if net.ParseIP(dns) == nil {
r.err = wrapError(ErrInvalidDNS, "dns: %s", dns)
return r
}
r.config.DNS.CustomDNS = append(r.config.DNS.CustomDNS, dns)
return r
}
// SetLookupFunc 设置自定义 DNS 解析函数
func (r *Request) SetLookupFunc(fn func(ctx context.Context, host string) ([]net.IPAddr, error)) *Request {
if r.err != nil {
return r
}
r.config.DNS.LookupFunc = fn
return r
}
// SetBasicAuth 设置 Basic 认证
func (r *Request) SetBasicAuth(username, password string) *Request {
if r.err != nil {
return r
}
r.config.BasicAuth = [2]string{username, password}
return r
}
// SetContentLength 设置 Content-Length
func (r *Request) SetContentLength(length int64) *Request {
if r.err != nil {
return r
}
r.config.ContentLength = length
return r
}
// SetAutoCalcContentLength 设置是否自动计算 Content-Length
// 警告:启用后会将整个 body 读入内存
func (r *Request) SetAutoCalcContentLength(auto bool) *Request {
if r.err != nil {
return r
}
if r.doRaw {
r.err = fmt.Errorf("cannot set auto calc content length in raw mode")
return r
}
r.config.AutoCalcContentLength = auto
return r
}
// SetTransport 设置自定义 Transport
func (r *Request) SetTransport(transport *http.Transport) *Request {
if r.err != nil {
return r
}
r.config.Transport = transport
r.config.CustomTransport = true
return r
}
// SetUploadProgress 设置文件上传进度回调
func (r *Request) SetUploadProgress(fn UploadProgressFunc) *Request {
if r.err != nil {
return r
}
r.config.UploadProgress = fn
return r
}
// AddQuery 添加查询参数
func (r *Request) AddQuery(key, value string) *Request {
if r.err != nil {
return r
}
r.config.Queries[key] = append(r.config.Queries[key], value)
return r
}
// SetQuery 设置查询参数(覆盖)
func (r *Request) SetQuery(key, value string) *Request {
if r.err != nil {
return r
}
r.config.Queries[key] = []string{value}
return r
}
// SetQueries 设置所有查询参数(覆盖)
func (r *Request) SetQueries(queries map[string][]string) *Request {
if r.err != nil {
return r
}
r.config.Queries = queries
return r
}
// AddQueries 批量添加查询参数
func (r *Request) AddQueries(queries map[string]string) *Request {
if r.err != nil {
return r
}
for k, v := range queries {
r.config.Queries[k] = append(r.config.Queries[k], v)
}
return r
}
// DeleteQuery 删除查询参数
func (r *Request) DeleteQuery(key string) *Request {
if r.err != nil {
return r
}
delete(r.config.Queries, key)
return r
}
// DeleteQueryValue 删除查询参数的特定值
func (r *Request) DeleteQueryValue(key, value string) *Request {
if r.err != nil {
return r
}
values, ok := r.config.Queries[key]
if !ok {
return r
}
newValues := make([]string, 0, len(values))
for _, v := range values {
if v != value {
newValues = append(newValues, v)
}
}
if len(newValues) == 0 {
delete(r.config.Queries, key)
} else {
r.config.Queries[key] = newValues
}
return r
}

180
request_header.go Normal file
View File

@ -0,0 +1,180 @@
package starnet
import (
"net/http"
)
// SetHeader 设置 Header覆盖
func (r *Request) SetHeader(key, value string) *Request {
if r.err != nil {
return r
}
if r.doRaw {
return r
}
r.config.Headers.Set(key, value)
return r
}
// AddHeader 添加 Header
func (r *Request) AddHeader(key, value string) *Request {
if r.err != nil {
return r
}
if r.doRaw {
return r
}
r.config.Headers.Add(key, value)
return r
}
// SetHeaders 设置所有 Headers覆盖
func (r *Request) SetHeaders(headers http.Header) *Request {
if r.err != nil {
return r
}
if r.doRaw {
return r
}
r.config.Headers = headers
return r
}
// AddHeaders 批量添加 Headers
func (r *Request) AddHeaders(headers map[string]string) *Request {
if r.err != nil {
return r
}
if r.doRaw {
return r
}
for k, v := range headers {
r.config.Headers.Add(k, v)
}
return r
}
// DeleteHeader 删除 Header
func (r *Request) DeleteHeader(key string) *Request {
if r.err != nil {
return r
}
if r.doRaw {
return r
}
r.config.Headers.Del(key)
return r
}
// GetHeader 获取 Header
func (r *Request) GetHeader(key string) string {
return r.config.Headers.Get(key)
}
// Headers 获取所有 Headers
func (r *Request) Headers() http.Header {
return r.config.Headers
}
// SetContentType 设置 Content-Type
func (r *Request) SetContentType(contentType string) *Request {
return r.SetHeader("Content-Type", contentType)
}
// SetUserAgent 设置 User-Agent
func (r *Request) SetUserAgent(userAgent string) *Request {
return r.SetHeader("User-Agent", userAgent)
}
// SetReferer 设置 Referer
func (r *Request) SetReferer(referer string) *Request {
return r.SetHeader("Referer", referer)
}
// SetBearerToken 设置 Bearer Token
func (r *Request) SetBearerToken(token string) *Request {
return r.SetHeader("Authorization", "Bearer "+token)
}
// AddCookie 添加 Cookie
func (r *Request) AddCookie(cookie *http.Cookie) *Request {
if r.err != nil {
return r
}
if r.doRaw {
return r
}
r.config.Cookies = append(r.config.Cookies, cookie)
return r
}
// AddSimpleCookie 添加简单 Cookiepath 为 /
func (r *Request) AddSimpleCookie(name, value string) *Request {
return r.AddCookie(&http.Cookie{
Name: name,
Value: value,
Path: "/",
})
}
// AddCookieKV 添加 Cookie指定 path
func (r *Request) AddCookieKV(name, value, path string) *Request {
return r.AddCookie(&http.Cookie{
Name: name,
Value: value,
Path: path,
})
}
// SetCookies 设置所有 Cookies覆盖
func (r *Request) SetCookies(cookies []*http.Cookie) *Request {
if r.err != nil {
return r
}
if r.doRaw {
return r
}
r.config.Cookies = cookies
return r
}
// AddCookies 批量添加 Cookies
func (r *Request) AddCookies(cookies map[string]string) *Request {
if r.err != nil {
return r
}
if r.doRaw {
return r
}
for name, value := range cookies {
r.config.Cookies = append(r.config.Cookies, &http.Cookie{
Name: name,
Value: value,
Path: "/",
})
}
return r
}
// Cookies 获取所有 Cookies
func (r *Request) Cookies() []*http.Cookie {
return r.config.Cookies
}
// ResetHeaders 重置所有 Headers
func (r *Request) ResetHeaders() *Request {
if r.err != nil {
return r
}
r.config.Headers = make(http.Header)
return r
}
// ResetCookies 重置所有 Cookies
func (r *Request) ResetCookies() *Request {
if r.err != nil {
return r
}
r.config.Cookies = []*http.Cookie{}
return r
}

172
request_test.go Normal file
View File

@ -0,0 +1,172 @@
package starnet
import (
"context"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
)
func TestNewSimpleRequest(t *testing.T) {
tests := []struct {
name string
url string
method string
expectErr bool
}{
{
name: "valid GET request",
url: "https://example.com",
method: "GET",
expectErr: false,
},
{
name: "valid POST request",
url: "https://example.com",
method: "POST",
expectErr: false,
},
{
name: "invalid URL",
url: "://invalid",
method: "GET",
expectErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req, err := NewRequest(tt.url, tt.method)
if tt.expectErr {
if err == nil && req.Err() == nil {
t.Errorf("NewRequest() expected error, got nil")
}
} else {
if err != nil {
t.Errorf("NewRequest() unexpected error: %v", err)
}
if req.Method() != strings.ToUpper(tt.method) {
t.Errorf("Method = %v; want %v", req.Method(), tt.method)
}
}
})
}
}
func TestRequestMethods(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Method", r.Method)
w.WriteHeader(http.StatusOK)
w.Write([]byte(r.Method))
}))
defer server.Close()
methods := []string{"GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"}
for _, method := range methods {
t.Run(method, func(t *testing.T) {
req := NewSimpleRequest(server.URL, method)
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("StatusCode = %v; want %v", resp.StatusCode, http.StatusOK)
}
if method != "HEAD" {
body, _ := resp.Body().String()
if body != method {
t.Errorf("Body = %v; want %v", body, method)
}
}
})
}
}
func TestRequestSetMethod(t *testing.T) {
req := NewSimpleRequest("https://example.com", "GET")
req.SetMethod("POST")
if req.Method() != "POST" {
t.Errorf("Method = %v; want POST", req.Method())
}
req.SetMethod("invalid method!")
if req.Err() == nil {
t.Error("SetMethod with invalid method should set error")
}
}
func TestRequestSetURL(t *testing.T) {
req := NewSimpleRequest("https://example.com", "GET")
req.SetURL("https://newexample.com")
if req.URL() != "https://newexample.com" {
t.Errorf("URL = %v; want https://newexample.com", req.URL())
}
req2 := NewSimpleRequest("https://example.com", "GET")
req2.SetURL("://invalid")
if req2.Err() == nil {
t.Error("SetURL with invalid URL should set error")
}
}
func TestRequestClone(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("X-Test") != "value" {
w.WriteHeader(http.StatusBadRequest)
return
}
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
}))
defer server.Close()
req := NewSimpleRequest(server.URL, "GET").
SetHeader("X-Test", "value")
// 第一次请求
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
resp.Close()
// 克隆请求
cloned := req.Clone()
cloned.SetHeader("X-Extra", "extra")
// 克隆的请求应该也能成功
resp2, err := cloned.Do()
if err != nil {
t.Fatalf("Cloned Do() error: %v", err)
}
defer resp2.Close()
if resp2.StatusCode != http.StatusOK {
t.Errorf("Cloned request StatusCode = %v; want %v", resp2.StatusCode, http.StatusOK)
}
}
func TestRequestContext(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(100 * time.Millisecond)
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
defer cancel()
req := NewSimpleRequest(server.URL, "GET").SetContext(ctx)
_, err := req.Do()
if err == nil {
t.Error("Expected timeout error, got nil")
}
}

167
response.go Normal file
View File

@ -0,0 +1,167 @@
package starnet
import (
"bytes"
"encoding/json"
"io"
"net/http"
"sync"
)
// Response HTTP 响应
type Response struct {
*http.Response
request *Request
httpClient *http.Client
body *Body
}
// Body 响应体
type Body struct {
raw io.ReadCloser
data []byte
consumed bool
mu sync.Mutex
}
// Request 获取原始请求
func (r *Response) Request() *Request {
return r.request
}
// Body 获取响应体
func (r *Response) Body() *Body {
return r.body
}
// Close 关闭响应体
func (r *Response) Close() error {
if r == nil {
return nil
}
if r.body != nil && r.body.raw != nil {
return r.body.raw.Close()
}
return nil
}
// CloseWithClient 关闭响应体并关闭空闲连接
func (r *Response) CloseWithClient() error {
if r == nil {
return nil
}
if r.httpClient != nil {
r.httpClient.CloseIdleConnections()
}
return r.Close()
}
// readAll 读取所有数据
func (b *Body) readAll() error {
b.mu.Lock()
defer b.mu.Unlock()
if b.consumed {
return nil
}
if b.raw == nil {
b.consumed = true
return nil
}
data, err := io.ReadAll(b.raw)
if err != nil {
return wrapError(err, "read response body")
}
b.data = data
b.consumed = true
b.raw.Close()
return nil
}
// Bytes 获取响应体字节
func (b *Body) Bytes() ([]byte, error) {
if err := b.readAll(); err != nil {
return nil, err
}
return b.data, nil
}
// String 获取响应体字符串
func (b *Body) String() (string, error) {
data, err := b.Bytes()
if err != nil {
return "", err
}
return string(data), nil
}
// JSON 解析 JSON 响应
func (b *Body) JSON(v interface{}) error {
data, err := b.Bytes()
if err != nil {
return err
}
return json.Unmarshal(data, v)
}
// Reader 获取 Reader只能调用一次
func (b *Body) Reader() (io.ReadCloser, error) {
b.mu.Lock()
defer b.mu.Unlock()
if b.consumed {
if b.data != nil {
// 已读取,返回缓存数据的 Reader
return io.NopCloser(bytes.NewReader(b.data)), nil
}
return nil, ErrBodyAlreadyConsumed
}
b.consumed = true
return b.raw, nil
}
// IsConsumed 检查是否已消费
func (b *Body) IsConsumed() bool {
b.mu.Lock()
defer b.mu.Unlock()
return b.consumed
}
// Close 关闭 Body
func (b *Body) Close() error {
b.mu.Lock()
defer b.mu.Unlock()
if b.raw != nil {
return b.raw.Close()
}
return nil
}
// MustBytes 获取响应体字节(忽略错误,失败返回 nil
func (b *Body) MustBytes() []byte {
data, err := b.Bytes()
if err != nil {
return nil
}
return data
}
// MustString 获取响应体字符串(忽略错误,失败返回空串)
func (b *Body) MustString() string {
s, err := b.String()
if err != nil {
return ""
}
return s
}
// Unmarshal 解析 JSON 响应(兼容旧 API
func (b *Body) Unmarshal(v interface{}) error {
return b.JSON(v)
}

179
response_test.go Normal file
View File

@ -0,0 +1,179 @@
package starnet
import (
"io"
"net/http"
"net/http/httptest"
"testing"
)
func TestResponseBody(t *testing.T) {
testData := "test response data"
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(testData))
}))
defer server.Close()
resp, err := Get(server.URL)
if err != nil {
t.Fatalf("Get() error: %v", err)
}
defer resp.Close()
// Test String()
body, err := resp.Body().String()
if err != nil {
t.Fatalf("Body().String() error: %v", err)
}
if body != testData {
t.Errorf("Body = %v; want %v", body, testData)
}
// Test multiple reads (should work because body is cached)
body2, err := resp.Body().String()
if err != nil {
t.Fatalf("Second Body().String() error: %v", err)
}
if body2 != testData {
t.Errorf("Second Body = %v; want %v", body2, testData)
}
}
func TestResponseJSON(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{"name":"John","age":30}`))
}))
defer server.Close()
resp, err := Get(server.URL)
if err != nil {
t.Fatalf("Get() error: %v", err)
}
defer resp.Close()
var result struct {
Name string `json:"name"`
Age int `json:"age"`
}
err = resp.Body().JSON(&result)
if err != nil {
t.Fatalf("Body().JSON() error: %v", err)
}
if result.Name != "John" {
t.Errorf("Name = %v; want John", result.Name)
}
if result.Age != 30 {
t.Errorf("Age = %v; want 30", result.Age)
}
}
func TestResponseBytes(t *testing.T) {
testData := []byte("binary data")
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write(testData)
}))
defer server.Close()
resp, err := Get(server.URL)
if err != nil {
t.Fatalf("Get() error: %v", err)
}
defer resp.Close()
body, err := resp.Body().Bytes()
if err != nil {
t.Fatalf("Body().Bytes() error: %v", err)
}
if string(body) != string(testData) {
t.Errorf("Body = %v; want %v", body, testData)
}
}
func TestResponseReader(t *testing.T) {
testData := "stream data"
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(testData))
}))
defer server.Close()
resp, err := Get(server.URL)
if err != nil {
t.Fatalf("Get() error: %v", err)
}
defer resp.Close()
reader, err := resp.Body().Reader()
if err != nil {
t.Fatalf("Body().Reader() error: %v", err)
}
defer reader.Close()
body, err := io.ReadAll(reader)
if err != nil {
t.Fatalf("ReadAll() error: %v", err)
}
if string(body) != testData {
t.Errorf("Body = %v; want %v", string(body), testData)
}
}
func TestResponseAutoFetch(t *testing.T) {
testData := "auto fetch data"
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(testData))
}))
defer server.Close()
// With auto fetch
resp, err := Get(server.URL, WithAutoFetch(true))
if err != nil {
t.Fatalf("Get() error: %v", err)
}
defer resp.Close()
if !resp.Body().IsConsumed() {
t.Error("Body should be consumed with auto fetch")
}
body, _ := resp.Body().String()
if body != testData {
t.Errorf("Body = %v; want %v", body, testData)
}
}
func TestResponseStatusCode(t *testing.T) {
tests := []struct {
name string
statusCode int
}{
{"OK", http.StatusOK},
{"Created", http.StatusCreated},
{"BadRequest", http.StatusBadRequest},
{"NotFound", http.StatusNotFound},
{"InternalServerError", http.StatusInternalServerError},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(tt.statusCode)
}))
defer server.Close()
resp, err := Get(server.URL)
if err != nil {
t.Fatalf("Get() error: %v", err)
}
defer resp.Close()
if resp.StatusCode != tt.statusCode {
t.Errorf("StatusCode = %v; want %v", resp.StatusCode, tt.statusCode)
}
})
}
}

66
timeout_test.go Normal file
View File

@ -0,0 +1,66 @@
package starnet
import (
"net/http"
"net/http/httptest"
"testing"
"time"
)
func TestRequestTimeout(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(200 * time.Millisecond)
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
// Should timeout
req := NewSimpleRequest(server.URL, "GET").SetTimeout(100 * time.Millisecond)
_, err := req.Do()
if err == nil {
t.Error("Expected timeout error, got nil")
}
// Should succeed
req2 := NewSimpleRequest(server.URL, "GET").SetTimeout(300 * time.Millisecond)
resp, err := req2.Do()
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if resp != nil {
resp.Close()
}
}
func TestRequestDialTimeout(t *testing.T) {
// Use a non-routable IP to test dial timeout
req := NewSimpleRequest("http://192.0.2.1:80", "GET").
SetDialTimeout(100 * time.Millisecond)
start := time.Now()
_, err := req.Do()
elapsed := time.Since(start)
if err == nil {
t.Error("Expected dial timeout error, got nil")
}
// Should timeout within reasonable time (not wait forever)
if elapsed > 2*time.Second {
t.Errorf("Dial timeout took too long: %v", elapsed)
}
}
func TestClientTimeout(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(200 * time.Millisecond)
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
client := NewClientNoErr(WithTimeout(100 * time.Millisecond))
_, err := client.Get(server.URL)
if err == nil {
t.Error("Expected timeout error, got nil")
}
}

229
tls_test.go Normal file
View File

@ -0,0 +1,229 @@
package starnet
import (
"context"
"crypto/tls"
"net/http"
"net/http/httptest"
"testing"
"time"
)
func TestRequestSkipTLSVerify(t *testing.T) {
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
}))
defer server.Close()
// Without skip verify (should fail)
req := NewSimpleRequest(server.URL, "GET")
_, err := req.Do()
if err == nil {
t.Error("Expected TLS error without skip verify, got nil")
}
// With skip verify (should succeed)
req2 := NewSimpleRequest(server.URL, "GET").SetSkipTLSVerify(true)
resp, err := req2.Do()
if err != nil {
t.Fatalf("Do() with skip verify error: %v", err)
}
defer resp.Close()
body, _ := resp.Body().String()
if body != "OK" {
t.Errorf("Body = %v; want OK", body)
}
}
func TestRequestCustomTLSConfig(t *testing.T) {
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
}))
defer server.Close()
tlsConfig := &tls.Config{
InsecureSkipVerify: true,
MinVersion: tls.VersionTLS12,
}
req := NewSimpleRequest(server.URL, "GET").SetTLSConfig(tlsConfig)
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("StatusCode = %v; want %v", resp.StatusCode, http.StatusOK)
}
}
func TestClientDefaultTLSConfig(t *testing.T) {
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
client := NewClientNoErr()
client.SetDefaultSkipTLSVerify(true)
resp, err := client.Get(server.URL)
if err != nil {
t.Fatalf("Get() error: %v", err)
}
defer resp.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("StatusCode = %v; want %v", resp.StatusCode, http.StatusOK)
}
}
func TestRequestLevelTLSOverride(t *testing.T) {
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
// Client level: skip verify = false
client := NewClientNoErr()
client.SetDefaultSkipTLSVerify(false)
// Request level: skip verify = true (should override)
resp, err := client.Get(server.URL, WithSkipTLSVerify(true))
if err != nil {
t.Fatalf("Get() error: %v", err)
}
defer resp.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("StatusCode = %v; want %v", resp.StatusCode, http.StatusOK)
}
}
func TestRequestTls(t *testing.T) {
resp, err := NewSimpleRequest("https://www.b612.me", "GET").Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("StatusCode = %v; want %v", resp.StatusCode, http.StatusOK)
}
t.Logf("Response: %v", resp.Body().MustString())
client, err := NewClient()
if err != nil {
t.Fatalf("NewClient() error: %v", err)
}
resp, err = client.NewSimpleRequest("https://www.b612.me", "GET",
WithHeader("hello", "world"),
WithContext(context.Background()),
WithBearerToken("ddddddd")).Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("StatusCode = %v; want %v", resp.StatusCode, http.StatusOK)
}
t.Logf("Response: %v", resp.Body().MustString())
}
func TestTLSWithProxyPath(t *testing.T) {
client, err := NewClient()
if err != nil {
t.Fatal(err)
}
req, err := client.NewRequest("https://registry-1.docker.io/v2/", "GET",
WithTimeout(10*time.Second),
WithProxy("http://127.0.0.1:29992"),
)
if err != nil {
t.Fatal(err)
}
resp, err := req.Do()
if err != nil {
t.Fatalf("Do error: %v", err)
}
defer resp.Close()
t.Log(resp.Status)
}
func TestTLSWithProxyBug(t *testing.T) {
client, err := NewClient()
if err != nil {
t.Fatal(err)
}
// 关键:使用 WithProxy 触发 needsDynamicTransport
// 即使 proxy 是空串或无效地址,只要设置了就会走 buildDynamicTransport 分支
req, err := client.NewRequest("https://registry-1.docker.io/v2/", "GET",
WithTimeout(10*time.Second),
WithProxy("http://127.0.0.1:29992"), // 随便一个 proxy 地址,触发动态 transport
)
if err != nil {
t.Fatal(err)
}
resp, err := req.Do()
if err != nil {
// 修复前会报tls: either ServerName or InsecureSkipVerify must be specified
t.Fatalf("Do error: %v", err)
}
defer resp.Close()
t.Logf("Status: %s", resp.Status)
}
// 更精准的复现:直接测试有问题的分支
func TestTLSDialWithoutServerName(t *testing.T) {
client, err := NewClient()
if err != nil {
t.Fatal(err)
}
// 使用 WithCustomIP 也能触发 defaultDialTLSFunc
req, err := client.NewRequest("https://www.google.com", "GET",
WithTimeout(10*time.Second),
WithCustomIP([]string{"142.250.185.46"}), // Google 的一个 IP
)
if err != nil {
t.Fatal(err)
}
resp, err := req.Do()
if err != nil {
t.Fatalf("Do error: %v", err)
}
defer resp.Close()
t.Logf("Status: %s", resp.Status)
}
// 最小复现:只要触发 needsDynamicTransport 即可
func TestMinimalTLSBug(t *testing.T) {
client, err := NewClient()
if err != nil {
t.Fatal(err)
}
// WithDialTimeout 也会触发动态 transport
req, err := client.NewRequest("https://www.baidu.com", "GET",
WithDialTimeout(5*time.Second),
)
if err != nil {
t.Fatal(err)
}
resp, err := req.Do()
if err != nil {
// 修复前必现tls handshake: tls: either ServerName or InsecureSkipVerify must be specified
t.Fatalf("Do error: %v", err)
}
defer resp.Close()
t.Logf("Status: %s", resp.Status)
}

55
tlsconfig.go Normal file
View File

@ -0,0 +1,55 @@
package starnet
import (
"crypto/tls"
"net"
"time"
)
// GetConfigForClientFunc selects TLS config by hostname/SNI.
type GetConfigForClientFunc func(hostname string) (*tls.Config, error)
// ListenerConfig controls listener behavior.
type ListenerConfig struct {
// BaseTLSConfig is used for TLS when dynamic selection returns nil.
BaseTLSConfig *tls.Config
// GetConfigForClient selects TLS config for a hostname.
GetConfigForClient GetConfigForClientFunc
// AllowNonTLS allows plain TCP fallback.
AllowNonTLS bool
// SniffTimeout bounds protocol sniffing time. 0 means no timeout.
SniffTimeout time.Duration
// MaxClientHelloBytes limits buffered sniff data.
// If <= 0, default 64KiB.
MaxClientHelloBytes int
// Logger is optional.
Logger Logger
}
// DefaultListenerConfig returns a conservative default config.
func DefaultListenerConfig() ListenerConfig {
return ListenerConfig{
AllowNonTLS: false,
SniffTimeout: 5 * time.Second,
MaxClientHelloBytes: 64 * 1024,
}
}
// TLSDefaults returns a TLS config baseline.
// Caller should set Certificates / GetCertificate as needed.
func TLSDefaults() *tls.Config {
return &tls.Config{
MinVersion: tls.VersionTLS12,
}
}
// DialConfig controls dialing behavior.
type DialConfig struct {
Timeout time.Duration
LocalAddr net.Addr
}

500
tlssniffer.go Normal file
View File

@ -0,0 +1,500 @@
package starnet
import (
"bytes"
"context"
"crypto/tls"
"io"
"net"
"sync"
"time"
)
// replayConn replays buffered bytes first, then reads from live conn.
type replayConn struct {
reader io.Reader
conn net.Conn
}
func newReplayConn(buffered io.Reader, conn net.Conn) *replayConn {
return &replayConn{
reader: io.MultiReader(buffered, conn),
conn: conn,
}
}
func (c *replayConn) Read(p []byte) (int, error) { return c.reader.Read(p) }
func (c *replayConn) Write(p []byte) (int, error) { return c.conn.Write(p) }
func (c *replayConn) Close() error { return c.conn.Close() }
func (c *replayConn) LocalAddr() net.Addr { return c.conn.LocalAddr() }
func (c *replayConn) RemoteAddr() net.Addr { return c.conn.RemoteAddr() }
func (c *replayConn) SetDeadline(t time.Time) error { return c.conn.SetDeadline(t) }
func (c *replayConn) SetReadDeadline(t time.Time) error { return c.conn.SetReadDeadline(t) }
func (c *replayConn) SetWriteDeadline(t time.Time) error { return c.conn.SetWriteDeadline(t) }
// SniffResult describes protocol sniffing result.
type SniffResult struct {
IsTLS bool
Hostname string
Buffer *bytes.Buffer
}
// Sniffer detects protocol and metadata from initial bytes.
type Sniffer interface {
Sniff(conn net.Conn, maxBytes int) (SniffResult, error)
}
// TLSSniffer is the default sniffer implementation.
type TLSSniffer struct{}
// Sniff detects TLS and extracts SNI when possible.
func (s TLSSniffer) Sniff(conn net.Conn, maxBytes int) (SniffResult, error) {
if maxBytes <= 0 {
maxBytes = 64 * 1024
}
var buf bytes.Buffer
limited := &io.LimitedReader{R: conn, N: int64(maxBytes)}
tee := io.TeeReader(limited, &buf)
var hello *tls.ClientHelloInfo
_ = tls.Server(readOnlyConn{r: tee, raw: conn}, &tls.Config{
GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) {
cp := *ch
hello = &cp
return nil, nil
},
}).Handshake()
peek := buf.Bytes()
isTLS := len(peek) >= 3 && peek[0] == 0x16 && peek[1] == 0x03
out := SniffResult{
IsTLS: isTLS,
Buffer: bytes.NewBuffer(append([]byte(nil), peek...)),
}
if hello != nil {
out.Hostname = hello.ServerName
}
return out, nil
}
// readOnlyConn rejects writes/close and reads from a reader.
type readOnlyConn struct {
r io.Reader
raw net.Conn
}
func (c readOnlyConn) Read(p []byte) (int, error) { return c.r.Read(p) }
func (c readOnlyConn) Write(p []byte) (int, error) { return 0, io.ErrClosedPipe }
func (c readOnlyConn) Close() error { return nil }
func (c readOnlyConn) LocalAddr() net.Addr { return c.raw.LocalAddr() }
func (c readOnlyConn) RemoteAddr() net.Addr { return c.raw.RemoteAddr() }
func (c readOnlyConn) SetDeadline(_ time.Time) error { return nil }
func (c readOnlyConn) SetReadDeadline(_ time.Time) error { return nil }
func (c readOnlyConn) SetWriteDeadline(_ time.Time) error { return nil }
// Conn wraps net.Conn with lazy protocol initialization.
type Conn struct {
net.Conn
once sync.Once
initErr error
closeOnce sync.Once
isTLS bool
tlsConn *tls.Conn
plainConn net.Conn
hostname string
baseTLSConfig *tls.Config
getConfigForClient GetConfigForClientFunc
allowNonTLS bool
sniffer Sniffer
sniffTimeout time.Duration
maxClientHello int
logger Logger
stats *Stats
skipSniff bool
}
func newConn(raw net.Conn, cfg ListenerConfig, stats *Stats) *Conn {
return &Conn{
Conn: raw,
plainConn: raw,
baseTLSConfig: cfg.BaseTLSConfig,
getConfigForClient: cfg.GetConfigForClient,
allowNonTLS: cfg.AllowNonTLS,
sniffer: TLSSniffer{},
sniffTimeout: cfg.SniffTimeout,
maxClientHello: cfg.MaxClientHelloBytes,
logger: cfg.Logger,
stats: stats,
}
}
func (c *Conn) init() {
c.once.Do(func() {
if c.skipSniff {
return
}
if c.baseTLSConfig == nil && c.getConfigForClient == nil {
c.isTLS = false
return
}
if c.sniffTimeout > 0 {
_ = c.Conn.SetReadDeadline(time.Now().Add(c.sniffTimeout))
}
res, err := c.sniffer.Sniff(c.Conn, c.maxClientHello)
if c.sniffTimeout > 0 {
_ = c.Conn.SetReadDeadline(time.Time{})
}
if err != nil {
c.initErr = err
c.failAndClose("sniff failed: %v", err)
return
}
c.isTLS = res.IsTLS
c.hostname = res.Hostname
if c.isTLS {
if c.stats != nil {
c.stats.incTLSDetected()
}
tlsCfg, errCfg := c.selectTLSConfig()
if errCfg != nil {
c.initErr = errCfg
c.failAndClose("tls config select failed: %v", errCfg)
return
}
rc := newReplayConn(bytes.NewBuffer(res.Buffer.Bytes()), c.Conn)
c.tlsConn = tls.Server(rc, tlsCfg)
return
}
if c.stats != nil {
c.stats.incPlainDetected()
}
if !c.allowNonTLS {
c.initErr = ErrNonTLSNotAllowed
c.failAndClose("plain tcp rejected")
return
}
c.plainConn = newReplayConn(bytes.NewBuffer(res.Buffer.Bytes()), c.Conn)
})
}
func (c *Conn) failAndClose(format string, v ...interface{}) {
if c.stats != nil {
c.stats.incInitFailures()
}
if c.logger != nil {
c.logger.Printf("starnet: "+format, v...)
}
_ = c.Close()
}
func (c *Conn) selectTLSConfig() (*tls.Config, error) {
if c.getConfigForClient != nil {
cfg, err := c.getConfigForClient(c.hostname)
if err != nil {
return nil, err
}
if cfg != nil {
return cfg, nil
}
}
if c.baseTLSConfig != nil {
return c.baseTLSConfig, nil
}
return nil, ErrNoTLSConfig
}
// Hostname returns sniffed SNI hostname (if any).
func (c *Conn) Hostname() string {
c.init()
return c.hostname
}
func (c *Conn) IsTLS() bool {
c.init()
return c.initErr == nil && c.isTLS
}
func (c *Conn) TLSConn() (*tls.Conn, error) {
c.init()
if c.initErr != nil {
return nil, c.initErr
}
if !c.isTLS || c.tlsConn == nil {
return nil, ErrNotTLS
}
return c.tlsConn, nil
}
func (c *Conn) Read(b []byte) (int, error) {
c.init()
if c.initErr != nil {
return 0, c.initErr
}
if c.isTLS {
return c.tlsConn.Read(b)
}
return c.plainConn.Read(b)
}
func (c *Conn) Write(b []byte) (int, error) {
c.init()
if c.initErr != nil {
return 0, c.initErr
}
if c.isTLS {
return c.tlsConn.Write(b)
}
return c.plainConn.Write(b)
}
func (c *Conn) Close() error {
var err error
c.closeOnce.Do(func() {
if c.tlsConn != nil {
err = c.tlsConn.Close()
} else {
err = c.Conn.Close()
}
if c.stats != nil {
c.stats.incClosed()
}
})
return err
}
func (c *Conn) SetDeadline(t time.Time) error {
c.init()
if c.initErr != nil {
return c.initErr
}
if c.isTLS && c.tlsConn != nil {
return c.tlsConn.SetDeadline(t)
}
return c.plainConn.SetDeadline(t)
}
func (c *Conn) SetReadDeadline(t time.Time) error {
c.init()
if c.initErr != nil {
return c.initErr
}
if c.isTLS && c.tlsConn != nil {
return c.tlsConn.SetReadDeadline(t)
}
return c.plainConn.SetReadDeadline(t)
}
func (c *Conn) SetWriteDeadline(t time.Time) error {
c.init()
if c.initErr != nil {
return c.initErr
}
if c.isTLS && c.tlsConn != nil {
return c.tlsConn.SetWriteDeadline(t)
}
return c.plainConn.SetWriteDeadline(t)
}
// Listener wraps net.Listener and returns starnet.Conn from Accept.
type Listener struct {
net.Listener
mu sync.RWMutex
cfg ListenerConfig
stats Stats
}
// Listen creates a plain listener config (no TLS detection).
func Listen(network, address string) (*Listener, error) {
ln, err := net.Listen(network, address)
if err != nil {
return nil, err
}
cfg := DefaultListenerConfig()
cfg.AllowNonTLS = true
cfg.BaseTLSConfig = nil
cfg.GetConfigForClient = nil
return &Listener{Listener: ln, cfg: cfg}, nil
}
// ListenWithConfig creates a listener with full config.
func ListenWithConfig(network, address string, cfg ListenerConfig) (*Listener, error) {
ln, err := net.Listen(network, address)
if err != nil {
return nil, err
}
return &Listener{Listener: ln, cfg: normalizeConfig(cfg)}, nil
}
// ListenWithListenConfig creates listener using net.ListenConfig.
func ListenWithListenConfig(lc net.ListenConfig, network, address string, cfg ListenerConfig) (*Listener, error) {
ln, err := lc.Listen(context.Background(), network, address)
if err != nil {
return nil, err
}
return &Listener{Listener: ln, cfg: normalizeConfig(cfg)}, nil
}
// ListenTLS creates TLS listener from cert/key paths.
func ListenTLS(network, address, certFile, keyFile string, allowNonTLS bool) (*Listener, error) {
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return nil, err
}
cfg := DefaultListenerConfig()
cfg.AllowNonTLS = allowNonTLS
cfg.BaseTLSConfig = TLSDefaults()
cfg.BaseTLSConfig.Certificates = []tls.Certificate{cert}
return ListenWithConfig(network, address, cfg)
}
func normalizeConfig(cfg ListenerConfig) ListenerConfig {
out := DefaultListenerConfig()
out.AllowNonTLS = cfg.AllowNonTLS
out.SniffTimeout = cfg.SniffTimeout
out.MaxClientHelloBytes = cfg.MaxClientHelloBytes
out.BaseTLSConfig = cfg.BaseTLSConfig
out.GetConfigForClient = cfg.GetConfigForClient
out.Logger = cfg.Logger
if out.MaxClientHelloBytes <= 0 {
out.MaxClientHelloBytes = 64 * 1024
}
return out
}
// SetConfig atomically replaces listener config for new accepted connections.
func (l *Listener) SetConfig(cfg ListenerConfig) {
l.mu.Lock()
l.cfg = normalizeConfig(cfg)
l.mu.Unlock()
}
// Config returns a copy of current config.
func (l *Listener) Config() ListenerConfig {
l.mu.RLock()
cfg := l.cfg
l.mu.RUnlock()
return cfg
}
// Stats returns current counters snapshot.
func (l *Listener) Stats() StatsSnapshot {
return l.stats.Snapshot()
}
func (l *Listener) Accept() (net.Conn, error) {
raw, err := l.Listener.Accept()
if err != nil {
return nil, err
}
l.stats.incAccepted()
l.mu.RLock()
cfg := l.cfg
l.mu.RUnlock()
return newConn(raw, cfg, &l.stats), nil
}
// AcceptContext supports cancellation by closing accepted conn when ctx is done early.
func (l *Listener) AcceptContext(ctx context.Context) (net.Conn, error) {
type result struct {
c net.Conn
err error
}
ch := make(chan result, 1)
go func() {
c, err := l.Accept()
ch <- result{c: c, err: err}
}()
select {
case <-ctx.Done():
return nil, ctx.Err()
case r := <-ch:
return r.c, r.err
}
}
// Dial creates a plain TCP starnet.Conn.
func Dial(network, address string) (*Conn, error) {
raw, err := net.Dial(network, address)
if err != nil {
return nil, err
}
cfg := DefaultListenerConfig()
cfg.AllowNonTLS = true
cfg.BaseTLSConfig = nil
cfg.GetConfigForClient = nil
c := newConn(raw, cfg, nil)
c.isTLS = false
return c, nil
}
// DialWithConfig dials with net.Dialer options.
func DialWithConfig(network, address string, dc DialConfig) (*Conn, error) {
d := net.Dialer{
Timeout: dc.Timeout,
LocalAddr: dc.LocalAddr,
}
raw, err := d.Dial(network, address)
if err != nil {
return nil, err
}
cfg := DefaultListenerConfig()
cfg.AllowNonTLS = true
c := newConn(raw, cfg, nil)
c.isTLS = false
return c, nil
}
// DialTLSWithConfig creates a TLS client connection wrapper.
func DialTLSWithConfig(network, address string, tlsCfg *tls.Config, timeout time.Duration) (*Conn, error) {
d := net.Dialer{Timeout: timeout}
raw, err := d.Dial(network, address)
if err != nil {
return nil, err
}
tc := tls.Client(raw, tlsCfg)
return &Conn{
Conn: raw,
plainConn: raw,
isTLS: true,
tlsConn: tc,
hostname: "",
initErr: nil,
allowNonTLS: false,
skipSniff: true,
}, nil
}
// DialTLS creates TLS client conn from cert/key paths.
func DialTLS(network, address, certFile, keyFile string) (*Conn, error) {
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return nil, err
}
cfg := TLSDefaults()
cfg.Certificates = []tls.Certificate{cert}
return DialTLSWithConfig(network, address, cfg, 0)
}
func WrapListener(listener net.Listener, cfg ListenerConfig) (*Listener, error) {
if listener == nil {
return nil, ErrNilConn
}
return &Listener{
Listener: listener,
cfg: normalizeConfig(cfg),
}, nil
}

691
tlssniffer_test.go Normal file
View File

@ -0,0 +1,691 @@
package starnet
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"errors"
"io"
"math/big"
"net"
"os"
"sync"
"testing"
"time"
)
// ---------- cert helpers ----------
func genSelfSignedCertPEM(t *testing.T, dnsNames ...string) (certPEM, keyPEM []byte) {
t.Helper()
priv, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("GenerateKey: %v", err)
}
serial, err := rand.Int(rand.Reader, big.NewInt(1<<62))
if err != nil {
t.Fatalf("serial: %v", err)
}
tpl := &x509.Certificate{
SerialNumber: serial,
Subject: pkix.Name{
CommonName: "starnet-test",
},
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(24 * time.Hour),
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
DNSNames: dnsNames,
}
der, err := x509.CreateCertificate(rand.Reader, tpl, tpl, &priv.PublicKey, priv)
if err != nil {
t.Fatalf("CreateCertificate: %v", err)
}
certPEM = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der})
keyPEM = pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)})
return certPEM, keyPEM
}
func genSelfSignedCert(t *testing.T, dnsNames ...string) tls.Certificate {
t.Helper()
certPEM, keyPEM := genSelfSignedCertPEM(t, dnsNames...)
cert, err := tls.X509KeyPair(certPEM, keyPEM)
if err != nil {
t.Fatalf("X509KeyPair: %v", err)
}
return cert
}
func writeTempCertFiles(t *testing.T, dnsNames ...string) (certFile, keyFile string, cleanup func()) {
t.Helper()
certPEM, keyPEM := genSelfSignedCertPEM(t, dnsNames...)
cf, err := os.CreateTemp("", "starnet-cert-*.pem")
if err != nil {
t.Fatalf("CreateTemp cert: %v", err)
}
kf, err := os.CreateTemp("", "starnet-key-*.pem")
if err != nil {
_ = cf.Close()
_ = os.Remove(cf.Name())
t.Fatalf("CreateTemp key: %v", err)
}
if _, err := cf.Write(certPEM); err != nil {
t.Fatalf("write cert: %v", err)
}
if _, err := kf.Write(keyPEM); err != nil {
t.Fatalf("write key: %v", err)
}
_ = cf.Close()
_ = kf.Close()
return cf.Name(), kf.Name(), func() {
_ = os.Remove(cf.Name())
_ = os.Remove(kf.Name())
}
}
// ---------- server helpers ----------
func startEchoServer(t *testing.T, cfg ListenerConfig) (*Listener, string, func()) {
t.Helper()
ln, err := ListenWithConfig("tcp", "127.0.0.1:0", cfg)
if err != nil {
t.Fatalf("ListenWithConfig: %v", err)
}
var wg sync.WaitGroup
stop := make(chan struct{})
wg.Add(1)
go func() {
defer wg.Done()
for {
c, err := ln.Accept()
if err != nil {
select {
case <-stop:
return
default:
return
}
}
go func(conn net.Conn) {
defer conn.Close()
_, _ = io.Copy(conn, conn)
}(c)
}
}()
cleanup := func() {
close(stop)
_ = ln.Close()
wg.Wait()
}
return ln, ln.Addr().String(), cleanup
}
// ---------- tests ----------
func TestListen(t *testing.T) {
ln, err := Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("Listen: %v", err)
}
defer ln.Close()
go func() {
c, err := ln.Accept()
if err != nil {
return
}
defer c.Close()
_, _ = io.Copy(c, c)
}()
c, err := net.Dial("tcp", ln.Addr().String())
if err != nil {
t.Fatalf("dial: %v", err)
}
defer c.Close()
msg := []byte("x")
if _, err := c.Write(msg); err != nil {
t.Fatalf("write: %v", err)
}
buf := make([]byte, 1)
if _, err := io.ReadFull(c, buf); err != nil {
t.Fatalf("read: %v", err)
}
}
func TestListenWithListenConfig(t *testing.T) {
lc := net.ListenConfig{}
cfg := DefaultListenerConfig()
cfg.AllowNonTLS = true
ln, err := ListenWithListenConfig(lc, "tcp", "127.0.0.1:0", cfg)
if err != nil {
t.Fatalf("ListenWithListenConfig: %v", err)
}
defer ln.Close()
go func() {
c, err := ln.Accept()
if err != nil {
return
}
defer c.Close()
_, _ = io.Copy(c, c)
}()
c, err := net.Dial("tcp", ln.Addr().String())
if err != nil {
t.Fatalf("dial: %v", err)
}
defer c.Close()
msg := []byte("ok")
if _, err := c.Write(msg); err != nil {
t.Fatalf("write: %v", err)
}
buf := make([]byte, 2)
if _, err := io.ReadFull(c, buf); err != nil {
t.Fatalf("read: %v", err)
}
}
func TestListenerSetConfig(t *testing.T) {
cfg := DefaultListenerConfig()
cfg.AllowNonTLS = true
ln, err := ListenWithConfig("tcp", "127.0.0.1:0", cfg)
if err != nil {
t.Fatalf("listen: %v", err)
}
defer ln.Close()
cfg2 := cfg
cfg2.SniffTimeout = time.Second
ln.SetConfig(cfg2)
got := ln.Config()
if got.SniffTimeout != time.Second {
t.Fatalf("SetConfig not applied")
}
}
func TestPlainAllowed(t *testing.T) {
cfg := DefaultListenerConfig()
cfg.AllowNonTLS = true
cfg.BaseTLSConfig = nil
_, addr, cleanup := startEchoServer(t, cfg)
defer cleanup()
c, err := net.Dial("tcp", addr)
if err != nil {
t.Fatalf("Dial: %v", err)
}
defer c.Close()
msg := []byte("hello-plain")
if _, err := c.Write(msg); err != nil {
t.Fatalf("write: %v", err)
}
buf := make([]byte, len(msg))
if _, err := io.ReadFull(c, buf); err != nil {
t.Fatalf("read: %v", err)
}
if string(buf) != string(msg) {
t.Fatalf("echo mismatch: got=%q want=%q", string(buf), string(msg))
}
}
func TestPlainRejectedWhenNonTLSDisabled(t *testing.T) {
cert := genSelfSignedCert(t, "localhost")
base := TLSDefaults()
base.Certificates = []tls.Certificate{cert}
cfg := DefaultListenerConfig()
cfg.AllowNonTLS = false
cfg.BaseTLSConfig = base
_, addr, cleanup := startEchoServer(t, cfg)
defer cleanup()
c, err := net.Dial("tcp", addr)
if err != nil {
t.Fatalf("Dial: %v", err)
}
defer c.Close()
_, _ = c.Write([]byte("plain"))
_ = c.SetReadDeadline(time.Now().Add(500 * time.Millisecond))
b := make([]byte, 1)
_, err = c.Read(b)
if err == nil {
t.Fatalf("expected read error due to non-tls rejection")
}
}
func TestTLSHandshakeAndEcho(t *testing.T) {
cert := genSelfSignedCert(t, "localhost")
base := TLSDefaults()
base.Certificates = []tls.Certificate{cert}
cfg := DefaultListenerConfig()
cfg.AllowNonTLS = false
cfg.BaseTLSConfig = base
_, addr, cleanup := startEchoServer(t, cfg)
defer cleanup()
tc, err := tls.Dial("tcp", addr, &tls.Config{
InsecureSkipVerify: true,
ServerName: "localhost",
MinVersion: tls.VersionTLS12,
})
if err != nil {
t.Fatalf("tls dial: %v", err)
}
defer tc.Close()
msg := []byte("hello-tls")
if _, err := tc.Write(msg); err != nil {
t.Fatalf("tls write: %v", err)
}
buf := make([]byte, len(msg))
if _, err := io.ReadFull(tc, buf); err != nil {
t.Fatalf("tls read: %v", err)
}
if string(buf) != string(msg) {
t.Fatalf("tls echo mismatch: got=%q want=%q", string(buf), string(msg))
}
}
func TestDynamicConfigBySNI(t *testing.T) {
certA := genSelfSignedCert(t, "a.local")
certB := genSelfSignedCert(t, "b.local")
base := TLSDefaults()
base.Certificates = []tls.Certificate{certA}
cfg := DefaultListenerConfig()
cfg.AllowNonTLS = false
cfg.BaseTLSConfig = base
cfg.GetConfigForClient = func(host string) (*tls.Config, error) {
if host == "b.local" {
b := TLSDefaults()
b.Certificates = []tls.Certificate{certB}
return b, nil
}
return nil, nil
}
_, addr, cleanup := startEchoServer(t, cfg)
defer cleanup()
tc, err := tls.Dial("tcp", addr, &tls.Config{
InsecureSkipVerify: true,
ServerName: "b.local",
MinVersion: tls.VersionTLS12,
})
if err != nil {
t.Fatalf("tls dial: %v", err)
}
defer tc.Close()
if !tc.ConnectionState().HandshakeComplete {
t.Fatalf("handshake not complete")
}
}
func TestGetConfigForClientError(t *testing.T) {
cert := genSelfSignedCert(t, "localhost")
base := TLSDefaults()
base.Certificates = []tls.Certificate{cert}
cfg := DefaultListenerConfig()
cfg.AllowNonTLS = false
cfg.BaseTLSConfig = base
cfg.GetConfigForClient = func(host string) (*tls.Config, error) {
return nil, errors.New("boom")
}
_, addr, cleanup := startEchoServer(t, cfg)
defer cleanup()
_, err := tls.Dial("tcp", addr, &tls.Config{
InsecureSkipVerify: true,
ServerName: "localhost",
})
if err == nil {
t.Fatalf("expected tls dial failure due to selector error")
}
}
func TestAcceptContextCancel(t *testing.T) {
cfg := DefaultListenerConfig()
cfg.AllowNonTLS = true
ln, err := ListenWithConfig("tcp", "127.0.0.1:0", cfg)
if err != nil {
t.Fatalf("listen: %v", err)
}
defer ln.Close()
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
_, err = ln.AcceptContext(ctx)
if err == nil {
t.Fatalf("expected context timeout/cancel")
}
}
func TestListenerStats(t *testing.T) {
cfg := DefaultListenerConfig()
cfg.AllowNonTLS = true
ln, addr, cleanup := startEchoServer(t, cfg)
defer cleanup()
c, err := net.Dial("tcp", addr)
if err != nil {
t.Fatalf("dial: %v", err)
}
_, _ = c.Write([]byte("x"))
_ = c.Close()
time.Sleep(100 * time.Millisecond)
s := ln.Stats()
if s.Accepted == 0 {
t.Fatalf("expected accepted > 0")
}
}
func TestDialAndDialWithConfig(t *testing.T) {
nl, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen: %v", err)
}
defer nl.Close()
go func() {
c, err := nl.Accept()
if err != nil {
return
}
defer c.Close()
_, _ = io.Copy(c, c)
}()
c1, err := Dial("tcp", nl.Addr().String())
if err != nil {
t.Fatalf("Dial: %v", err)
}
defer c1.Close()
msg := []byte("abc")
if _, err := c1.Write(msg); err != nil {
t.Fatalf("c1 write: %v", err)
}
got := make([]byte, 3)
if _, err := io.ReadFull(c1, got); err != nil {
t.Fatalf("c1 read: %v", err)
}
c2, err := DialWithConfig("tcp", nl.Addr().String(), DialConfig{Timeout: time.Second})
if err != nil {
t.Fatalf("DialWithConfig: %v", err)
}
defer c2.Close()
}
func TestListenTLS_FileAPI(t *testing.T) {
certFile, keyFile, cleanupFiles := writeTempCertFiles(t, "localhost")
defer cleanupFiles()
ln, err := ListenTLS("tcp", "127.0.0.1:0", certFile, keyFile, false)
if err != nil {
t.Fatalf("ListenTLS: %v", err)
}
defer ln.Close()
go func() {
c, err := ln.Accept()
if err != nil {
return
}
defer c.Close()
_, _ = io.Copy(c, c)
}()
tc, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{
InsecureSkipVerify: true,
ServerName: "localhost",
})
if err != nil {
t.Fatalf("tls dial: %v", err)
}
defer tc.Close()
msg := []byte("hi")
if _, err := tc.Write(msg); err != nil {
t.Fatalf("tls write: %v", err)
}
out := make([]byte, 2)
if _, err := io.ReadFull(tc, out); err != nil {
t.Fatalf("tls read: %v", err)
}
}
func TestDialTLSWithConfig(t *testing.T) {
cert := genSelfSignedCert(t, "localhost")
base := TLSDefaults()
base.Certificates = []tls.Certificate{cert}
cfg := DefaultListenerConfig()
cfg.BaseTLSConfig = base
cfg.AllowNonTLS = false
ln, err := ListenWithConfig("tcp", "127.0.0.1:0", cfg)
if err != nil {
t.Fatalf("listen: %v", err)
}
defer ln.Close()
go func() {
c, err := ln.Accept()
if err != nil {
return
}
defer c.Close()
_, _ = io.Copy(c, c)
}()
clientCfg := &tls.Config{
InsecureSkipVerify: true,
ServerName: "localhost",
}
c, err := DialTLSWithConfig("tcp", ln.Addr().String(), clientCfg, time.Second)
if err != nil {
t.Fatalf("DialTLSWithConfig: %v", err)
}
defer c.Close()
if !c.IsTLS() {
t.Fatalf("expected IsTLS true")
}
}
func TestDialTLS_FileAPI(t *testing.T) {
cert := genSelfSignedCert(t, "localhost")
base := TLSDefaults()
base.Certificates = []tls.Certificate{cert}
cfg := DefaultListenerConfig()
cfg.BaseTLSConfig = base
cfg.AllowNonTLS = false
ln, err := ListenWithConfig("tcp", "127.0.0.1:0", cfg)
if err != nil {
t.Fatalf("listen: %v", err)
}
defer ln.Close()
go func() {
c, err := ln.Accept()
if err != nil {
return
}
defer c.Close()
_, _ = io.Copy(c, c)
}()
clientCertFile, clientKeyFile, cleanupFiles := writeTempCertFiles(t, "localhost")
defer cleanupFiles()
c, err := DialTLS("tcp", ln.Addr().String(), clientCertFile, clientKeyFile)
if err != nil {
t.Fatalf("DialTLS: %v", err)
}
defer c.Close()
if !c.IsTLS() {
t.Fatalf("expected IsTLS true")
}
}
func TestConnIsTLS_PlainAndTLS(t *testing.T) {
// ---- plain case ----
plainCfg := DefaultListenerConfig()
plainCfg.AllowNonTLS = true
ln1, err := ListenWithConfig("tcp", "127.0.0.1:0", plainCfg)
if err != nil {
t.Fatalf("listen1: %v", err)
}
defer ln1.Close()
plainDone := make(chan *Conn, 1)
plainErr := make(chan error, 1)
go func() {
nc, err := ln1.Accept()
if err != nil {
plainErr <- err
return
}
sc, ok := nc.(*Conn)
if !ok {
_ = nc.Close()
plainErr <- errors.New("accepted conn is not *Conn")
return
}
plainDone <- sc
// block until client sends one byte, then close
buf := make([]byte, 1)
_, _ = sc.Read(buf)
_ = sc.Close()
}()
c1, err := net.Dial("tcp", ln1.Addr().String())
if err != nil {
t.Fatalf("dial1: %v", err)
}
if _, err := c1.Write([]byte("p")); err != nil {
_ = c1.Close()
t.Fatalf("plain client write: %v", err)
}
_ = c1.Close()
select {
case err := <-plainErr:
t.Fatalf("plain server error: %v", err)
case sc1 := <-plainDone:
if sc1.IsTLS() {
t.Fatalf("plain conn should not be TLS")
}
case <-time.After(2 * time.Second):
t.Fatalf("timeout waiting plain side")
}
// ---- tls case ----
cert := genSelfSignedCert(t, "localhost")
tlsBase := TLSDefaults()
tlsBase.Certificates = []tls.Certificate{cert}
tlsCfg := DefaultListenerConfig()
tlsCfg.BaseTLSConfig = tlsBase
tlsCfg.AllowNonTLS = false
ln2, err := ListenWithConfig("tcp", "127.0.0.1:0", tlsCfg)
if err != nil {
t.Fatalf("listen2: %v", err)
}
defer ln2.Close()
tlsDone := make(chan *Conn, 1)
tlsErr := make(chan error, 1)
go func() {
nc, err := ln2.Accept()
if err != nil {
tlsErr <- err
return
}
sc, ok := nc.(*Conn)
if !ok {
_ = nc.Close()
tlsErr <- errors.New("accepted conn is not *Conn")
return
}
tlsDone <- sc
// key point: wait for real data to ensure TLS handshake/path is executed
buf := make([]byte, 1)
_, _ = sc.Read(buf)
_ = sc.Close()
}()
d := &net.Dialer{Timeout: 2 * time.Second}
tc, err := tls.DialWithDialer(d, "tcp", ln2.Addr().String(), &tls.Config{
InsecureSkipVerify: true, // test only
ServerName: "localhost",
MinVersion: tls.VersionTLS12,
})
if err != nil {
t.Fatalf("tls dial: %v", err)
}
if _, err := tc.Write([]byte("t")); err != nil {
_ = tc.Close()
t.Fatalf("tls client write: %v", err)
}
_ = tc.Close()
select {
case err := <-tlsErr:
t.Fatalf("tls server error: %v", err)
case sc2 := <-tlsDone:
if !sc2.IsTLS() {
t.Fatalf("tls conn should be TLS")
}
case <-time.After(3 * time.Second):
t.Fatalf("timeout waiting tls side")
}
}

43
tlsstats.go Normal file
View File

@ -0,0 +1,43 @@
package starnet
import "sync/atomic"
// StatsSnapshot is a read-only copy of runtime counters.
type StatsSnapshot struct {
Accepted uint64
TLSDetected uint64
PlainDetected uint64
InitFailures uint64
Closed uint64
}
// Stats provides lock-free counters.
type Stats struct {
accepted uint64
tlsDetected uint64
plainDetected uint64
initFailures uint64
closed uint64
}
func (s *Stats) incAccepted() { atomic.AddUint64(&s.accepted, 1) }
func (s *Stats) incTLSDetected() { atomic.AddUint64(&s.tlsDetected, 1) }
func (s *Stats) incPlainDetected() { atomic.AddUint64(&s.plainDetected, 1) }
func (s *Stats) incInitFailures() { atomic.AddUint64(&s.initFailures, 1) }
func (s *Stats) incClosed() { atomic.AddUint64(&s.closed, 1) }
// Snapshot returns a stable view of counters.
func (s *Stats) Snapshot() StatsSnapshot {
return StatsSnapshot{
Accepted: atomic.LoadUint64(&s.accepted),
TLSDetected: atomic.LoadUint64(&s.tlsDetected),
PlainDetected: atomic.LoadUint64(&s.plainDetected),
InitFailures: atomic.LoadUint64(&s.initFailures),
Closed: atomic.LoadUint64(&s.closed),
}
}
// Logger is a minimal logging abstraction.
type Logger interface {
Printf(format string, v ...interface{})
}

97
transport.go Normal file
View File

@ -0,0 +1,97 @@
package starnet
import (
"net/http"
"net/url"
"sync"
"time"
)
// Transport 自定义 Transport支持请求级配置
type Transport struct {
base *http.Transport
mu sync.RWMutex
}
// RoundTrip 实现 http.RoundTripper 接口
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
// 确保 base 已初始化
if t.base == nil {
t.mu.Lock()
if t.base == nil {
t.base = &http.Transport{
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
}
t.mu.Unlock()
}
// 提取请求级别的配置
reqCtx := getRequestContext(req.Context())
// 优先级1完全自定义的 transport
if reqCtx.Transport != nil {
return reqCtx.Transport.RoundTrip(req)
}
// 优先级2需要动态配置
if needsDynamicTransport(reqCtx) {
dynamicTransport := t.buildDynamicTransport(reqCtx)
return dynamicTransport.RoundTrip(req)
}
// 优先级3使用基础 transport
t.mu.RLock()
defer t.mu.RUnlock()
return t.base.RoundTrip(req)
}
// buildDynamicTransport 构建动态 Transport
func (t *Transport) buildDynamicTransport(rc *RequestContext) *http.Transport {
t.mu.RLock()
transport := t.base.Clone()
t.mu.RUnlock()
// 应用 TLS 配置(即使为 nil 也要检查 SkipVerify
if rc.TLSConfig != nil {
transport.TLSClientConfig = rc.TLSConfig
}
// 应用代理配置
if rc.Proxy != "" {
proxyURL, err := url.Parse(rc.Proxy)
if err == nil {
transport.Proxy = http.ProxyURL(proxyURL)
}
}
// 应用自定义 Dial 函数
if rc.DialFn != nil {
transport.DialContext = rc.DialFn
} else if len(rc.CustomIP) > 0 || len(rc.CustomDNS) > 0 || rc.DialTimeout > 0 || rc.LookupIPFn != nil {
// 使用默认 Dial 函数(会从 context 读取配置)
transport.DialContext = defaultDialFunc
transport.DialTLSContext = defaultDialTLSFunc
}
return transport
}
// Base 获取基础 Transport
func (t *Transport) Base() *http.Transport {
t.mu.RLock()
defer t.mu.RUnlock()
return t.base
}
// SetBase 设置基础 Transport
func (t *Transport) SetBase(base *http.Transport) {
t.mu.Lock()
t.base = base
t.mu.Unlock()
}

131
types.go Normal file
View File

@ -0,0 +1,131 @@
package starnet
import (
"context"
"crypto/tls"
"io"
"net"
"net/http"
"time"
)
// HTTP Content-Type 常量
const (
ContentTypeFormURLEncoded = "application/x-www-form-urlencoded"
ContentTypeFormData = "multipart/form-data"
ContentTypeJSON = "application/json"
ContentTypeXML = "application/xml"
ContentTypePlain = "text/plain"
ContentTypeHTML = "text/html"
ContentTypeOctetStream = "application/octet-stream"
)
// 默认配置
const (
DefaultDialTimeout = 5 * time.Second
DefaultTimeout = 10 * time.Second
DefaultUserAgent = "Starnet/1.0.0"
DefaultFetchRespBody = false
)
// RequestFile 表示要上传的文件
type RequestFile struct {
FormName string // 表单字段名
FileName string // 文件名
FilePath string // 文件路径(如果从文件读取)
FileData io.Reader // 文件数据流
FileSize int64 // 文件大小
FileType string // MIME 类型
}
// UploadProgressFunc 文件上传进度回调函数
type UploadProgressFunc func(filename string, uploaded int64, total int64)
// NetworkConfig 网络配置
type NetworkConfig struct {
Proxy string // 代理地址
DialTimeout time.Duration // 连接超时
Timeout time.Duration // 总超时
DialFunc func(ctx context.Context, network, addr string) (net.Conn, error)
}
// TLSConfig TLS 配置
type TLSConfig struct {
Config *tls.Config // TLS 配置
SkipVerify bool // 跳过证书验证
}
// DNSConfig DNS 配置
type DNSConfig struct {
CustomIP []string // 直接指定 IP最高优先级
CustomDNS []string // 自定义 DNS 服务器
LookupFunc func(ctx context.Context, host string) ([]net.IPAddr, error) // 自定义解析函数
}
// BodyConfig 请求体配置
type BodyConfig struct {
Bytes []byte // 原始字节
Reader io.Reader // 数据流
FormData map[string][]string // 表单数据
Files []RequestFile // 文件列表
}
// RequestConfig 请求配置(内部使用)
type RequestConfig struct {
Network NetworkConfig
TLS TLSConfig
DNS DNSConfig
Body BodyConfig
Headers http.Header
Cookies []*http.Cookie
Queries map[string][]string
// 其他配置
BasicAuth [2]string // Basic 认证
ContentLength int64 // 手动设置的 Content-Length
AutoCalcContentLength bool // 自动计算 Content-Length
UploadProgress UploadProgressFunc // 上传进度回调
// Transport 配置
CustomTransport bool // 是否使用自定义 Transport
Transport *http.Transport // 自定义 Transport
}
// Clone 克隆配置
func (c *RequestConfig) Clone() *RequestConfig {
return &RequestConfig{
Network: NetworkConfig{
Proxy: c.Network.Proxy,
DialTimeout: c.Network.DialTimeout,
Timeout: c.Network.Timeout,
DialFunc: c.Network.DialFunc,
},
TLS: TLSConfig{
Config: cloneTLSConfig(c.TLS.Config),
SkipVerify: c.TLS.SkipVerify,
},
DNS: DNSConfig{
CustomIP: cloneStringSlice(c.DNS.CustomIP),
CustomDNS: cloneStringSlice(c.DNS.CustomDNS),
LookupFunc: c.DNS.LookupFunc,
},
Body: BodyConfig{
Bytes: cloneBytes(c.Body.Bytes),
Reader: c.Body.Reader, // Reader 不可克隆
FormData: cloneStringMapSlice(c.Body.FormData),
Files: cloneFiles(c.Body.Files),
},
Headers: cloneHeader(c.Headers),
Cookies: cloneCookies(c.Cookies),
Queries: cloneStringMapSlice(c.Queries),
BasicAuth: c.BasicAuth,
ContentLength: c.ContentLength,
AutoCalcContentLength: c.AutoCalcContentLength,
UploadProgress: c.UploadProgress,
CustomTransport: c.CustomTransport,
Transport: c.Transport, // Transport 共享
}
}
// RequestOpt 请求选项函数
type RequestOpt func(*Request) error

212
utils.go Normal file
View File

@ -0,0 +1,212 @@
package starnet
import (
"context"
"crypto/tls"
"io"
"net/http"
"net/url"
"strings"
)
// validMethod 验证 HTTP 方法是否有效
func validMethod(method string) bool {
return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1
}
// isNotToken 检查字符是否不是 token 字符
func isNotToken(r rune) bool {
return !isTokenRune(r)
}
// isTokenRune 检查字符是否是 token 字符
func isTokenRune(r rune) bool {
i := int(r)
return i < 127 && isTokenTable[i]
}
// isTokenTable token 字符表
var isTokenTable = [127]bool{
'!': true, '#': true, '$': true, '%': true, '&': true, '\'': true, '*': true,
'+': true, '-': true, '.': true, '0': true, '1': true, '2': true, '3': true,
'4': true, '5': true, '6': true, '7': true, '8': true, '9': true, 'A': true,
'B': true, 'C': true, 'D': true, 'E': true, 'F': true, 'G': true, 'H': true,
'I': true, 'J': true, 'K': true, 'L': true, 'M': true, 'N': true, 'O': true,
'P': true, 'Q': true, 'R': true, 'S': true, 'T': true, 'U': true, 'V': true,
'W': true, 'X': true, 'Y': true, 'Z': true, '^': true, '_': true, '`': true,
'a': true, 'b': true, 'c': true, 'd': true, 'e': true, 'f': true, 'g': true,
'h': true, 'i': true, 'j': true, 'k': true, 'l': true, 'm': true, 'n': true,
'o': true, 'p': true, 'q': true, 'r': true, 's': true, 't': true, 'u': true,
'v': true, 'w': true, 'x': true, 'y': true, 'z': true, '|': true, '~': true,
}
// hasPort 检查地址是否包含端口
func hasPort(s string) bool {
return strings.LastIndex(s, ":") > strings.LastIndex(s, "]")
}
// removeEmptyPort 移除空端口
func removeEmptyPort(host string) string {
if hasPort(host) {
return strings.TrimSuffix(host, ":")
}
return host
}
// UrlEncode URL 编码
func UrlEncode(str string) string {
return url.QueryEscape(str)
}
// UrlEncodeRaw URL 编码(空格编码为 %20
func UrlEncodeRaw(str string) string {
return strings.Replace(url.QueryEscape(str), "+", "%20", -1)
}
// UrlDecode URL 解码
func UrlDecode(str string) (string, error) {
return url.QueryUnescape(str)
}
// BuildQuery 构建查询字符串
func BuildQuery(data map[string]string) string {
query := url.Values{}
for k, v := range data {
query.Add(k, v)
}
return query.Encode()
}
// BuildPostForm 构建 POST 表单数据
func BuildPostForm(data map[string]string) []byte {
return []byte(BuildQuery(data))
}
// cloneHeader 克隆 Header
func cloneHeader(h http.Header) http.Header {
if h == nil {
return make(http.Header)
}
newHeader := make(http.Header, len(h))
for k, v := range h {
newHeader[k] = append([]string(nil), v...)
}
return newHeader
}
// cloneCookies 克隆 Cookies
func cloneCookies(cookies []*http.Cookie) []*http.Cookie {
if cookies == nil {
return nil
}
newCookies := make([]*http.Cookie, len(cookies))
for i, c := range cookies {
newCookies[i] = &http.Cookie{
Name: c.Name,
Value: c.Value,
Path: c.Path,
Domain: c.Domain,
Expires: c.Expires,
RawExpires: c.RawExpires,
MaxAge: c.MaxAge,
Secure: c.Secure,
HttpOnly: c.HttpOnly,
SameSite: c.SameSite,
Raw: c.Raw,
Unparsed: append([]string(nil), c.Unparsed...),
}
}
return newCookies
}
// cloneStringMapSlice 克隆 map[string][]string
func cloneStringMapSlice(m map[string][]string) map[string][]string {
if m == nil {
return make(map[string][]string)
}
newMap := make(map[string][]string, len(m))
for k, v := range m {
newMap[k] = append([]string(nil), v...)
}
return newMap
}
// cloneBytes 克隆字节切片
func cloneBytes(b []byte) []byte {
if b == nil {
return nil
}
newBytes := make([]byte, len(b))
copy(newBytes, b)
return newBytes
}
// cloneStringSlice 克隆字符串切片
func cloneStringSlice(s []string) []string {
if s == nil {
return nil
}
newSlice := make([]string, len(s))
copy(newSlice, s)
return newSlice
}
// cloneFiles 克隆文件列表
func cloneFiles(files []RequestFile) []RequestFile {
if files == nil {
return nil
}
newFiles := make([]RequestFile, len(files))
copy(newFiles, files)
return newFiles
}
// cloneTLSConfig 克隆 TLS 配置
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
if cfg == nil {
return nil
}
return cfg.Clone()
}
// copyWithProgress 带进度的复制
func copyWithProgress(ctx context.Context, dst io.Writer, src io.Reader, filename string, total int64, progress UploadProgressFunc) (int64, error) {
if progress == nil {
return io.Copy(dst, src)
}
var written int64
buf := make([]byte, 32*1024) // 32KB buffer
for {
select {
case <-ctx.Done():
return written, ctx.Err()
default:
}
nr, err := src.Read(buf)
if nr > 0 {
nw, ew := dst.Write(buf[:nr])
if nw > 0 {
written += int64(nw)
// 同步调用进度回调(不使用 goroutine
progress(filename, written, total)
}
if ew != nil {
return written, ew
}
if nr != nw {
return written, io.ErrShortWrite
}
}
if err != nil {
if err == io.EOF {
// 最后一次进度回调
progress(filename, written, total)
return written, nil
}
return written, err
}
}
}

284
utils_test.go Normal file
View File

@ -0,0 +1,284 @@
package starnet
import (
"net/http"
"testing"
"time"
)
func TestUrlEncodeRaw(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "basic string with space",
input: "hello world",
expected: "hello%20world",
},
{
name: "special characters",
input: "hello world!@#$%^&*()_+-=~`",
expected: "hello%20world%21%40%23%24%25%5E%26%2A%28%29_%2B-%3D~%60",
},
{
name: "empty string",
input: "",
expected: "",
},
{
name: "chinese characters",
input: "你好世界",
expected: "%E4%BD%A0%E5%A5%BD%E4%B8%96%E7%95%8C",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := UrlEncodeRaw(tt.input)
if result != tt.expected {
t.Errorf("UrlEncodeRaw(%q) = %q; want %q", tt.input, result, tt.expected)
}
})
}
}
func TestUrlEncode(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "space encoded as plus",
input: "hello world",
expected: "hello+world",
},
{
name: "special characters",
input: "hello world!@#$%^&*()_+-=~`",
expected: "hello+world%21%40%23%24%25%5E%26%2A%28%29_%2B-%3D~%60",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := UrlEncode(tt.input)
if result != tt.expected {
t.Errorf("UrlEncode(%q) = %q; want %q", tt.input, result, tt.expected)
}
})
}
}
func TestUrlDecode(t *testing.T) {
tests := []struct {
name string
input string
expected string
expectErr bool
}{
{
name: "basic decode",
input: "hello%20world",
expected: "hello world",
expectErr: false,
},
{
name: "plus to space",
input: "hello+world",
expected: "hello world",
expectErr: false,
},
{
name: "special characters",
input: "hello%20world%21%40%23%24%25%5E%26*%28%29_%2B-%3D~%60",
expected: "hello world!@#$%^&*()_+-=~`",
expectErr: false,
},
{
name: "invalid encoding",
input: "%zz",
expected: "",
expectErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := UrlDecode(tt.input)
if tt.expectErr {
if err == nil {
t.Errorf("UrlDecode(%q) expected error, got nil", tt.input)
}
} else {
if err != nil {
t.Errorf("UrlDecode(%q) unexpected error: %v", tt.input, err)
}
if result != tt.expected {
t.Errorf("UrlDecode(%q) = %q; want %q", tt.input, result, tt.expected)
}
}
})
}
}
func TestBuildQuery(t *testing.T) {
tests := []struct {
name string
input map[string]string
expected string
}{
{
name: "single parameter",
input: map[string]string{
"key": "value",
},
expected: "key=value",
},
{
name: "empty map",
input: map[string]string{},
expected: "",
},
{
name: "nil map",
input: nil,
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := BuildQuery(tt.input)
if result != tt.expected {
t.Errorf("BuildQuery(%v) = %q; want %q", tt.input, result, tt.expected)
}
})
}
}
func TestBuildPostForm(t *testing.T) {
tests := []struct {
name string
input map[string]string
expected []byte
}{
{
name: "basic form",
input: map[string]string{
"key1": "value1",
},
expected: []byte("key1=value1"),
},
{
name: "empty map",
input: map[string]string{},
expected: []byte(""),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := BuildPostForm(tt.input)
if string(result) != string(tt.expected) {
t.Errorf("BuildPostForm(%v) = %v; want %v", tt.input, result, tt.expected)
}
})
}
}
func TestValidMethod(t *testing.T) {
tests := []struct {
name string
method string
expected bool
}{
{"GET", "GET", true},
{"POST", "POST", true},
{"PUT", "PUT", true},
{"DELETE", "DELETE", true},
{"PATCH", "PATCH", true},
{"OPTIONS", "OPTIONS", true},
{"HEAD", "HEAD", true},
{"TRACE", "TRACE", true},
{"CONNECT", "CONNECT", true},
{"invalid with space", "GET POST", false},
{"invalid with special char", "GET<>", false},
{"empty", "", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := validMethod(tt.method)
if result != tt.expected {
t.Errorf("validMethod(%q) = %v; want %v", tt.method, result, tt.expected)
}
})
}
}
func TestCloneCookies_FullFields(t *testing.T) {
expire := time.Now().Add(2 * time.Hour)
src := []*http.Cookie{
{
Name: "sid",
Value: "abc123",
Path: "/",
Domain: "example.com",
Expires: expire,
RawExpires: expire.UTC().Format(time.RFC1123),
MaxAge: 3600,
Secure: true,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
Raw: "sid=abc123; Path=/; HttpOnly",
Unparsed: []string{"Priority=High", "Partitioned"},
},
}
got := cloneCookies(src)
if got == nil || len(got) != 1 {
t.Fatalf("cloneCookies() len=%v; want 1", len(got))
}
// 指针应不同(不是浅拷贝)
if got[0] == src[0] {
t.Fatal("cookie pointer should be different (deep copy expected)")
}
// 字段值应一致
s := src[0]
g := got[0]
if g.Name != s.Name ||
g.Value != s.Value ||
g.Path != s.Path ||
g.Domain != s.Domain ||
!g.Expires.Equal(s.Expires) ||
g.RawExpires != s.RawExpires ||
g.MaxAge != s.MaxAge ||
g.Secure != s.Secure ||
g.HttpOnly != s.HttpOnly ||
g.SameSite != s.SameSite ||
g.Raw != s.Raw {
t.Fatalf("cloned cookie fields mismatch:\n got=%+v\n src=%+v", g, s)
}
// Unparsed 内容一致
if len(g.Unparsed) != len(s.Unparsed) {
t.Fatalf("Unparsed len=%d; want %d", len(g.Unparsed), len(s.Unparsed))
}
for i := range s.Unparsed {
if g.Unparsed[i] != s.Unparsed[i] {
t.Fatalf("Unparsed[%d]=%q; want %q", i, g.Unparsed[i], s.Unparsed[i])
}
}
// 验证 Unparsed 是深拷贝(修改源不影响目标)
src[0].Unparsed[0] = "Modified=Yes"
if got[0].Unparsed[0] == "Modified=Yes" {
t.Fatal("Unparsed should be deep-copied, but was affected by source mutation")
}
}