fix(starnet): 重构请求执行链路并补齐代理/重试/trace边界

- 分离 Request 的配置态与执行态,修复二次 Do、raw 模式网络配置失效和 body 来源互斥问题
  - 新增 starnet trace 抽象,补齐 DNS/连接/TLS/重试事件,并优化动态 transport 缓存与代理解析路径
  - 收紧非法代理为 fail-fast,多目标目标回退仅限幂等请求,修复 Host/TLS/SNI 等语义边界
  - 补充防御性拷贝、专项回归测试、本地代理/TLS 用例与 README 行为说明
This commit is contained in:
兔子 2026-04-19 15:39:51 +08:00
parent 9ac9b65bc5
commit 732e81316c
Signed by: b612
GPG Key ID: 99DD2222B612B612
43 changed files with 5633 additions and 1728 deletions

2
.gitignore vendored
View File

@ -2,3 +2,5 @@
.sentrux/ .sentrux/
agent_readme.md agent_readme.md
target.md target.md
agents.md
.codex

View File

@ -1,58 +1,58 @@
# starnet # starnet
`starnet` is a Go network toolkit focused on practical HTTP request control, TLS sniff utilities, and ICMP ping capabilities. `starnet` 是一个面向 Go 的网络工具库,提供 HTTP 请求控制、TLS 嗅探和 ICMP Ping 能力。
## Highlights ## 功能概览
- Request-level timeout by context (without mutating shared `http.Client` timeout) - 基于 `context` 的请求级超时控制,不修改共享 `http.Client` 的全局超时
- Fine-grained network controls: custom DNS/IP, dial timeout, proxy, TLS config - 请求级网络控制:代理、自定义 IP / DNS、拨号超时、TLS 配置
- Built-in retry with replay safety checks and configurable backoff/jitter/statuses - 内置重试机制,支持重试次数、退避、抖动、状态码白名单和自定义错误判定
- Response body safety guard via max body bytes limit - 响应体大小限制,避免一次性读取过大内容
- Error classification helpers (`ClassifyError`, `IsTimeout`, `IsDNS`, `IsTLS`, `IsProxy`, `IsCanceled`) - 错误分类辅助:`ClassifyError``IsTimeout``IsDNS``IsTLS``IsProxy``IsCanceled`
- TLS sniffer listener/dialer utilities for mixed TLS/plain traffic scenarios - TLS 嗅探监听 / 拨号工具,适用于 TLS 与明文混合场景
- ICMP ping with IPv4/IPv6 target handling and option-based probing API - ICMP Ping支持 IPv4 / IPv6 目标和选项化探测
## Main Features ## 主要能力
### HTTP Client and Request ### HTTP 客户端与请求构建
- Fluent APIs with both `WithXxx` options and `SetXxx` chain methods - 同时提供 `WithXxx` 选项和 `SetXxx` 链式调用两套接口
- Methods: `Get/Post/Put/Delete/Head/Patch/Options/Trace/Connect` - 支持 `Get``Post``Put``Delete``Head``Patch``Options``Trace``Connect`
- Request body helpers: JSON, form data, multipart file upload, stream body - 支持 JSON、表单、`multipart/form-data`、流式请求体等常见请求体形态
- Header/cookie/query helpers with defensive copy on key setters - Header、Cookie、Query 等输入在关键路径上做防御性拷贝,降低外部可变状态污染风险
- Request cloning for safe reuse in concurrent or variant calls - `Request.Clone()` 可用于并发场景或同一基础请求的变体构造
### Timeout and Retry ### 超时与重试
- Request timeout is applied by context deadline, not global client timeout - 请求超时通过 `context` 截止时间控制,不污染共享客户端配置
- Retry supports: - 重试支持:
- max attempts - 最大尝试次数
- backoff factor/base/max - 基础退避、最大退避和退避因子
- jitter - 抖动比例
- retry status whitelist - 可重试状态码集合
- idempotent-only guard - 仅幂等方法重试
- custom retry-on-error callback - 自定义错误判定函数
- Retry keeps original request pointer in final response for consistency - 重试成功后返回的 `Response` 仍保持对原始 `Request` 的引用
### Response Handling ### 响应处理
- `Bytes/String/JSON/Reader` helpers - 提供 `Bytes``String``JSON``Reader` 等响应体读取接口
- optional auto-fetch mode - 支持自动预取响应体
- configurable max response body bytes to prevent oversized reads - 支持按字节数限制响应体读取上限
### Ping Module ### Ping 模块
- `Ping`, `PingWithContext`, `Pingable`, and compatibility helper `IsIpPingable` - 提供 `Ping``PingWithContext``Pingable` 以及兼容函数 `IsIpPingable`
- `PingOptions` for count/timeout/interval/deadline/address preference/source IP/payload size - `PingOptions` 支持次数、超时、间隔、截止时间、地址族偏好、源地址、负载长度等参数
- explicit error semantics for permission/protocol/timeout/resolve failures - 对权限不足、协议不支持、超时、解析失败等情况提供明确错误语义
## Install ## 安装
```bash ```bash
go get b612.me/starnet go get b612.me/starnet
``` ```
## Quick Example ## 快速示例
```go ```go
package main package main
@ -94,13 +94,18 @@ func main() {
} }
``` ```
## Stability Notes ## 行为说明
- Raw ICMP ping may require elevated privileges on some systems. - `NewClient``NewRequest` 以及请求构造相关接口在遇到非法选项时会直接返回错误,例如格式不合法的代理地址。
- Integration tests that rely on external network are environment-dependent. - `NewClientNoErr` 是便利构造函数;如果选项校验失败,仍可能返回一个占位 `Client`,需要严格校验配置时应优先使用 `NewClient`
- 重试默认仅对幂等方法生效。即使显式关闭“仅幂等方法重试”,通过 `SetBodyReader``WithBodyReader` 构造的请求在非幂等方法上仍不会自动重试。
- 当同时使用 `proxy + custom IP/DNS` 且解析出多个目标地址时,自动目标回退仅对幂等请求生效,以避免重复写入。
## License ## 稳定性说明
This project is licensed under the Apache License 2.0. - 原始 ICMP Ping 在部分系统上需要额外权限。
See [LICENSE](./LICENSE). - 依赖外部网络环境的集成测试结果可能受运行环境影响。
## 许可证
本项目采用 Apache License 2.0,详见 [LICENSE](./LICENSE)。

View File

@ -148,6 +148,33 @@ func BenchmarkRequestCreation(b *testing.B) {
} }
} }
func BenchmarkRequestPrepareDefaultPath(b *testing.B) {
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
req := NewSimpleRequest("https://example.com", "GET")
if err := req.prepare(); err != nil {
b.Fatalf("prepare() error: %v", err)
}
}
}
func BenchmarkRequestPrepareDynamicPath(b *testing.B) {
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
req := NewSimpleRequest("https://example.com", "GET",
WithCustomIP([]string{"127.0.0.1"}),
WithSkipTLSVerify(true),
)
if err := req.prepare(); err != nil {
b.Fatalf("prepare() error: %v", err)
}
}
}
func BenchmarkResponseBodyRead(b *testing.B) { func BenchmarkResponseBodyRead(b *testing.B) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("test response data")) w.Write([]byte("test response data"))

View File

@ -6,7 +6,6 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"sync" "sync"
"time"
) )
// Client HTTP 客户端封装 // Client HTTP 客户端封装
@ -19,14 +18,7 @@ type Client struct {
// NewClient 创建新的 Client // NewClient 创建新的 Client
func NewClient(opts ...RequestOpt) (*Client, error) { func NewClient(opts ...RequestOpt) (*Client, error) {
// 创建基础 Transport // 创建基础 Transport
baseTransport := &http.Transport{ baseTransport := newBaseHTTPTransport()
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
httpClient := &http.Client{ httpClient := &http.Client{
Transport: &Transport{base: baseTransport}, Transport: &Transport{base: baseTransport},
@ -40,6 +32,9 @@ func NewClient(opts ...RequestOpt) (*Client, error) {
if err != nil { if err != nil {
return nil, wrapError(err, "create client") return nil, wrapError(err, "create client")
} }
if req.err != nil {
return nil, wrapError(req.err, "create client")
}
/* /*
// 如果选项中有自定义配置,应用到 httpClient // 如果选项中有自定义配置,应用到 httpClient
@ -61,7 +56,9 @@ func NewClient(opts ...RequestOpt) (*Client, error) {
}, nil }, nil
} }
// NewClientNoErr 创建新的 Client忽略错误 // NewClientNoErr 创建新的 Client忽略错误
// 当 opts 校验失败时,它仍会返回一个可用的 Client 占位对象;
// 如果调用方需要感知选项错误或依赖默认 starnet Transport 行为,应优先使用 NewClient。
func NewClientNoErr(opts ...RequestOpt) *Client { func NewClientNoErr(opts ...RequestOpt) *Client {
client, _ := NewClient(opts...) client, _ := NewClient(opts...)
if client == nil { if client == nil {
@ -172,11 +169,13 @@ func (c *Client) Clone() *Client {
func (c *Client) SetDefaultTLSConfig(tlsConfig *tls.Config) *Client { func (c *Client) SetDefaultTLSConfig(tlsConfig *tls.Config) *Client {
if transport, ok := c.client.Transport.(*Transport); ok { if transport, ok := c.client.Transport.(*Transport); ok {
transport.mu.Lock() transport.mu.Lock()
transport.ensureBaseLocked()
if tlsConfig != nil { if tlsConfig != nil {
transport.base.TLSClientConfig = tlsConfig.Clone() transport.base.TLSClientConfig = tlsConfig.Clone()
} else { } else {
transport.base.TLSClientConfig = nil transport.base.TLSClientConfig = nil
} }
transport.resetDynamicTransportCacheLocked()
transport.mu.Unlock() transport.mu.Unlock()
} }
return c return c
@ -186,12 +185,14 @@ func (c *Client) SetDefaultTLSConfig(tlsConfig *tls.Config) *Client {
func (c *Client) SetDefaultSkipTLSVerify(skip bool) *Client { func (c *Client) SetDefaultSkipTLSVerify(skip bool) *Client {
if transport, ok := c.client.Transport.(*Transport); ok { if transport, ok := c.client.Transport.(*Transport); ok {
transport.mu.Lock() transport.mu.Lock()
transport.ensureBaseLocked()
if transport.base.TLSClientConfig == nil { if transport.base.TLSClientConfig == nil {
transport.base.TLSClientConfig = &tls.Config{} transport.base.TLSClientConfig = &tls.Config{}
} else { } else {
transport.base.TLSClientConfig = transport.base.TLSClientConfig.Clone() transport.base.TLSClientConfig = transport.base.TLSClientConfig.Clone()
} }
transport.base.TLSClientConfig.InsecureSkipVerify = skip transport.base.TLSClientConfig.InsecureSkipVerify = skip
transport.resetDynamicTransportCacheLocked()
transport.mu.Unlock() transport.mu.Unlock()
} }
return c return c
@ -227,6 +228,9 @@ func (c *Client) NewRequestWithContext(ctx context.Context, url, method string,
if err != nil { if err != nil {
return nil, err return nil, err
} }
if req.err != nil {
return nil, req.err
}
req.client = c req.client = c
req.httpClient = c.client req.httpClient = c.client

View File

@ -14,6 +14,8 @@ type contextKey int
const ( const (
ctxKeyTransport contextKey = iota ctxKeyTransport contextKey = iota
ctxKeyTLSConfig ctxKeyTLSConfig
ctxKeyTLSConfigCacheable
ctxKeyTLSServerName
ctxKeyProxy ctxKeyProxy
ctxKeyCustomIP ctxKeyCustomIP
ctxKeyCustomDNS ctxKeyCustomDNS
@ -21,58 +23,95 @@ const (
ctxKeyTimeout ctxKeyTimeout
ctxKeyLookupIP ctxKeyLookupIP
ctxKeyDialFunc ctxKeyDialFunc
ctxKeyRequestContext
) )
// RequestContext 从 context 中提取的请求配置 // RequestContext 从 context 中提取的请求配置
type RequestContext struct { type RequestContext struct {
Transport *http.Transport Transport *http.Transport
TLSConfig *tls.Config TLSConfig *tls.Config
Proxy string TLSConfigCacheable bool
CustomIP []string TLSServerName string
CustomDNS []string Proxy string
DialTimeout time.Duration CustomIP []string
Timeout time.Duration CustomDNS []string
LookupIPFn func(ctx context.Context, host string) ([]net.IPAddr, error) DialTimeout time.Duration
DialFn func(ctx context.Context, network, addr string) (net.Conn, error) Timeout time.Duration
LookupIPFn func(ctx context.Context, host string) ([]net.IPAddr, error)
DialFn func(ctx context.Context, network, addr string) (net.Conn, error)
} }
var emptyRequestContext = &RequestContext{}
// getRequestContext 从 context 中提取请求配置 // getRequestContext 从 context 中提取请求配置
func getRequestContext(ctx context.Context) *RequestContext { func getRequestContext(ctx context.Context) *RequestContext {
rc := &RequestContext{} if v := ctx.Value(ctxKeyRequestContext); v != nil {
if rc, ok := v.(*RequestContext); ok && rc != nil {
return rc
}
}
var rc *RequestContext
ensure := func() *RequestContext {
if rc == nil {
rc = &RequestContext{}
}
return rc
}
if v := ctx.Value(ctxKeyTransport); v != nil { if v := ctx.Value(ctxKeyTransport); v != nil {
rc.Transport, _ = v.(*http.Transport) ensure().Transport, _ = v.(*http.Transport)
} }
if v := ctx.Value(ctxKeyTLSConfig); v != nil { if v := ctx.Value(ctxKeyTLSConfig); v != nil {
rc.TLSConfig, _ = v.(*tls.Config) ensure().TLSConfig, _ = v.(*tls.Config)
}
if v := ctx.Value(ctxKeyTLSConfigCacheable); v != nil {
ensure().TLSConfigCacheable, _ = v.(bool)
}
if v := ctx.Value(ctxKeyTLSServerName); v != nil {
ensure().TLSServerName, _ = v.(string)
} }
if v := ctx.Value(ctxKeyProxy); v != nil { if v := ctx.Value(ctxKeyProxy); v != nil {
rc.Proxy, _ = v.(string) ensure().Proxy, _ = v.(string)
} }
if v := ctx.Value(ctxKeyCustomIP); v != nil { if v := ctx.Value(ctxKeyCustomIP); v != nil {
rc.CustomIP, _ = v.([]string) ensure().CustomIP, _ = v.([]string)
} }
if v := ctx.Value(ctxKeyCustomDNS); v != nil { if v := ctx.Value(ctxKeyCustomDNS); v != nil {
rc.CustomDNS, _ = v.([]string) ensure().CustomDNS, _ = v.([]string)
} }
if v := ctx.Value(ctxKeyDialTimeout); v != nil { if v := ctx.Value(ctxKeyDialTimeout); v != nil {
rc.DialTimeout, _ = v.(time.Duration) ensure().DialTimeout, _ = v.(time.Duration)
} }
if v := ctx.Value(ctxKeyTimeout); v != nil { if v := ctx.Value(ctxKeyTimeout); v != nil {
rc.Timeout, _ = v.(time.Duration) ensure().Timeout, _ = v.(time.Duration)
} }
if v := ctx.Value(ctxKeyLookupIP); v != nil { if v := ctx.Value(ctxKeyLookupIP); v != nil {
rc.LookupIPFn, _ = v.(func(context.Context, string) ([]net.IPAddr, error)) ensure().LookupIPFn, _ = v.(func(context.Context, string) ([]net.IPAddr, error))
} }
if v := ctx.Value(ctxKeyDialFunc); v != nil { if v := ctx.Value(ctxKeyDialFunc); v != nil {
rc.DialFn, _ = v.(func(context.Context, string, string) (net.Conn, error)) ensure().DialFn, _ = v.(func(context.Context, string, string) (net.Conn, error))
}
if rc == nil {
return emptyRequestContext
} }
return rc return rc
} }
func cloneRequestContext(rc *RequestContext) *RequestContext {
if rc == nil {
return nil
}
cloned := *rc
cloned.CustomIP = cloneStringSlice(rc.CustomIP)
cloned.CustomDNS = cloneStringSlice(rc.CustomDNS)
return &cloned
}
// needsDynamicTransport 判断是否需要动态 Transport // needsDynamicTransport 判断是否需要动态 Transport
func needsDynamicTransport(rc *RequestContext) bool { func needsDynamicTransport(rc *RequestContext) bool {
if rc == nil {
return false
}
return rc.Transport != nil || return rc.Transport != nil ||
rc.TLSConfig != nil || rc.TLSConfig != nil ||
rc.Proxy != "" || rc.Proxy != "" ||
@ -83,63 +122,67 @@ func needsDynamicTransport(rc *RequestContext) bool {
rc.LookupIPFn != nil rc.LookupIPFn != nil
} }
// injectRequestConfig 将请求配置注入到 context func buildRequestContext(config *RequestConfig, defaultTLSServerName string) *RequestContext {
func injectRequestConfig(ctx context.Context, config *RequestConfig) context.Context { if config == nil {
execCtx := ctx return nil
}
rc := &RequestContext{
DialTimeout: config.Network.DialTimeout,
Timeout: config.Network.Timeout,
}
// 处理 TLS 配置 // 处理 TLS 配置
var tlsConfig *tls.Config var tlsConfig *tls.Config
tlsConfigCacheable := false
if config.TLS.Config != nil { if config.TLS.Config != nil {
tlsConfig = config.TLS.Config.Clone() tlsConfig = config.TLS.Config.Clone()
if config.TLS.SkipVerify { } else if config.TLS.SkipVerify || config.TLS.ServerName != "" {
tlsConfig.InsecureSkipVerify = true
}
} else if config.TLS.SkipVerify {
tlsConfig = &tls.Config{ tlsConfig = &tls.Config{
NextProtos: []string{"h2", "http/1.1"}, NextProtos: []string{"h2", "http/1.1"},
InsecureSkipVerify: true,
} }
tlsConfigCacheable = true
} }
if config.TLS.SkipVerify && tlsConfig != nil {
tlsConfig.InsecureSkipVerify = true
}
if config.TLS.ServerName != "" && tlsConfig != nil {
tlsConfig.ServerName = config.TLS.ServerName
}
if tlsConfig != nil { if tlsConfig != nil {
execCtx = context.WithValue(execCtx, ctxKeyTLSConfig, tlsConfig) rc.TLSConfig = tlsConfig
rc.TLSConfigCacheable = tlsConfigCacheable
}
if config.TLS.ServerName != "" {
rc.TLSServerName = config.TLS.ServerName
} else if defaultTLSServerName != "" {
rc.TLSServerName = defaultTLSServerName
} }
// 注入代理 rc.Proxy = config.Network.Proxy
if config.Network.Proxy != "" { rc.CustomIP = cloneStringSlice(config.DNS.CustomIP)
execCtx = context.WithValue(execCtx, ctxKeyProxy, config.Network.Proxy) rc.CustomDNS = cloneStringSlice(config.DNS.CustomDNS)
} rc.LookupIPFn = config.DNS.LookupFunc
rc.DialFn = config.Network.DialFunc
// 注入自定义 IP
if len(config.DNS.CustomIP) > 0 {
execCtx = context.WithValue(execCtx, ctxKeyCustomIP, config.DNS.CustomIP)
}
// 注入自定义 DNS
if len(config.DNS.CustomDNS) > 0 {
execCtx = context.WithValue(execCtx, ctxKeyCustomDNS, config.DNS.CustomDNS)
}
// 总是注入 DialTimeout与原始代码一致
if config.Network.DialTimeout > 0 {
execCtx = context.WithValue(execCtx, ctxKeyDialTimeout, config.Network.DialTimeout)
}
// 注入 DNS 解析函数
if config.DNS.LookupFunc != nil {
execCtx = context.WithValue(execCtx, ctxKeyLookupIP, config.DNS.LookupFunc)
}
// 注入 Dial 函数
if config.Network.DialFunc != nil {
execCtx = context.WithValue(execCtx, ctxKeyDialFunc, config.Network.DialFunc)
}
// 注入自定义 Transport
if config.CustomTransport && config.Transport != nil { if config.CustomTransport && config.Transport != nil {
execCtx = context.WithValue(execCtx, ctxKeyTransport, config.Transport) rc.Transport = config.Transport
} }
return execCtx if !needsDynamicTransport(rc) {
return nil
}
return rc
}
// injectRequestConfig 将请求配置注入到 context
func injectRequestConfig(ctx context.Context, config *RequestConfig, defaultTLSServerName string) context.Context {
rc := buildRequestContext(config, defaultTLSServerName)
if rc == nil {
return ctx
}
return context.WithValue(ctx, ctxKeyRequestContext, rc)
} }

View File

@ -57,3 +57,55 @@ func TestSetFormDataDefensiveCopy(t *testing.T) {
t.Fatalf("form data mutated by external map change: got=%v want=[alice]", got) t.Fatalf("form data mutated by external map change: got=%v want=[alice]", got)
} }
} }
func TestWithBodyDefensiveCopy(t *testing.T) {
body := []byte("hello")
req, err := NewRequest("http://example.com", "POST", WithBody(body))
if err != nil {
t.Fatalf("NewRequest() error: %v", err)
}
body[0] = 'j'
if string(req.config.Body.Bytes) != "hello" {
t.Fatalf("body mutated by external slice change: got=%q want=%q", string(req.config.Body.Bytes), "hello")
}
}
func TestWithFormDataDefensiveCopy(t *testing.T) {
form := map[string][]string{
"name": []string{"alice"},
}
req, err := NewRequest("http://example.com", "POST", WithFormData(form))
if err != nil {
t.Fatalf("NewRequest() error: %v", err)
}
form["name"][0] = "bob"
form["name"] = append(form["name"], "carol")
got := req.config.Body.FormData["name"]
if len(got) != 1 || got[0] != "alice" {
t.Fatalf("form data mutated by external map change: got=%v want=[alice]", got)
}
}
func TestSetCustomIPDefensiveCopy(t *testing.T) {
ips := []string{"1.1.1.1", "8.8.8.8"}
req := NewSimpleRequest("http://example.com", "GET").SetCustomIP(ips)
ips[0] = "9.9.9.9"
if got := req.config.DNS.CustomIP[0]; got != "1.1.1.1" {
t.Fatalf("custom ip mutated by external slice change: got=%q want=%q", got, "1.1.1.1")
}
}
func TestSetCustomDNSDefensiveCopy(t *testing.T) {
servers := []string{"8.8.8.8", "1.1.1.1"}
req := NewSimpleRequest("http://example.com", "GET").SetCustomDNS(servers)
servers[0] = "9.9.9.9"
if got := req.config.DNS.CustomDNS[0]; got != "8.8.8.8" {
t.Fatalf("custom dns mutated by external slice change: got=%q want=%q", got, "8.8.8.8")
}
}

176
dialer.go
View File

@ -9,10 +9,100 @@ import (
"time" "time"
) )
func traceDNSLookup(traceState *traceState, host string, lookup func() ([]net.IPAddr, error)) ([]net.IPAddr, error) {
if traceState != nil {
traceState.beginManualDNS()
defer traceState.endManualDNS()
traceState.dnsStart(TraceDNSStartInfo{Host: host})
}
ipAddrs, err := lookup()
if traceState != nil {
traceState.dnsDone(TraceDNSDoneInfo{
Addrs: append([]net.IPAddr(nil), ipAddrs...),
Err: err,
})
}
return ipAddrs, err
}
func resolveDialAddresses(ctx context.Context, reqCtx *RequestContext, host, port string, traceState *traceState) ([]string, error) {
if reqCtx == nil {
reqCtx = &RequestContext{}
}
var addrs []string
if len(reqCtx.CustomIP) > 0 {
for _, ip := range reqCtx.CustomIP {
addrs = append(addrs, joinResolvedHostPort(ip, port))
}
return addrs, nil
}
var (
ipAddrs []net.IPAddr
err error
)
if reqCtx.LookupIPFn != nil {
ipAddrs, err = traceDNSLookup(traceState, host, func() ([]net.IPAddr, error) {
return reqCtx.LookupIPFn(ctx, host)
})
} else if len(reqCtx.CustomDNS) > 0 {
dialTimeout := reqCtx.DialTimeout
if dialTimeout == 0 {
dialTimeout = DefaultDialTimeout
}
dialer := &net.Dialer{Timeout: dialTimeout}
resolver := &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
var lastErr error
for _, dnsServer := range reqCtx.CustomDNS {
conn, err := dialer.DialContext(ctx, "udp", net.JoinHostPort(dnsServer, "53"))
if err != nil {
lastErr = err
continue
}
return conn, nil
}
return nil, lastErr
},
}
ipAddrs, err = traceDNSLookup(traceState, host, func() ([]net.IPAddr, error) {
return resolver.LookupIPAddr(ctx, host)
})
} else {
ipAddrs, err = traceDNSLookup(traceState, host, func() ([]net.IPAddr, error) {
return net.DefaultResolver.LookupIPAddr(ctx, host)
})
}
if err != nil {
return nil, wrapError(err, "lookup ip")
}
for _, ipAddr := range ipAddrs {
addrs = append(addrs, joinResolvedHostPort(ipAddr.String(), port))
}
return addrs, nil
}
func joinResolvedHostPort(host, port string) string {
if port == "" {
if ip := net.ParseIP(host); ip != nil && ip.To4() == nil {
return "[" + host + "]"
}
return host
}
return net.JoinHostPort(host, port)
}
// defaultDialFunc 默认 Dial 函数(支持自定义 IP 和 DNS // defaultDialFunc 默认 Dial 函数(支持自定义 IP 和 DNS
func defaultDialFunc(ctx context.Context, network, addr string) (net.Conn, error) { func defaultDialFunc(ctx context.Context, network, addr string) (net.Conn, error) {
// 提取配置 // 提取配置
reqCtx := getRequestContext(ctx) reqCtx := getRequestContext(ctx)
traceState := getTraceState(ctx)
dialTimeout := reqCtx.DialTimeout dialTimeout := reqCtx.DialTimeout
if dialTimeout == 0 { if dialTimeout == 0 {
@ -25,52 +115,9 @@ func defaultDialFunc(ctx context.Context, network, addr string) (net.Conn, error
return nil, wrapError(err, "split host port") return nil, wrapError(err, "split host port")
} }
// 获取 IP 地址列表 addrs, err := resolveDialAddresses(ctx, reqCtx, host, port, traceState)
var addrs []string if err != nil {
return nil, err
// 优先级1直接指定的 IP
if len(reqCtx.CustomIP) > 0 {
for _, ip := range reqCtx.CustomIP {
addrs = append(addrs, net.JoinHostPort(ip, port))
}
} else {
// 优先级2DNS 解析
var ipAddrs []net.IPAddr
// 使用自定义解析函数
if reqCtx.LookupIPFn != nil {
ipAddrs, err = reqCtx.LookupIPFn(ctx, host)
} else if len(reqCtx.CustomDNS) > 0 {
// 使用自定义 DNS 服务器
dialer := &net.Dialer{Timeout: dialTimeout}
resolver := &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
var lastErr error
for _, dnsServer := range reqCtx.CustomDNS {
conn, err := dialer.DialContext(ctx, "udp", net.JoinHostPort(dnsServer, "53"))
if err != nil {
lastErr = err
continue
}
return conn, nil
}
return nil, lastErr
},
}
ipAddrs, err = resolver.LookupIPAddr(ctx, host)
} else {
// 使用默认解析器
ipAddrs, err = net.DefaultResolver.LookupIPAddr(ctx, host)
}
if err != nil {
return nil, wrapError(err, "lookup ip")
}
for _, ipAddr := range ipAddrs {
addrs = append(addrs, net.JoinHostPort(ipAddr.String(), port))
}
} }
// 尝试连接所有地址 // 尝试连接所有地址
@ -103,13 +150,17 @@ func defaultDialTLSFunc(ctx context.Context, network, addr string) (net.Conn, er
// 提取 TLS 配置 // 提取 TLS 配置
reqCtx := getRequestContext(ctx) reqCtx := getRequestContext(ctx)
traceState := getTraceState(ctx)
tlsConfig := reqCtx.TLSConfig tlsConfig := reqCtx.TLSConfig
if tlsConfig == nil { if tlsConfig == nil {
tlsConfig = &tls.Config{} tlsConfig = &tls.Config{}
} }
// ← 新增:如果 ServerName 为空且没有 InsecureSkipVerify自动设置 serverName := tlsConfig.ServerName
if tlsConfig.ServerName == "" && !tlsConfig.InsecureSkipVerify { if serverName == "" {
serverName = reqCtx.TLSServerName
}
if serverName == "" && !tlsConfig.InsecureSkipVerify {
host, _, err := net.SplitHostPort(addr) host, _, err := net.SplitHostPort(addr)
if err != nil { if err != nil {
if idx := strings.LastIndex(addr, ":"); idx > 0 { if idx := strings.LastIndex(addr, ":"); idx > 0 {
@ -118,8 +169,19 @@ func defaultDialTLSFunc(ctx context.Context, network, addr string) (net.Conn, er
host = addr host = addr
} }
} }
serverName = host
}
if serverName != "" && tlsConfig.ServerName != serverName {
tlsConfig = tlsConfig.Clone() // 避免修改原 config tlsConfig = tlsConfig.Clone() // 避免修改原 config
tlsConfig.ServerName = host tlsConfig.ServerName = serverName
}
if traceState != nil {
traceState.markCustomTLS()
traceState.tlsHandshakeStart(TraceTLSHandshakeStartInfo{
Network: network,
Addr: addr,
ServerName: serverName,
})
} }
// 执行 TLS 握手 // 执行 TLS 握手
@ -130,9 +192,25 @@ func defaultDialTLSFunc(ctx context.Context, network, addr string) (net.Conn, er
tlsConn := tls.Client(conn, tlsConfig) tlsConn := tls.Client(conn, tlsConfig)
if err := tlsConn.Handshake(); err != nil { if err := tlsConn.Handshake(); err != nil {
if traceState != nil {
traceState.tlsHandshakeDone(TraceTLSHandshakeDoneInfo{
Network: network,
Addr: addr,
ServerName: serverName,
Err: err,
})
}
conn.Close() conn.Close()
return nil, wrapError(err, "tls handshake") return nil, wrapError(err, "tls handshake")
} }
if traceState != nil {
traceState.tlsHandshakeDone(TraceTLSHandshakeDoneInfo{
Network: network,
Addr: addr,
ServerName: serverName,
ConnectionState: tlsConn.ConnectionState(),
})
}
return tlsConn, nil return tlsConn, nil
} }

View File

@ -0,0 +1,144 @@
package starnet
import (
"crypto/tls"
"net/http"
"net/url"
"testing"
)
func BenchmarkDynamicTransportCustomIP(b *testing.B) {
server := newIPv4Server(b, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("OK"))
}))
defer server.Close()
targetURL := benchmarkTargetURL(b, server.URL, "bench-custom-ip.test")
client := NewClientNoErr()
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
resp, err := client.Get(targetURL, WithCustomIP([]string{"127.0.0.1"}))
if err != nil {
b.Fatalf("Get() error: %v", err)
}
_, _ = resp.Body().Bytes()
resp.Close()
}
}
func BenchmarkDynamicTransportProxyTLSCacheable(b *testing.B) {
server := newIPv4TLSServer(b, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("OK"))
}))
defer server.Close()
proxy := newIPv4ConnectProxyServer(b, nil)
defer proxy.Close()
targetURL := httpsURLForHost(b, server, "bench-proxy-cacheable.test")
client := NewClientNoErr()
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
resp, err := client.Get(targetURL,
WithProxy(proxy.URL),
WithCustomIP([]string{"127.0.0.1"}),
WithSkipTLSVerify(true),
)
if err != nil {
b.Fatalf("Get() error: %v", err)
}
_, _ = resp.Body().Bytes()
resp.Close()
}
}
func BenchmarkDynamicTransportCustomIPTLSCacheable(b *testing.B) {
server := newIPv4TLSServer(b, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("OK"))
}))
defer server.Close()
targetURL := httpsURLForHost(b, server, "bench-custom-ip-cacheable.test")
client := NewClientNoErr()
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
resp, err := client.Get(targetURL,
WithCustomIP([]string{"127.0.0.1"}),
WithSkipTLSVerify(true),
)
if err != nil {
b.Fatalf("Get() error: %v", err)
}
_, _ = resp.Body().Bytes()
resp.Close()
}
}
func BenchmarkDynamicTransportCustomIPUserTLSConfig(b *testing.B) {
server := newIPv4TLSServer(b, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("OK"))
}))
defer server.Close()
targetURL := httpsURLForHost(b, server, "bench-user-tls.test")
client := NewClientNoErr()
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
resp, err := client.Get(targetURL,
WithCustomIP([]string{"127.0.0.1"}),
WithTLSConfig(&tls.Config{InsecureSkipVerify: true}),
)
if err != nil {
b.Fatalf("Get() error: %v", err)
}
_, _ = resp.Body().Bytes()
resp.Close()
}
}
func benchmarkTargetURL(tb testing.TB, rawURL, host string) string {
tb.Helper()
parsed, err := url.Parse(rawURL)
if err != nil {
tb.Fatalf("url.Parse() error: %v", err)
}
port := parsed.Port()
if port == "" {
switch parsed.Scheme {
case "https":
port = "443"
default:
port = "80"
}
}
return parsed.Scheme + "://" + host + ":" + port + pathWithQuery(parsed.Path, parsed.RawQuery)
}
func pathWithQuery(path, rawQuery string) string {
if path == "" {
path = "/"
}
if rawQuery == "" {
return path
}
return path + "?" + rawQuery
}

150
host_tls_regression_test.go Normal file
View File

@ -0,0 +1,150 @@
package starnet
import (
"crypto/tls"
"net/http"
"net/http/httptest"
"testing"
)
func TestRequestSetURLDoesNotMutateProvidedTLSConfig(t *testing.T) {
cfg := &tls.Config{}
req := NewSimpleRequest("https://example.com", http.MethodGet).
SetTLSConfig(cfg).
SetURL("https://other.example")
if req.Err() != nil {
t.Fatalf("unexpected request error: %v", req.Err())
}
if cfg.ServerName != "" {
t.Fatalf("provided tls.Config was mutated, ServerName=%q", cfg.ServerName)
}
}
func TestRequestPrepareSetTLSServerNameDoesNotMutateProvidedTLSConfig(t *testing.T) {
cfg := &tls.Config{InsecureSkipVerify: true}
req := NewSimpleRequest("https://example.com", http.MethodGet).
SetTLSConfig(cfg).
SetTLSServerName("override.example")
if err := req.prepare(); err != nil {
t.Fatalf("prepare error: %v", err)
}
if cfg.ServerName != "" {
t.Fatalf("provided tls.Config was mutated, ServerName=%q", cfg.ServerName)
}
rc := getRequestContext(req.execCtx)
if rc.TLSConfig == nil {
t.Fatal("expected injected tls config")
}
if rc.TLSConfig == cfg {
t.Fatal("expected injected tls config to be cloned")
}
if rc.TLSConfig.ServerName != "override.example" {
t.Fatalf("injected ServerName=%q", rc.TLSConfig.ServerName)
}
}
func TestRequestPrepareWithTLSServerNameWithoutTLSConfig(t *testing.T) {
req := NewSimpleRequest("https://example.com", http.MethodGet).
SetTLSServerName("override.example")
if err := req.prepare(); err != nil {
t.Fatalf("prepare error: %v", err)
}
rc := getRequestContext(req.execCtx)
if rc.TLSConfig == nil {
t.Fatal("expected injected tls config")
}
if rc.TLSConfig.ServerName != "override.example" {
t.Fatalf("injected ServerName=%q", rc.TLSConfig.ServerName)
}
}
func TestRequestPrepareDefaultPathSkipsRequestContextInjection(t *testing.T) {
req := NewSimpleRequest("https://example.com", http.MethodGet)
if err := req.prepare(); err != nil {
t.Fatalf("prepare error: %v", err)
}
if got := req.execCtx.Value(ctxKeyRequestContext); got != nil {
t.Fatalf("unexpected request context injection: %#v", got)
}
rc := getRequestContext(req.execCtx)
if needsDynamicTransport(rc) {
t.Fatalf("default path unexpectedly marked dynamic: %#v", rc)
}
if rc.TLSServerName != "" {
t.Fatalf("default path unexpectedly injected tls server name: %q", rc.TLSServerName)
}
}
func TestRequestPrepareDynamicPathInjectsAggregatedRequestContext(t *testing.T) {
req := NewSimpleRequest("https://example.com", http.MethodGet).
SetCustomIP([]string{"127.0.0.1"}).
SetSkipTLSVerify(true)
if err := req.prepare(); err != nil {
t.Fatalf("prepare error: %v", err)
}
raw := req.execCtx.Value(ctxKeyRequestContext)
rc, ok := raw.(*RequestContext)
if !ok || rc == nil {
t.Fatalf("expected aggregated request context, got %#v", raw)
}
if len(rc.CustomIP) != 1 || rc.CustomIP[0] != "127.0.0.1" {
t.Fatalf("custom ip=%v", rc.CustomIP)
}
if rc.TLSConfig == nil || !rc.TLSConfig.InsecureSkipVerify {
t.Fatal("expected tls config with skip verify")
}
if rc.TLSServerName != "example.com" {
t.Fatalf("default tls server name=%q", rc.TLSServerName)
}
}
func TestRequestSetHostOverridesRequestHost(t *testing.T) {
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Host != "override.example" {
t.Fatalf("host=%q", r.Host)
}
w.WriteHeader(http.StatusOK)
}))
defer s.Close()
resp, err := NewSimpleRequest(s.URL, http.MethodGet).
SetHost("override.example").
Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
}
func TestWithHostOverridesRequestHost(t *testing.T) {
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Host != "option.example" {
t.Fatalf("host=%q", r.Host)
}
w.WriteHeader(http.StatusOK)
}))
defer s.Close()
resp, err := NewRequest(s.URL, http.MethodGet, WithHost("option.example"))
if err != nil {
t.Fatalf("NewRequest() error: %v", err)
}
got, err := resp.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer got.Close()
}

230
internal/pingcore/core.go Normal file
View File

@ -0,0 +1,230 @@
package pingcore
import (
"encoding/binary"
"net"
"os"
"sync/atomic"
"time"
)
const icmpHeaderLen = 8
type ICMP struct {
Type uint8
Code uint8
CheckSum uint16
Identifier uint16
SequenceNum uint16
}
type Options struct {
Count int
Timeout time.Duration
Interval time.Duration
Deadline time.Time
PreferIPv4 bool
PreferIPv6 bool
SourceIP net.IP
PayloadSize int
}
type Result struct {
Duration time.Duration
RecvCount int
RemoteIP string
}
var identifierSeed uint32
func NextIdentifier() uint16 {
pid := uint32(os.Getpid() & 0xffff)
n := atomic.AddUint32(&identifierSeed, 1)
return uint16((pid + n) & 0xffff)
}
func Payload(size int) []byte {
if size <= 0 {
return nil
}
payload := make([]byte, size)
for index := 0; index < len(payload); index++ {
payload[index] = byte(index)
}
return payload
}
func BuildICMP(seq, identifier uint16, typ uint8, payload []byte) ICMP {
icmp := ICMP{
Type: typ,
Code: 0,
CheckSum: 0,
Identifier: identifier,
SequenceNum: seq,
}
buf := MarshalPacket(icmp, payload)
icmp.CheckSum = Checksum(buf)
return icmp
}
func Checksum(data []byte) uint16 {
var (
sum uint32
length = len(data)
index int
)
for length > 1 {
sum += uint32(data[index])<<8 + uint32(data[index+1])
index += 2
length -= 2
}
if length > 0 {
sum += uint32(data[index]) << 8
}
for sum>>16 != 0 {
sum = (sum & 0xffff) + (sum >> 16)
}
return uint16(^sum)
}
func Marshal(icmp ICMP) []byte {
return MarshalPacket(icmp, nil)
}
func MarshalPacket(icmp ICMP, payload []byte) []byte {
buf := make([]byte, icmpHeaderLen+len(payload))
buf[0] = icmp.Type
buf[1] = icmp.Code
binary.BigEndian.PutUint16(buf[2:], icmp.CheckSum)
binary.BigEndian.PutUint16(buf[4:], icmp.Identifier)
binary.BigEndian.PutUint16(buf[6:], icmp.SequenceNum)
copy(buf[icmpHeaderLen:], payload)
return buf
}
func IsExpectedEchoReply(packet []byte, family int, expectedType uint8, identifier, seq uint16) bool {
for _, offset := range CandidateICMPOffsets(packet, family) {
if offset < 0 || offset+icmpHeaderLen > len(packet) {
continue
}
if packet[offset] != expectedType || packet[offset+1] != 0 {
continue
}
if binary.BigEndian.Uint16(packet[offset+4:offset+6]) != identifier {
continue
}
if binary.BigEndian.Uint16(packet[offset+6:offset+8]) != seq {
continue
}
return true
}
return false
}
func CandidateICMPOffsets(packet []byte, family int) []int {
offsets := []int{0}
if len(packet) == 0 {
return offsets
}
version := packet[0] >> 4
if version == 4 && len(packet) >= 20 {
ihl := int(packet[0]&0x0f) * 4
if ihl >= 20 && ihl <= len(packet)-icmpHeaderLen {
offsets = append(offsets, ihl)
}
} else if version == 6 && len(packet) >= 40+icmpHeaderLen {
offsets = append(offsets, 40)
}
if family == 4 && len(packet) >= 20+icmpHeaderLen {
offsets = append(offsets, 20)
}
if family == 6 && len(packet) >= 40+icmpHeaderLen {
offsets = append(offsets, 40)
}
return DedupOffsets(offsets)
}
func DedupOffsets(offsets []int) []int {
if len(offsets) <= 1 {
return offsets
}
seen := make(map[int]struct{}, len(offsets))
out := make([]int, 0, len(offsets))
for _, offset := range offsets {
if _, ok := seen[offset]; ok {
continue
}
seen[offset] = struct{}{}
out = append(out, offset)
}
return out
}
func ResolveTargets(host string, preferIPv4, preferIPv6 bool) ([]*net.IPAddr, error) {
if parsed := net.ParseIP(host); parsed != nil {
return []*net.IPAddr{{IP: parsed}}, nil
}
var targets []*net.IPAddr
var err4 error
var err6 error
if ip4, err := net.ResolveIPAddr("ip4", host); err == nil && ip4 != nil && ip4.IP != nil {
targets = append(targets, ip4)
} else {
err4 = err
}
if ip6, err := net.ResolveIPAddr("ip6", host); err == nil && ip6 != nil && ip6.IP != nil {
targets = append(targets, ip6)
} else {
err6 = err
}
if len(targets) > 0 {
return OrderTargets(targets, preferIPv4, preferIPv6), nil
}
if err4 != nil {
return nil, err4
}
if err6 != nil {
return nil, err6
}
return nil, nil
}
func OrderTargets(targets []*net.IPAddr, preferIPv4, preferIPv6 bool) []*net.IPAddr {
if len(targets) <= 1 || preferIPv4 == preferIPv6 {
return targets
}
ordered := make([]*net.IPAddr, 0, len(targets))
if preferIPv4 {
for _, target := range targets {
if target != nil && target.IP != nil && target.IP.To4() != nil {
ordered = append(ordered, target)
}
}
for _, target := range targets {
if target != nil && target.IP != nil && target.IP.To4() == nil {
ordered = append(ordered, target)
}
}
return ordered
}
for _, target := range targets {
if target != nil && target.IP != nil && target.IP.To4() == nil {
ordered = append(ordered, target)
}
}
for _, target := range targets {
if target != nil && target.IP != nil && target.IP.To4() != nil {
ordered = append(ordered, target)
}
}
return ordered
}

View File

@ -0,0 +1,123 @@
package tlssniffercore
import "crypto/tls"
func ComposeServerTLSConfig(base, selected *tls.Config) *tls.Config {
if base == nil {
return selected
}
if selected == nil {
return base
}
out := base.Clone()
ApplyServerTLSOverrides(out, selected)
return out
}
func ApplyServerTLSOverrides(dst, src *tls.Config) {
if dst == nil || src == nil {
return
}
if src.Rand != nil {
dst.Rand = src.Rand
}
if src.Time != nil {
dst.Time = src.Time
}
if len(src.Certificates) > 0 {
dst.Certificates = append([]tls.Certificate(nil), src.Certificates...)
}
if len(src.NameToCertificate) > 0 {
copied := make(map[string]*tls.Certificate, len(src.NameToCertificate))
for name, cert := range src.NameToCertificate {
copied[name] = cert
}
dst.NameToCertificate = copied
}
if src.GetCertificate != nil {
dst.GetCertificate = src.GetCertificate
}
if src.GetClientCertificate != nil {
dst.GetClientCertificate = src.GetClientCertificate
}
if src.GetConfigForClient != nil {
dst.GetConfigForClient = src.GetConfigForClient
}
if src.VerifyPeerCertificate != nil {
dst.VerifyPeerCertificate = src.VerifyPeerCertificate
}
if src.VerifyConnection != nil {
dst.VerifyConnection = src.VerifyConnection
}
if src.RootCAs != nil {
dst.RootCAs = src.RootCAs
}
if len(src.NextProtos) > 0 {
dst.NextProtos = append([]string(nil), src.NextProtos...)
}
if src.ServerName != "" {
dst.ServerName = src.ServerName
}
if src.ClientAuth > dst.ClientAuth {
dst.ClientAuth = src.ClientAuth
}
if src.ClientCAs != nil {
dst.ClientCAs = src.ClientCAs
}
if src.InsecureSkipVerify {
dst.InsecureSkipVerify = true
}
if len(src.CipherSuites) > 0 {
dst.CipherSuites = append([]uint16(nil), src.CipherSuites...)
}
if src.PreferServerCipherSuites {
dst.PreferServerCipherSuites = true
}
if src.SessionTicketsDisabled {
dst.SessionTicketsDisabled = true
}
if src.SessionTicketKey != ([32]byte{}) {
dst.SessionTicketKey = src.SessionTicketKey
}
if src.ClientSessionCache != nil {
dst.ClientSessionCache = src.ClientSessionCache
}
if src.UnwrapSession != nil {
dst.UnwrapSession = src.UnwrapSession
}
if src.WrapSession != nil {
dst.WrapSession = src.WrapSession
}
if src.MinVersion != 0 && (dst.MinVersion == 0 || src.MinVersion > dst.MinVersion) {
dst.MinVersion = src.MinVersion
}
if src.MaxVersion != 0 && (dst.MaxVersion == 0 || src.MaxVersion < dst.MaxVersion) {
dst.MaxVersion = src.MaxVersion
}
if len(src.CurvePreferences) > 0 {
dst.CurvePreferences = append([]tls.CurveID(nil), src.CurvePreferences...)
}
if src.DynamicRecordSizingDisabled {
dst.DynamicRecordSizingDisabled = true
}
if src.Renegotiation != 0 {
dst.Renegotiation = src.Renegotiation
}
if src.KeyLogWriter != nil {
dst.KeyLogWriter = src.KeyLogWriter
}
if len(src.EncryptedClientHelloConfigList) > 0 {
dst.EncryptedClientHelloConfigList = append([]byte(nil), src.EncryptedClientHelloConfigList...)
}
if src.EncryptedClientHelloRejectionVerify != nil {
dst.EncryptedClientHelloRejectionVerify = src.EncryptedClientHelloRejectionVerify
}
if src.GetEncryptedClientHelloKeys != nil {
dst.GetEncryptedClientHelloKeys = src.GetEncryptedClientHelloKeys
}
if len(src.EncryptedClientHelloKeys) > 0 {
dst.EncryptedClientHelloKeys = append([]tls.EncryptedClientHelloKey(nil), src.EncryptedClientHelloKeys...)
}
}

View File

@ -0,0 +1,237 @@
package tlssniffercore
import (
"bytes"
"encoding/binary"
"io"
"net"
)
type ClientHelloMeta struct {
ServerName string
LocalAddr net.Addr
RemoteAddr net.Addr
SupportedProtos []string
SupportedVersions []uint16
CipherSuites []uint16
}
type SniffResult struct {
IsTLS bool
ClientHello *ClientHelloMeta
Buffer *bytes.Buffer
}
type Sniffer struct{}
func (s Sniffer) Sniff(conn net.Conn, maxBytes int) (SniffResult, error) {
if maxBytes <= 0 {
maxBytes = 64 * 1024
}
var buf bytes.Buffer
limited := &io.LimitedReader{R: conn, N: int64(maxBytes)}
meta, isTLS := sniffClientHello(limited, &buf, conn)
out := SniffResult{
IsTLS: isTLS,
Buffer: bytes.NewBuffer(append([]byte(nil), buf.Bytes()...)),
}
if isTLS {
out.ClientHello = meta
}
return out, nil
}
func sniffClientHello(reader io.Reader, buf *bytes.Buffer, conn net.Conn) (*ClientHelloMeta, bool) {
meta := &ClientHelloMeta{
LocalAddr: conn.LocalAddr(),
RemoteAddr: conn.RemoteAddr(),
}
header, complete := readTLSRecordHeader(reader, buf)
if len(header) < 3 {
return nil, false
}
isTLS := header[0] == 0x16 && header[1] == 0x03
if !isTLS {
return nil, false
}
if len(header) < 5 || !complete {
return meta, true
}
recordLen := int(binary.BigEndian.Uint16(header[3:5]))
recordBody, bodyOK := readBufferedBytes(reader, buf, recordLen)
if !bodyOK {
return meta, true
}
if len(recordBody) < 4 || recordBody[0] != 0x01 {
return nil, false
}
helloLen := int(recordBody[1])<<16 | int(recordBody[2])<<8 | int(recordBody[3])
helloBytes := append([]byte(nil), recordBody[4:]...)
for len(helloBytes) < helloLen {
nextHeader, ok := readTLSRecordHeader(reader, buf)
if len(nextHeader) < 5 || !ok {
return meta, true
}
if nextHeader[0] != 0x16 || nextHeader[1] != 0x03 {
return meta, true
}
nextLen := int(binary.BigEndian.Uint16(nextHeader[3:5]))
nextBody, bodyOK := readBufferedBytes(reader, buf, nextLen)
if !bodyOK {
return meta, true
}
helloBytes = append(helloBytes, nextBody...)
}
parseClientHelloBody(meta, helloBytes[:helloLen])
return meta, true
}
func readTLSRecordHeader(reader io.Reader, buf *bytes.Buffer) ([]byte, bool) {
return readBufferedBytes(reader, buf, 5)
}
func readBufferedBytes(reader io.Reader, buf *bytes.Buffer, count int) ([]byte, bool) {
if count <= 0 {
return nil, true
}
tmp := make([]byte, count)
readN, err := io.ReadFull(reader, tmp)
if readN > 0 {
buf.Write(tmp[:readN])
}
return append([]byte(nil), tmp[:readN]...), err == nil
}
func parseClientHelloBody(meta *ClientHelloMeta, body []byte) {
if meta == nil || len(body) < 34 {
return
}
offset := 2 + 32
sessionIDLen := int(body[offset])
offset++
if offset+sessionIDLen > len(body) {
return
}
offset += sessionIDLen
if offset+2 > len(body) {
return
}
cipherSuitesLen := int(binary.BigEndian.Uint16(body[offset : offset+2]))
offset += 2
if offset+cipherSuitesLen > len(body) {
return
}
for index := 0; index+1 < cipherSuitesLen; index += 2 {
meta.CipherSuites = append(meta.CipherSuites, binary.BigEndian.Uint16(body[offset+index:offset+index+2]))
}
offset += cipherSuitesLen
if offset >= len(body) {
return
}
compressionMethodsLen := int(body[offset])
offset++
if offset+compressionMethodsLen > len(body) {
return
}
offset += compressionMethodsLen
if offset+2 > len(body) {
return
}
extensionsLen := int(binary.BigEndian.Uint16(body[offset : offset+2]))
offset += 2
if offset+extensionsLen > len(body) {
return
}
parseClientHelloExtensions(meta, body[offset:offset+extensionsLen])
}
func parseClientHelloExtensions(meta *ClientHelloMeta, exts []byte) {
for offset := 0; offset+4 <= len(exts); {
extType := binary.BigEndian.Uint16(exts[offset : offset+2])
extLen := int(binary.BigEndian.Uint16(exts[offset+2 : offset+4]))
offset += 4
if offset+extLen > len(exts) {
return
}
extData := exts[offset : offset+extLen]
offset += extLen
switch extType {
case 0:
parseServerNameExtension(meta, extData)
case 16:
parseALPNExtension(meta, extData)
case 43:
parseSupportedVersionsExtension(meta, extData)
}
}
}
func parseServerNameExtension(meta *ClientHelloMeta, data []byte) {
if len(data) < 2 {
return
}
listLen := int(binary.BigEndian.Uint16(data[:2]))
if listLen == 0 || 2+listLen > len(data) {
return
}
list := data[2 : 2+listLen]
for offset := 0; offset+3 <= len(list); {
nameType := list[offset]
nameLen := int(binary.BigEndian.Uint16(list[offset+1 : offset+3]))
offset += 3
if offset+nameLen > len(list) {
return
}
if nameType == 0 {
meta.ServerName = string(list[offset : offset+nameLen])
return
}
offset += nameLen
}
}
func parseALPNExtension(meta *ClientHelloMeta, data []byte) {
if len(data) < 2 {
return
}
listLen := int(binary.BigEndian.Uint16(data[:2]))
if listLen == 0 || 2+listLen > len(data) {
return
}
list := data[2 : 2+listLen]
for offset := 0; offset < len(list); {
nameLen := int(list[offset])
offset++
if offset+nameLen > len(list) {
return
}
meta.SupportedProtos = append(meta.SupportedProtos, string(list[offset:offset+nameLen]))
offset += nameLen
}
}
func parseSupportedVersionsExtension(meta *ClientHelloMeta, data []byte) {
if len(data) < 1 {
return
}
listLen := int(data[0])
if listLen == 0 || 1+listLen > len(data) {
return
}
list := data[1 : 1+listLen]
for offset := 0; offset+1 < len(list); offset += 2 {
meta.SupportedVersions = append(meta.SupportedVersions, binary.BigEndian.Uint16(list[offset:offset+2]))
}
}

View File

@ -1,405 +0,0 @@
package starnet
import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"os"
"time"
)
// WithTimeout 设置请求总超时时间
// timeout > 0: 为本次请求注入 context 超时
// timeout = 0: 不额外设置请求总超时
// timeout < 0: 禁用 starnet 默认总超时
func WithTimeout(timeout time.Duration) RequestOpt {
return func(r *Request) error {
r.config.Network.Timeout = timeout
return nil
}
}
// WithDialTimeout 设置连接超时时间
func WithDialTimeout(timeout time.Duration) RequestOpt {
return func(r *Request) error {
r.config.Network.DialTimeout = timeout
return nil
}
}
// WithProxy 设置代理
func WithProxy(proxy string) RequestOpt {
return func(r *Request) error {
r.config.Network.Proxy = proxy
return nil
}
}
// WithDialFunc 设置自定义 Dial 函数
func WithDialFunc(fn func(ctx context.Context, network, addr string) (net.Conn, error)) RequestOpt {
return func(r *Request) error {
r.config.Network.DialFunc = fn
return nil
}
}
// WithTLSConfig 设置 TLS 配置
func WithTLSConfig(tlsConfig *tls.Config) RequestOpt {
return func(r *Request) error {
r.config.TLS.Config = tlsConfig
return nil
}
}
// WithSkipTLSVerify 设置是否跳过 TLS 验证
func WithSkipTLSVerify(skip bool) RequestOpt {
return func(r *Request) error {
r.config.TLS.SkipVerify = skip
return nil
}
}
// WithCustomIP 设置自定义 IP
func WithCustomIP(ips []string) RequestOpt {
return func(r *Request) error {
for _, ip := range ips {
if net.ParseIP(ip) == nil {
return wrapError(ErrInvalidIP, "ip: %s", ip)
}
}
r.config.DNS.CustomIP = ips
return nil
}
}
// WithAddCustomIP 添加自定义 IP
func WithAddCustomIP(ip string) RequestOpt {
return func(r *Request) error {
if net.ParseIP(ip) == nil {
return wrapError(ErrInvalidIP, "ip: %s", ip)
}
r.config.DNS.CustomIP = append(r.config.DNS.CustomIP, ip)
return nil
}
}
// WithCustomDNS 设置自定义 DNS 服务器
func WithCustomDNS(dnsServers []string) RequestOpt {
return func(r *Request) error {
for _, dns := range dnsServers {
if net.ParseIP(dns) == nil {
return wrapError(ErrInvalidDNS, "dns: %s", dns)
}
}
r.config.DNS.CustomDNS = dnsServers
return nil
}
}
// WithAddCustomDNS 添加自定义 DNS 服务器
func WithAddCustomDNS(dns string) RequestOpt {
return func(r *Request) error {
if net.ParseIP(dns) == nil {
return wrapError(ErrInvalidDNS, "dns: %s", dns)
}
r.config.DNS.CustomDNS = append(r.config.DNS.CustomDNS, dns)
return nil
}
}
// WithLookupFunc 设置自定义 DNS 解析函数
func WithLookupFunc(fn func(ctx context.Context, host string) ([]net.IPAddr, error)) RequestOpt {
return func(r *Request) error {
r.config.DNS.LookupFunc = fn
return nil
}
}
// WithHeader 设置 Header
func WithHeader(key, value string) RequestOpt {
return func(r *Request) error {
r.config.Headers.Set(key, value)
return nil
}
}
// WithHeaders 批量设置 Headers
func WithHeaders(headers map[string]string) RequestOpt {
return func(r *Request) error {
for k, v := range headers {
r.config.Headers.Set(k, v)
}
return nil
}
}
// WithContentType 设置 Content-Type
func WithContentType(contentType string) RequestOpt {
return func(r *Request) error {
r.config.Headers.Set("Content-Type", contentType)
return nil
}
}
// WithUserAgent 设置 User-Agent
func WithUserAgent(userAgent string) RequestOpt {
return func(r *Request) error {
r.config.Headers.Set("User-Agent", userAgent)
return nil
}
}
// WithBearerToken 设置 Bearer Token
func WithBearerToken(token string) RequestOpt {
return func(r *Request) error {
r.config.Headers.Set("Authorization", "Bearer "+token)
return nil
}
}
// WithBasicAuth 设置 Basic 认证
func WithBasicAuth(username, password string) RequestOpt {
return func(r *Request) error {
r.config.BasicAuth = [2]string{username, password}
return nil
}
}
// WithCookie 添加 Cookie
func WithCookie(name, value, path string) RequestOpt {
return func(r *Request) error {
r.config.Cookies = append(r.config.Cookies, &http.Cookie{
Name: name,
Value: value,
Path: path,
})
return nil
}
}
// WithSimpleCookie 添加简单 Cookiepath 为 /
func WithSimpleCookie(name, value string) RequestOpt {
return func(r *Request) error {
r.config.Cookies = append(r.config.Cookies, &http.Cookie{
Name: name,
Value: value,
Path: "/",
})
return nil
}
}
// WithCookies 批量添加 Cookies
func WithCookies(cookies map[string]string) RequestOpt {
return func(r *Request) error {
for name, value := range cookies {
r.config.Cookies = append(r.config.Cookies, &http.Cookie{
Name: name,
Value: value,
Path: "/",
})
}
return nil
}
}
// WithBody 设置请求体(字节)
func WithBody(body []byte) RequestOpt {
return func(r *Request) error {
r.config.Body.Bytes = body
r.config.Body.Reader = nil
return nil
}
}
// WithBodyString 设置请求体(字符串)
func WithBodyString(body string) RequestOpt {
return func(r *Request) error {
r.config.Body.Bytes = []byte(body)
r.config.Body.Reader = nil
return nil
}
}
// WithBodyReader 设置请求体Reader
func WithBodyReader(reader io.Reader) RequestOpt {
return func(r *Request) error {
r.config.Body.Reader = reader
r.config.Body.Bytes = nil
return nil
}
}
// WithJSON 设置 JSON 请求体
func WithJSON(v interface{}) RequestOpt {
return func(r *Request) error {
data, err := json.Marshal(v)
if err != nil {
return wrapError(err, "marshal json")
}
r.config.Headers.Set("Content-Type", ContentTypeJSON)
r.config.Body.Bytes = data
r.config.Body.Reader = nil
return nil
}
}
// WithFormData 设置表单数据
func WithFormData(data map[string][]string) RequestOpt {
return func(r *Request) error {
r.config.Body.FormData = data
return nil
}
}
// WithFormDataMap 设置表单数据(简化版)
func WithFormDataMap(data map[string]string) RequestOpt {
return func(r *Request) error {
for k, v := range data {
r.config.Body.FormData[k] = []string{v}
}
return nil
}
}
// WithAddFormData 添加表单数据
func WithAddFormData(key, value string) RequestOpt {
return func(r *Request) error {
r.config.Body.FormData[key] = append(r.config.Body.FormData[key], value)
return nil
}
}
// WithFile 添加文件
func WithFile(formName, filePath string) RequestOpt {
return func(r *Request) error {
stat, err := os.Stat(filePath)
if err != nil {
return wrapError(ErrFileNotFound, "file: %s", filePath)
}
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
FormName: formName,
FileName: stat.Name(),
FilePath: filePath,
FileSize: stat.Size(),
FileType: ContentTypeOctetStream,
})
return nil
}
}
// WithFileStream 添加文件流
func WithFileStream(formName, fileName string, size int64, reader io.Reader) RequestOpt {
return func(r *Request) error {
if reader == nil {
return ErrNilReader
}
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
FormName: formName,
FileName: fileName,
FileData: reader,
FileSize: size,
FileType: ContentTypeOctetStream,
})
return nil
}
}
// WithQuery 添加查询参数
func WithQuery(key, value string) RequestOpt {
return func(r *Request) error {
r.config.Queries[key] = append(r.config.Queries[key], value)
return nil
}
}
// WithQueries 批量添加查询参数
func WithQueries(queries map[string]string) RequestOpt {
return func(r *Request) error {
for k, v := range queries {
r.config.Queries[k] = append(r.config.Queries[k], v)
}
return nil
}
}
// WithContentLength 设置 Content-Length
func WithContentLength(length int64) RequestOpt {
return func(r *Request) error {
r.config.ContentLength = length
return nil
}
}
// WithAutoCalcContentLength 设置是否自动计算 Content-Length
func WithAutoCalcContentLength(auto bool) RequestOpt {
return func(r *Request) error {
r.config.AutoCalcContentLength = auto
return nil
}
}
// WithUploadProgress 设置文件上传进度回调
func WithUploadProgress(fn UploadProgressFunc) RequestOpt {
return func(r *Request) error {
r.config.UploadProgress = fn
return nil
}
}
// WithTransport 设置自定义 Transport
func WithTransport(transport *http.Transport) RequestOpt {
return func(r *Request) error {
r.config.Transport = transport
r.config.CustomTransport = true
return nil
}
}
// WithAutoFetch 设置是否自动获取响应体
func WithAutoFetch(auto bool) RequestOpt {
return func(r *Request) error {
r.autoFetch = auto
return nil
}
}
// WithMaxRespBodyBytes 设置响应体最大读取字节数(<=0 表示不限制)
func WithMaxRespBodyBytes(maxBytes int64) RequestOpt {
return func(r *Request) error {
if maxBytes < 0 {
return fmt.Errorf("max response body bytes must be >= 0")
}
r.config.MaxRespBodyBytes = maxBytes
return nil
}
}
// WithRawRequest 设置原始请求
func WithRawRequest(httpReq *http.Request) RequestOpt {
return func(r *Request) error {
if httpReq == nil {
return fmt.Errorf("httpReq cannot be nil")
}
r.httpReq = httpReq
r.doRaw = true
return nil
}
}
// WithContext 设置 context
func WithContext(ctx context.Context) RequestOpt {
return func(r *Request) error {
r.ctx = ctx
r.httpReq = r.httpReq.WithContext(ctx)
return nil
}
}

112
options_body.go Normal file
View File

@ -0,0 +1,112 @@
package starnet
import (
"encoding/json"
"io"
"os"
)
// WithBody 设置请求体(字节)
func WithBody(body []byte) RequestOpt {
return func(r *Request) error {
setBytesBodyConfig(&r.config.Body, body)
return nil
}
}
// WithBodyString 设置请求体(字符串)
func WithBodyString(body string) RequestOpt {
return func(r *Request) error {
setBytesBodyConfig(&r.config.Body, []byte(body))
return nil
}
}
// WithBodyReader 设置请求体Reader
// 出于避免重复写的保守策略Reader 形态的 body 在非幂等方法上不会自动参与 retry。
func WithBodyReader(reader io.Reader) RequestOpt {
return func(r *Request) error {
setReaderBodyConfig(&r.config.Body, reader)
return nil
}
}
// WithJSON 设置 JSON 请求体
func WithJSON(v interface{}) RequestOpt {
return func(r *Request) error {
data, err := json.Marshal(v)
if err != nil {
return wrapError(err, "marshal json")
}
r.config.Headers.Set("Content-Type", ContentTypeJSON)
setBytesBodyConfig(&r.config.Body, data)
return nil
}
}
// WithFormData 设置表单数据
func WithFormData(data map[string][]string) RequestOpt {
return func(r *Request) error {
setFormBodyConfig(&r.config.Body, data)
return nil
}
}
// WithFormDataMap 设置表单数据(简化版)
func WithFormDataMap(data map[string]string) RequestOpt {
return func(r *Request) error {
setFormBodyConfig(&r.config.Body, nil)
for key, value := range data {
r.config.Body.FormData[key] = []string{value}
}
return nil
}
}
// WithAddFormData 添加表单数据
func WithAddFormData(key, value string) RequestOpt {
return func(r *Request) error {
ensureFormMode(&r.config.Body)
r.config.Body.FormData[key] = append(r.config.Body.FormData[key], value)
return nil
}
}
// WithFile 添加文件
func WithFile(formName, filePath string) RequestOpt {
return func(r *Request) error {
stat, err := os.Stat(filePath)
if err != nil {
return wrapError(ErrFileNotFound, "file: %s", filePath)
}
ensureMultipartMode(&r.config.Body)
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
FormName: formName,
FileName: stat.Name(),
FilePath: filePath,
FileSize: stat.Size(),
FileType: ContentTypeOctetStream,
})
return nil
}
}
// WithFileStream 添加文件流
func WithFileStream(formName, fileName string, size int64, reader io.Reader) RequestOpt {
return func(r *Request) error {
if reader == nil {
return ErrNilReader
}
ensureMultipartMode(&r.config.Body)
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
FormName: formName,
FileName: fileName,
FileData: reader,
FileSize: size,
FileType: ContentTypeOctetStream,
})
return nil
}
}

132
options_config.go Normal file
View File

@ -0,0 +1,132 @@
package starnet
import (
"context"
"crypto/tls"
"net"
"net/http"
"time"
)
// WithTimeout 设置请求总超时时间
// timeout > 0: 为本次请求注入 context 超时
// timeout = 0: 不额外设置请求总超时
// timeout < 0: 禁用 starnet 默认总超时
func WithTimeout(timeout time.Duration) RequestOpt {
return requestOptFromMutation(mutateTimeout(timeout))
}
// WithDialTimeout 设置连接超时时间
func WithDialTimeout(timeout time.Duration) RequestOpt {
return requestOptFromMutation(mutateDialTimeout(timeout))
}
// WithProxy 设置代理
func WithProxy(proxy string) RequestOpt {
return requestOptFromMutation(mutateProxy(proxy))
}
// WithDialFunc 设置自定义 Dial 函数
func WithDialFunc(fn func(ctx context.Context, network, addr string) (net.Conn, error)) RequestOpt {
return requestOptFromMutation(mutateDialFunc(fn))
}
// WithTLSConfig 设置 TLS 配置
func WithTLSConfig(tlsConfig *tls.Config) RequestOpt {
return requestOptFromMutation(mutateTLSConfig(tlsConfig))
}
// WithTLSServerName 设置显式 TLS ServerName/SNI。
func WithTLSServerName(serverName string) RequestOpt {
return requestOptFromMutation(mutateTLSServerName(serverName))
}
// WithTraceHooks 设置请求 trace 回调。
func WithTraceHooks(hooks *TraceHooks) RequestOpt {
return requestOptFromMutation(mutateTraceHooks(hooks))
}
// WithSkipTLSVerify 设置是否跳过 TLS 验证
func WithSkipTLSVerify(skip bool) RequestOpt {
return requestOptFromMutation(mutateSkipTLSVerify(skip))
}
// WithCustomIP 设置自定义 IP
func WithCustomIP(ips []string) RequestOpt {
return requestOptFromMutation(mutateCustomIP(ips))
}
// WithAddCustomIP 添加自定义 IP
func WithAddCustomIP(ip string) RequestOpt {
return requestOptFromMutation(mutateAddCustomIP(ip))
}
// WithCustomDNS 设置自定义 DNS 服务器
func WithCustomDNS(dnsServers []string) RequestOpt {
return requestOptFromMutation(mutateCustomDNS(dnsServers))
}
// WithAddCustomDNS 添加自定义 DNS 服务器
func WithAddCustomDNS(dns string) RequestOpt {
return requestOptFromMutation(mutateAddCustomDNS(dns))
}
// WithLookupFunc 设置自定义 DNS 解析函数
func WithLookupFunc(fn func(ctx context.Context, host string) ([]net.IPAddr, error)) RequestOpt {
return requestOptFromMutation(mutateLookupFunc(fn))
}
// WithBasicAuth 设置 Basic 认证
func WithBasicAuth(username, password string) RequestOpt {
return requestOptFromMutation(mutateBasicAuth(username, password))
}
// WithQuery 添加查询参数
func WithQuery(key, value string) RequestOpt {
return requestOptFromMutation(mutateAddQuery(key, value))
}
// WithQueries 批量添加查询参数
func WithQueries(queries map[string]string) RequestOpt {
return requestOptFromMutation(mutateAddQueries(queries))
}
// WithContentLength 设置 Content-Length
func WithContentLength(length int64) RequestOpt {
return requestOptFromMutation(mutateContentLength(length))
}
// WithAutoCalcContentLength 设置是否自动计算 Content-Length
func WithAutoCalcContentLength(auto bool) RequestOpt {
return requestOptFromMutation(mutateAutoCalcContentLength(auto))
}
// WithUploadProgress 设置文件上传进度回调
func WithUploadProgress(fn UploadProgressFunc) RequestOpt {
return requestOptFromMutation(mutateUploadProgress(fn))
}
// WithTransport 设置自定义 Transport
func WithTransport(transport *http.Transport) RequestOpt {
return requestOptFromMutation(mutateTransport(transport))
}
// WithAutoFetch 设置是否自动获取响应体
func WithAutoFetch(auto bool) RequestOpt {
return requestOptFromMutation(mutateAutoFetch(auto))
}
// WithMaxRespBodyBytes 设置响应体最大读取字节数(<=0 表示不限制)
func WithMaxRespBodyBytes(maxBytes int64) RequestOpt {
return requestOptFromMutation(mutateMaxRespBodyBytes(maxBytes))
}
// WithRawRequest 设置原始请求
func WithRawRequest(httpReq *http.Request) RequestOpt {
return requestOptFromMutation(mutateRawRequest(httpReq))
}
// WithContext 设置 context
func WithContext(ctx context.Context) RequestOpt {
return requestOptFromMutation(mutateContext(ctx))
}

99
options_header.go Normal file
View File

@ -0,0 +1,99 @@
package starnet
import "net/http"
// WithHeader 设置 Header
func WithHeader(key, value string) RequestOpt {
return func(r *Request) error {
if isHostHeaderKey(key) {
setRequestHostConfig(r.config, value)
return nil
}
r.config.Headers.Set(key, value)
return nil
}
}
// WithHost 设置显式 Host 头覆盖。
func WithHost(host string) RequestOpt {
return func(r *Request) error {
setRequestHostConfig(r.config, host)
return nil
}
}
// WithHeaders 批量设置 Headers
func WithHeaders(headers map[string]string) RequestOpt {
return func(r *Request) error {
for key, value := range headers {
if isHostHeaderKey(key) {
setRequestHostConfig(r.config, value)
continue
}
r.config.Headers.Set(key, value)
}
return nil
}
}
// WithContentType 设置 Content-Type
func WithContentType(contentType string) RequestOpt {
return func(r *Request) error {
r.config.Headers.Set("Content-Type", contentType)
return nil
}
}
// WithUserAgent 设置 User-Agent
func WithUserAgent(userAgent string) RequestOpt {
return func(r *Request) error {
r.config.Headers.Set("User-Agent", userAgent)
return nil
}
}
// WithBearerToken 设置 Bearer Token
func WithBearerToken(token string) RequestOpt {
return func(r *Request) error {
r.config.Headers.Set("Authorization", "Bearer "+token)
return nil
}
}
// WithCookie 添加 Cookie
func WithCookie(name, value, path string) RequestOpt {
return func(r *Request) error {
r.config.Cookies = append(r.config.Cookies, &http.Cookie{
Name: name,
Value: value,
Path: path,
})
return nil
}
}
// WithSimpleCookie 添加简单 Cookiepath 为 /
func WithSimpleCookie(name, value string) RequestOpt {
return func(r *Request) error {
r.config.Cookies = append(r.config.Cookies, &http.Cookie{
Name: name,
Value: value,
Path: "/",
})
return nil
}
}
// WithCookies 批量添加 Cookies
func WithCookies(cookies map[string]string) RequestOpt {
return func(r *Request) error {
for name, value := range cookies {
r.config.Cookies = append(r.config.Cookies, &http.Cookie{
Name: name,
Value: value,
Path: "/",
})
}
return nil
}
}

221
ping.go
View File

@ -2,14 +2,14 @@ package starnet
import ( import (
"context" "context"
"encoding/binary"
"errors" "errors"
"fmt" "fmt"
"net" "net"
"os" "os"
"strings" "strings"
"sync/atomic"
"time" "time"
"b612.me/starnet/internal/pingcore"
) )
const ( const (
@ -18,7 +18,6 @@ const (
icmpTypeEchoRequestV6 = 128 icmpTypeEchoRequestV6 = 128
icmpTypeEchoReplyV6 = 129 icmpTypeEchoReplyV6 = 129
icmpHeaderLen = 8
icmpReadBufSz = 1500 icmpReadBufSz = 1500
defaultPingAttemptTimeout = 2 * time.Second defaultPingAttemptTimeout = 2 * time.Second
@ -26,13 +25,7 @@ const (
maxPingPayloadSize = 65499 // 65507 - ICMP header(8) maxPingPayloadSize = 65499 // 65507 - ICMP header(8)
) )
type ICMP struct { type ICMP = pingcore.ICMP
Type uint8
Code uint8
CheckSum uint16
Identifier uint16
SequenceNum uint16
}
type pingSocketSpec struct { type pingSocketSpec struct {
network string network string
@ -42,53 +35,20 @@ type pingSocketSpec struct {
} }
// PingOptions controls ping probing behavior. // PingOptions controls ping probing behavior.
type PingOptions struct { type PingOptions = pingcore.Options
Count int // ping attempts for Pingable, default 3
Timeout time.Duration // per-attempt timeout, default 2s
Interval time.Duration // delay between attempts, default 0
Deadline time.Time // overall deadline for Pingable/PingWithContext
PreferIPv4 bool // prefer IPv4 targets
PreferIPv6 bool // prefer IPv6 targets
SourceIP net.IP // optional source IP for raw socket bind
PayloadSize int // ICMP payload bytes, default 0
}
type PingResult struct { type PingResult = pingcore.Result
Duration time.Duration
RecvCount int
RemoteIP string
}
var pingIdentifierSeed uint32
func nextPingIdentifier() uint16 { func nextPingIdentifier() uint16 {
pid := uint32(os.Getpid() & 0xffff) return pingcore.NextIdentifier()
n := atomic.AddUint32(&pingIdentifierSeed, 1)
return uint16((pid + n) & 0xffff)
} }
func pingPayload(size int) []byte { func pingPayload(size int) []byte {
if size <= 0 { return pingcore.Payload(size)
return nil
}
payload := make([]byte, size)
for i := 0; i < len(payload); i++ {
payload[i] = byte(i)
}
return payload
} }
func getICMP(seq, identifier uint16, typ uint8, payload []byte) ICMP { func getICMP(seq, identifier uint16, typ uint8, payload []byte) ICMP {
icmp := ICMP{ return pingcore.BuildICMP(seq, identifier, typ, payload)
Type: typ,
Code: 0,
CheckSum: 0,
Identifier: identifier,
SequenceNum: seq,
}
buf := marshalICMPPacket(icmp, payload)
icmp.CheckSum = checkSum(buf)
return icmp
} }
func sendICMPRequest(ctx context.Context, icmp ICMP, payload []byte, destAddr *net.IPAddr, sourceIP net.IP, spec pingSocketSpec, timeout time.Duration) (PingResult, error) { func sendICMPRequest(ctx context.Context, icmp ICMP, payload []byte, destAddr *net.IPAddr, sourceIP net.IP, spec pingSocketSpec, timeout time.Duration) (PingResult, error) {
@ -120,8 +80,8 @@ func sendICMPRequest(ctx context.Context, icmp ICMP, payload []byte, destAddr *n
return res, wrapError(err, "ping write request") return res, wrapError(err, "ping write request")
} }
tStart := time.Now() startedAt := time.Now()
deadline := tStart.Add(timeout) deadline := startedAt.Add(timeout)
if d, ok := ctx.Deadline(); ok && d.Before(deadline) { if d, ok := ctx.Deadline(); ok && d.Before(deadline) {
deadline = d deadline = d
} }
@ -150,108 +110,34 @@ func sendICMPRequest(ctx context.Context, icmp ICMP, payload []byte, destAddr *n
} }
if isExpectedEchoReply(recv[:n], spec.family, spec.replyType, icmp.Identifier, icmp.SequenceNum) { if isExpectedEchoReply(recv[:n], spec.family, spec.replyType, icmp.Identifier, icmp.SequenceNum) {
res.RecvCount = n res.RecvCount = n
res.Duration = time.Since(tStart) res.Duration = time.Since(startedAt)
return res, nil return res, nil
} }
} }
} }
func checkSum(data []byte) uint16 { func checkSum(data []byte) uint16 {
var ( return pingcore.Checksum(data)
sum uint32
length int = len(data)
index int
)
for length > 1 {
sum += uint32(data[index])<<8 + uint32(data[index+1])
index += 2
length -= 2
}
if length > 0 {
sum += uint32(data[index]) << 8
}
for sum>>16 != 0 {
sum = (sum & 0xffff) + (sum >> 16)
}
return uint16(^sum)
} }
func marshalICMP(icmp ICMP) []byte { func marshalICMP(icmp ICMP) []byte {
return marshalICMPPacket(icmp, nil) return pingcore.Marshal(icmp)
} }
func marshalICMPPacket(icmp ICMP, payload []byte) []byte { func marshalICMPPacket(icmp ICMP, payload []byte) []byte {
buf := make([]byte, icmpHeaderLen+len(payload)) return pingcore.MarshalPacket(icmp, payload)
buf[0] = icmp.Type
buf[1] = icmp.Code
binary.BigEndian.PutUint16(buf[2:], icmp.CheckSum)
binary.BigEndian.PutUint16(buf[4:], icmp.Identifier)
binary.BigEndian.PutUint16(buf[6:], icmp.SequenceNum)
copy(buf[icmpHeaderLen:], payload)
return buf
} }
func isExpectedEchoReply(packet []byte, family int, expectedType uint8, identifier, seq uint16) bool { func isExpectedEchoReply(packet []byte, family int, expectedType uint8, identifier, seq uint16) bool {
for _, off := range candidateICMPOffsets(packet, family) { return pingcore.IsExpectedEchoReply(packet, family, expectedType, identifier, seq)
if off < 0 || off+icmpHeaderLen > len(packet) {
continue
}
if packet[off] != expectedType || packet[off+1] != 0 {
continue
}
if binary.BigEndian.Uint16(packet[off+4:off+6]) != identifier {
continue
}
if binary.BigEndian.Uint16(packet[off+6:off+8]) != seq {
continue
}
return true
}
return false
} }
func candidateICMPOffsets(packet []byte, family int) []int { func candidateICMPOffsets(packet []byte, family int) []int {
offsets := []int{0} return pingcore.CandidateICMPOffsets(packet, family)
if len(packet) == 0 {
return offsets
}
ver := packet[0] >> 4
if ver == 4 && len(packet) >= 20 {
ihl := int(packet[0]&0x0f) * 4
if ihl >= 20 && ihl <= len(packet)-icmpHeaderLen {
offsets = append(offsets, ihl)
}
} else if ver == 6 && len(packet) >= 40+icmpHeaderLen {
offsets = append(offsets, 40)
}
// 某些平台/内核可能回包含链路层头部,保守再尝试常见偏移。
if family == 4 && len(packet) >= 20+icmpHeaderLen {
offsets = append(offsets, 20)
}
if family == 6 && len(packet) >= 40+icmpHeaderLen {
offsets = append(offsets, 40)
}
return dedupOffsets(offsets)
} }
func dedupOffsets(offsets []int) []int { func dedupOffsets(offsets []int) []int {
if len(offsets) <= 1 { return pingcore.DedupOffsets(offsets)
return offsets
}
m := make(map[int]struct{}, len(offsets))
out := make([]int, 0, len(offsets))
for _, off := range offsets {
if _, ok := m[off]; ok {
continue
}
m[off] = struct{}{}
out = append(out, off)
}
return out
} }
func socketSpecForIP(ip net.IP) (pingSocketSpec, error) { func socketSpecForIP(ip net.IP) (pingSocketSpec, error) {
@ -297,70 +183,18 @@ func localIPAddrForFamily(sourceIP net.IP, family int) (*net.IPAddr, error) {
} }
func resolvePingTargets(host string, preferIPv4, preferIPv6 bool) ([]*net.IPAddr, error) { func resolvePingTargets(host string, preferIPv4, preferIPv6 bool) ([]*net.IPAddr, error) {
if parsed := net.ParseIP(host); parsed != nil { targets, err := pingcore.ResolveTargets(host, preferIPv4, preferIPv6)
return []*net.IPAddr{{IP: parsed}}, nil if err != nil {
return nil, err
} }
if len(targets) == 0 {
var targets []*net.IPAddr return nil, ErrPingNoResolvedTarget
var err4 error
var err6 error
if ip4, e := net.ResolveIPAddr("ip4", host); e == nil && ip4 != nil && ip4.IP != nil {
targets = append(targets, ip4)
} else {
err4 = e
} }
return targets, nil
if ip6, e := net.ResolveIPAddr("ip6", host); e == nil && ip6 != nil && ip6.IP != nil {
targets = append(targets, ip6)
} else {
err6 = e
}
if len(targets) > 0 {
return orderPingTargets(targets, preferIPv4, preferIPv6), nil
}
if err4 != nil {
return nil, err4
}
if err6 != nil {
return nil, err6
}
return nil, ErrPingNoResolvedTarget
} }
func orderPingTargets(targets []*net.IPAddr, preferIPv4, preferIPv6 bool) []*net.IPAddr { func orderPingTargets(targets []*net.IPAddr, preferIPv4, preferIPv6 bool) []*net.IPAddr {
if len(targets) <= 1 || preferIPv4 == preferIPv6 { return pingcore.OrderTargets(targets, preferIPv4, preferIPv6)
return targets
}
ordered := make([]*net.IPAddr, 0, len(targets))
if preferIPv4 {
for _, t := range targets {
if t != nil && t.IP != nil && t.IP.To4() != nil {
ordered = append(ordered, t)
}
}
for _, t := range targets {
if t != nil && t.IP != nil && t.IP.To4() == nil {
ordered = append(ordered, t)
}
}
return ordered
}
for _, t := range targets {
if t != nil && t.IP != nil && t.IP.To4() == nil {
ordered = append(ordered, t)
}
}
for _, t := range targets {
if t != nil && t.IP != nil && t.IP.To4() != nil {
ordered = append(ordered, t)
}
}
return ordered
} }
func normalizePingDialError(err error) error { func normalizePingDialError(err error) error {
@ -450,7 +284,6 @@ func pingOnceWithOptions(ctx context.Context, host string, seq int, opts PingOpt
return resp, nil return resp, nil
} }
// 权限问题通常与地址族无关,继续重试意义不大。
if errors.Is(err, ErrPingPermissionDenied) { if errors.Is(err, ErrPingPermissionDenied) {
return res, err return res, err
} }
@ -501,8 +334,8 @@ func Pingable(host string, opts *PingOptions) (bool, error) {
} }
var lastErr error var lastErr error
for i := 0; i < cfg.Count; i++ { for index := 0; index < cfg.Count; index++ {
_, err := pingOnceWithOptions(ctx, host, 29+i, cfg) _, err := pingOnceWithOptions(ctx, host, 29+index, cfg)
if err == nil { if err == nil {
return true, nil return true, nil
} }
@ -512,7 +345,7 @@ func Pingable(host string, opts *PingOptions) (bool, error) {
break break
} }
if i < cfg.Count-1 && cfg.Interval > 0 { if index < cfg.Count-1 && cfg.Interval > 0 {
timer := time.NewTimer(cfg.Interval) timer := time.NewTimer(cfg.Interval)
select { select {
case <-ctx.Done(): case <-ctx.Done():

110
proxy_custom_ip_test.go Normal file
View File

@ -0,0 +1,110 @@
package starnet
import (
"fmt"
"net"
"net/http"
"testing"
)
func TestRequestProxyWithCustomIPTargetsOriginWithoutRewritingProxyDial(t *testing.T) {
tlsReqInfo := make(chan struct {
host string
sni string
}, 1)
tlsServer := newIPv4TLSServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
tlsReqInfo <- struct {
host string
sni string
}{
host: r.Host,
sni: r.TLS.ServerName,
}
_, _ = w.Write([]byte("ok"))
}))
defer tlsServer.Close()
_, port, err := net.SplitHostPort(tlsServer.Listener.Addr().String())
if err != nil {
t.Fatalf("split tls server addr: %v", err)
}
proxyServer := newIPv4ConnectProxyServer(t, nil)
defer proxyServer.Close()
targetHost := "proxy-custom-ip.test"
reqURL := fmt.Sprintf("https://%s:%s", targetHost, port)
req := NewSimpleRequest(reqURL, http.MethodGet).
SetProxy(proxyServer.URL).
SetCustomIP([]string{"127.0.0.1"}).
SetSkipTLSVerify(true)
resp, err := req.Do()
if err != nil {
t.Fatalf("Do error: %v", err)
}
defer resp.Close()
targets := proxyServer.Targets()
if len(targets) != 1 {
t.Fatalf("connect targets=%v; want 1 target", targets)
}
gotConnectTarget := targets[0]
wantConnectTarget := net.JoinHostPort("127.0.0.1", port)
if gotConnectTarget != wantConnectTarget {
t.Fatalf("CONNECT target = %q; want %q", gotConnectTarget, wantConnectTarget)
}
gotTLS := <-tlsReqInfo
wantHost := net.JoinHostPort(targetHost, port)
if gotTLS.host != wantHost {
t.Fatalf("request host = %q; want %q", gotTLS.host, wantHost)
}
if gotTLS.sni != targetHost {
t.Fatalf("tls sni = %q; want %q", gotTLS.sni, targetHost)
}
}
func TestRequestCustomIPPreservesOriginalHostAndSNI(t *testing.T) {
tlsReqInfo := make(chan struct {
host string
sni string
}, 1)
tlsServer := newIPv4TLSServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
tlsReqInfo <- struct {
host string
sni string
}{
host: r.Host,
sni: r.TLS.ServerName,
}
_, _ = w.Write([]byte("ok"))
}))
defer tlsServer.Close()
_, port, err := net.SplitHostPort(tlsServer.Listener.Addr().String())
if err != nil {
t.Fatalf("split tls server addr: %v", err)
}
targetHost := "custom-ip-direct.test"
reqURL := fmt.Sprintf("https://%s:%s", targetHost, port)
req := NewSimpleRequest(reqURL, http.MethodGet).
SetCustomIP([]string{"127.0.0.1"}).
SetSkipTLSVerify(true)
resp, err := req.Do()
if err != nil {
t.Fatalf("Do error: %v", err)
}
defer resp.Close()
gotTLS := <-tlsReqInfo
wantHost := net.JoinHostPort(targetHost, port)
if gotTLS.host != wantHost {
t.Fatalf("request host = %q; want %q", gotTLS.host, wantHost)
}
if gotTLS.sni != targetHost {
t.Fatalf("tls sni = %q; want %q", gotTLS.sni, targetHost)
}
}

331
proxy_local_helpers_test.go Normal file
View File

@ -0,0 +1,331 @@
package starnet
import (
"crypto/tls"
"crypto/x509"
"encoding/binary"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"sync"
"testing"
)
type connectProxyServer struct {
*httptest.Server
mu sync.Mutex
targets []string
}
func newIPv4Server(t testing.TB, handler http.Handler) *httptest.Server {
t.Helper()
listener, err := net.Listen("tcp4", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen tcp4: %v", err)
}
server := httptest.NewUnstartedServer(handler)
server.Listener = listener
server.Start()
return server
}
func newIPv4TLSServer(t testing.TB, handler http.Handler) *httptest.Server {
t.Helper()
listener, err := net.Listen("tcp4", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen tcp4: %v", err)
}
server := httptest.NewUnstartedServer(handler)
server.Listener = listener
server.StartTLS()
return server
}
func newTrustedIPv4TLSServer(t testing.TB, dnsName string, handler http.Handler) (*httptest.Server, *x509.CertPool) {
t.Helper()
testT, ok := t.(*testing.T)
if !ok {
t.Fatal("newTrustedIPv4TLSServer requires *testing.T")
}
certPEM, keyPEM := genSelfSignedCertPEM(testT, dnsName)
cert, err := tls.X509KeyPair(certPEM, keyPEM)
if err != nil {
t.Fatalf("X509KeyPair: %v", err)
}
pool := x509.NewCertPool()
if !pool.AppendCertsFromPEM(certPEM) {
t.Fatal("AppendCertsFromPEM returned false")
}
server := httptest.NewUnstartedServer(handler)
listener, err := net.Listen("tcp4", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen tcp4: %v", err)
}
server.Listener = listener
server.TLS = &tls.Config{
Certificates: []tls.Certificate{cert},
}
server.StartTLS()
return server, pool
}
func httpsURLForHost(t testing.TB, server *httptest.Server, host string) string {
t.Helper()
_, port, err := net.SplitHostPort(server.Listener.Addr().String())
if err != nil {
t.Fatalf("split host port: %v", err)
}
return fmt.Sprintf("https://%s:%s", host, port)
}
func newIPv4ConnectProxyServer(t testing.TB, dialTarget func(target string) (net.Conn, error)) *connectProxyServer {
t.Helper()
proxy := &connectProxyServer{}
if dialTarget == nil {
dialTarget = func(target string) (net.Conn, error) {
return net.Dial("tcp", target)
}
}
proxy.Server = newIPv4Server(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodConnect {
http.Error(w, "connect required", http.StatusMethodNotAllowed)
return
}
proxy.mu.Lock()
proxy.targets = append(proxy.targets, r.Host)
proxy.mu.Unlock()
targetConn, err := dialTarget(r.Host)
if err != nil {
http.Error(w, err.Error(), http.StatusBadGateway)
return
}
hijacker, ok := w.(http.Hijacker)
if !ok {
targetConn.Close()
t.Fatal("proxy response writer is not a hijacker")
}
clientConn, rw, err := hijacker.Hijack()
if err != nil {
targetConn.Close()
t.Fatalf("hijack proxy conn: %v", err)
}
if _, err := rw.WriteString("HTTP/1.1 200 Connection Established\r\n\r\n"); err != nil {
clientConn.Close()
targetConn.Close()
t.Fatalf("write connect response: %v", err)
}
if err := rw.Flush(); err != nil {
clientConn.Close()
targetConn.Close()
t.Fatalf("flush connect response: %v", err)
}
relayProxyConns(clientConn, targetConn)
}))
return proxy
}
func (p *connectProxyServer) Targets() []string {
p.mu.Lock()
defer p.mu.Unlock()
return append([]string(nil), p.targets...)
}
type socks5ProxyServer struct {
ln net.Listener
addr string
dial func(target string) (net.Conn, error)
stopCh chan struct{}
wg sync.WaitGroup
mu sync.Mutex
targets []string
}
func newSOCKS5ProxyServer(t testing.TB, dialTarget func(target string) (net.Conn, error)) *socks5ProxyServer {
t.Helper()
if dialTarget == nil {
dialTarget = func(target string) (net.Conn, error) {
return net.Dial("tcp", target)
}
}
ln, err := net.Listen("tcp4", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen tcp4 socks5: %v", err)
}
proxy := &socks5ProxyServer{
ln: ln,
addr: ln.Addr().String(),
dial: dialTarget,
stopCh: make(chan struct{}),
}
proxy.wg.Add(1)
go func() {
defer proxy.wg.Done()
for {
conn, err := ln.Accept()
if err != nil {
select {
case <-proxy.stopCh:
return
default:
return
}
}
proxy.wg.Add(1)
go func(c net.Conn) {
defer proxy.wg.Done()
proxy.handleConn(t, c)
}(conn)
}
}()
return proxy
}
func (p *socks5ProxyServer) URL() string {
return "socks5://" + p.addr
}
func (p *socks5ProxyServer) Targets() []string {
p.mu.Lock()
defer p.mu.Unlock()
return append([]string(nil), p.targets...)
}
func (p *socks5ProxyServer) Close() {
close(p.stopCh)
_ = p.ln.Close()
p.wg.Wait()
}
func (p *socks5ProxyServer) handleConn(t testing.TB, conn net.Conn) {
t.Helper()
closeConn := true
defer func() {
if closeConn {
_ = conn.Close()
}
}()
header := make([]byte, 2)
if _, err := io.ReadFull(conn, header); err != nil {
return
}
if header[0] != 0x05 {
return
}
methods := make([]byte, int(header[1]))
if _, err := io.ReadFull(conn, methods); err != nil {
return
}
if _, err := conn.Write([]byte{0x05, 0x00}); err != nil {
return
}
reqHeader := make([]byte, 4)
if _, err := io.ReadFull(conn, reqHeader); err != nil {
return
}
if reqHeader[0] != 0x05 || reqHeader[1] != 0x01 {
_, _ = conn.Write([]byte{0x05, 0x07, 0x00, 0x01, 0, 0, 0, 0, 0, 0})
return
}
host, err := readSOCKS5Addr(conn, reqHeader[3])
if err != nil {
_, _ = conn.Write([]byte{0x05, 0x08, 0x00, 0x01, 0, 0, 0, 0, 0, 0})
return
}
portBytes := make([]byte, 2)
if _, err := io.ReadFull(conn, portBytes); err != nil {
return
}
target := net.JoinHostPort(host, fmt.Sprintf("%d", binary.BigEndian.Uint16(portBytes)))
p.mu.Lock()
p.targets = append(p.targets, target)
p.mu.Unlock()
targetConn, err := p.dial(target)
if err != nil {
_, _ = conn.Write([]byte{0x05, 0x05, 0x00, 0x01, 0, 0, 0, 0, 0, 0})
return
}
if _, err := conn.Write([]byte{0x05, 0x00, 0x00, 0x01, 0, 0, 0, 0, 0, 0}); err != nil {
targetConn.Close()
return
}
closeConn = false
relayProxyConns(conn, targetConn)
}
func readSOCKS5Addr(r io.Reader, atyp byte) (string, error) {
switch atyp {
case 0x01:
buf := make([]byte, 4)
if _, err := io.ReadFull(r, buf); err != nil {
return "", err
}
return net.IP(buf).String(), nil
case 0x03:
var size [1]byte
if _, err := io.ReadFull(r, size[:]); err != nil {
return "", err
}
buf := make([]byte, int(size[0]))
if _, err := io.ReadFull(r, buf); err != nil {
return "", err
}
return string(buf), nil
case 0x04:
buf := make([]byte, 16)
if _, err := io.ReadFull(r, buf); err != nil {
return "", err
}
return net.IP(buf).String(), nil
default:
return "", fmt.Errorf("unsupported atyp: %d", atyp)
}
}
func relayProxyConns(left, right net.Conn) {
var once sync.Once
closeBoth := func() {
_ = left.Close()
_ = right.Close()
}
go func() {
_, _ = io.Copy(left, right)
once.Do(closeBoth)
}()
go func() {
_, _ = io.Copy(right, left)
once.Do(closeBoth)
}()
}

View File

@ -22,14 +22,131 @@ type Request struct {
httpClient *http.Client httpClient *http.Client
httpReq *http.Request httpReq *http.Request
retry *retryPolicy retry *retryPolicy
traceHooks *TraceHooks
traceState *traceState
applied bool // 是否已应用配置 applied bool // 是否已应用配置
doRaw bool // 是否使用原始请求(不修改) doRaw bool // 是否使用原始请求(不修改)
autoFetch bool // 是否自动获取响应体 autoFetch bool // 是否自动获取响应体
rawSourceExternal bool // 是否由 SetRawRequest/WithRawRequest 注入外部 raw request
rawTemplate *http.Request
}
func normalizeContext(ctx context.Context) context.Context {
if ctx != nil {
return ctx
}
return context.Background()
}
func cloneRawHTTPRequest(httpReq *http.Request, ctx context.Context) (*http.Request, error) {
if httpReq == nil {
return nil, fmt.Errorf("http request is nil")
}
cloned := httpReq.Clone(normalizeContext(ctx))
switch {
case httpReq.Body == nil || httpReq.Body == http.NoBody:
cloned.Body = httpReq.Body
case httpReq.GetBody != nil:
body, err := httpReq.GetBody()
if err != nil {
return cloned, wrapError(err, "clone raw request body")
}
cloned.Body = body
default:
return cloned, fmt.Errorf("cannot clone raw request with non-replayable body")
}
return cloned, nil
}
func (r *Request) rawBaseRequest() *http.Request {
if r == nil {
return nil
}
if r.rawTemplate != nil {
return r.rawTemplate
}
return r.httpReq
}
func (r *Request) invalidatePreparedState() {
if r == nil {
return
}
if r.cancel != nil {
r.cancel()
r.cancel = nil
}
r.execCtx = nil
r.traceState = nil
r.httpClient = nil
wasApplied := r.applied
r.applied = false
if !wasApplied || r.doRaw {
return
}
if err := r.rebuildPreparedRequestBase(); err != nil && r.err == nil {
r.err = err
}
}
func (r *Request) rebuildPreparedRequestBase() error {
if r == nil || r.doRaw {
return nil
}
ctx := r.ctx
if ctx == nil {
ctx = context.Background()
}
httpReq, err := http.NewRequestWithContext(ctx, r.method, r.url, nil)
if err != nil {
return wrapError(err, "rebuild http request")
}
r.httpReq = httpReq
r.syncRequestHost()
return nil
}
func (r *Request) rebuildRawRequestBase() error {
if r == nil || !r.doRaw {
return nil
}
baseReq := r.rawBaseRequest()
rawReq, err := cloneRawHTTPRequest(baseReq, normalizeContext(r.ctx))
if err != nil && baseReq != nil && baseReq == r.httpReq {
r.httpReq = baseReq.WithContext(normalizeContext(r.ctx))
return nil
}
if rawReq != nil {
r.httpReq = rawReq
}
return err
}
func (r *Request) rebuildExecutionRequestBase() error {
if r == nil {
return nil
}
if r.cancel != nil {
r.cancel()
r.cancel = nil
}
r.execCtx = nil
r.traceState = nil
r.applied = false
if r.doRaw {
return r.rebuildRawRequestBase()
}
return r.rebuildPreparedRequestBase()
} }
// newRequest 创建新请求(内部使用) // newRequest 创建新请求(内部使用)
func newRequest(ctx context.Context, urlStr string, method string, opts ...RequestOpt) (*Request, error) { func newRequest(ctx context.Context, urlStr string, method string, opts ...RequestOpt) (*Request, error) {
ctx = normalizeContext(ctx)
if method == "" { if method == "" {
method = http.MethodGet method = http.MethodGet
} }
@ -133,6 +250,7 @@ func NewSimpleRequest(url, method string, opts ...RequestOpt) *Request {
// NewSimpleRequestWithContext 创建新请求(带 context忽略错误 // NewSimpleRequestWithContext 创建新请求(带 context忽略错误
func NewSimpleRequestWithContext(ctx context.Context, url, method string, opts ...RequestOpt) *Request { func NewSimpleRequestWithContext(ctx context.Context, url, method string, opts ...RequestOpt) *Request {
ctx = normalizeContext(ctx)
req, err := newRequest(ctx, url, method, opts...) req, err := newRequest(ctx, url, method, opts...)
if err != nil { if err != nil {
return &Request{ return &Request{
@ -163,16 +281,24 @@ func (r *Request) Clone() *Request {
client: r.client, client: r.client,
httpClient: r.httpClient, httpClient: r.httpClient,
retry: cloneRetryPolicy(r.retry), retry: cloneRetryPolicy(r.retry),
traceHooks: r.traceHooks,
applied: false, // 重置应用状态 applied: false, // 重置应用状态
doRaw: r.doRaw, doRaw: r.doRaw,
autoFetch: r.autoFetch, autoFetch: r.autoFetch,
rawSourceExternal: r.rawSourceExternal,
} }
// 重新创建 http.Request // 重新创建 http.Request
if !r.doRaw { if !r.doRaw {
cloned.httpReq, _ = http.NewRequestWithContext(cloned.ctx, cloned.method, cloned.url, nil) cloned.httpReq, _ = http.NewRequestWithContext(cloned.ctx, cloned.method, cloned.url, nil)
} else { } else {
cloned.httpReq = r.httpReq rawTemplate, err := cloneRawHTTPRequest(r.rawBaseRequest(), cloned.ctx)
cloned.rawTemplate = rawTemplate
cloned.httpReq = rawTemplate
if err != nil && cloned.err == nil {
cloned.err = err
}
} }
return cloned return cloned
@ -190,12 +316,7 @@ func (r *Request) Context() context.Context {
// SetContext 设置 context // SetContext 设置 context
func (r *Request) SetContext(ctx context.Context) *Request { func (r *Request) SetContext(ctx context.Context) *Request {
if r.err != nil { return r.applyMutation(mutateContext(ctx))
return r
}
r.ctx = ctx
r.httpReq = r.httpReq.WithContext(ctx)
return r
} }
// Method 获取 HTTP 方法 // Method 获取 HTTP 方法
@ -215,7 +336,13 @@ func (r *Request) SetMethod(method string) *Request {
} }
r.method = method r.method = method
r.httpReq.Method = method if r.httpReq != nil {
r.httpReq.Method = method
}
if r.doRaw && r.rawTemplate != nil {
r.rawTemplate.Method = method
}
r.invalidatePreparedState()
return r return r
} }
@ -243,45 +370,74 @@ func (r *Request) SetURL(urlStr string) *Request {
r.url = urlStr r.url = urlStr
u.Host = removeEmptyPort(u.Host) u.Host = removeEmptyPort(u.Host)
r.httpReq.Host = u.Host
r.httpReq.URL = u r.httpReq.URL = u
r.syncRequestHost()
// 更新 TLS ServerName r.invalidatePreparedState()
if r.config.TLS.Config != nil {
r.config.TLS.Config.ServerName = u.Hostname()
}
return r return r
} }
func (r *Request) effectiveRequestHost() string {
if r == nil {
return ""
}
if r.config != nil && r.config.Host != "" {
return r.config.Host
}
if r.httpReq != nil && r.httpReq.URL != nil {
return removeEmptyPort(r.httpReq.URL.Host)
}
if r.url == "" {
return ""
}
u, err := url.Parse(r.url)
if err != nil {
return ""
}
return removeEmptyPort(u.Host)
}
func (r *Request) syncRequestHost() {
if r == nil || r.httpReq == nil {
return
}
r.httpReq.Host = r.effectiveRequestHost()
}
// RawRequest 获取底层 http.Request // RawRequest 获取底层 http.Request
func (r *Request) RawRequest() *http.Request { func (r *Request) RawRequest() *http.Request {
if r != nil && r.doRaw && r.rawTemplate != nil && !r.applied {
return r.rawTemplate
}
return r.httpReq return r.httpReq
} }
// SetRawRequest 设置底层 http.Request启用原始模式 // SetRawRequest 设置底层 http.Request启用原始模式
func (r *Request) SetRawRequest(httpReq *http.Request) *Request { func (r *Request) SetRawRequest(httpReq *http.Request) *Request {
if r.err != nil { return r.applyMutation(mutateRawRequest(httpReq))
return r
}
r.httpReq = httpReq
r.doRaw = true
if httpReq == nil {
r.err = fmt.Errorf("httpReq cannot be nil")
return r
}
return r
} }
// EnableRawMode 启用原始模式(不修改请求) // EnableRawMode 启用原始模式(不修改请求)
func (r *Request) EnableRawMode() *Request { func (r *Request) EnableRawMode() *Request {
if r.doRaw {
return r
}
r.doRaw = true r.doRaw = true
r.invalidatePreparedState()
return r return r
} }
// DisableRawMode 禁用原始模式 // DisableRawMode 禁用原始模式
func (r *Request) DisableRawMode() *Request { func (r *Request) DisableRawMode() *Request {
if !r.doRaw {
return r
}
if r.rawSourceExternal {
r.err = fmt.Errorf("cannot disable raw mode after SetRawRequest")
return r
}
r.doRaw = false r.doRaw = false
r.invalidatePreparedState()
return r return r
} }
@ -329,6 +485,10 @@ func (r *Request) Do() (*Response, error) {
} }
func (r *Request) doOnce() (*Response, error) { func (r *Request) doOnce() (*Response, error) {
if err := r.rebuildExecutionRequestBase(); err != nil {
return nil, wrapError(err, "rebuild execution request")
}
// 准备请求 // 准备请求
if err := r.prepare(); err != nil { if err := r.prepare(); err != nil {
return nil, wrapError(err, "prepare request") return nil, wrapError(err, "prepare request")

View File

@ -1,16 +1,9 @@
package starnet package starnet
import ( import (
"bytes"
"context"
"encoding/json" "encoding/json"
"fmt"
"io" "io"
"mime/multipart"
"net/http"
"net/url"
"os" "os"
"strings"
) )
// SetBody 设置请求体(字节) // SetBody 设置请求体(字节)
@ -21,12 +14,13 @@ func (r *Request) SetBody(body []byte) *Request {
if r.doRaw { if r.doRaw {
return r return r
} }
r.config.Body.Bytes = body setBytesBodyConfig(&r.config.Body, body)
r.config.Body.Reader = nil r.invalidatePreparedState()
return r return r
} }
// SetBodyReader 设置请求体Reader // SetBodyReader 设置请求体Reader
// 出于避免重复写的保守策略Reader 形态的 body 在非幂等方法上不会自动参与 retry。
func (r *Request) SetBodyReader(reader io.Reader) *Request { func (r *Request) SetBodyReader(reader io.Reader) *Request {
if r.err != nil { if r.err != nil {
return r return r
@ -34,8 +28,8 @@ func (r *Request) SetBodyReader(reader io.Reader) *Request {
if r.doRaw { if r.doRaw {
return r return r
} }
r.config.Body.Reader = reader setReaderBodyConfig(&r.config.Body, reader)
r.config.Body.Bytes = nil r.invalidatePreparedState()
return r return r
} }
@ -67,7 +61,8 @@ func (r *Request) SetFormData(data map[string][]string) *Request {
if r.doRaw { if r.doRaw {
return r return r
} }
r.config.Body.FormData = cloneStringMapSlice(data) setFormBodyConfig(&r.config.Body, data)
r.invalidatePreparedState()
return r return r
} }
@ -79,7 +74,9 @@ func (r *Request) AddFormData(key, value string) *Request {
if r.doRaw { if r.doRaw {
return r return r
} }
ensureFormMode(&r.config.Body)
r.config.Body.FormData[key] = append(r.config.Body.FormData[key], value) r.config.Body.FormData[key] = append(r.config.Body.FormData[key], value)
r.invalidatePreparedState()
return r return r
} }
@ -91,9 +88,11 @@ func (r *Request) AddFormDataMap(data map[string]string) *Request {
if r.doRaw { if r.doRaw {
return r return r
} }
for k, v := range data { ensureFormMode(&r.config.Body)
r.config.Body.FormData[k] = append(r.config.Body.FormData[k], v) for key, value := range data {
r.config.Body.FormData[key] = append(r.config.Body.FormData[key], value)
} }
r.invalidatePreparedState()
return r return r
} }
@ -109,6 +108,7 @@ func (r *Request) AddFile(formName, filePath string) *Request {
return r return r
} }
ensureMultipartMode(&r.config.Body)
r.config.Body.Files = append(r.config.Body.Files, RequestFile{ r.config.Body.Files = append(r.config.Body.Files, RequestFile{
FormName: formName, FormName: formName,
FileName: stat.Name(), FileName: stat.Name(),
@ -116,6 +116,7 @@ func (r *Request) AddFile(formName, filePath string) *Request {
FileSize: stat.Size(), FileSize: stat.Size(),
FileType: ContentTypeOctetStream, FileType: ContentTypeOctetStream,
}) })
r.invalidatePreparedState()
return r return r
} }
@ -132,6 +133,7 @@ func (r *Request) AddFileWithName(formName, filePath, fileName string) *Request
return r return r
} }
ensureMultipartMode(&r.config.Body)
r.config.Body.Files = append(r.config.Body.Files, RequestFile{ r.config.Body.Files = append(r.config.Body.Files, RequestFile{
FormName: formName, FormName: formName,
FileName: fileName, FileName: fileName,
@ -139,6 +141,7 @@ func (r *Request) AddFileWithName(formName, filePath, fileName string) *Request
FileSize: stat.Size(), FileSize: stat.Size(),
FileType: ContentTypeOctetStream, FileType: ContentTypeOctetStream,
}) })
r.invalidatePreparedState()
return r return r
} }
@ -155,6 +158,7 @@ func (r *Request) AddFileWithType(formName, filePath, fileType string) *Request
return r return r
} }
ensureMultipartMode(&r.config.Body)
r.config.Body.Files = append(r.config.Body.Files, RequestFile{ r.config.Body.Files = append(r.config.Body.Files, RequestFile{
FormName: formName, FormName: formName,
FileName: stat.Name(), FileName: stat.Name(),
@ -162,6 +166,7 @@ func (r *Request) AddFileWithType(formName, filePath, fileType string) *Request
FileSize: stat.Size(), FileSize: stat.Size(),
FileType: fileType, FileType: fileType,
}) })
r.invalidatePreparedState()
return r return r
} }
@ -177,6 +182,7 @@ func (r *Request) AddFileStream(formName, fileName string, size int64, reader io
return r return r
} }
ensureMultipartMode(&r.config.Body)
r.config.Body.Files = append(r.config.Body.Files, RequestFile{ r.config.Body.Files = append(r.config.Body.Files, RequestFile{
FormName: formName, FormName: formName,
FileName: fileName, FileName: fileName,
@ -184,6 +190,7 @@ func (r *Request) AddFileStream(formName, fileName string, size int64, reader io
FileSize: size, FileSize: size,
FileType: ContentTypeOctetStream, FileType: ContentTypeOctetStream,
}) })
r.invalidatePreparedState()
return r return r
} }
@ -199,6 +206,7 @@ func (r *Request) AddFileStreamWithType(formName, fileName, fileType string, siz
return r return r
} }
ensureMultipartMode(&r.config.Body)
r.config.Body.Files = append(r.config.Body.Files, RequestFile{ r.config.Body.Files = append(r.config.Body.Files, RequestFile{
FormName: formName, FormName: formName,
FileName: fileName, FileName: fileName,
@ -206,243 +214,7 @@ func (r *Request) AddFileStreamWithType(formName, fileName, fileType string, siz
FileSize: size, FileSize: size,
FileType: fileType, FileType: fileType,
}) })
r.invalidatePreparedState()
return r return r
} }
// applyBody 应用请求体
func (r *Request) applyBody() error {
// 优先级Reader > Bytes > Files > FormData
// 1. Reader
if r.config.Body.Reader != nil {
r.httpReq.Body = io.NopCloser(r.config.Body.Reader)
// 尝试获取长度
switch v := r.config.Body.Reader.(type) {
case *bytes.Buffer:
r.httpReq.ContentLength = int64(v.Len())
case *bytes.Reader:
r.httpReq.ContentLength = int64(v.Len())
case *strings.Reader:
r.httpReq.ContentLength = int64(v.Len())
}
return nil
}
// 2. Bytes
if len(r.config.Body.Bytes) > 0 {
r.httpReq.Body = io.NopCloser(bytes.NewReader(r.config.Body.Bytes))
r.httpReq.ContentLength = int64(len(r.config.Body.Bytes))
return nil
}
// 3. Filesmultipart/form-data
if len(r.config.Body.Files) > 0 {
return r.applyMultipartBody()
}
// 4. FormDataapplication/x-www-form-urlencoded
if len(r.config.Body.FormData) > 0 {
values := url.Values{}
for k, vs := range r.config.Body.FormData {
for _, v := range vs {
values.Add(k, v)
}
}
encoded := values.Encode()
r.httpReq.Body = io.NopCloser(strings.NewReader(encoded))
r.httpReq.ContentLength = int64(len(encoded))
return nil
}
return nil
}
// applyMultipartBody 应用 multipart 请求体
func (r *Request) applyMultipartBody() error {
pr, pw := io.Pipe()
writer := multipart.NewWriter(pw)
// 设置 Content-Type
r.httpReq.Header.Set("Content-Type", writer.FormDataContentType())
r.httpReq.Body = pr
// 在 goroutine 中写入数据
go func() {
defer pw.Close()
defer writer.Close()
// 写入表单字段
for k, vs := range r.config.Body.FormData {
for _, v := range vs {
if err := writer.WriteField(k, v); err != nil {
pw.CloseWithError(wrapError(err, "write form field"))
return
}
}
}
// 写入文件
for _, file := range r.config.Body.Files {
if err := r.writeFile(writer, file); err != nil {
pw.CloseWithError(err)
return
}
}
}()
return nil
}
// writeFile 写入文件到 multipart writer
func (r *Request) writeFile(writer *multipart.Writer, file RequestFile) error {
// 创建文件字段
part, err := writer.CreateFormFile(file.FormName, file.FileName)
if err != nil {
return wrapError(err, "create form file")
}
// 获取文件数据源
var reader io.Reader
if file.FileData != nil {
reader = file.FileData
} else if file.FilePath != "" {
f, err := os.Open(file.FilePath)
if err != nil {
return wrapError(err, "open file")
}
defer f.Close()
reader = f
} else {
return ErrNilReader
}
// 复制文件数据(带进度)
if r.config.UploadProgress != nil {
_, err = copyWithProgress(r.ctx, part, reader, file.FileName, file.FileSize, r.config.UploadProgress)
} else {
_, err = io.Copy(part, reader)
}
if err != nil {
return wrapError(err, "copy file data")
}
return nil
}
// prepare 准备请求(应用配置)
func (r *Request) prepare() error {
if r.applied {
return nil
}
// 即使 raw 模式也要确保有 httpClient
if r.httpClient == nil {
var err error
r.httpClient, err = r.buildHTTPClient()
if err != nil {
return err // ← 失败时不设置 applied
}
}
if r.httpReq == nil {
return fmt.Errorf("http request is nil")
}
// 原始模式不修改请求内容
if !r.doRaw {
// 应用查询参数
if len(r.config.Queries) > 0 {
q := r.httpReq.URL.Query()
for k, values := range r.config.Queries {
for _, v := range values {
q.Add(k, v)
}
}
r.httpReq.URL.RawQuery = q.Encode()
}
// 应用 Headers
for k, values := range r.config.Headers {
for _, v := range values {
r.httpReq.Header.Add(k, v)
}
}
// 应用 Cookies
for _, cookie := range r.config.Cookies {
r.httpReq.AddCookie(cookie)
}
// 应用 Basic Auth
if r.config.BasicAuth[0] != "" || r.config.BasicAuth[1] != "" {
r.httpReq.SetBasicAuth(r.config.BasicAuth[0], r.config.BasicAuth[1])
}
// 应用请求体
if err := r.applyBody(); err != nil {
return err
}
// 应用 Content-Length
if r.config.ContentLength > 0 {
r.httpReq.ContentLength = r.config.ContentLength
} else if r.config.ContentLength < 0 {
r.httpReq.ContentLength = 0
}
// 自动计算 Content-Length
if r.config.AutoCalcContentLength && r.httpReq.Body != nil {
data, err := io.ReadAll(r.httpReq.Body)
if err != nil {
return wrapError(err, "read body for content length")
}
r.httpReq.ContentLength = int64(len(data))
r.httpReq.Body = io.NopCloser(bytes.NewBuffer(data))
}
// 设置 TLS ServerName如果有 TLS Config
if r.config.TLS.Config != nil && r.httpReq.URL != nil {
r.config.TLS.Config.ServerName = r.httpReq.URL.Hostname()
}
}
execCtx := r.ctx
if !r.doRaw {
// raw 模式下不注入请求级网络配置,只应用 context/超时。
execCtx = injectRequestConfig(execCtx, r.config)
}
// 请求级总超时通过 context 控制,避免污染共享 http.Client。
if r.config.Network.Timeout > 0 {
execCtx, r.cancel = context.WithTimeout(execCtx, r.config.Network.Timeout)
}
r.execCtx = execCtx
r.httpReq = r.httpReq.WithContext(r.execCtx)
r.applied = true
return nil
}
// buildHTTPClient 构建 HTTP Client
func (r *Request) buildHTTPClient() (*http.Client, error) {
// 优先使用请求关联的 Client
if r.client != nil {
return r.client.HTTPClient(), nil
}
// 自定义 Transport
if r.config.CustomTransport && r.config.Transport != nil {
return &http.Client{
Transport: &Transport{base: r.config.Transport},
Timeout: 0,
}, nil
}
// 默认全局 client
return DefaultHTTPClient(), nil
}

View File

@ -1,282 +0,0 @@
package starnet
import (
"context"
"crypto/tls"
"fmt"
"net"
"net/http"
"time"
)
// SetTimeout 设置请求总超时时间
// timeout > 0: 为本次请求注入 context 超时
// timeout = 0: 不额外设置请求总超时
// timeout < 0: 禁用 starnet 默认总超时
func (r *Request) SetTimeout(timeout time.Duration) *Request {
if r.err != nil {
return r
}
r.config.Network.Timeout = timeout
return r
}
// SetDialTimeout 设置连接超时时间
func (r *Request) SetDialTimeout(timeout time.Duration) *Request {
if r.err != nil {
return r
}
r.config.Network.DialTimeout = timeout
return r
}
// SetProxy 设置代理
func (r *Request) SetProxy(proxy string) *Request {
if r.err != nil {
return r
}
r.config.Network.Proxy = proxy
return r
}
// SetDialFunc 设置自定义 Dial 函数
func (r *Request) SetDialFunc(fn func(ctx context.Context, network, addr string) (net.Conn, error)) *Request {
if r.err != nil {
return r
}
r.config.Network.DialFunc = fn
return r
}
// SetTLSConfig 设置 TLS 配置
func (r *Request) SetTLSConfig(tlsConfig *tls.Config) *Request {
if r.err != nil {
return r
}
r.config.TLS.Config = tlsConfig
return r
}
// SetSkipTLSVerify 设置是否跳过 TLS 验证
func (r *Request) SetSkipTLSVerify(skip bool) *Request {
if r.err != nil {
return r
}
r.config.TLS.SkipVerify = skip
return r
}
// SetCustomIP 设置自定义 IP直接指定 IP跳过 DNS
func (r *Request) SetCustomIP(ips []string) *Request {
if r.err != nil {
return r
}
// 验证 IP 格式
for _, ip := range ips {
if net.ParseIP(ip) == nil {
r.err = wrapError(ErrInvalidIP, "ip: %s", ip)
return r
}
}
r.config.DNS.CustomIP = ips
return r
}
// AddCustomIP 添加自定义 IP
func (r *Request) AddCustomIP(ip string) *Request {
if r.err != nil {
return r
}
if net.ParseIP(ip) == nil {
r.err = wrapError(ErrInvalidIP, "ip: %s", ip)
return r
}
r.config.DNS.CustomIP = append(r.config.DNS.CustomIP, ip)
return r
}
// SetCustomDNS 设置自定义 DNS 服务器
func (r *Request) SetCustomDNS(dnsServers []string) *Request {
if r.err != nil {
return r
}
// 验证 DNS 服务器格式
for _, dns := range dnsServers {
if net.ParseIP(dns) == nil {
r.err = wrapError(ErrInvalidDNS, "dns: %s", dns)
return r
}
}
r.config.DNS.CustomDNS = dnsServers
return r
}
// AddCustomDNS 添加自定义 DNS 服务器
func (r *Request) AddCustomDNS(dns string) *Request {
if r.err != nil {
return r
}
if net.ParseIP(dns) == nil {
r.err = wrapError(ErrInvalidDNS, "dns: %s", dns)
return r
}
r.config.DNS.CustomDNS = append(r.config.DNS.CustomDNS, dns)
return r
}
// SetLookupFunc 设置自定义 DNS 解析函数
func (r *Request) SetLookupFunc(fn func(ctx context.Context, host string) ([]net.IPAddr, error)) *Request {
if r.err != nil {
return r
}
r.config.DNS.LookupFunc = fn
return r
}
// SetBasicAuth 设置 Basic 认证
func (r *Request) SetBasicAuth(username, password string) *Request {
if r.err != nil {
return r
}
r.config.BasicAuth = [2]string{username, password}
return r
}
// SetContentLength 设置 Content-Length
func (r *Request) SetContentLength(length int64) *Request {
if r.err != nil {
return r
}
r.config.ContentLength = length
return r
}
// SetAutoCalcContentLength 设置是否自动计算 Content-Length
// 警告:启用后会将整个 body 读入内存
func (r *Request) SetAutoCalcContentLength(auto bool) *Request {
if r.err != nil {
return r
}
if r.doRaw {
r.err = fmt.Errorf("cannot set auto calc content length in raw mode")
return r
}
r.config.AutoCalcContentLength = auto
return r
}
// SetTransport 设置自定义 Transport
func (r *Request) SetTransport(transport *http.Transport) *Request {
if r.err != nil {
return r
}
r.config.Transport = transport
r.config.CustomTransport = true
return r
}
// SetUploadProgress 设置文件上传进度回调
func (r *Request) SetUploadProgress(fn UploadProgressFunc) *Request {
if r.err != nil {
return r
}
r.config.UploadProgress = fn
return r
}
// SetMaxRespBodyBytes 设置响应体最大读取字节数(<=0 表示不限制)
func (r *Request) SetMaxRespBodyBytes(maxBytes int64) *Request {
if r.err != nil {
return r
}
if maxBytes < 0 {
r.err = fmt.Errorf("max response body bytes must be >= 0")
return r
}
r.config.MaxRespBodyBytes = maxBytes
return r
}
// AddQuery 添加查询参数
func (r *Request) AddQuery(key, value string) *Request {
if r.err != nil {
return r
}
r.config.Queries[key] = append(r.config.Queries[key], value)
return r
}
// SetQuery 设置查询参数(覆盖)
func (r *Request) SetQuery(key, value string) *Request {
if r.err != nil {
return r
}
r.config.Queries[key] = []string{value}
return r
}
// SetQueries 设置所有查询参数(覆盖)
func (r *Request) SetQueries(queries map[string][]string) *Request {
if r.err != nil {
return r
}
r.config.Queries = cloneStringMapSlice(queries)
return r
}
// AddQueries 批量添加查询参数
func (r *Request) AddQueries(queries map[string]string) *Request {
if r.err != nil {
return r
}
for k, v := range queries {
r.config.Queries[k] = append(r.config.Queries[k], v)
}
return r
}
// DeleteQuery 删除查询参数
func (r *Request) DeleteQuery(key string) *Request {
if r.err != nil {
return r
}
delete(r.config.Queries, key)
return r
}
// DeleteQueryValue 删除查询参数的特定值
func (r *Request) DeleteQueryValue(key, value string) *Request {
if r.err != nil {
return r
}
values, ok := r.config.Queries[key]
if !ok {
return r
}
newValues := make([]string, 0, len(values))
for _, v := range values {
if v != value {
newValues = append(newValues, v)
}
}
if len(newValues) == 0 {
delete(r.config.Queries, key)
} else {
r.config.Queries[key] = newValues
}
return r
}

34
request_execution.go Normal file
View File

@ -0,0 +1,34 @@
package starnet
import "net/http"
// SetBasicAuth 设置 Basic 认证
func (r *Request) SetBasicAuth(username, password string) *Request {
return r.applyMutation(mutateBasicAuth(username, password))
}
// SetContentLength 设置 Content-Length
func (r *Request) SetContentLength(length int64) *Request {
return r.applyMutation(mutateContentLength(length))
}
// SetAutoCalcContentLength 设置是否自动计算 Content-Length
// 警告:启用后会将整个 body 读入内存
func (r *Request) SetAutoCalcContentLength(auto bool) *Request {
return r.applyMutation(mutateAutoCalcContentLength(auto))
}
// SetTransport 设置自定义 Transport
func (r *Request) SetTransport(transport *http.Transport) *Request {
return r.applyMutation(mutateTransport(transport))
}
// SetUploadProgress 设置文件上传进度回调
func (r *Request) SetUploadProgress(fn UploadProgressFunc) *Request {
return r.applyMutation(mutateUploadProgress(fn))
}
// SetMaxRespBodyBytes 设置响应体最大读取字节数(<=0 表示不限制)
func (r *Request) SetMaxRespBodyBytes(maxBytes int64) *Request {
return r.applyMutation(mutateMaxRespBodyBytes(maxBytes))
}

View File

@ -0,0 +1,172 @@
package starnet
import (
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"sync/atomic"
"testing"
)
func TestRequestDoTwiceRebuildsExecutionState(t *testing.T) {
var attempts int32
req := NewSimpleRequest("http://example.com/path", http.MethodPost).
SetHeader("X-Test", "one").
AddQuery("q", "v").
SetBodyReader(strings.NewReader("payload"))
req.client = &Client{client: &http.Client{
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
if err := r.Context().Err(); err != nil {
t.Fatalf("request context already done: %v", err)
}
if values := r.Header.Values("X-Test"); len(values) != 1 || values[0] != "one" {
t.Fatalf("header values=%v", values)
}
if values := r.URL.Query()["q"]; len(values) != 1 || values[0] != "v" {
t.Fatalf("query values=%v", values)
}
body, err := io.ReadAll(r.Body)
if err != nil {
return nil, err
}
_ = r.Body.Close()
if string(body) != "payload" {
t.Fatalf("body=%q", string(body))
}
n := atomic.AddInt32(&attempts, 1)
return &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader(fmt.Sprintf("ok-%d", n))),
Request: r,
}, nil
}),
}}
resp1, err := req.Do()
if err != nil {
t.Fatalf("first Do() error: %v", err)
}
if err := resp1.Close(); err != nil {
t.Fatalf("first Close() error: %v", err)
}
resp2, err := req.Do()
if err != nil {
t.Fatalf("second Do() error: %v", err)
}
defer resp2.Close()
if got := atomic.LoadInt32(&attempts); got != 2 {
t.Fatalf("attempts=%d; want 2", got)
}
}
func TestRequestPrepareRawDynamicPathInjectsAggregatedRequestContext(t *testing.T) {
rawReq, err := http.NewRequest(http.MethodGet, "https://example.com/resource", nil)
if err != nil {
t.Fatalf("NewRequest() error: %v", err)
}
req := NewSimpleRequest("", http.MethodGet).
SetRawRequest(rawReq).
SetProxy("http://proxy.example:8080").
SetCustomIP([]string{"127.0.0.1"}).
SetSkipTLSVerify(true).
SetTLSServerName("override.example")
if err := req.prepare(); err != nil {
t.Fatalf("prepare() error: %v", err)
}
raw := req.execCtx.Value(ctxKeyRequestContext)
rc, ok := raw.(*RequestContext)
if !ok || rc == nil {
t.Fatalf("expected request context, got %#v", raw)
}
if rc.Proxy != "http://proxy.example:8080" {
t.Fatalf("proxy=%q", rc.Proxy)
}
if len(rc.CustomIP) != 1 || rc.CustomIP[0] != "127.0.0.1" {
t.Fatalf("custom ip=%v", rc.CustomIP)
}
if rc.TLSConfig == nil || !rc.TLSConfig.InsecureSkipVerify {
t.Fatalf("tls config=%#v", rc.TLSConfig)
}
if rc.TLSServerName != "override.example" {
t.Fatalf("tls server name=%q", rc.TLSServerName)
}
}
func TestRequestSetFormDataOverridesBytesBody(t *testing.T) {
req := NewSimpleRequest("http://example.com", http.MethodPost).
SetBodyString("stale").
SetFormData(map[string][]string{"k": []string{"v"}})
if req.config.Body.Mode != bodyModeForm {
t.Fatalf("body mode=%v", req.config.Body.Mode)
}
if req.config.Body.Reader != nil || req.config.Body.Bytes != nil || len(req.config.Body.Files) != 0 {
t.Fatalf("unexpected stale body state: %#v", req.config.Body)
}
if err := req.prepare(); err != nil {
t.Fatalf("prepare() error: %v", err)
}
body, err := req.httpReq.GetBody()
if err != nil {
t.Fatalf("GetBody() error: %v", err)
}
defer body.Close()
data, err := io.ReadAll(body)
if err != nil {
t.Fatalf("ReadAll() error: %v", err)
}
if string(data) != "k=v" {
t.Fatalf("body=%q; want k=v", string(data))
}
}
func TestRequestAddFileClearsPreviousBytesBody(t *testing.T) {
tmpDir := t.TempDir()
filePath := filepath.Join(tmpDir, "payload.txt")
if err := os.WriteFile(filePath, []byte("file-body"), 0644); err != nil {
t.Fatalf("WriteFile() error: %v", err)
}
req := NewSimpleRequest("http://example.com", http.MethodPost).
SetJSON(map[string]string{"old": "json-only"}).
AddFile("file", filePath)
if req.config.Body.Mode != bodyModeMultipart {
t.Fatalf("body mode=%v", req.config.Body.Mode)
}
if req.config.Body.Reader != nil || req.config.Body.Bytes != nil {
t.Fatalf("unexpected stale simple body state: %#v", req.config.Body)
}
if err := req.prepare(); err != nil {
t.Fatalf("prepare() error: %v", err)
}
data, err := io.ReadAll(req.httpReq.Body)
if err != nil {
t.Fatalf("ReadAll() error: %v", err)
}
if !strings.Contains(req.httpReq.Header.Get("Content-Type"), "multipart/form-data") {
t.Fatalf("content-type=%q", req.httpReq.Header.Get("Content-Type"))
}
if !strings.Contains(string(data), "file-body") {
t.Fatalf("multipart body missing file content: %q", string(data))
}
if strings.Contains(string(data), "json-only") {
t.Fatalf("multipart body still contains stale json: %q", string(data))
}
}

View File

@ -4,6 +4,25 @@ import (
"net/http" "net/http"
) )
func isHostHeaderKey(key string) bool {
return http.CanonicalHeaderKey(key) == "Host"
}
func setRequestHostConfig(config *RequestConfig, host string) {
if config == nil {
return
}
if config.Headers == nil {
config.Headers = make(http.Header)
}
config.Host = host
if host == "" {
config.Headers.Del("Host")
return
}
config.Headers.Set("Host", host)
}
// SetHeader 设置 Header覆盖 // SetHeader 设置 Header覆盖
func (r *Request) SetHeader(key, value string) *Request { func (r *Request) SetHeader(key, value string) *Request {
if r.err != nil { if r.err != nil {
@ -12,7 +31,11 @@ func (r *Request) SetHeader(key, value string) *Request {
if r.doRaw { if r.doRaw {
return r return r
} }
if isHostHeaderKey(key) {
return r.SetHost(value)
}
r.config.Headers.Set(key, value) r.config.Headers.Set(key, value)
r.invalidatePreparedState()
return r return r
} }
@ -24,7 +47,11 @@ func (r *Request) AddHeader(key, value string) *Request {
if r.doRaw { if r.doRaw {
return r return r
} }
if isHostHeaderKey(key) {
return r.SetHost(value)
}
r.config.Headers.Add(key, value) r.config.Headers.Add(key, value)
r.invalidatePreparedState()
return r return r
} }
@ -37,6 +64,9 @@ func (r *Request) SetHeaders(headers http.Header) *Request {
return r return r
} }
r.config.Headers = cloneHeader(headers) r.config.Headers = cloneHeader(headers)
r.config.Host = r.config.Headers.Get("Host")
r.syncRequestHost()
r.invalidatePreparedState()
return r return r
} }
@ -49,8 +79,14 @@ func (r *Request) AddHeaders(headers map[string]string) *Request {
return r return r
} }
for k, v := range headers { for k, v := range headers {
if isHostHeaderKey(k) {
setRequestHostConfig(r.config, v)
continue
}
r.config.Headers.Add(k, v) r.config.Headers.Add(k, v)
} }
r.syncRequestHost()
r.invalidatePreparedState()
return r return r
} }
@ -62,18 +98,56 @@ func (r *Request) DeleteHeader(key string) *Request {
if r.doRaw { if r.doRaw {
return r return r
} }
if isHostHeaderKey(key) {
setRequestHostConfig(r.config, "")
r.syncRequestHost()
r.invalidatePreparedState()
return r
}
r.config.Headers.Del(key) r.config.Headers.Del(key)
r.invalidatePreparedState()
return r return r
} }
// GetHeader 获取 Header // GetHeader 获取 Header
func (r *Request) GetHeader(key string) string { func (r *Request) GetHeader(key string) string {
if isHostHeaderKey(key) {
return r.config.Host
}
return r.config.Headers.Get(key) return r.config.Headers.Get(key)
} }
// Headers 获取所有 Headers // Headers 获取所有 Headers
func (r *Request) Headers() http.Header { func (r *Request) Headers() http.Header {
return r.config.Headers if r == nil || r.config == nil {
return make(http.Header)
}
return cloneHeader(r.config.Headers)
}
// SetHost 设置请求 Host 头覆盖。
func (r *Request) SetHost(host string) *Request {
if r.err != nil {
return r
}
if r.doRaw {
return r
}
setRequestHostConfig(r.config, host)
r.syncRequestHost()
r.invalidatePreparedState()
return r
}
// Host 获取显式 Host 覆盖。
func (r *Request) Host() string {
if r.config != nil && r.config.Host != "" {
return r.config.Host
}
if r.httpReq != nil {
return r.httpReq.Host
}
return ""
} }
// SetContentType 设置 Content-Type // SetContentType 设置 Content-Type
@ -104,7 +178,8 @@ func (r *Request) AddCookie(cookie *http.Cookie) *Request {
if r.doRaw { if r.doRaw {
return r return r
} }
r.config.Cookies = append(r.config.Cookies, cookie) r.config.Cookies = append(r.config.Cookies, cloneCookie(cookie))
r.invalidatePreparedState()
return r return r
} }
@ -134,7 +209,8 @@ func (r *Request) SetCookies(cookies []*http.Cookie) *Request {
if r.doRaw { if r.doRaw {
return r return r
} }
r.config.Cookies = cookies r.config.Cookies = cloneCookies(cookies)
r.invalidatePreparedState()
return r return r
} }
@ -153,12 +229,16 @@ func (r *Request) AddCookies(cookies map[string]string) *Request {
Path: "/", Path: "/",
}) })
} }
r.invalidatePreparedState()
return r return r
} }
// Cookies 获取所有 Cookies // Cookies 获取所有 Cookies
func (r *Request) Cookies() []*http.Cookie { func (r *Request) Cookies() []*http.Cookie {
return r.config.Cookies if r == nil || r.config == nil {
return nil
}
return cloneCookies(r.config.Cookies)
} }
// ResetHeaders 重置所有 Headers // ResetHeaders 重置所有 Headers
@ -167,6 +247,9 @@ func (r *Request) ResetHeaders() *Request {
return r return r
} }
r.config.Headers = make(http.Header) r.config.Headers = make(http.Header)
r.config.Host = ""
r.syncRequestHost()
r.invalidatePreparedState()
return r return r
} }
@ -176,5 +259,6 @@ func (r *Request) ResetCookies() *Request {
return r return r
} }
r.config.Cookies = []*http.Cookie{} r.config.Cookies = []*http.Cookie{}
r.invalidatePreparedState()
return r return r
} }

69
request_multipart.go Normal file
View File

@ -0,0 +1,69 @@
package starnet
import (
"context"
"io"
"mime/multipart"
"os"
)
// applyMultipartBody 应用 multipart 请求体
func (r *Request) applyMultipartBody(execCtx context.Context) error {
pr, pw := io.Pipe()
writer := multipart.NewWriter(pw)
r.httpReq.Header.Set("Content-Type", writer.FormDataContentType())
r.httpReq.Body = pr
go func() {
defer pw.Close()
defer writer.Close()
for key, values := range r.config.Body.FormData {
for _, value := range values {
if err := writer.WriteField(key, value); err != nil {
pw.CloseWithError(wrapError(err, "write form field"))
return
}
}
}
for _, file := range r.config.Body.Files {
if err := r.writeFile(execCtx, writer, file); err != nil {
pw.CloseWithError(err)
return
}
}
}()
return nil
}
// writeFile 写入文件到 multipart writer
func (r *Request) writeFile(execCtx context.Context, writer *multipart.Writer, file RequestFile) error {
part, err := writer.CreateFormFile(file.FormName, file.FileName)
if err != nil {
return wrapError(err, "create form file")
}
var reader io.Reader
if file.FileData != nil {
reader = file.FileData
} else if file.FilePath != "" {
f, err := os.Open(file.FilePath)
if err != nil {
return wrapError(err, "open file")
}
defer f.Close()
reader = f
} else {
return ErrNilReader
}
_, err = copyWithProgress(execCtx, part, reader, file.FileName, file.FileSize, r.config.UploadProgress)
if err != nil {
return wrapError(err, "copy file data")
}
return nil
}

326
request_mutation.go Normal file
View File

@ -0,0 +1,326 @@
package starnet
import (
"context"
"crypto/tls"
"fmt"
"net"
"net/http"
"net/url"
"time"
)
type requestMutation func(*Request) error
func (r *Request) applyMutation(mutation requestMutation) *Request {
if r == nil || r.err != nil {
return r
}
if err := mutation(r); err != nil {
r.err = err
return r
}
r.invalidatePreparedState()
return r
}
func requestOptFromMutation(mutation requestMutation) RequestOpt {
return func(r *Request) error {
if r == nil {
return nil
}
return mutation(r)
}
}
func validateCustomIPs(ips []string) error {
for _, ip := range ips {
if net.ParseIP(ip) == nil {
return wrapError(ErrInvalidIP, "ip: %s", ip)
}
}
return nil
}
func validateCustomDNS(dnsServers []string) error {
for _, dns := range dnsServers {
if net.ParseIP(dns) == nil {
return wrapError(ErrInvalidDNS, "dns: %s", dns)
}
}
return nil
}
func parseProxyURL(proxy string) (*url.URL, error) {
if proxy == "" {
return nil, nil
}
proxyURL, err := url.Parse(proxy)
if err != nil {
return nil, wrapError(err, "parse proxy url")
}
if proxyURL.Scheme == "" {
return nil, fmt.Errorf("proxy scheme is required: %s", proxy)
}
if proxyURL.Host == "" {
return nil, fmt.Errorf("proxy host is required: %s", proxy)
}
return proxyURL, nil
}
func mutateTimeout(timeout time.Duration) requestMutation {
return func(r *Request) error {
r.config.Network.Timeout = timeout
return nil
}
}
func mutateDialTimeout(timeout time.Duration) requestMutation {
return func(r *Request) error {
r.config.Network.DialTimeout = timeout
return nil
}
}
func mutateProxy(proxy string) requestMutation {
return func(r *Request) error {
if _, err := parseProxyURL(proxy); err != nil {
return err
}
r.config.Network.Proxy = proxy
return nil
}
}
func mutateDialFunc(fn func(ctx context.Context, network, addr string) (net.Conn, error)) requestMutation {
return func(r *Request) error {
r.config.Network.DialFunc = fn
return nil
}
}
func mutateTLSConfig(tlsConfig *tls.Config) requestMutation {
return func(r *Request) error {
r.config.TLS.Config = tlsConfig
return nil
}
}
func mutateTLSServerName(serverName string) requestMutation {
return func(r *Request) error {
r.config.TLS.ServerName = serverName
return nil
}
}
func mutateTraceHooks(hooks *TraceHooks) requestMutation {
return func(r *Request) error {
r.traceHooks = hooks
return nil
}
}
func mutateSkipTLSVerify(skip bool) requestMutation {
return func(r *Request) error {
r.config.TLS.SkipVerify = skip
return nil
}
}
func mutateCustomIP(ips []string) requestMutation {
return func(r *Request) error {
if err := validateCustomIPs(ips); err != nil {
return err
}
r.config.DNS.CustomIP = cloneStringSlice(ips)
return nil
}
}
func mutateAddCustomIP(ip string) requestMutation {
return func(r *Request) error {
if err := validateCustomIPs([]string{ip}); err != nil {
return err
}
r.config.DNS.CustomIP = append(r.config.DNS.CustomIP, ip)
return nil
}
}
func mutateCustomDNS(dnsServers []string) requestMutation {
return func(r *Request) error {
if err := validateCustomDNS(dnsServers); err != nil {
return err
}
r.config.DNS.CustomDNS = cloneStringSlice(dnsServers)
return nil
}
}
func mutateAddCustomDNS(dns string) requestMutation {
return func(r *Request) error {
if err := validateCustomDNS([]string{dns}); err != nil {
return err
}
r.config.DNS.CustomDNS = append(r.config.DNS.CustomDNS, dns)
return nil
}
}
func mutateLookupFunc(fn func(ctx context.Context, host string) ([]net.IPAddr, error)) requestMutation {
return func(r *Request) error {
r.config.DNS.LookupFunc = fn
return nil
}
}
func mutateBasicAuth(username, password string) requestMutation {
return func(r *Request) error {
r.config.BasicAuth = [2]string{username, password}
return nil
}
}
func mutateContentLength(length int64) requestMutation {
return func(r *Request) error {
r.config.ContentLength = length
return nil
}
}
func mutateAutoCalcContentLength(auto bool) requestMutation {
return func(r *Request) error {
if r.doRaw {
return fmt.Errorf("cannot set auto calc content length in raw mode")
}
r.config.AutoCalcContentLength = auto
return nil
}
}
func mutateTransport(transport *http.Transport) requestMutation {
return func(r *Request) error {
r.config.Transport = transport
r.config.CustomTransport = true
return nil
}
}
func mutateUploadProgress(fn UploadProgressFunc) requestMutation {
return func(r *Request) error {
r.config.UploadProgress = fn
return nil
}
}
func mutateAutoFetch(auto bool) requestMutation {
return func(r *Request) error {
r.autoFetch = auto
return nil
}
}
func mutateMaxRespBodyBytes(maxBytes int64) requestMutation {
return func(r *Request) error {
if maxBytes < 0 {
return fmt.Errorf("max response body bytes must be >= 0")
}
r.config.MaxRespBodyBytes = maxBytes
return nil
}
}
func mutateContext(ctx context.Context) requestMutation {
return func(r *Request) error {
ctx = normalizeContext(ctx)
r.ctx = ctx
if r.doRaw && r.rawTemplate != nil {
r.rawTemplate = r.rawTemplate.WithContext(ctx)
}
if r.httpReq != nil {
r.httpReq = r.httpReq.WithContext(ctx)
}
return nil
}
}
func mutateRawRequest(httpReq *http.Request) requestMutation {
return func(r *Request) error {
if httpReq == nil {
return fmt.Errorf("httpReq cannot be nil")
}
r.httpReq = httpReq
r.rawTemplate = httpReq
r.ctx = normalizeContext(httpReq.Context())
r.method = httpReq.Method
if httpReq.URL != nil {
r.url = httpReq.URL.String()
}
r.doRaw = true
r.rawSourceExternal = true
return nil
}
}
func mutateAddQuery(key, value string) requestMutation {
return func(r *Request) error {
r.config.Queries[key] = append(r.config.Queries[key], value)
return nil
}
}
func mutateSetQuery(key, value string) requestMutation {
return func(r *Request) error {
r.config.Queries[key] = []string{value}
return nil
}
}
func mutateSetQueries(queries map[string][]string) requestMutation {
return func(r *Request) error {
r.config.Queries = cloneStringMapSlice(queries)
return nil
}
}
func mutateAddQueries(queries map[string]string) requestMutation {
return func(r *Request) error {
for key, value := range queries {
r.config.Queries[key] = append(r.config.Queries[key], value)
}
return nil
}
}
func mutateDeleteQuery(key string) requestMutation {
return func(r *Request) error {
delete(r.config.Queries, key)
return nil
}
}
func mutateDeleteQueryValue(key, value string) requestMutation {
return func(r *Request) error {
values, ok := r.config.Queries[key]
if !ok {
return nil
}
newValues := make([]string, 0, len(values))
for _, item := range values {
if item != value {
newValues = append(newValues, item)
}
}
if len(newValues) == 0 {
delete(r.config.Queries, key)
return nil
}
r.config.Queries[key] = newValues
return nil
}
}

71
request_network.go Normal file
View File

@ -0,0 +1,71 @@
package starnet
import (
"context"
"crypto/tls"
"net"
"time"
)
// SetTimeout 设置请求总超时时间
// timeout > 0: 为本次请求注入 context 超时
// timeout = 0: 不额外设置请求总超时
// timeout < 0: 禁用 starnet 默认总超时
func (r *Request) SetTimeout(timeout time.Duration) *Request {
return r.applyMutation(mutateTimeout(timeout))
}
// SetDialTimeout 设置连接超时时间
func (r *Request) SetDialTimeout(timeout time.Duration) *Request {
return r.applyMutation(mutateDialTimeout(timeout))
}
// SetProxy 设置代理
func (r *Request) SetProxy(proxy string) *Request {
return r.applyMutation(mutateProxy(proxy))
}
// SetDialFunc 设置自定义 Dial 函数
func (r *Request) SetDialFunc(fn func(ctx context.Context, network, addr string) (net.Conn, error)) *Request {
return r.applyMutation(mutateDialFunc(fn))
}
// SetTLSConfig 设置 TLS 配置
func (r *Request) SetTLSConfig(tlsConfig *tls.Config) *Request {
return r.applyMutation(mutateTLSConfig(tlsConfig))
}
// SetTLSServerName 设置显式 TLS ServerName/SNI。
func (r *Request) SetTLSServerName(serverName string) *Request {
return r.applyMutation(mutateTLSServerName(serverName))
}
// SetSkipTLSVerify 设置是否跳过 TLS 验证
func (r *Request) SetSkipTLSVerify(skip bool) *Request {
return r.applyMutation(mutateSkipTLSVerify(skip))
}
// SetCustomIP 设置自定义 IP直接指定 IP跳过 DNS
func (r *Request) SetCustomIP(ips []string) *Request {
return r.applyMutation(mutateCustomIP(ips))
}
// AddCustomIP 添加自定义 IP
func (r *Request) AddCustomIP(ip string) *Request {
return r.applyMutation(mutateAddCustomIP(ip))
}
// SetCustomDNS 设置自定义 DNS 服务器
func (r *Request) SetCustomDNS(dnsServers []string) *Request {
return r.applyMutation(mutateCustomDNS(dnsServers))
}
// AddCustomDNS 添加自定义 DNS 服务器
func (r *Request) AddCustomDNS(dns string) *Request {
return r.applyMutation(mutateAddCustomDNS(dns))
}
// SetLookupFunc 设置自定义 DNS 解析函数
func (r *Request) SetLookupFunc(fn func(ctx context.Context, host string) ([]net.IPAddr, error)) *Request {
return r.applyMutation(mutateLookupFunc(fn))
}

314
request_prepare.go Normal file
View File

@ -0,0 +1,314 @@
package starnet
import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"net/http/httptrace"
"net/url"
"strings"
)
func setReplayableRequestBodyBytes(httpReq *http.Request, data []byte) {
if httpReq == nil {
return
}
httpReq.Body = io.NopCloser(bytes.NewReader(data))
httpReq.ContentLength = int64(len(data))
httpReq.GetBody = func() (io.ReadCloser, error) {
return io.NopCloser(bytes.NewReader(data)), nil
}
}
func clearSimpleBodyState(body *BodyConfig) {
if body == nil {
return
}
body.Bytes = nil
body.Reader = nil
}
func resetFormBodyState(body *BodyConfig) {
if body == nil {
return
}
body.FormData = make(map[string][]string)
}
func resetMultipartBodyState(body *BodyConfig) {
if body == nil {
return
}
body.Files = nil
}
func setBytesBodyConfig(body *BodyConfig, data []byte) {
if body == nil {
return
}
body.Mode = bodyModeBytes
body.Bytes = cloneBytes(data)
body.Reader = nil
resetFormBodyState(body)
resetMultipartBodyState(body)
}
func setReaderBodyConfig(body *BodyConfig, reader io.Reader) {
if body == nil {
return
}
body.Mode = bodyModeReader
body.Reader = reader
body.Bytes = nil
resetFormBodyState(body)
resetMultipartBodyState(body)
}
func setFormBodyConfig(body *BodyConfig, data map[string][]string) {
if body == nil {
return
}
body.Mode = bodyModeForm
clearSimpleBodyState(body)
resetMultipartBodyState(body)
body.FormData = cloneStringMapSlice(data)
}
func ensureFormMode(body *BodyConfig) {
if body == nil {
return
}
if body.Mode == bodyModeForm || body.Mode == bodyModeMultipart {
if body.FormData == nil {
body.FormData = make(map[string][]string)
}
return
}
clearSimpleBodyState(body)
resetMultipartBodyState(body)
body.FormData = make(map[string][]string)
body.Mode = bodyModeForm
}
func ensureMultipartMode(body *BodyConfig) {
if body == nil {
return
}
if body.Mode == bodyModeMultipart {
if body.FormData == nil {
body.FormData = make(map[string][]string)
}
return
}
if body.Mode != bodyModeForm {
clearSimpleBodyState(body)
body.FormData = make(map[string][]string)
}
body.Mode = bodyModeMultipart
if body.FormData == nil {
body.FormData = make(map[string][]string)
}
}
func snapshotBytesReader(reader *bytes.Reader) ([]byte, error) {
if reader == nil {
return nil, nil
}
data := make([]byte, reader.Len())
_, err := reader.ReadAt(data, reader.Size()-int64(reader.Len()))
if err != nil && err != io.EOF {
return nil, err
}
return data, nil
}
func snapshotStringReader(reader *strings.Reader) ([]byte, error) {
if reader == nil {
return nil, nil
}
data := make([]byte, reader.Len())
_, err := reader.ReadAt(data, reader.Size()-int64(reader.Len()))
if err != nil && err != io.EOF {
return nil, err
}
return data, nil
}
// applyBody 应用请求体
func (r *Request) applyBody(execCtx context.Context) error {
r.httpReq.Body = nil
r.httpReq.GetBody = nil
r.httpReq.ContentLength = 0
switch r.config.Body.Mode {
case bodyModeReader:
if r.config.Body.Reader == nil {
return nil
}
switch reader := r.config.Body.Reader.(type) {
case *bytes.Buffer:
setReplayableRequestBodyBytes(r.httpReq, append([]byte(nil), reader.Bytes()...))
case *bytes.Reader:
data, err := snapshotBytesReader(reader)
if err != nil {
return wrapError(err, "snapshot bytes reader")
}
setReplayableRequestBodyBytes(r.httpReq, data)
case *strings.Reader:
data, err := snapshotStringReader(reader)
if err != nil {
return wrapError(err, "snapshot strings reader")
}
setReplayableRequestBodyBytes(r.httpReq, data)
default:
r.httpReq.Body = io.NopCloser(r.config.Body.Reader)
}
switch reader := r.config.Body.Reader.(type) {
case *bytes.Buffer:
r.httpReq.ContentLength = int64(reader.Len())
case *bytes.Reader:
r.httpReq.ContentLength = int64(reader.Len())
case *strings.Reader:
r.httpReq.ContentLength = int64(reader.Len())
}
return nil
case bodyModeBytes:
setReplayableRequestBodyBytes(r.httpReq, r.config.Body.Bytes)
return nil
case bodyModeMultipart:
return r.applyMultipartBody(execCtx)
case bodyModeForm:
values := url.Values{}
for key, items := range r.config.Body.FormData {
for _, value := range items {
values.Add(key, value)
}
}
encoded := values.Encode()
setReplayableRequestBodyBytes(r.httpReq, []byte(encoded))
return nil
}
return nil
}
// prepare 准备请求(应用配置)
func (r *Request) prepare() (err error) {
if r.applied {
return nil
}
if r.httpReq == nil {
return fmt.Errorf("http request is nil")
}
execCtx := r.ctx
if execCtx == nil {
execCtx = context.Background()
}
defaultTLSServerName := ""
if r.httpReq.URL != nil && r.httpReq.URL.Scheme == "https" {
defaultTLSServerName = r.httpReq.URL.Hostname()
}
execCtx = injectRequestConfig(execCtx, r.config, defaultTLSServerName)
var traceState *traceState
if r.traceHooks != nil {
traceState = newTraceState(r.traceHooks)
execCtx = withTraceState(execCtx, traceState)
if clientTrace := traceState.clientTrace(); clientTrace != nil {
execCtx = httptrace.WithClientTrace(execCtx, clientTrace)
}
}
var cancel context.CancelFunc
if r.config.Network.Timeout > 0 {
execCtx, cancel = context.WithTimeout(execCtx, r.config.Network.Timeout)
}
defer func() {
if err != nil && cancel != nil {
cancel()
}
}()
if r.httpClient == nil {
r.httpClient, err = r.buildHTTPClient()
if err != nil {
return err
}
}
if !r.doRaw {
if len(r.config.Queries) > 0 {
query := r.httpReq.URL.Query()
for key, values := range r.config.Queries {
for _, value := range values {
query.Add(key, value)
}
}
r.httpReq.URL.RawQuery = query.Encode()
}
for key, values := range r.config.Headers {
if isHostHeaderKey(key) {
continue
}
for _, value := range values {
r.httpReq.Header.Add(key, value)
}
}
for _, cookie := range r.config.Cookies {
r.httpReq.AddCookie(cookie)
}
if r.config.BasicAuth[0] != "" || r.config.BasicAuth[1] != "" {
r.httpReq.SetBasicAuth(r.config.BasicAuth[0], r.config.BasicAuth[1])
}
if err := r.applyBody(execCtx); err != nil {
return err
}
if r.config.ContentLength > 0 {
r.httpReq.ContentLength = r.config.ContentLength
} else if r.config.ContentLength < 0 {
r.httpReq.ContentLength = 0
}
if r.config.AutoCalcContentLength && r.httpReq.Body != nil {
data, err := io.ReadAll(r.httpReq.Body)
if err != nil {
return wrapError(err, "read body for content length")
}
setReplayableRequestBodyBytes(r.httpReq, data)
}
r.syncRequestHost()
}
r.execCtx = execCtx
r.traceState = traceState
r.cancel = cancel
r.httpReq = r.httpReq.WithContext(r.execCtx)
r.applied = true
return nil
}
// buildHTTPClient 构建 HTTP Client
func (r *Request) buildHTTPClient() (*http.Client, error) {
if r.client != nil {
return r.client.HTTPClient(), nil
}
if r.config.CustomTransport && r.config.Transport != nil {
return &http.Client{
Transport: &Transport{base: r.config.Transport},
Timeout: 0,
}, nil
}
return DefaultHTTPClient(), nil
}

View File

@ -0,0 +1,335 @@
package starnet
import (
"bytes"
"context"
"errors"
"io"
"mime/multipart"
"net/http"
"strings"
"sync/atomic"
"testing"
"time"
)
type roundTripFunc func(*http.Request) (*http.Response, error)
func (fn roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return fn(req)
}
func TestRequestPreparedMutationReappliesHeadersAndBody(t *testing.T) {
req := NewSimpleRequest("http://example.com", http.MethodPost).
SetHeader("X-Test", "one").
SetBodyString("first")
req.client = &Client{client: &http.Client{
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
body, err := io.ReadAll(r.Body)
if err != nil {
return nil, err
}
_ = r.Body.Close()
return &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader(r.Header.Get("X-Test") + ":" + string(body))),
Request: r,
}, nil
}),
}}
if _, err := req.HTTPClient(); err != nil {
t.Fatalf("HTTPClient() error: %v", err)
}
req.SetHeader("X-Test", "two").SetBodyString("second")
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
body, err := resp.Body().String()
if err != nil {
t.Fatalf("Body().String() error: %v", err)
}
if body != "two:second" {
t.Fatalf("body=%q; want %q", body, "two:second")
}
}
func TestRequestPreparedMutationReappliesTimeout(t *testing.T) {
var attempts int32
req := NewSimpleRequest("http://example.com", http.MethodGet)
req.client = &Client{client: &http.Client{
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
if atomic.AddInt32(&attempts, 1) == 1 {
return &http.Response{
StatusCode: http.StatusNoContent,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader("")),
Request: r,
}, nil
}
select {
case <-time.After(50 * time.Millisecond):
return &http.Response{
StatusCode: http.StatusNoContent,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader("")),
Request: r,
}, nil
case <-r.Context().Done():
return nil, r.Context().Err()
}
}),
}}
resp, err := req.Do()
if err != nil {
t.Fatalf("first Do() error: %v", err)
}
_ = resp.Close()
_, err = req.SetTimeout(10 * time.Millisecond).Do()
if err == nil {
t.Fatal("second Do() succeeded; want timeout error")
}
if !IsTimeout(err) && !errors.Is(err, context.DeadlineExceeded) {
t.Fatalf("second Do() error=%v; want timeout", err)
}
}
func TestWriteFileUsesExecContextWithoutProgressHook(t *testing.T) {
req := NewSimpleRequest("http://example.com", http.MethodPost)
pr, pw := io.Pipe()
writer := multipart.NewWriter(pw)
done := make(chan struct{})
go func() {
_, _ = io.Copy(io.Discard, pr)
_ = pr.Close()
close(done)
}()
ctx, cancel := context.WithCancel(context.Background())
cancel()
err := req.writeFile(ctx, writer, RequestFile{
FormName: "file",
FileName: "payload.txt",
FileData: strings.NewReader("payload"),
FileSize: int64(len("payload")),
})
_ = writer.Close()
_ = pw.Close()
<-done
if !errors.Is(err, context.Canceled) {
t.Fatalf("writeFile() error=%v; want context.Canceled", err)
}
}
func TestCopyWithProgressHonorsCanceledContextWithoutHook(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err := copyWithProgress(ctx, io.Discard, strings.NewReader("payload"), "payload.txt", int64(len("payload")), nil)
if !errors.Is(err, context.Canceled) {
t.Fatalf("copyWithProgress() error=%v; want context.Canceled", err)
}
}
func TestPrepareSetsGetBodyForReplayableBodies(t *testing.T) {
tests := []struct {
name string
req *Request
want string
}{
{
name: "bytes",
req: NewSimpleRequest("http://example.com", http.MethodPost).SetBody([]byte("payload")),
want: "payload",
},
{
name: "bytes-reader",
req: NewSimpleRequest("http://example.com", http.MethodPost).SetBodyReader(bytes.NewReader([]byte("payload"))),
want: "payload",
},
{
name: "strings-reader",
req: NewSimpleRequest("http://example.com", http.MethodPost).SetBodyReader(strings.NewReader("payload")),
want: "payload",
},
{
name: "form-data",
req: NewSimpleRequest("http://example.com", http.MethodPost).AddFormData("k", "v"),
want: "k=v",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.req.prepare(); err != nil {
t.Fatalf("prepare() error: %v", err)
}
if tt.req.httpReq.GetBody == nil {
t.Fatal("GetBody is nil")
}
body, err := tt.req.httpReq.GetBody()
if err != nil {
t.Fatalf("GetBody() error: %v", err)
}
defer body.Close()
data, err := io.ReadAll(body)
if err != nil {
t.Fatalf("ReadAll() error: %v", err)
}
if string(data) != tt.want {
t.Fatalf("body=%q; want %q", string(data), tt.want)
}
})
}
}
type replayRoundTripper struct {
attempts int
bodies []string
}
func (rt *replayRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
body, err := io.ReadAll(req.Body)
if err != nil {
return nil, err
}
_ = req.Body.Close()
rt.attempts++
rt.bodies = append(rt.bodies, string(body))
if rt.attempts == 1 {
return nil, errors.New("first target failed")
}
return &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader("ok")),
Request: req,
}, nil
}
func TestRoundTripResolvedTargetsReplaysPreparedBody(t *testing.T) {
req := NewSimpleRequest("http://example.com/upload", http.MethodPut).
SetBodyReader(strings.NewReader("payload"))
if err := req.prepare(); err != nil {
t.Fatalf("prepare() error: %v", err)
}
rt := &replayRoundTripper{}
resp, err := roundTripResolvedTargets(rt, req.httpReq, []string{"127.0.0.2:80", "127.0.0.1:80"})
if err != nil {
t.Fatalf("roundTripResolvedTargets() error: %v", err)
}
defer resp.Body.Close()
if len(rt.bodies) != 2 {
t.Fatalf("attempt bodies=%v; want 2 attempts", rt.bodies)
}
if rt.bodies[0] != "payload" || rt.bodies[1] != "payload" {
t.Fatalf("attempt bodies=%v; want both payload", rt.bodies)
}
}
func TestRoundTripResolvedTargetsDoesNotFallbackNonIdempotentRequest(t *testing.T) {
req := NewSimpleRequest("http://example.com/upload", http.MethodPost).
SetBodyReader(strings.NewReader("payload"))
if err := req.prepare(); err != nil {
t.Fatalf("prepare() error: %v", err)
}
rt := &replayRoundTripper{}
_, err := roundTripResolvedTargets(rt, req.httpReq, []string{"127.0.0.2:80", "127.0.0.1:80"})
if err == nil {
t.Fatal("roundTripResolvedTargets() succeeded; want first target error")
}
if len(rt.bodies) != 1 {
t.Fatalf("attempt bodies=%v; want only first target attempt", rt.bodies)
}
if rt.bodies[0] != "payload" {
t.Fatalf("attempt body=%q; want payload", rt.bodies[0])
}
}
func TestRetryReplayableReaderBody(t *testing.T) {
var attempts int32
req := NewSimpleRequest("http://example.com/upload", http.MethodPut).
SetBodyReader(strings.NewReader("payload")).
SetRetry(1, WithRetryBackoff(0, 0, 1), WithRetryJitter(0))
req.client = &Client{client: &http.Client{
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
body, err := io.ReadAll(r.Body)
if err != nil {
return nil, err
}
_ = r.Body.Close()
if string(body) != "payload" {
t.Fatalf("body=%q; want payload", string(body))
}
if atomic.AddInt32(&attempts, 1) == 1 {
return &http.Response{
StatusCode: http.StatusServiceUnavailable,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader("retry")),
Request: r,
}, nil
}
return &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader("ok")),
Request: r,
}, nil
}),
}}
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
if got := atomic.LoadInt32(&attempts); got != 2 {
t.Fatalf("attempts=%d; want 2", got)
}
}
func TestWithProxyInvalidReturnsError(t *testing.T) {
_, err := NewRequest("http://example.com", http.MethodGet, WithProxy("://bad-proxy"))
if err == nil {
t.Fatal("NewRequest() succeeded; want invalid proxy error")
}
}
func TestClientNewRequestWithInvalidProxyReturnsError(t *testing.T) {
client := NewClientNoErr()
_, err := client.NewRequest("http://example.com", http.MethodGet, WithProxy("://bad-proxy"))
if err == nil {
t.Fatal("Client.NewRequest() succeeded; want invalid proxy error")
}
}
func TestNewClientWithInvalidProxyReturnsError(t *testing.T) {
_, err := NewClient(WithProxy("://bad-proxy"))
if err == nil {
t.Fatal("NewClient() succeeded; want invalid proxy error")
}
}

31
request_query.go Normal file
View File

@ -0,0 +1,31 @@
package starnet
// AddQuery 添加查询参数
func (r *Request) AddQuery(key, value string) *Request {
return r.applyMutation(mutateAddQuery(key, value))
}
// SetQuery 设置查询参数(覆盖)
func (r *Request) SetQuery(key, value string) *Request {
return r.applyMutation(mutateSetQuery(key, value))
}
// SetQueries 设置所有查询参数(覆盖)
func (r *Request) SetQueries(queries map[string][]string) *Request {
return r.applyMutation(mutateSetQueries(queries))
}
// AddQueries 批量添加查询参数
func (r *Request) AddQueries(queries map[string]string) *Request {
return r.applyMutation(mutateAddQueries(queries))
}
// DeleteQuery 删除查询参数
func (r *Request) DeleteQuery(key string) *Request {
return r.applyMutation(mutateDeleteQuery(key))
}
// DeleteQueryValue 删除查询参数的特定值
func (r *Request) DeleteQueryValue(key, value string) *Request {
return r.applyMutation(mutateDeleteQueryValue(key, value))
}

View File

@ -0,0 +1,168 @@
package starnet
import (
"io"
"net/http"
"net/url"
"strings"
"sync/atomic"
"testing"
)
type stateRoundTripperFunc func(*http.Request) (*http.Response, error)
func (fn stateRoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return fn(req)
}
func TestSetContextNilUsesBackground(t *testing.T) {
req := NewSimpleRequest("http://example.com", http.MethodGet)
req.client = &Client{client: &http.Client{
Transport: stateRoundTripperFunc(func(r *http.Request) (*http.Response, error) {
if r.Context() == nil {
t.Fatal("request context is nil")
}
return &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader("ok")),
Request: r,
}, nil
}),
}}
resp, err := req.SetContext(nil).Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
if req.Context() == nil {
t.Fatal("request Context() is nil")
}
}
func TestWithContextNilRetryPathDoesNotPanic(t *testing.T) {
var hits int32
req, err := NewRequest("http://example.com", http.MethodGet, WithContext(nil))
if err != nil {
t.Fatalf("NewRequest() error: %v", err)
}
req.client = &Client{client: &http.Client{
Transport: stateRoundTripperFunc(func(r *http.Request) (*http.Response, error) {
if r.Context() == nil {
t.Fatal("retry request context is nil")
}
if atomic.AddInt32(&hits, 1) == 1 {
return &http.Response{
StatusCode: http.StatusServiceUnavailable,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader("retry")),
Request: r,
}, nil
}
return &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader("ok")),
Request: r,
}, nil
}),
}}
resp, err := req.
SetTimeout(DefaultTimeout).
SetRetry(1, WithRetryBackoff(0, 0, 1), WithRetryJitter(0)).
Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
if got := atomic.LoadInt32(&hits); got != 2 {
t.Fatalf("hits=%d; want 2", got)
}
}
func TestCloneRawRequestCreatesIndependentCopy(t *testing.T) {
rawReq, err := http.NewRequest(http.MethodPost, "http://example.com/upload", strings.NewReader("payload"))
if err != nil {
t.Fatalf("NewRequest() error: %v", err)
}
rawReq.Header.Set("X-Test", "one")
req := NewSimpleRequest("", http.MethodPost).SetRawRequest(rawReq)
cloned := req.Clone()
if cloned.Err() != nil {
t.Fatalf("Clone() err = %v", cloned.Err())
}
if cloned.RawRequest() == rawReq {
t.Fatal("raw request pointer reused")
}
cloned.RawRequest().Header.Set("X-Test", "two")
if rawReq.Header.Get("X-Test") != "one" {
t.Fatalf("original header mutated: %q", rawReq.Header.Get("X-Test"))
}
body, err := cloned.RawRequest().GetBody()
if err != nil {
t.Fatalf("GetBody() error: %v", err)
}
defer body.Close()
data, err := io.ReadAll(body)
if err != nil {
t.Fatalf("ReadAll() error: %v", err)
}
if string(data) != "payload" {
t.Fatalf("body=%q; want payload", string(data))
}
}
func TestCloneRawRequestWithNonReplayableBodyFailsExplicitly(t *testing.T) {
rawReq := &http.Request{
Method: http.MethodPost,
URL: mustParseURL(t, "http://example.com/upload"),
Header: make(http.Header),
Body: io.NopCloser(io.MultiReader(strings.NewReader("payload"))),
}
req := NewSimpleRequest("", http.MethodPost).SetRawRequest(rawReq)
cloned := req.Clone()
if cloned.Err() == nil {
t.Fatal("Clone() should fail for non-replayable raw body")
}
if !strings.Contains(cloned.Err().Error(), "non-replayable") {
t.Fatalf("Clone() err=%v; want non-replayable body error", cloned.Err())
}
}
func TestDisableRawModeAfterSetRawRequestReturnsError(t *testing.T) {
rawReq, err := http.NewRequest(http.MethodGet, "http://example.com", nil)
if err != nil {
t.Fatalf("NewRequest() error: %v", err)
}
req := NewSimpleRequest("", http.MethodGet).SetRawRequest(rawReq).DisableRawMode()
if req.Err() == nil {
t.Fatal("DisableRawMode() should set error")
}
if !strings.Contains(req.Err().Error(), "cannot disable raw mode") {
t.Fatalf("DisableRawMode() err=%v", req.Err())
}
if !req.doRaw {
t.Fatal("request should remain in raw mode")
}
}
func mustParseURL(t *testing.T, raw string) *url.URL {
t.Helper()
parsed, err := url.Parse(raw)
if err != nil {
t.Fatalf("url.Parse() error: %v", err)
}
return parsed
}

6
request_trace.go Normal file
View File

@ -0,0 +1,6 @@
package starnet
// SetTraceHooks 设置请求 trace 回调。
func (r *Request) SetTraceHooks(hooks *TraceHooks) *Request {
return r.applyMutation(mutateTraceHooks(hooks))
}

View File

@ -1,6 +1,7 @@
package starnet package starnet
import ( import (
"bytes"
"context" "context"
"errors" "errors"
"fmt" "fmt"
@ -9,6 +10,7 @@ import (
"math/rand" "math/rand"
"net" "net"
"net/http" "net/http"
"strings"
"time" "time"
) )
@ -87,6 +89,9 @@ func buildRetryPolicy(max int, opts ...RetryOpt) (*retryPolicy, error) {
return policy, nil return policy, nil
} }
// WithRetry 为请求启用自动重试。
// 默认只重试幂等方法即使显式关闭幂等限制Reader 形态的 body 仍会对非幂等方法保持保守禁用,
// 以避免请求体已落地后再次发送。
func WithRetry(max int, opts ...RetryOpt) RequestOpt { func WithRetry(max int, opts ...RetryOpt) RequestOpt {
return func(r *Request) error { return func(r *Request) error {
policy, err := buildRetryPolicy(max, opts...) policy, err := buildRetryPolicy(max, opts...)
@ -98,6 +103,9 @@ func WithRetry(max int, opts ...RetryOpt) RequestOpt {
} }
} }
// SetRetry 为请求启用自动重试。
// 默认只重试幂等方法即使显式关闭幂等限制Reader 形态的 body 仍会对非幂等方法保持保守禁用,
// 以避免请求体已落地后再次发送。
func (r *Request) SetRetry(max int, opts ...RetryOpt) *Request { func (r *Request) SetRetry(max int, opts ...RetryOpt) *Request {
if r.err != nil { if r.err != nil {
return r return r
@ -226,10 +234,10 @@ func (r *Request) doWithRetry() (*Response, error) {
return r.doOnce() return r.doOnce()
} }
retryCtx := r.ctx retryCtx := normalizeContext(r.ctx)
retryCancel := func() {} retryCancel := func() {}
if r.config.Network.Timeout > 0 { if r.config.Network.Timeout > 0 {
retryCtx, retryCancel = context.WithTimeout(r.ctx, r.config.Network.Timeout) retryCtx, retryCancel = context.WithTimeout(retryCtx, r.config.Network.Timeout)
} }
defer retryCancel() defer retryCancel()
@ -238,6 +246,12 @@ func (r *Request) doWithRetry() (*Response, error) {
var lastErr error var lastErr error
for attempt := 0; attempt < maxAttempts; attempt++ { for attempt := 0; attempt < maxAttempts; attempt++ {
attemptNo := attempt + 1
emitRetryAttemptStart(r.traceHooks, TraceRetryAttemptStartInfo{
Attempt: attemptNo,
MaxAttempts: maxAttempts,
})
attemptReq, err := r.newRetryAttempt(retryCtx) attemptReq, err := r.newRetryAttempt(retryCtx)
if err != nil { if err != nil {
return nil, wrapError(err, "build retry attempt") return nil, wrapError(err, "build retry attempt")
@ -248,7 +262,19 @@ func (r *Request) doWithRetry() (*Response, error) {
resp.request = r resp.request = r
} }
if !policy.shouldRetry(resp, err, attempt, maxAttempts, retryCtx) { willRetry := policy.shouldRetry(resp, err, attempt, maxAttempts, retryCtx)
statusCode := 0
if resp != nil {
statusCode = resp.StatusCode
}
emitRetryAttemptDone(r.traceHooks, TraceRetryAttemptDoneInfo{
Attempt: attemptNo,
MaxAttempts: maxAttempts,
StatusCode: statusCode,
Err: err,
WillRetry: willRetry,
})
if !willRetry {
return resp, err return resp, err
} }
@ -262,6 +288,10 @@ func (r *Request) doWithRetry() (*Response, error) {
if delay <= 0 { if delay <= 0 {
continue continue
} }
emitRetryBackoff(r.traceHooks, TraceRetryBackoffInfo{
Attempt: attemptNo,
Delay: delay,
})
timer := time.NewTimer(delay) timer := time.NewTimer(delay)
select { select {
@ -293,19 +323,9 @@ func (r *Request) newRetryAttempt(ctx context.Context) (*Request, error) {
return attempt, nil return attempt, nil
} }
if r.httpReq == nil { raw, err := cloneRawHTTPRequest(r.httpReq, ctx)
return nil, fmt.Errorf("http request is nil") if err != nil {
} return nil, err
raw := r.httpReq.Clone(ctx)
if r.httpReq.GetBody != nil {
body, err := r.httpReq.GetBody()
if err != nil {
return nil, wrapError(err, "get raw request body")
}
raw.Body = body
} else if r.httpReq.Body != nil && r.httpReq.Body != http.NoBody {
return nil, fmt.Errorf("raw request body is not replayable")
} }
attempt.httpReq = raw attempt.httpReq = raw
@ -316,6 +336,9 @@ func (p *retryPolicy) canRetryRequest(r *Request) bool {
if p.idempotentOnly && !isIdempotentMethod(r.method) { if p.idempotentOnly && !isIdempotentMethod(r.method) {
return false return false
} }
if hasReaderRequestBody(r) && !isIdempotentMethod(r.method) {
return false
}
return isReplayableRequest(r) return isReplayableRequest(r)
} }
@ -347,20 +370,40 @@ func isReplayableRequest(r *Request) bool {
return false return false
} }
// Reader / stream body 通常不可重放,保守地不重试。 return isReplayableConfiguredBody(r.config.Body)
if r.config.Body.Reader != nil { }
func hasReaderRequestBody(r *Request) bool {
if r == nil || r.config == nil {
return false return false
} }
return r.config.Body.Mode == bodyModeReader && r.config.Body.Reader != nil
}
for _, f := range r.config.Body.Files { func isReplayableConfiguredBody(body BodyConfig) bool {
if f.FileData != nil || f.FilePath == "" { switch body.Mode {
return false case bodyModeReader:
return isReplayableBodyReader(body.Reader)
case bodyModeMultipart:
for _, file := range body.Files {
if file.FileData != nil || file.FilePath == "" {
return false
}
} }
} }
return true return true
} }
func isReplayableBodyReader(reader io.Reader) bool {
switch reader.(type) {
case *bytes.Buffer, *bytes.Reader, *strings.Reader:
return true
default:
return false
}
}
func (p *retryPolicy) shouldRetry(resp *Response, err error, attempt, maxAttempts int, ctx context.Context) bool { func (p *retryPolicy) shouldRetry(resp *Response, err error, attempt, maxAttempts int, ctx context.Context) bool {
if attempt >= maxAttempts-1 { if attempt >= maxAttempts-1 {
return false return false

244
review_regression_test.go Normal file
View File

@ -0,0 +1,244 @@
package starnet
import (
"context"
"fmt"
"net"
"net/http"
"strconv"
"sync"
"testing"
"time"
)
func TestRequestProxyWithCustomIPFallbackTriesNextResolvedTarget(t *testing.T) {
tlsServer := newIPv4TLSServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer tlsServer.Close()
_, port, err := net.SplitHostPort(tlsServer.Listener.Addr().String())
if err != nil {
t.Fatalf("split tls server addr: %v", err)
}
firstTarget := net.JoinHostPort("127.0.0.2", port)
secondTarget := net.JoinHostPort("127.0.0.1", port)
var (
mu sync.Mutex
connectTargets []string
)
proxyServer := newIPv4Server(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodConnect {
http.Error(w, "connect required", http.StatusMethodNotAllowed)
return
}
mu.Lock()
connectTargets = append(connectTargets, r.Host)
mu.Unlock()
if r.Host == firstTarget {
http.Error(w, "first target failed", http.StatusBadGateway)
return
}
targetConn, err := net.Dial("tcp", r.Host)
if err != nil {
http.Error(w, err.Error(), http.StatusBadGateway)
return
}
hijacker, ok := w.(http.Hijacker)
if !ok {
targetConn.Close()
t.Fatal("proxy response writer is not a hijacker")
}
clientConn, rw, err := hijacker.Hijack()
if err != nil {
targetConn.Close()
t.Fatalf("hijack proxy conn: %v", err)
}
if _, err := rw.WriteString("HTTP/1.1 200 Connection Established\r\n\r\n"); err != nil {
clientConn.Close()
targetConn.Close()
t.Fatalf("write connect response: %v", err)
}
if err := rw.Flush(); err != nil {
clientConn.Close()
targetConn.Close()
t.Fatalf("flush connect response: %v", err)
}
relayProxyConns(clientConn, targetConn)
}))
defer proxyServer.Close()
reqURL := fmt.Sprintf("https://proxy-fallback.test:%s", port)
resp, err := NewSimpleRequest(reqURL, http.MethodGet).
SetProxy(proxyServer.URL).
SetCustomIP([]string{"127.0.0.2", "127.0.0.1"}).
SetSkipTLSVerify(true).
Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
mu.Lock()
defer mu.Unlock()
if len(connectTargets) != 2 {
t.Fatalf("connect target attempts=%d; want 2 (%v)", len(connectTargets), connectTargets)
}
if connectTargets[0] != firstTarget {
t.Fatalf("first connect target=%q; want %q", connectTargets[0], firstTarget)
}
if connectTargets[1] != secondTarget {
t.Fatalf("second connect target=%q; want %q", connectTargets[1], secondTarget)
}
}
func TestTraceHooksDefaultResolverEmitsDNSEvents(t *testing.T) {
server := newIPv4Server(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
addr, err := net.ResolveTCPAddr("tcp", server.Listener.Addr().String())
if err != nil {
t.Fatalf("ResolveTCPAddr() error: %v", err)
}
var (
mu sync.Mutex
dnsStartCount int
dnsDoneCount int
lastHost string
)
hooks := &TraceHooks{
DNSStart: func(info TraceDNSStartInfo) {
mu.Lock()
dnsStartCount++
lastHost = info.Host
mu.Unlock()
},
DNSDone: func(info TraceDNSDoneInfo) {
mu.Lock()
dnsDoneCount++
mu.Unlock()
if info.Err != nil {
t.Errorf("unexpected dns error: %v", info.Err)
}
},
}
reqURL := "http://localhost:" + strconv.Itoa(addr.Port)
resp, err := NewSimpleRequest(reqURL, http.MethodGet).
SetDialTimeout(DefaultDialTimeout + 200*time.Millisecond).
SetTraceHooks(hooks).
Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
mu.Lock()
defer mu.Unlock()
if dnsStartCount != 1 {
t.Fatalf("dnsStartCount=%d", dnsStartCount)
}
if dnsDoneCount != 1 {
t.Fatalf("dnsDoneCount=%d", dnsDoneCount)
}
if lastHost != "localhost" {
t.Fatalf("lastHost=%q; want localhost", lastHost)
}
}
func TestRequestHeadersReturnsCopy(t *testing.T) {
req := NewSimpleRequest("http://example.com", http.MethodGet).
SetHeader("X-Test", "one").
SetHost("origin.example")
headers := req.Headers()
headers.Set("X-Test", "two")
headers.Set("Host", "mutated.example")
if got := req.GetHeader("X-Test"); got != "one" {
t.Fatalf("request header=%q; want one", got)
}
if got := req.Host(); got != "origin.example" {
t.Fatalf("request host=%q; want origin.example", got)
}
}
func TestRequestCookiesIsolation(t *testing.T) {
req := NewSimpleRequest("http://example.com", http.MethodGet)
source := []*http.Cookie{{
Name: "session",
Value: "one",
Path: "/",
}}
req.SetCookies(source)
source[0].Value = "mutated-outside"
got := req.Cookies()
if len(got) != 1 || got[0].Value != "one" {
t.Fatalf("cookies after SetCookies=%v", got)
}
got[0].Value = "mutated-copy"
if latest := req.Cookies()[0].Value; latest != "one" {
t.Fatalf("internal cookie mutated via getter, got %q", latest)
}
cookie := &http.Cookie{Name: "auth", Value: "token"}
req.ResetCookies().AddCookie(cookie)
cookie.Value = "changed"
latest := req.Cookies()
if len(latest) != 1 || latest[0].Value != "token" {
t.Fatalf("cookies after AddCookie=%v", latest)
}
}
func TestTraceHooksLookupFuncStillEmitsDNSEvents(t *testing.T) {
server := newIPv4Server(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
addr, err := net.ResolveTCPAddr("tcp", server.Listener.Addr().String())
if err != nil {
t.Fatalf("ResolveTCPAddr() error: %v", err)
}
var dnsStartCount int
var dnsDoneCount int
hooks := &TraceHooks{
DNSStart: func(info TraceDNSStartInfo) {
dnsStartCount++
},
DNSDone: func(info TraceDNSDoneInfo) {
dnsDoneCount++
},
}
resp, err := NewSimpleRequest("http://lookup-copy.test:"+strconv.Itoa(addr.Port), http.MethodGet).
SetLookupFunc(func(ctx context.Context, host string) ([]net.IPAddr, error) {
return []net.IPAddr{{IP: addr.IP}}, nil
}).
SetTraceHooks(hooks).
Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
if dnsStartCount != 1 || dnsDoneCount != 1 {
t.Fatalf("dns trace counts start=%d done=%d", dnsStartCount, dnsDoneCount)
}
}

View File

@ -104,7 +104,34 @@ func TestRequestLevelTLSOverride(t *testing.T) {
} }
func TestRequestTls(t *testing.T) { func TestRequestTls(t *testing.T) {
resp, err := NewSimpleRequest("https://www.b612.me", "GET").Do() var requestCount int
server, pool := newTrustedIPv4TLSServer(t, "localhost", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requestCount++
switch requestCount {
case 1:
if r.Header.Get("Hello") != "" {
t.Fatalf("unexpected hello header on first request: %q", r.Header.Get("Hello"))
}
if auth := r.Header.Get("Authorization"); auth != "" {
t.Fatalf("unexpected authorization on first request: %q", auth)
}
case 2:
if got := r.Header.Get("Hello"); got != "world" {
t.Fatalf("hello header=%q; want world", got)
}
if got := r.Header.Get("Authorization"); got != "Bearer ddddddd" {
t.Fatalf("authorization=%q; want bearer token", got)
}
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("OK"))
}))
defer server.Close()
localURL := httpsURLForHost(t, server, "localhost")
resp, err := NewSimpleRequest(localURL, "GET").
SetTLSConfig(&tls.Config{RootCAs: pool}).
Do()
if err != nil { if err != nil {
t.Fatalf("Do() error: %v", err) t.Fatalf("Do() error: %v", err)
} }
@ -114,11 +141,13 @@ func TestRequestTls(t *testing.T) {
t.Errorf("StatusCode = %v; want %v", resp.StatusCode, http.StatusOK) t.Errorf("StatusCode = %v; want %v", resp.StatusCode, http.StatusOK)
} }
t.Logf("Response: %v", resp.Body().MustString()) t.Logf("Response: %v", resp.Body().MustString())
client, err := NewClient() client, err := NewClient()
if err != nil { if err != nil {
t.Fatalf("NewClient() error: %v", err) t.Fatalf("NewClient() error: %v", err)
} }
resp, err = client.NewSimpleRequest("https://www.b612.me", "GET", resp, err = client.NewSimpleRequest(localURL, "GET",
WithTLSConfig(&tls.Config{RootCAs: pool}),
WithHeader("hello", "world"), WithHeader("hello", "world"),
WithContext(context.Background()), WithContext(context.Background()),
WithBearerToken("ddddddd")).Do() WithBearerToken("ddddddd")).Do()
@ -134,14 +163,24 @@ func TestRequestTls(t *testing.T) {
} }
func TestTLSWithProxyPath(t *testing.T) { func TestTLSWithProxyPath(t *testing.T) {
server, pool := newTrustedIPv4TLSServer(t, "localhost", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("proxied"))
}))
defer server.Close()
proxy := newIPv4ConnectProxyServer(t, nil)
defer proxy.Close()
client, err := NewClient() client, err := NewClient()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
req, err := client.NewRequest("https://registry-1.docker.io/v2/", "GET", req, err := client.NewRequest(httpsURLForHost(t, server, "localhost"), "GET",
WithTimeout(10*time.Second), WithTimeout(10*time.Second),
WithProxy("http://127.0.0.1:29992"), WithProxy(proxy.URL),
WithTLSConfig(&tls.Config{RootCAs: pool}),
) )
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -152,10 +191,22 @@ func TestTLSWithProxyPath(t *testing.T) {
t.Fatalf("Do error: %v", err) t.Fatalf("Do error: %v", err)
} }
defer resp.Close() defer resp.Close()
if targets := proxy.Targets(); len(targets) != 1 {
t.Fatalf("proxy targets=%v; want 1 target", targets)
}
t.Log(resp.Status) t.Log(resp.Status)
} }
func TestTLSWithProxyBug(t *testing.T) { func TestTLSWithProxyBug(t *testing.T) {
server, pool := newTrustedIPv4TLSServer(t, "proxy-bug.test", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer server.Close()
proxy := newIPv4ConnectProxyServer(t, nil)
defer proxy.Close()
client, err := NewClient() client, err := NewClient()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -163,9 +214,11 @@ func TestTLSWithProxyBug(t *testing.T) {
// 关键:使用 WithProxy 触发 needsDynamicTransport // 关键:使用 WithProxy 触发 needsDynamicTransport
// 即使 proxy 是空串或无效地址,只要设置了就会走 buildDynamicTransport 分支 // 即使 proxy 是空串或无效地址,只要设置了就会走 buildDynamicTransport 分支
req, err := client.NewRequest("https://registry-1.docker.io/v2/", "GET", req, err := client.NewRequest(httpsURLForHost(t, server, "proxy-bug.test"), "GET",
WithTimeout(10*time.Second), WithTimeout(10*time.Second),
WithProxy("http://127.0.0.1:29992"), // 随便一个 proxy 地址,触发动态 transport WithProxy(proxy.URL),
WithCustomIP([]string{"127.0.0.1"}),
WithTLSConfig(&tls.Config{RootCAs: pool}),
) )
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -177,20 +230,30 @@ func TestTLSWithProxyBug(t *testing.T) {
t.Fatalf("Do error: %v", err) t.Fatalf("Do error: %v", err)
} }
defer resp.Close() defer resp.Close()
if targets := proxy.Targets(); len(targets) != 1 || targets[0] == "" {
t.Fatalf("proxy targets=%v", targets)
}
t.Logf("Status: %s", resp.Status) t.Logf("Status: %s", resp.Status)
} }
// 更精准的复现:直接测试有问题的分支 // 更精准的复现:直接测试有问题的分支
func TestTLSDialWithoutServerName(t *testing.T) { func TestTLSDialWithoutServerName(t *testing.T) {
server, pool := newTrustedIPv4TLSServer(t, "custom-ip.test", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer server.Close()
client, err := NewClient() client, err := NewClient()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
// 使用 WithCustomIP 也能触发 defaultDialTLSFunc // 使用 WithCustomIP 也能触发 defaultDialTLSFunc
req, err := client.NewRequest("https://www.google.com", "GET", req, err := client.NewRequest(httpsURLForHost(t, server, "custom-ip.test"), "GET",
WithTimeout(10*time.Second), WithTimeout(10*time.Second),
WithCustomIP([]string{"142.250.185.46"}), // Google 的一个 IP WithCustomIP([]string{"127.0.0.1"}),
WithTLSConfig(&tls.Config{RootCAs: pool}),
) )
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -206,14 +269,21 @@ func TestTLSDialWithoutServerName(t *testing.T) {
// 最小复现:只要触发 needsDynamicTransport 即可 // 最小复现:只要触发 needsDynamicTransport 即可
func TestMinimalTLSBug(t *testing.T) { func TestMinimalTLSBug(t *testing.T) {
server, pool := newTrustedIPv4TLSServer(t, "localhost", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer server.Close()
client, err := NewClient() client, err := NewClient()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
// WithDialTimeout 也会触发动态 transport // WithDialTimeout 也会触发动态 transport
req, err := client.NewRequest("https://www.baidu.com", "GET", req, err := client.NewRequest(httpsURLForHost(t, server, "localhost"), "GET",
WithDialTimeout(5*time.Second), WithDialTimeout(5*time.Second),
WithTLSConfig(&tls.Config{RootCAs: pool}),
) )
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -227,3 +297,40 @@ func TestMinimalTLSBug(t *testing.T) {
defer resp.Close() defer resp.Close()
t.Logf("Status: %s", resp.Status) t.Logf("Status: %s", resp.Status)
} }
func TestTLSWithSOCKS5ProxyPath(t *testing.T) {
server, pool := newTrustedIPv4TLSServer(t, "socks5-proxy.test", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer server.Close()
proxy := newSOCKS5ProxyServer(t, nil)
defer proxy.Close()
client, err := NewClient()
if err != nil {
t.Fatal(err)
}
req, err := client.NewRequest(httpsURLForHost(t, server, "socks5-proxy.test"), "GET",
WithTimeout(10*time.Second),
WithProxy(proxy.URL()),
WithCustomIP([]string{"127.0.0.1"}),
WithTLSConfig(&tls.Config{RootCAs: pool}),
)
if err != nil {
t.Fatal(err)
}
resp, err := req.Do()
if err != nil {
t.Fatalf("Do error: %v", err)
}
defer resp.Close()
if targets := proxy.Targets(); len(targets) != 1 || targets[0] == "" {
t.Fatalf("socks5 targets=%v", targets)
}
t.Logf("Status: %s", resp.Status)
}

View File

@ -4,12 +4,13 @@ import (
"bytes" "bytes"
"context" "context"
"crypto/tls" "crypto/tls"
"encoding/binary"
"errors" "errors"
"io" "io"
"net" "net"
"sync" "sync"
"time" "time"
"b612.me/starnet/internal/tlssniffercore"
) )
// replayConn replays buffered bytes first, then reads from live conn. // replayConn replays buffered bytes first, then reads from live conn.
@ -51,214 +52,35 @@ type TLSSniffer struct{}
// Sniff detects TLS and extracts SNI when possible. // Sniff detects TLS and extracts SNI when possible.
func (s TLSSniffer) Sniff(conn net.Conn, maxBytes int) (SniffResult, error) { func (s TLSSniffer) Sniff(conn net.Conn, maxBytes int) (SniffResult, error) {
if maxBytes <= 0 { res, err := (tlssniffercore.Sniffer{}).Sniff(conn, maxBytes)
maxBytes = 64 * 1024 if err != nil {
return SniffResult{}, err
} }
return convertCoreSniffResult(res), nil
}
var buf bytes.Buffer func convertCoreSniffResult(res tlssniffercore.SniffResult) SniffResult {
limited := &io.LimitedReader{R: conn, N: int64(maxBytes)}
meta, isTLS := sniffClientHello(limited, &buf, conn)
out := SniffResult{ out := SniffResult{
IsTLS: isTLS, IsTLS: res.IsTLS,
Buffer: bytes.NewBuffer(append([]byte(nil), buf.Bytes()...)), Buffer: res.Buffer,
} }
if isTLS { if res.ClientHello != nil {
out.ClientHello = meta out.ClientHello = convertCoreClientHelloMeta(res.ClientHello)
} }
return out, nil return out
} }
func sniffClientHello(r io.Reader, buf *bytes.Buffer, conn net.Conn) (*ClientHelloMeta, bool) { func convertCoreClientHelloMeta(meta *tlssniffercore.ClientHelloMeta) *ClientHelloMeta {
meta := &ClientHelloMeta{ if meta == nil {
LocalAddr: conn.LocalAddr(), return nil
RemoteAddr: conn.RemoteAddr(),
} }
return &ClientHelloMeta{
header, complete := readTLSRecordHeader(r, buf) ServerName: meta.ServerName,
if len(header) < 3 { LocalAddr: meta.LocalAddr,
return nil, false RemoteAddr: meta.RemoteAddr,
} SupportedProtos: append([]string(nil), meta.SupportedProtos...),
isTLS := header[0] == 0x16 && header[1] == 0x03 SupportedVersions: append([]uint16(nil), meta.SupportedVersions...),
if !isTLS { CipherSuites: append([]uint16(nil), meta.CipherSuites...),
return nil, false
}
if len(header) < 5 || !complete {
return meta, true
}
recordLen := int(binary.BigEndian.Uint16(header[3:5]))
recordBody, bodyOK := readBufferedBytes(r, buf, recordLen)
if !bodyOK {
return meta, true
}
if len(recordBody) < 4 || recordBody[0] != 0x01 {
return nil, false
}
helloLen := int(recordBody[1])<<16 | int(recordBody[2])<<8 | int(recordBody[3])
helloBytes := append([]byte(nil), recordBody[4:]...)
for len(helloBytes) < helloLen {
nextHeader, nextOK := readTLSRecordHeader(r, buf)
if len(nextHeader) < 5 || !nextOK {
return meta, true
}
if nextHeader[0] != 0x16 || nextHeader[1] != 0x03 {
return meta, true
}
nextLen := int(binary.BigEndian.Uint16(nextHeader[3:5]))
nextBody, nextBodyOK := readBufferedBytes(r, buf, nextLen)
if !nextBodyOK {
return meta, true
}
helloBytes = append(helloBytes, nextBody...)
}
parseClientHelloBody(meta, helloBytes[:helloLen])
return meta, true
}
func readTLSRecordHeader(r io.Reader, buf *bytes.Buffer) ([]byte, bool) {
return readBufferedBytes(r, buf, 5)
}
func readBufferedBytes(r io.Reader, buf *bytes.Buffer, n int) ([]byte, bool) {
if n <= 0 {
return nil, true
}
tmp := make([]byte, n)
readN, err := io.ReadFull(r, tmp)
if readN > 0 {
buf.Write(tmp[:readN])
}
return append([]byte(nil), tmp[:readN]...), err == nil
}
func parseClientHelloBody(meta *ClientHelloMeta, body []byte) {
if meta == nil || len(body) < 34 {
return
}
offset := 2 + 32
sessionIDLen := int(body[offset])
offset++
if offset+sessionIDLen > len(body) {
return
}
offset += sessionIDLen
if offset+2 > len(body) {
return
}
cipherSuitesLen := int(binary.BigEndian.Uint16(body[offset : offset+2]))
offset += 2
if offset+cipherSuitesLen > len(body) {
return
}
for i := 0; i+1 < cipherSuitesLen; i += 2 {
meta.CipherSuites = append(meta.CipherSuites, binary.BigEndian.Uint16(body[offset+i:offset+i+2]))
}
offset += cipherSuitesLen
if offset >= len(body) {
return
}
compressionMethodsLen := int(body[offset])
offset++
if offset+compressionMethodsLen > len(body) {
return
}
offset += compressionMethodsLen
if offset+2 > len(body) {
return
}
extensionsLen := int(binary.BigEndian.Uint16(body[offset : offset+2]))
offset += 2
if offset+extensionsLen > len(body) {
return
}
parseClientHelloExtensions(meta, body[offset:offset+extensionsLen])
}
func parseClientHelloExtensions(meta *ClientHelloMeta, exts []byte) {
for offset := 0; offset+4 <= len(exts); {
extType := binary.BigEndian.Uint16(exts[offset : offset+2])
extLen := int(binary.BigEndian.Uint16(exts[offset+2 : offset+4]))
offset += 4
if offset+extLen > len(exts) {
return
}
extData := exts[offset : offset+extLen]
offset += extLen
switch extType {
case 0:
parseServerNameExtension(meta, extData)
case 16:
parseALPNExtension(meta, extData)
case 43:
parseSupportedVersionsExtension(meta, extData)
}
}
}
func parseServerNameExtension(meta *ClientHelloMeta, data []byte) {
if len(data) < 2 {
return
}
listLen := int(binary.BigEndian.Uint16(data[:2]))
if listLen == 0 || 2+listLen > len(data) {
return
}
list := data[2 : 2+listLen]
for offset := 0; offset+3 <= len(list); {
nameType := list[offset]
nameLen := int(binary.BigEndian.Uint16(list[offset+1 : offset+3]))
offset += 3
if offset+nameLen > len(list) {
return
}
if nameType == 0 {
meta.ServerName = string(list[offset : offset+nameLen])
return
}
offset += nameLen
}
}
func parseALPNExtension(meta *ClientHelloMeta, data []byte) {
if len(data) < 2 {
return
}
listLen := int(binary.BigEndian.Uint16(data[:2]))
if listLen == 0 || 2+listLen > len(data) {
return
}
list := data[2 : 2+listLen]
for offset := 0; offset < len(list); {
nameLen := int(list[offset])
offset++
if offset+nameLen > len(list) {
return
}
meta.SupportedProtos = append(meta.SupportedProtos, string(list[offset:offset+nameLen]))
offset += nameLen
}
}
func parseSupportedVersionsExtension(meta *ClientHelloMeta, data []byte) {
if len(data) < 1 {
return
}
listLen := int(data[0])
if listLen == 0 || 1+listLen > len(data) {
return
}
list := data[1 : 1+listLen]
for offset := 0; offset+1 < len(list); offset += 2 {
meta.SupportedVersions = append(meta.SupportedVersions, binary.BigEndian.Uint16(list[offset:offset+2]))
} }
} }
@ -290,17 +112,17 @@ type Conn struct {
func newConn(raw net.Conn, cfg ListenerConfig, stats *Stats) *Conn { func newConn(raw net.Conn, cfg ListenerConfig, stats *Stats) *Conn {
return &Conn{ return &Conn{
Conn: raw, Conn: raw,
plainConn: raw, plainConn: raw,
baseTLSConfig: cfg.BaseTLSConfig, baseTLSConfig: cfg.BaseTLSConfig,
getConfigForClient: cfg.GetConfigForClient, getConfigForClient: cfg.GetConfigForClient,
getConfigForClientHello: cfg.GetConfigForClientHello, getConfigForClientHello: cfg.GetConfigForClientHello,
allowNonTLS: cfg.AllowNonTLS, allowNonTLS: cfg.AllowNonTLS,
sniffer: TLSSniffer{}, sniffer: TLSSniffer{},
sniffTimeout: cfg.SniffTimeout, sniffTimeout: cfg.SniffTimeout,
maxClientHello: cfg.MaxClientHelloBytes, maxClientHello: cfg.MaxClientHelloBytes,
logger: cfg.Logger, logger: cfg.Logger,
stats: stats, stats: stats,
} }
} }
@ -433,123 +255,11 @@ func (c *Conn) serverName() string {
} }
func composeServerTLSConfig(base, selected *tls.Config) *tls.Config { func composeServerTLSConfig(base, selected *tls.Config) *tls.Config {
if base == nil { return tlssniffercore.ComposeServerTLSConfig(base, selected)
return selected
}
if selected == nil {
return base
}
out := base.Clone()
applyServerTLSOverrides(out, selected)
return out
} }
func applyServerTLSOverrides(dst, src *tls.Config) { func applyServerTLSOverrides(dst, src *tls.Config) {
if dst == nil || src == nil { tlssniffercore.ApplyServerTLSOverrides(dst, src)
return
}
if src.Rand != nil {
dst.Rand = src.Rand
}
if src.Time != nil {
dst.Time = src.Time
}
if len(src.Certificates) > 0 {
dst.Certificates = append([]tls.Certificate(nil), src.Certificates...)
}
if len(src.NameToCertificate) > 0 {
m := make(map[string]*tls.Certificate, len(src.NameToCertificate))
for k, v := range src.NameToCertificate {
m[k] = v
}
dst.NameToCertificate = m
}
if src.GetCertificate != nil {
dst.GetCertificate = src.GetCertificate
}
if src.GetClientCertificate != nil {
dst.GetClientCertificate = src.GetClientCertificate
}
if src.GetConfigForClient != nil {
dst.GetConfigForClient = src.GetConfigForClient
}
if src.VerifyPeerCertificate != nil {
dst.VerifyPeerCertificate = src.VerifyPeerCertificate
}
if src.VerifyConnection != nil {
dst.VerifyConnection = src.VerifyConnection
}
if src.RootCAs != nil {
dst.RootCAs = src.RootCAs
}
if len(src.NextProtos) > 0 {
dst.NextProtos = append([]string(nil), src.NextProtos...)
}
if src.ServerName != "" {
dst.ServerName = src.ServerName
}
if src.ClientAuth > dst.ClientAuth {
dst.ClientAuth = src.ClientAuth
}
if src.ClientCAs != nil {
dst.ClientCAs = src.ClientCAs
}
if src.InsecureSkipVerify {
dst.InsecureSkipVerify = true
}
if len(src.CipherSuites) > 0 {
dst.CipherSuites = append([]uint16(nil), src.CipherSuites...)
}
if src.PreferServerCipherSuites {
dst.PreferServerCipherSuites = true
}
if src.SessionTicketsDisabled {
dst.SessionTicketsDisabled = true
}
if src.SessionTicketKey != ([32]byte{}) {
dst.SessionTicketKey = src.SessionTicketKey
}
if src.ClientSessionCache != nil {
dst.ClientSessionCache = src.ClientSessionCache
}
if src.UnwrapSession != nil {
dst.UnwrapSession = src.UnwrapSession
}
if src.WrapSession != nil {
dst.WrapSession = src.WrapSession
}
if src.MinVersion != 0 && (dst.MinVersion == 0 || src.MinVersion > dst.MinVersion) {
dst.MinVersion = src.MinVersion
}
if src.MaxVersion != 0 && (dst.MaxVersion == 0 || src.MaxVersion < dst.MaxVersion) {
dst.MaxVersion = src.MaxVersion
}
if len(src.CurvePreferences) > 0 {
dst.CurvePreferences = append([]tls.CurveID(nil), src.CurvePreferences...)
}
if src.DynamicRecordSizingDisabled {
dst.DynamicRecordSizingDisabled = true
}
if src.Renegotiation != 0 {
dst.Renegotiation = src.Renegotiation
}
if src.KeyLogWriter != nil {
dst.KeyLogWriter = src.KeyLogWriter
}
if len(src.EncryptedClientHelloConfigList) > 0 {
dst.EncryptedClientHelloConfigList = append([]byte(nil), src.EncryptedClientHelloConfigList...)
}
if src.EncryptedClientHelloRejectionVerify != nil {
dst.EncryptedClientHelloRejectionVerify = src.EncryptedClientHelloRejectionVerify
}
if src.GetEncryptedClientHelloKeys != nil {
dst.GetEncryptedClientHelloKeys = src.GetEncryptedClientHelloKeys
}
if len(src.EncryptedClientHelloKeys) > 0 {
dst.EncryptedClientHelloKeys = append([]tls.EncryptedClientHelloKey(nil), src.EncryptedClientHelloKeys...)
}
} }
func (c *Conn) IsTLS() bool { func (c *Conn) IsTLS() bool {

340
trace.go Normal file
View File

@ -0,0 +1,340 @@
package starnet
import (
"context"
"crypto/tls"
"net"
"net/http/httptrace"
"sync/atomic"
"time"
)
type traceContextKey struct{}
// TraceHooks defines optional callbacks for network lifecycle events.
// Hooks may be called concurrently.
type TraceHooks struct {
GetConn func(TraceGetConnInfo)
GotConn func(TraceGotConnInfo)
PutIdleConn func(TracePutIdleConnInfo)
DNSStart func(TraceDNSStartInfo)
DNSDone func(TraceDNSDoneInfo)
ConnectStart func(TraceConnectStartInfo)
ConnectDone func(TraceConnectDoneInfo)
TLSHandshakeStart func(TraceTLSHandshakeStartInfo)
TLSHandshakeDone func(TraceTLSHandshakeDoneInfo)
WroteHeaderField func(TraceWroteHeaderFieldInfo)
WroteHeaders func()
WroteRequest func(TraceWroteRequestInfo)
GotFirstResponseByte func()
RetryAttemptStart func(TraceRetryAttemptStartInfo)
RetryAttemptDone func(TraceRetryAttemptDoneInfo)
RetryBackoff func(TraceRetryBackoffInfo)
}
type TraceGetConnInfo struct {
Addr string
}
type TraceGotConnInfo struct {
Conn net.Conn
Reused bool
WasIdle bool
IdleTime time.Duration
}
type TracePutIdleConnInfo struct {
Err error
}
type TraceDNSStartInfo struct {
Host string
}
type TraceDNSDoneInfo struct {
Addrs []net.IPAddr
Coalesced bool
Err error
}
type TraceConnectStartInfo struct {
Network string
Addr string
}
type TraceConnectDoneInfo struct {
Network string
Addr string
Err error
}
type TraceTLSHandshakeStartInfo struct {
Network string
Addr string
ServerName string
}
type TraceTLSHandshakeDoneInfo struct {
Network string
Addr string
ServerName string
ConnectionState tls.ConnectionState
Err error
}
type TraceWroteHeaderFieldInfo struct {
Key string
Values []string
}
type TraceWroteRequestInfo struct {
Err error
}
type TraceRetryAttemptStartInfo struct {
Attempt int
MaxAttempts int
}
type TraceRetryAttemptDoneInfo struct {
Attempt int
MaxAttempts int
StatusCode int
Err error
WillRetry bool
}
type TraceRetryBackoffInfo struct {
Attempt int
Delay time.Duration
}
type traceState struct {
hooks *TraceHooks
customTLS atomic.Uint32
manualDNSRefs atomic.Int32
}
func newTraceState(hooks *TraceHooks) *traceState {
if hooks == nil {
return nil
}
return &traceState{hooks: hooks}
}
func withTraceState(ctx context.Context, state *traceState) context.Context {
if state == nil {
return ctx
}
return context.WithValue(ctx, traceContextKey{}, state)
}
func getTraceState(ctx context.Context) *traceState {
if ctx == nil {
return nil
}
state, _ := ctx.Value(traceContextKey{}).(*traceState)
return state
}
func (t *traceState) needsHTTPTrace() bool {
if t == nil || t.hooks == nil {
return false
}
h := t.hooks
return h.GetConn != nil ||
h.GotConn != nil ||
h.PutIdleConn != nil ||
h.DNSStart != nil ||
h.DNSDone != nil ||
h.ConnectStart != nil ||
h.ConnectDone != nil ||
h.TLSHandshakeStart != nil ||
h.TLSHandshakeDone != nil ||
h.WroteHeaderField != nil ||
h.WroteHeaders != nil ||
h.WroteRequest != nil ||
h.GotFirstResponseByte != nil
}
func (t *traceState) clientTrace() *httptrace.ClientTrace {
if !t.needsHTTPTrace() {
return nil
}
h := t.hooks
trace := &httptrace.ClientTrace{}
if h.GetConn != nil {
trace.GetConn = func(hostPort string) {
h.GetConn(TraceGetConnInfo{Addr: hostPort})
}
}
if h.GotConn != nil {
trace.GotConn = func(info httptrace.GotConnInfo) {
h.GotConn(TraceGotConnInfo{
Conn: info.Conn,
Reused: info.Reused,
WasIdle: info.WasIdle,
IdleTime: info.IdleTime,
})
}
}
if h.PutIdleConn != nil {
trace.PutIdleConn = func(err error) {
h.PutIdleConn(TracePutIdleConnInfo{Err: err})
}
}
if h.DNSStart != nil {
trace.DNSStart = func(info httptrace.DNSStartInfo) {
if t.usesManualDNS() {
return
}
h.DNSStart(TraceDNSStartInfo{Host: info.Host})
}
}
if h.DNSDone != nil {
trace.DNSDone = func(info httptrace.DNSDoneInfo) {
if t.usesManualDNS() {
return
}
h.DNSDone(TraceDNSDoneInfo{
Addrs: append([]net.IPAddr(nil), info.Addrs...),
Coalesced: info.Coalesced,
Err: info.Err,
})
}
}
if h.ConnectStart != nil {
trace.ConnectStart = func(network, addr string) {
h.ConnectStart(TraceConnectStartInfo{Network: network, Addr: addr})
}
}
if h.ConnectDone != nil {
trace.ConnectDone = func(network, addr string, err error) {
h.ConnectDone(TraceConnectDoneInfo{Network: network, Addr: addr, Err: err})
}
}
if h.TLSHandshakeStart != nil {
trace.TLSHandshakeStart = func() {
if t.usesCustomTLS() {
return
}
h.TLSHandshakeStart(TraceTLSHandshakeStartInfo{})
}
}
if h.TLSHandshakeDone != nil {
trace.TLSHandshakeDone = func(state tls.ConnectionState, err error) {
if t.usesCustomTLS() {
return
}
h.TLSHandshakeDone(TraceTLSHandshakeDoneInfo{
ConnectionState: state,
Err: err,
})
}
}
if h.WroteHeaderField != nil {
trace.WroteHeaderField = func(key string, value []string) {
h.WroteHeaderField(TraceWroteHeaderFieldInfo{
Key: key,
Values: value,
})
}
}
if h.WroteHeaders != nil {
trace.WroteHeaders = h.WroteHeaders
}
if h.WroteRequest != nil {
trace.WroteRequest = func(info httptrace.WroteRequestInfo) {
h.WroteRequest(TraceWroteRequestInfo{Err: info.Err})
}
}
if h.GotFirstResponseByte != nil {
trace.GotFirstResponseByte = h.GotFirstResponseByte
}
return trace
}
func (t *traceState) markCustomTLS() {
if t == nil {
return
}
t.customTLS.Store(1)
}
func (t *traceState) usesCustomTLS() bool {
if t == nil {
return false
}
return t.customTLS.Load() != 0
}
func (t *traceState) beginManualDNS() {
if t == nil {
return
}
t.manualDNSRefs.Add(1)
}
func (t *traceState) endManualDNS() {
if t == nil {
return
}
t.manualDNSRefs.Add(-1)
}
func (t *traceState) usesManualDNS() bool {
if t == nil {
return false
}
return t.manualDNSRefs.Load() > 0
}
func (t *traceState) tlsHandshakeStart(info TraceTLSHandshakeStartInfo) {
if t == nil || t.hooks == nil || t.hooks.TLSHandshakeStart == nil {
return
}
t.hooks.TLSHandshakeStart(info)
}
func (t *traceState) tlsHandshakeDone(info TraceTLSHandshakeDoneInfo) {
if t == nil || t.hooks == nil || t.hooks.TLSHandshakeDone == nil {
return
}
t.hooks.TLSHandshakeDone(info)
}
func (t *traceState) dnsStart(info TraceDNSStartInfo) {
if t == nil || t.hooks == nil || t.hooks.DNSStart == nil {
return
}
t.hooks.DNSStart(info)
}
func (t *traceState) dnsDone(info TraceDNSDoneInfo) {
if t == nil || t.hooks == nil || t.hooks.DNSDone == nil {
return
}
t.hooks.DNSDone(info)
}
func emitRetryAttemptStart(hooks *TraceHooks, info TraceRetryAttemptStartInfo) {
if hooks == nil || hooks.RetryAttemptStart == nil {
return
}
hooks.RetryAttemptStart(info)
}
func emitRetryAttemptDone(hooks *TraceHooks, info TraceRetryAttemptDoneInfo) {
if hooks == nil || hooks.RetryAttemptDone == nil {
return
}
hooks.RetryAttemptDone(info)
}
func emitRetryBackoff(hooks *TraceHooks, info TraceRetryBackoffInfo) {
if hooks == nil || hooks.RetryBackoff == nil {
return
}
hooks.RetryBackoff(info)
}

324
trace_test.go Normal file
View File

@ -0,0 +1,324 @@
package starnet
import (
"context"
"errors"
"net"
"net/http"
"net/http/httptest"
"strconv"
"sync"
"testing"
"time"
)
func TestTraceHooksStandardHTTPSPath(t *testing.T) {
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer server.Close()
var mu sync.Mutex
events := map[string]int{}
hooks := &TraceHooks{
GetConn: func(info TraceGetConnInfo) {
mu.Lock()
events["get_conn"]++
mu.Unlock()
},
GotConn: func(info TraceGotConnInfo) {
mu.Lock()
events["got_conn"]++
mu.Unlock()
},
TLSHandshakeStart: func(info TraceTLSHandshakeStartInfo) {
mu.Lock()
events["tls_start"]++
mu.Unlock()
},
TLSHandshakeDone: func(info TraceTLSHandshakeDoneInfo) {
mu.Lock()
events["tls_done"]++
mu.Unlock()
if info.Err != nil {
t.Errorf("unexpected tls handshake error: %v", info.Err)
}
},
WroteHeaders: func() {
mu.Lock()
events["wrote_headers"]++
mu.Unlock()
},
WroteRequest: func(info TraceWroteRequestInfo) {
mu.Lock()
events["wrote_request"]++
mu.Unlock()
if info.Err != nil {
t.Errorf("unexpected write error: %v", info.Err)
}
},
GotFirstResponseByte: func() {
mu.Lock()
events["first_byte"]++
mu.Unlock()
},
}
resp, err := NewSimpleRequest(server.URL, http.MethodGet).
SetSkipTLSVerify(true).
SetTraceHooks(hooks).
Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
mu.Lock()
defer mu.Unlock()
for _, key := range []string{"get_conn", "got_conn", "tls_start", "tls_done", "wrote_headers", "wrote_request", "first_byte"} {
if events[key] == 0 {
t.Fatalf("expected trace event %q", key)
}
}
}
func TestTraceHooksDynamicHTTPSPathDoesNotDuplicateTLSHandshake(t *testing.T) {
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
var mu sync.Mutex
tlsStartCount := 0
tlsDoneCount := 0
var lastInfo TraceTLSHandshakeDoneInfo
hooks := &TraceHooks{
TLSHandshakeStart: func(info TraceTLSHandshakeStartInfo) {
mu.Lock()
tlsStartCount++
mu.Unlock()
},
TLSHandshakeDone: func(info TraceTLSHandshakeDoneInfo) {
mu.Lock()
tlsDoneCount++
lastInfo = info
mu.Unlock()
},
}
resp, err := NewSimpleRequest(server.URL, http.MethodGet).
SetSkipTLSVerify(true).
SetDialTimeout(1500 * time.Millisecond).
SetTraceHooks(hooks).
Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
mu.Lock()
defer mu.Unlock()
if tlsStartCount != 1 {
t.Fatalf("tlsStartCount=%d", tlsStartCount)
}
if tlsDoneCount != 1 {
t.Fatalf("tlsDoneCount=%d", tlsDoneCount)
}
if lastInfo.Err != nil {
t.Fatalf("unexpected tls handshake error: %v", lastInfo.Err)
}
if lastInfo.ConnectionState.Version == 0 {
t.Fatal("expected tls connection state")
}
}
func TestTraceHooksCustomLookupFuncEmitsDNSEvents(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
addr, err := net.ResolveTCPAddr("tcp", server.Listener.Addr().String())
if err != nil {
t.Fatalf("ResolveTCPAddr() error: %v", err)
}
var mu sync.Mutex
dnsStartCount := 0
dnsDoneCount := 0
var dnsStartHost string
hooks := &TraceHooks{
DNSStart: func(info TraceDNSStartInfo) {
mu.Lock()
dnsStartCount++
dnsStartHost = info.Host
mu.Unlock()
},
DNSDone: func(info TraceDNSDoneInfo) {
mu.Lock()
dnsDoneCount++
mu.Unlock()
if info.Err != nil {
t.Errorf("unexpected dns error: %v", info.Err)
}
},
}
url := "http://trace.example.test:" + strconv.Itoa(addr.Port)
resp, err := NewSimpleRequest(url, http.MethodGet).
SetLookupFunc(func(ctx context.Context, host string) ([]net.IPAddr, error) {
return []net.IPAddr{{IP: addr.IP}}, nil
}).
SetTraceHooks(hooks).
Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
mu.Lock()
defer mu.Unlock()
if dnsStartCount != 1 {
t.Fatalf("dnsStartCount=%d", dnsStartCount)
}
if dnsDoneCount != 1 {
t.Fatalf("dnsDoneCount=%d", dnsDoneCount)
}
if dnsStartHost != "trace.example.test" {
t.Fatalf("dnsStartHost=%q", dnsStartHost)
}
}
func TestTraceHooksCustomDialFuncEmitsConnectEvents(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
var mu sync.Mutex
connectStartCount := 0
connectDoneCount := 0
hooks := &TraceHooks{
ConnectStart: func(info TraceConnectStartInfo) {
mu.Lock()
connectStartCount++
mu.Unlock()
},
ConnectDone: func(info TraceConnectDoneInfo) {
mu.Lock()
connectDoneCount++
mu.Unlock()
if info.Err != nil {
t.Errorf("unexpected connect error: %v", info.Err)
}
},
}
resp, err := NewSimpleRequest(server.URL, http.MethodGet).
SetDialFunc(func(ctx context.Context, network, addr string) (net.Conn, error) {
var dialer net.Dialer
return dialer.DialContext(context.Background(), network, addr)
}).
SetTraceHooks(hooks).
Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
mu.Lock()
defer mu.Unlock()
if connectStartCount != 1 {
t.Fatalf("connectStartCount=%d", connectStartCount)
}
if connectDoneCount != 1 {
t.Fatalf("connectDoneCount=%d", connectDoneCount)
}
}
func TestTraceHooksRetryEvents(t *testing.T) {
var hits int
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
hits++
if hits == 1 {
w.WriteHeader(http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
var mu sync.Mutex
starts := 0
dones := 0
backoffs := 0
var finalDone TraceRetryAttemptDoneInfo
hooks := &TraceHooks{
RetryAttemptStart: func(info TraceRetryAttemptStartInfo) {
mu.Lock()
starts++
mu.Unlock()
},
RetryAttemptDone: func(info TraceRetryAttemptDoneInfo) {
mu.Lock()
dones++
finalDone = info
mu.Unlock()
},
RetryBackoff: func(info TraceRetryBackoffInfo) {
mu.Lock()
backoffs++
mu.Unlock()
},
}
resp, err := NewSimpleRequest(server.URL, http.MethodGet).
SetRetry(1, WithRetryBackoff(time.Millisecond, time.Millisecond, 1), WithRetryJitter(0)).
SetTraceHooks(hooks).
Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
mu.Lock()
defer mu.Unlock()
if starts != 2 {
t.Fatalf("starts=%d", starts)
}
if dones != 2 {
t.Fatalf("dones=%d", dones)
}
if backoffs != 1 {
t.Fatalf("backoffs=%d", backoffs)
}
if finalDone.WillRetry {
t.Fatal("expected final attempt not to retry")
}
if finalDone.StatusCode != http.StatusOK {
t.Fatalf("final status=%d", finalDone.StatusCode)
}
}
func TestTraceHooksCustomLookupFuncPropagatesDNSError(t *testing.T) {
var gotErr error
hooks := &TraceHooks{
DNSDone: func(info TraceDNSDoneInfo) {
gotErr = info.Err
},
}
_, err := NewSimpleRequest("http://trace.example.test:80", http.MethodGet).
SetLookupFunc(func(ctx context.Context, host string) ([]net.IPAddr, error) {
return nil, errors.New("lookup failed")
}).
SetTraceHooks(hooks).
Do()
if err == nil {
t.Fatal("expected request error")
}
if gotErr == nil || gotErr.Error() != "lookup failed" {
t.Fatalf("gotErr=%v", gotErr)
}
}

View File

@ -1,61 +1,220 @@
package starnet package starnet
import ( import (
"context"
"crypto/tls"
"net"
"net/http" "net/http"
"net/url" "net/url"
"strings"
"sync" "sync"
"time" "time"
) )
const dynamicTransportCacheMaxEntries = 64
type dynamicTransportCacheKey struct {
proxyKey string
dialTimeout time.Duration
customIPs string
customDNS string
tlsServerName string
skipVerify bool
}
// Transport 自定义 Transport支持请求级配置 // Transport 自定义 Transport支持请求级配置
type Transport struct { type Transport struct {
base *http.Transport base *http.Transport
mu sync.RWMutex dynamicCache map[dynamicTransportCacheKey]*http.Transport
dynamicCacheOrder []dynamicTransportCacheKey
mu sync.RWMutex
} }
// RoundTrip 实现 http.RoundTripper 接口 // RoundTrip 实现 http.RoundTripper 接口
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
// 确保 base 已初始化 t.ensureBase()
if t.base == nil {
t.mu.Lock()
if t.base == nil {
t.base = &http.Transport{
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
}
t.mu.Unlock()
}
// 提取请求级别的配置 // 提取请求级别的配置
reqCtx := getRequestContext(req.Context()) reqCtx := getRequestContext(req.Context())
traceState := getTraceState(req.Context())
execReq := req
execReqCtx := reqCtx
var targetAddrs []string
// 优先级1完全自定义的 transport // 优先级1完全自定义的 transport
if reqCtx.Transport != nil { if execReqCtx.Transport != nil {
return reqCtx.Transport.RoundTrip(req) return execReqCtx.Transport.RoundTrip(execReq)
}
var err error
execReq, execReqCtx, targetAddrs, err = prepareProxyTargetRequest(execReq, execReqCtx, traceState)
if err != nil {
return nil, err
} }
// 优先级2需要动态配置 // 优先级2需要动态配置
if needsDynamicTransport(reqCtx) { if needsDynamicTransport(execReqCtx) {
dynamicTransport := t.buildDynamicTransport(reqCtx) dynamicTransport := t.getDynamicTransport(execReqCtx, traceState)
return dynamicTransport.RoundTrip(req) if len(targetAddrs) > 0 {
return roundTripResolvedTargets(dynamicTransport, execReq, targetAddrs)
}
return dynamicTransport.RoundTrip(execReq)
} }
// 优先级3使用基础 transport // 优先级3使用基础 transport
t.mu.RLock() t.mu.RLock()
defer t.mu.RUnlock() baseTransport := t.base
return t.base.RoundTrip(req) t.mu.RUnlock()
if len(targetAddrs) > 0 {
return roundTripResolvedTargets(baseTransport, execReq, targetAddrs)
}
return baseTransport.RoundTrip(execReq)
}
func newBaseHTTPTransport() *http.Transport {
return &http.Transport{
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
}
func (t *Transport) ensureBase() {
if t.base != nil {
return
}
t.mu.Lock()
defer t.mu.Unlock()
t.ensureBaseLocked()
}
func (t *Transport) ensureBaseLocked() {
if t.base == nil {
t.base = newBaseHTTPTransport()
}
}
func (t *Transport) getDynamicTransport(rc *RequestContext, traceState *traceState) *http.Transport {
if key, ok := newDynamicTransportCacheKey(rc); ok {
return t.getOrCreateCachedDynamicTransport(key, rc)
}
return t.buildDynamicTransport(rc, traceState)
}
func (t *Transport) getOrCreateCachedDynamicTransport(key dynamicTransportCacheKey, rc *RequestContext) *http.Transport {
t.mu.RLock()
if transport := t.dynamicCache[key]; transport != nil {
t.mu.RUnlock()
return transport
}
t.mu.RUnlock()
t.mu.Lock()
defer t.mu.Unlock()
t.ensureBaseLocked()
if transport := t.dynamicCache[key]; transport != nil {
return transport
}
transport := buildDynamicTransportFromBase(t.base, rc, nil)
if t.dynamicCache == nil {
t.dynamicCache = make(map[dynamicTransportCacheKey]*http.Transport)
}
if len(t.dynamicCacheOrder) >= dynamicTransportCacheMaxEntries {
oldestKey := t.dynamicCacheOrder[0]
t.dynamicCacheOrder = t.dynamicCacheOrder[1:]
if oldest := t.dynamicCache[oldestKey]; oldest != nil {
oldest.CloseIdleConnections()
delete(t.dynamicCache, oldestKey)
}
}
t.dynamicCache[key] = transport
t.dynamicCacheOrder = append(t.dynamicCacheOrder, key)
return transport
}
func (t *Transport) resetDynamicTransportCacheLocked() {
for _, key := range t.dynamicCacheOrder {
if transport := t.dynamicCache[key]; transport != nil {
transport.CloseIdleConnections()
}
}
t.dynamicCache = nil
t.dynamicCacheOrder = nil
}
func newDynamicTransportCacheKey(rc *RequestContext) (dynamicTransportCacheKey, bool) {
if rc == nil {
return dynamicTransportCacheKey{}, false
}
if rc.Transport != nil || rc.DialFn != nil || rc.LookupIPFn != nil {
return dynamicTransportCacheKey{}, false
}
if rc.TLSConfig != nil && !rc.TLSConfigCacheable {
return dynamicTransportCacheKey{}, false
}
key := dynamicTransportCacheKey{
proxyKey: normalizeProxyCacheKey(rc.Proxy),
dialTimeout: rc.DialTimeout,
customIPs: serializeTransportCacheList(rc.CustomIP),
customDNS: serializeTransportCacheList(rc.CustomDNS),
tlsServerName: effectiveTLSServerName(rc),
}
if rc.TLSConfig != nil {
key.skipVerify = rc.TLSConfig.InsecureSkipVerify
}
return key, true
}
func normalizeProxyCacheKey(proxy string) string {
if proxy == "" {
return ""
}
proxyURL, err := parseProxyURL(proxy)
if err != nil {
return "\x00invalid:" + proxy
}
return proxyURL.String()
}
func serializeTransportCacheList(values []string) string {
if len(values) == 0 {
return ""
}
var builder strings.Builder
for _, value := range values {
builder.WriteString(value)
builder.WriteByte(0)
}
return builder.String()
}
func effectiveTLSServerName(rc *RequestContext) string {
if rc == nil {
return ""
}
if rc.TLSConfig != nil && rc.TLSConfig.ServerName != "" {
return rc.TLSConfig.ServerName
}
return rc.TLSServerName
} }
// buildDynamicTransport 构建动态 Transport // buildDynamicTransport 构建动态 Transport
func (t *Transport) buildDynamicTransport(rc *RequestContext) *http.Transport { func (t *Transport) buildDynamicTransport(rc *RequestContext, traceState *traceState) *http.Transport {
t.ensureBase()
t.mu.RLock() t.mu.RLock()
transport := t.base.Clone() baseTransport := t.base
t.mu.RUnlock() t.mu.RUnlock()
return buildDynamicTransportFromBase(baseTransport, rc, traceState)
}
func buildDynamicTransportFromBase(baseTransport *http.Transport, rc *RequestContext, traceState *traceState) *http.Transport {
transport := baseTransport.Clone()
// 应用 TLS 配置(即使为 nil 也要检查 SkipVerify // 应用 TLS 配置(即使为 nil 也要检查 SkipVerify
if rc.TLSConfig != nil { if rc.TLSConfig != nil {
@ -64,15 +223,33 @@ func (t *Transport) buildDynamicTransport(rc *RequestContext) *http.Transport {
// 应用代理配置 // 应用代理配置
if rc.Proxy != "" { if rc.Proxy != "" {
proxyURL, err := url.Parse(rc.Proxy) proxyURL, err := parseProxyURL(rc.Proxy)
if err == nil { if err != nil {
transport.Proxy = func(*http.Request) (*url.URL, error) {
return nil, err
}
} else {
transport.Proxy = http.ProxyURL(proxyURL) transport.Proxy = http.ProxyURL(proxyURL)
} }
} }
// 应用自定义 Dial 函数 // 应用自定义 Dial 函数
if rc.DialFn != nil { if rc.DialFn != nil {
transport.DialContext = rc.DialFn if traceState != nil && traceState.hooks != nil && (traceState.hooks.ConnectStart != nil || traceState.hooks.ConnectDone != nil) {
dialFn := rc.DialFn
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
if traceState.hooks.ConnectStart != nil {
traceState.hooks.ConnectStart(TraceConnectStartInfo{Network: network, Addr: addr})
}
conn, err := dialFn(ctx, network, addr)
if traceState.hooks.ConnectDone != nil {
traceState.hooks.ConnectDone(TraceConnectDoneInfo{Network: network, Addr: addr, Err: err})
}
return conn, err
}
} else {
transport.DialContext = rc.DialFn
}
} else if len(rc.CustomIP) > 0 || len(rc.CustomDNS) > 0 || rc.DialTimeout > 0 || rc.LookupIPFn != nil { } else if len(rc.CustomIP) > 0 || len(rc.CustomDNS) > 0 || rc.DialTimeout > 0 || rc.LookupIPFn != nil {
// 使用默认 Dial 函数(会从 context 读取配置) // 使用默认 Dial 函数(会从 context 读取配置)
transport.DialContext = defaultDialFunc transport.DialContext = defaultDialFunc
@ -93,5 +270,147 @@ func (t *Transport) Base() *http.Transport {
func (t *Transport) SetBase(base *http.Transport) { func (t *Transport) SetBase(base *http.Transport) {
t.mu.Lock() t.mu.Lock()
t.base = base t.base = base
t.resetDynamicTransportCacheLocked()
t.mu.Unlock() t.mu.Unlock()
} }
func prepareProxyTargetRequest(req *http.Request, reqCtx *RequestContext, traceState *traceState) (*http.Request, *RequestContext, []string, error) {
if req == nil || req.URL == nil || reqCtx == nil {
return req, reqCtx, nil, nil
}
if reqCtx.Proxy == "" || reqCtx.DialFn != nil {
return req, reqCtx, nil, nil
}
if len(reqCtx.CustomIP) == 0 && len(reqCtx.CustomDNS) == 0 && reqCtx.LookupIPFn == nil {
return req, reqCtx, nil, nil
}
host := req.URL.Hostname()
if host == "" {
return req, reqCtx, nil, nil
}
targetAddrs, err := resolveDialAddresses(req.Context(), reqCtx, host, req.URL.Port(), traceState)
if err != nil {
return nil, nil, nil, err
}
if len(targetAddrs) == 0 {
return req, reqCtx, nil, nil
}
execReqCtx := *reqCtx
execReqCtx.CustomIP = nil
execReqCtx.CustomDNS = nil
execReqCtx.LookupIPFn = nil
if req.URL.Scheme == "https" {
execReqCtx.TLSConfig = withDefaultServerName(execReqCtx.TLSConfig, host)
if execReqCtx.TLSConfigCacheable || reqCtx.TLSConfig == nil {
execReqCtx.TLSConfigCacheable = true
}
}
execCtx := clearTargetResolutionContext(req.Context())
execReq := req.Clone(execCtx)
execReq.Host = req.Host
if len(targetAddrs) == 1 {
execReq.URL.Host = targetAddrs[0]
return execReq, &execReqCtx, nil, nil
}
return execReq, &execReqCtx, targetAddrs, nil
}
func clearTargetResolutionContext(ctx context.Context) context.Context {
if v := ctx.Value(ctxKeyRequestContext); v != nil {
if rc, ok := v.(*RequestContext); ok && rc != nil {
cloned := cloneRequestContext(rc)
cloned.CustomIP = nil
cloned.CustomDNS = nil
cloned.LookupIPFn = nil
ctx = context.WithValue(ctx, ctxKeyRequestContext, cloned)
}
}
ctx = context.WithValue(ctx, ctxKeyCustomIP, []string(nil))
ctx = context.WithValue(ctx, ctxKeyCustomDNS, []string(nil))
ctx = context.WithValue(ctx, ctxKeyLookupIP, (func(context.Context, string) ([]net.IPAddr, error))(nil))
return ctx
}
func withDefaultServerName(cfg *tls.Config, serverName string) *tls.Config {
if serverName == "" {
return cfg
}
if cfg != nil {
if cfg.ServerName != "" {
return cfg
}
cloned := cfg.Clone()
cloned.ServerName = serverName
return cloned
}
return &tls.Config{
ServerName: serverName,
NextProtos: []string{"h2", "http/1.1"},
}
}
func roundTripResolvedTargets(rt http.RoundTripper, baseReq *http.Request, targetAddrs []string) (*http.Response, error) {
if rt == nil || baseReq == nil || len(targetAddrs) == 0 {
return rt.RoundTrip(baseReq)
}
if !requestAllowsResolvedTargetFallback(baseReq) && len(targetAddrs) > 1 {
targetAddrs = targetAddrs[:1]
}
var lastErr error
for _, targetAddr := range targetAddrs {
attemptReq, err := cloneRequestForResolvedTarget(baseReq, targetAddr)
if err != nil {
return nil, err
}
resp, err := rt.RoundTrip(attemptReq)
if err == nil {
return resp, nil
}
lastErr = err
}
return nil, lastErr
}
func requestAllowsResolvedTargetFallback(req *http.Request) bool {
if req == nil {
return false
}
if !isIdempotentMethod(req.Method) {
return false
}
if req.Body == nil || req.Body == http.NoBody {
return true
}
return req.GetBody != nil
}
func cloneRequestForResolvedTarget(baseReq *http.Request, targetAddr string) (*http.Request, error) {
req := baseReq.Clone(baseReq.Context())
switch {
case baseReq.Body == nil || baseReq.Body == http.NoBody:
req.Body = baseReq.Body
case baseReq.GetBody != nil:
body, err := baseReq.GetBody()
if err != nil {
return nil, wrapError(err, "clone request body for resolved target")
}
req.Body = body
default:
req.Body = baseReq.Body
}
req.URL.Host = targetAddr
req.Host = baseReq.Host
return req, nil
}

224
transport_cache_test.go Normal file
View File

@ -0,0 +1,224 @@
package starnet
import (
"crypto/tls"
"net"
"net/http"
"strconv"
"sync"
"testing"
"time"
)
func TestTransportDynamicCacheReusesSafeProfile(t *testing.T) {
transport := &Transport{base: newBaseHTTPTransport()}
first := transport.getDynamicTransport(&RequestContext{
Proxy: "http://127.0.0.1:8080",
DialTimeout: 2 * time.Second,
CustomIP: []string{"127.0.0.1"},
TLSServerName: "cache.test",
}, nil)
second := transport.getDynamicTransport(&RequestContext{
Proxy: "http://127.0.0.1:8080",
DialTimeout: 2 * time.Second,
CustomIP: []string{"127.0.0.1"},
TLSServerName: "cache.test",
}, nil)
if first != second {
t.Fatal("expected cached dynamic transport to be reused")
}
if got := len(transport.dynamicCache); got != 1 {
t.Fatalf("dynamic cache size=%d; want 1", got)
}
}
func TestTransportDynamicCacheSeparatesTLSServerName(t *testing.T) {
transport := &Transport{base: newBaseHTTPTransport()}
first := transport.getDynamicTransport(&RequestContext{
CustomIP: []string{"127.0.0.1"},
TLSServerName: "first.test",
}, nil)
second := transport.getDynamicTransport(&RequestContext{
CustomIP: []string{"127.0.0.1"},
TLSServerName: "second.test",
}, nil)
if first == second {
t.Fatal("expected distinct tls server names to use different transports")
}
if got := len(transport.dynamicCache); got != 2 {
t.Fatalf("dynamic cache size=%d; want 2", got)
}
}
func TestTransportDynamicCacheSkipsUserTLSConfig(t *testing.T) {
transport := &Transport{base: newBaseHTTPTransport()}
reqCtx := &RequestContext{
CustomIP: []string{"127.0.0.1"},
TLSConfig: &tls.Config{InsecureSkipVerify: true},
}
first := transport.getDynamicTransport(reqCtx, nil)
second := transport.getDynamicTransport(reqCtx, nil)
if first == second {
t.Fatal("expected user tls config to bypass dynamic transport cache")
}
if got := len(transport.dynamicCache); got != 0 {
t.Fatalf("dynamic cache size=%d; want 0", got)
}
}
func TestTransportDynamicCacheResetOnDefaultTLSChange(t *testing.T) {
client := NewClientNoErr()
transport, ok := client.HTTPClient().Transport.(*Transport)
if !ok {
t.Fatalf("transport type=%T; want *Transport", client.HTTPClient().Transport)
}
reqCtx := &RequestContext{CustomIP: []string{"127.0.0.1"}}
first := transport.getDynamicTransport(reqCtx, nil)
if got := len(transport.dynamicCache); got != 1 {
t.Fatalf("dynamic cache size=%d; want 1 before reset", got)
}
client.SetDefaultSkipTLSVerify(true)
if got := len(transport.dynamicCache); got != 0 {
t.Fatalf("dynamic cache size=%d; want 0 after reset", got)
}
second := transport.getDynamicTransport(reqCtx, nil)
if first == second {
t.Fatal("expected cache reset after default tls change")
}
}
func TestDynamicTransportCacheReusesConnectionForCustomIP(t *testing.T) {
server := newIPv4Server(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer server.Close()
addr, err := net.ResolveTCPAddr("tcp", server.Listener.Addr().String())
if err != nil {
t.Fatalf("ResolveTCPAddr() error: %v", err)
}
client := NewClientNoErr()
targetURL := "http://cache-reuse.test:" + strconv.Itoa(addr.Port)
runRequest := func() bool {
var (
mu sync.Mutex
gotConn bool
reused bool
)
resp, err := client.NewSimpleRequest(targetURL, http.MethodGet).
SetCustomIP([]string{"127.0.0.1"}).
SetTraceHooks(&TraceHooks{
GotConn: func(info TraceGotConnInfo) {
mu.Lock()
gotConn = true
reused = info.Reused
mu.Unlock()
},
}).
Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
if _, err := resp.Body().Bytes(); err != nil {
t.Fatalf("Body().Bytes() error: %v", err)
}
mu.Lock()
defer mu.Unlock()
if !gotConn {
t.Fatal("expected GotConn trace event")
}
return reused
}
if runRequest() {
t.Fatal("first request unexpectedly reused a connection")
}
if !runRequest() {
t.Fatal("second request did not reuse cached dynamic transport connection")
}
transport, ok := client.HTTPClient().Transport.(*Transport)
if !ok {
t.Fatalf("transport type=%T; want *Transport", client.HTTPClient().Transport)
}
if got := len(transport.dynamicCache); got != 1 {
t.Fatalf("dynamic cache size=%d; want 1", got)
}
}
func TestPrepareProxyTargetRequestSingleTargetRewritesExecRequest(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, "https://proxy-single.test:8443/path", nil)
if err != nil {
t.Fatalf("http.NewRequest() error: %v", err)
}
req.Host = req.URL.Host
execReq, execReqCtx, targetAddrs, err := prepareProxyTargetRequest(req, &RequestContext{
Proxy: "http://127.0.0.1:8080",
CustomIP: []string{"127.0.0.1"},
}, nil)
if err != nil {
t.Fatalf("prepareProxyTargetRequest() error: %v", err)
}
if execReq == req {
t.Fatal("expected cloned request for proxy target preparation")
}
if got := execReq.URL.Host; got != "127.0.0.1:8443" {
t.Fatalf("execReq.URL.Host=%q; want %q", got, "127.0.0.1:8443")
}
if got := req.URL.Host; got != "proxy-single.test:8443" {
t.Fatalf("original req.URL.Host=%q; want %q", got, "proxy-single.test:8443")
}
if len(targetAddrs) != 0 {
t.Fatalf("targetAddrs=%v; want empty after single target rewrite", targetAddrs)
}
if execReqCtx == nil || execReqCtx.TLSConfig == nil {
t.Fatal("expected synthesized tls config for single target proxy request")
}
if got := execReqCtx.TLSConfig.ServerName; got != "proxy-single.test" {
t.Fatalf("tls server name=%q; want %q", got, "proxy-single.test")
}
}
func TestPrepareProxyTargetRequestMultiTargetPreservesFallbackList(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, "https://proxy-multi.test:9443/path", nil)
if err != nil {
t.Fatalf("http.NewRequest() error: %v", err)
}
req.Host = req.URL.Host
execReq, _, targetAddrs, err := prepareProxyTargetRequest(req, &RequestContext{
Proxy: "http://127.0.0.1:8080",
CustomIP: []string{"127.0.0.1", "127.0.0.2"},
}, nil)
if err != nil {
t.Fatalf("prepareProxyTargetRequest() error: %v", err)
}
if got := execReq.URL.Host; got != "proxy-multi.test:9443" {
t.Fatalf("execReq.URL.Host=%q; want original host", got)
}
if len(targetAddrs) != 2 {
t.Fatalf("targetAddrs=%v; want 2 targets", targetAddrs)
}
if targetAddrs[0] != "127.0.0.1:9443" || targetAddrs[1] != "127.0.0.2:9443" {
t.Fatalf("targetAddrs=%v; want ordered fallback targets", targetAddrs)
}
}

View File

@ -53,6 +53,7 @@ type NetworkConfig struct {
type TLSConfig struct { type TLSConfig struct {
Config *tls.Config // TLS 配置 Config *tls.Config // TLS 配置
SkipVerify bool // 跳过证书验证 SkipVerify bool // 跳过证书验证
ServerName string // 显式 TLS ServerName/SNI 覆盖
} }
// DNSConfig DNS 配置 // DNSConfig DNS 配置
@ -62,8 +63,19 @@ type DNSConfig struct {
LookupFunc func(ctx context.Context, host string) ([]net.IPAddr, error) // 自定义解析函数 LookupFunc func(ctx context.Context, host string) ([]net.IPAddr, error) // 自定义解析函数
} }
type bodyMode uint8
const (
bodyModeUnset bodyMode = iota
bodyModeBytes
bodyModeReader
bodyModeForm
bodyModeMultipart
)
// BodyConfig 请求体配置 // BodyConfig 请求体配置
type BodyConfig struct { type BodyConfig struct {
Mode bodyMode // 当前 body 来源模式
Bytes []byte // 原始字节 Bytes []byte // 原始字节
Reader io.Reader // 数据流 Reader io.Reader // 数据流
FormData map[string][]string // 表单数据 FormData map[string][]string // 表单数据
@ -82,6 +94,7 @@ type RequestConfig struct {
// 其他配置 // 其他配置
BasicAuth [2]string // Basic 认证 BasicAuth [2]string // Basic 认证
Host string // 显式 Host 头覆盖
ContentLength int64 // 手动设置的 Content-Length ContentLength int64 // 手动设置的 Content-Length
AutoCalcContentLength bool // 自动计算 Content-Length AutoCalcContentLength bool // 自动计算 Content-Length
MaxRespBodyBytes int64 // 响应体最大读取字节数(<=0 表示不限制) MaxRespBodyBytes int64 // 响应体最大读取字节数(<=0 表示不限制)
@ -104,6 +117,7 @@ func (c *RequestConfig) Clone() *RequestConfig {
TLS: TLSConfig{ TLS: TLSConfig{
Config: cloneTLSConfig(c.TLS.Config), Config: cloneTLSConfig(c.TLS.Config),
SkipVerify: c.TLS.SkipVerify, SkipVerify: c.TLS.SkipVerify,
ServerName: c.TLS.ServerName,
}, },
DNS: DNSConfig{ DNS: DNSConfig{
CustomIP: cloneStringSlice(c.DNS.CustomIP), CustomIP: cloneStringSlice(c.DNS.CustomIP),
@ -111,6 +125,7 @@ func (c *RequestConfig) Clone() *RequestConfig {
LookupFunc: c.DNS.LookupFunc, LookupFunc: c.DNS.LookupFunc,
}, },
Body: BodyConfig{ Body: BodyConfig{
Mode: c.Body.Mode,
Bytes: cloneBytes(c.Body.Bytes), Bytes: cloneBytes(c.Body.Bytes),
Reader: c.Body.Reader, // Reader 不可克隆 Reader: c.Body.Reader, // Reader 不可克隆
FormData: cloneStringMapSlice(c.Body.FormData), FormData: cloneStringMapSlice(c.Body.FormData),
@ -120,6 +135,7 @@ func (c *RequestConfig) Clone() *RequestConfig {
Cookies: cloneCookies(c.Cookies), Cookies: cloneCookies(c.Cookies),
Queries: cloneStringMapSlice(c.Queries), Queries: cloneStringMapSlice(c.Queries),
BasicAuth: c.BasicAuth, BasicAuth: c.BasicAuth,
Host: c.Host,
ContentLength: c.ContentLength, ContentLength: c.ContentLength,
AutoCalcContentLength: c.AutoCalcContentLength, AutoCalcContentLength: c.AutoCalcContentLength,
MaxRespBodyBytes: c.MaxRespBodyBytes, MaxRespBodyBytes: c.MaxRespBodyBytes,

View File

@ -101,24 +101,31 @@ func cloneCookies(cookies []*http.Cookie) []*http.Cookie {
} }
newCookies := make([]*http.Cookie, len(cookies)) newCookies := make([]*http.Cookie, len(cookies))
for i, c := range cookies { for i, c := range cookies {
newCookies[i] = &http.Cookie{ newCookies[i] = cloneCookie(c)
Name: c.Name,
Value: c.Value,
Path: c.Path,
Domain: c.Domain,
Expires: c.Expires,
RawExpires: c.RawExpires,
MaxAge: c.MaxAge,
Secure: c.Secure,
HttpOnly: c.HttpOnly,
SameSite: c.SameSite,
Raw: c.Raw,
Unparsed: append([]string(nil), c.Unparsed...),
}
} }
return newCookies return newCookies
} }
func cloneCookie(cookie *http.Cookie) *http.Cookie {
if cookie == nil {
return nil
}
return &http.Cookie{
Name: cookie.Name,
Value: cookie.Value,
Path: cookie.Path,
Domain: cookie.Domain,
Expires: cookie.Expires,
RawExpires: cookie.RawExpires,
MaxAge: cookie.MaxAge,
Secure: cookie.Secure,
HttpOnly: cookie.HttpOnly,
SameSite: cookie.SameSite,
Raw: cookie.Raw,
Unparsed: append([]string(nil), cookie.Unparsed...),
}
}
// cloneStringMapSlice 克隆 map[string][]string // cloneStringMapSlice 克隆 map[string][]string
func cloneStringMapSlice(m map[string][]string) map[string][]string { func cloneStringMapSlice(m map[string][]string) map[string][]string {
if m == nil { if m == nil {
@ -171,8 +178,8 @@ func cloneTLSConfig(cfg *tls.Config) *tls.Config {
// copyWithProgress 带进度的复制 // copyWithProgress 带进度的复制
func copyWithProgress(ctx context.Context, dst io.Writer, src io.Reader, filename string, total int64, progress UploadProgressFunc) (int64, error) { func copyWithProgress(ctx context.Context, dst io.Writer, src io.Reader, filename string, total int64, progress UploadProgressFunc) (int64, error) {
if progress == nil { if ctx == nil {
return io.Copy(dst, src) ctx = context.Background()
} }
var written int64 var written int64
@ -190,8 +197,10 @@ func copyWithProgress(ctx context.Context, dst io.Writer, src io.Reader, filenam
nw, ew := dst.Write(buf[:nr]) nw, ew := dst.Write(buf[:nr])
if nw > 0 { if nw > 0 {
written += int64(nw) written += int64(nw)
// 同步调用进度回调(不使用 goroutine if progress != nil {
progress(filename, written, total) // 同步调用进度回调(不使用 goroutine
progress(filename, written, total)
}
} }
if ew != nil { if ew != nil {
return written, ew return written, ew
@ -202,8 +211,10 @@ func copyWithProgress(ctx context.Context, dst io.Writer, src io.Reader, filenam
} }
if err != nil { if err != nil {
if err == io.EOF { if err == io.EOF {
// 最后一次进度回调 if progress != nil {
progress(filename, written, total) // 最后一次进度回调
progress(filename, written, total)
}
return written, nil return written, nil
} }
return written, err return written, err