From 732e81316ce79c05624b81ce39c5247f3ceed92f Mon Sep 17 00:00:00 2001 From: starainrt Date: Sun, 19 Apr 2026 15:39:51 +0800 Subject: [PATCH] =?UTF-8?q?fix(starnet):=20=E9=87=8D=E6=9E=84=E8=AF=B7?= =?UTF-8?q?=E6=B1=82=E6=89=A7=E8=A1=8C=E9=93=BE=E8=B7=AF=E5=B9=B6=E8=A1=A5?= =?UTF-8?q?=E9=BD=90=E4=BB=A3=E7=90=86/=E9=87=8D=E8=AF=95/trace=E8=BE=B9?= =?UTF-8?q?=E7=95=8C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 分离 Request 的配置态与执行态,修复二次 Do、raw 模式网络配置失效和 body 来源互斥问题 - 新增 starnet trace 抽象,补齐 DNS/连接/TLS/重试事件,并优化动态 transport 缓存与代理解析路径 - 收紧非法代理为 fail-fast,多目标目标回退仅限幂等请求,修复 Host/TLS/SNI 等语义边界 - 补充防御性拷贝、专项回归测试、本地代理/TLS 用例与 README 行为说明 --- .gitignore | 2 + README.md | 89 +++--- benchmark_test.go | 27 ++ client.go | 24 +- context.go | 167 +++++++---- defensive_copy_test.go | 52 ++++ dialer.go | 176 ++++++++---- dynamic_transport_benchmark_test.go | 144 ++++++++++ host_tls_regression_test.go | 150 ++++++++++ internal/pingcore/core.go | 230 +++++++++++++++ internal/tlssniffercore/config.go | 123 ++++++++ internal/tlssniffercore/parser.go | 237 ++++++++++++++++ options.go | 405 --------------------------- options_body.go | 112 ++++++++ options_config.go | 132 +++++++++ options_header.go | 99 +++++++ ping.go | 221 ++------------- proxy_custom_ip_test.go | 110 ++++++++ proxy_local_helpers_test.go | 331 ++++++++++++++++++++++ request.go | 208 ++++++++++++-- request_body.go | 276 ++---------------- request_config.go | 282 ------------------- request_execution.go | 34 +++ request_execution_regression_test.go | 172 ++++++++++++ request_header.go | 92 +++++- request_multipart.go | 69 +++++ request_mutation.go | 326 +++++++++++++++++++++ request_network.go | 71 +++++ request_prepare.go | 314 +++++++++++++++++++++ request_prepare_regression_test.go | 335 ++++++++++++++++++++++ request_query.go | 31 ++ request_state_boundary_test.go | 168 +++++++++++ request_trace.go | 6 + retry.go | 85 ++++-- review_regression_test.go | 244 ++++++++++++++++ tls_test.go | 125 ++++++++- tlssniffer.go | 360 +++--------------------- trace.go | 340 ++++++++++++++++++++++ trace_test.go | 324 +++++++++++++++++++++ transport.go | 377 +++++++++++++++++++++++-- transport_cache_test.go | 224 +++++++++++++++ types.go | 16 ++ utils.go | 51 ++-- 43 files changed, 5633 insertions(+), 1728 deletions(-) create mode 100644 dynamic_transport_benchmark_test.go create mode 100644 host_tls_regression_test.go create mode 100644 internal/pingcore/core.go create mode 100644 internal/tlssniffercore/config.go create mode 100644 internal/tlssniffercore/parser.go delete mode 100644 options.go create mode 100644 options_body.go create mode 100644 options_config.go create mode 100644 options_header.go create mode 100644 proxy_custom_ip_test.go create mode 100644 proxy_local_helpers_test.go delete mode 100644 request_config.go create mode 100644 request_execution.go create mode 100644 request_execution_regression_test.go create mode 100644 request_multipart.go create mode 100644 request_mutation.go create mode 100644 request_network.go create mode 100644 request_prepare.go create mode 100644 request_prepare_regression_test.go create mode 100644 request_query.go create mode 100644 request_state_boundary_test.go create mode 100644 request_trace.go create mode 100644 review_regression_test.go create mode 100644 trace.go create mode 100644 trace_test.go create mode 100644 transport_cache_test.go diff --git a/.gitignore b/.gitignore index fa67efa..120a5e5 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,5 @@ .sentrux/ agent_readme.md target.md +agents.md +.codex \ No newline at end of file diff --git a/README.md b/README.md index 6123cb4..d4bca50 100644 --- a/README.md +++ b/README.md @@ -1,58 +1,58 @@ # starnet -`starnet` is a Go network toolkit focused on practical HTTP request control, TLS sniff utilities, and ICMP ping capabilities. +`starnet` 是一个面向 Go 的网络工具库,提供 HTTP 请求控制、TLS 嗅探和 ICMP Ping 能力。 -## Highlights +## 功能概览 -- Request-level timeout by context (without mutating shared `http.Client` timeout) -- Fine-grained network controls: custom DNS/IP, dial timeout, proxy, TLS config -- Built-in retry with replay safety checks and configurable backoff/jitter/statuses -- Response body safety guard via max body bytes limit -- Error classification helpers (`ClassifyError`, `IsTimeout`, `IsDNS`, `IsTLS`, `IsProxy`, `IsCanceled`) -- TLS sniffer listener/dialer utilities for mixed TLS/plain traffic scenarios -- ICMP ping with IPv4/IPv6 target handling and option-based probing API +- 基于 `context` 的请求级超时控制,不修改共享 `http.Client` 的全局超时 +- 请求级网络控制:代理、自定义 IP / DNS、拨号超时、TLS 配置 +- 内置重试机制,支持重试次数、退避、抖动、状态码白名单和自定义错误判定 +- 响应体大小限制,避免一次性读取过大内容 +- 错误分类辅助:`ClassifyError`、`IsTimeout`、`IsDNS`、`IsTLS`、`IsProxy`、`IsCanceled` +- TLS 嗅探监听 / 拨号工具,适用于 TLS 与明文混合场景 +- ICMP Ping,支持 IPv4 / IPv6 目标和选项化探测 -## Main Features +## 主要能力 -### HTTP Client and Request +### HTTP 客户端与请求构建 -- Fluent APIs with both `WithXxx` options and `SetXxx` chain methods -- Methods: `Get/Post/Put/Delete/Head/Patch/Options/Trace/Connect` -- Request body helpers: JSON, form data, multipart file upload, stream body -- Header/cookie/query helpers with defensive copy on key setters -- Request cloning for safe reuse in concurrent or variant calls +- 同时提供 `WithXxx` 选项和 `SetXxx` 链式调用两套接口 +- 支持 `Get`、`Post`、`Put`、`Delete`、`Head`、`Patch`、`Options`、`Trace`、`Connect` +- 支持 JSON、表单、`multipart/form-data`、流式请求体等常见请求体形态 +- Header、Cookie、Query 等输入在关键路径上做防御性拷贝,降低外部可变状态污染风险 +- `Request.Clone()` 可用于并发场景或同一基础请求的变体构造 -### Timeout and Retry +### 超时与重试 -- Request timeout is applied by context deadline, not global client timeout -- Retry supports: - - max attempts - - backoff factor/base/max - - jitter - - retry status whitelist - - idempotent-only guard - - custom retry-on-error callback -- Retry keeps original request pointer in final response for consistency +- 请求超时通过 `context` 截止时间控制,不污染共享客户端配置 +- 重试支持: + - 最大尝试次数 + - 基础退避、最大退避和退避因子 + - 抖动比例 + - 可重试状态码集合 + - 仅幂等方法重试 + - 自定义错误判定函数 +- 重试成功后返回的 `Response` 仍保持对原始 `Request` 的引用 -### Response Handling +### 响应处理 -- `Bytes/String/JSON/Reader` helpers -- optional auto-fetch mode -- configurable max response body bytes to prevent oversized reads +- 提供 `Bytes`、`String`、`JSON`、`Reader` 等响应体读取接口 +- 支持自动预取响应体 +- 支持按字节数限制响应体读取上限 -### Ping Module +### Ping 模块 -- `Ping`, `PingWithContext`, `Pingable`, and compatibility helper `IsIpPingable` -- `PingOptions` for count/timeout/interval/deadline/address preference/source IP/payload size -- explicit error semantics for permission/protocol/timeout/resolve failures +- 提供 `Ping`、`PingWithContext`、`Pingable` 以及兼容函数 `IsIpPingable` +- `PingOptions` 支持次数、超时、间隔、截止时间、地址族偏好、源地址、负载长度等参数 +- 对权限不足、协议不支持、超时、解析失败等情况提供明确错误语义 -## Install +## 安装 ```bash go get b612.me/starnet ``` -## Quick Example +## 快速示例 ```go package main @@ -94,13 +94,18 @@ func main() { } ``` -## Stability Notes +## 行为说明 -- Raw ICMP ping may require elevated privileges on some systems. -- Integration tests that rely on external network are environment-dependent. +- `NewClient`、`NewRequest` 以及请求构造相关接口在遇到非法选项时会直接返回错误,例如格式不合法的代理地址。 +- `NewClientNoErr` 是便利构造函数;如果选项校验失败,仍可能返回一个占位 `Client`,需要严格校验配置时应优先使用 `NewClient`。 +- 重试默认仅对幂等方法生效。即使显式关闭“仅幂等方法重试”,通过 `SetBodyReader` 或 `WithBodyReader` 构造的请求在非幂等方法上仍不会自动重试。 +- 当同时使用 `proxy + custom IP/DNS` 且解析出多个目标地址时,自动目标回退仅对幂等请求生效,以避免重复写入。 -## License +## 稳定性说明 -This project is licensed under the Apache License 2.0. -See [LICENSE](./LICENSE). +- 原始 ICMP Ping 在部分系统上需要额外权限。 +- 依赖外部网络环境的集成测试结果可能受运行环境影响。 +## 许可证 + +本项目采用 Apache License 2.0,详见 [LICENSE](./LICENSE)。 diff --git a/benchmark_test.go b/benchmark_test.go index b994d51..bc56ac4 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -148,6 +148,33 @@ func BenchmarkRequestCreation(b *testing.B) { } } +func BenchmarkRequestPrepareDefaultPath(b *testing.B) { + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + req := NewSimpleRequest("https://example.com", "GET") + if err := req.prepare(); err != nil { + b.Fatalf("prepare() error: %v", err) + } + } +} + +func BenchmarkRequestPrepareDynamicPath(b *testing.B) { + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + req := NewSimpleRequest("https://example.com", "GET", + WithCustomIP([]string{"127.0.0.1"}), + WithSkipTLSVerify(true), + ) + if err := req.prepare(); err != nil { + b.Fatalf("prepare() error: %v", err) + } + } +} + func BenchmarkResponseBodyRead(b *testing.B) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("test response data")) diff --git a/client.go b/client.go index 3bc1eb4..c6235b4 100644 --- a/client.go +++ b/client.go @@ -6,7 +6,6 @@ import ( "fmt" "net/http" "sync" - "time" ) // Client HTTP 客户端封装 @@ -19,14 +18,7 @@ type Client struct { // NewClient 创建新的 Client func NewClient(opts ...RequestOpt) (*Client, error) { // 创建基础 Transport - baseTransport := &http.Transport{ - ForceAttemptHTTP2: true, - MaxIdleConns: 100, - MaxIdleConnsPerHost: 10, - IdleConnTimeout: 90 * time.Second, - TLSHandshakeTimeout: 10 * time.Second, - ExpectContinueTimeout: 1 * time.Second, - } + baseTransport := newBaseHTTPTransport() httpClient := &http.Client{ Transport: &Transport{base: baseTransport}, @@ -40,6 +32,9 @@ func NewClient(opts ...RequestOpt) (*Client, error) { if err != nil { return nil, wrapError(err, "create client") } + if req.err != nil { + return nil, wrapError(req.err, "create client") + } /* // 如果选项中有自定义配置,应用到 httpClient @@ -61,7 +56,9 @@ func NewClient(opts ...RequestOpt) (*Client, error) { }, nil } -// NewClientNoErr 创建新的 Client(忽略错误) +// NewClientNoErr 创建新的 Client(忽略错误)。 +// 当 opts 校验失败时,它仍会返回一个可用的 Client 占位对象; +// 如果调用方需要感知选项错误或依赖默认 starnet Transport 行为,应优先使用 NewClient。 func NewClientNoErr(opts ...RequestOpt) *Client { client, _ := NewClient(opts...) if client == nil { @@ -172,11 +169,13 @@ func (c *Client) Clone() *Client { func (c *Client) SetDefaultTLSConfig(tlsConfig *tls.Config) *Client { if transport, ok := c.client.Transport.(*Transport); ok { transport.mu.Lock() + transport.ensureBaseLocked() if tlsConfig != nil { transport.base.TLSClientConfig = tlsConfig.Clone() } else { transport.base.TLSClientConfig = nil } + transport.resetDynamicTransportCacheLocked() transport.mu.Unlock() } return c @@ -186,12 +185,14 @@ func (c *Client) SetDefaultTLSConfig(tlsConfig *tls.Config) *Client { func (c *Client) SetDefaultSkipTLSVerify(skip bool) *Client { if transport, ok := c.client.Transport.(*Transport); ok { transport.mu.Lock() + transport.ensureBaseLocked() if transport.base.TLSClientConfig == nil { transport.base.TLSClientConfig = &tls.Config{} } else { transport.base.TLSClientConfig = transport.base.TLSClientConfig.Clone() } transport.base.TLSClientConfig.InsecureSkipVerify = skip + transport.resetDynamicTransportCacheLocked() transport.mu.Unlock() } return c @@ -227,6 +228,9 @@ func (c *Client) NewRequestWithContext(ctx context.Context, url, method string, if err != nil { return nil, err } + if req.err != nil { + return nil, req.err + } req.client = c req.httpClient = c.client diff --git a/context.go b/context.go index 69e6458..62f57db 100644 --- a/context.go +++ b/context.go @@ -14,6 +14,8 @@ type contextKey int const ( ctxKeyTransport contextKey = iota ctxKeyTLSConfig + ctxKeyTLSConfigCacheable + ctxKeyTLSServerName ctxKeyProxy ctxKeyCustomIP ctxKeyCustomDNS @@ -21,58 +23,95 @@ const ( ctxKeyTimeout ctxKeyLookupIP ctxKeyDialFunc + ctxKeyRequestContext ) // RequestContext 从 context 中提取的请求配置 type RequestContext struct { - Transport *http.Transport - TLSConfig *tls.Config - Proxy string - CustomIP []string - CustomDNS []string - DialTimeout time.Duration - Timeout time.Duration - LookupIPFn func(ctx context.Context, host string) ([]net.IPAddr, error) - DialFn func(ctx context.Context, network, addr string) (net.Conn, error) + Transport *http.Transport + TLSConfig *tls.Config + TLSConfigCacheable bool + TLSServerName string + Proxy string + CustomIP []string + CustomDNS []string + DialTimeout time.Duration + Timeout time.Duration + LookupIPFn func(ctx context.Context, host string) ([]net.IPAddr, error) + DialFn func(ctx context.Context, network, addr string) (net.Conn, error) } +var emptyRequestContext = &RequestContext{} + // getRequestContext 从 context 中提取请求配置 func getRequestContext(ctx context.Context) *RequestContext { - rc := &RequestContext{} + if v := ctx.Value(ctxKeyRequestContext); v != nil { + if rc, ok := v.(*RequestContext); ok && rc != nil { + return rc + } + } + var rc *RequestContext + ensure := func() *RequestContext { + if rc == nil { + rc = &RequestContext{} + } + return rc + } if v := ctx.Value(ctxKeyTransport); v != nil { - rc.Transport, _ = v.(*http.Transport) + ensure().Transport, _ = v.(*http.Transport) } if v := ctx.Value(ctxKeyTLSConfig); v != nil { - rc.TLSConfig, _ = v.(*tls.Config) + ensure().TLSConfig, _ = v.(*tls.Config) + } + if v := ctx.Value(ctxKeyTLSConfigCacheable); v != nil { + ensure().TLSConfigCacheable, _ = v.(bool) + } + if v := ctx.Value(ctxKeyTLSServerName); v != nil { + ensure().TLSServerName, _ = v.(string) } if v := ctx.Value(ctxKeyProxy); v != nil { - rc.Proxy, _ = v.(string) + ensure().Proxy, _ = v.(string) } if v := ctx.Value(ctxKeyCustomIP); v != nil { - rc.CustomIP, _ = v.([]string) + ensure().CustomIP, _ = v.([]string) } if v := ctx.Value(ctxKeyCustomDNS); v != nil { - rc.CustomDNS, _ = v.([]string) + ensure().CustomDNS, _ = v.([]string) } if v := ctx.Value(ctxKeyDialTimeout); v != nil { - rc.DialTimeout, _ = v.(time.Duration) + ensure().DialTimeout, _ = v.(time.Duration) } if v := ctx.Value(ctxKeyTimeout); v != nil { - rc.Timeout, _ = v.(time.Duration) + ensure().Timeout, _ = v.(time.Duration) } if v := ctx.Value(ctxKeyLookupIP); v != nil { - rc.LookupIPFn, _ = v.(func(context.Context, string) ([]net.IPAddr, error)) + ensure().LookupIPFn, _ = v.(func(context.Context, string) ([]net.IPAddr, error)) } if v := ctx.Value(ctxKeyDialFunc); v != nil { - rc.DialFn, _ = v.(func(context.Context, string, string) (net.Conn, error)) + ensure().DialFn, _ = v.(func(context.Context, string, string) (net.Conn, error)) + } + if rc == nil { + return emptyRequestContext } - return rc } +func cloneRequestContext(rc *RequestContext) *RequestContext { + if rc == nil { + return nil + } + cloned := *rc + cloned.CustomIP = cloneStringSlice(rc.CustomIP) + cloned.CustomDNS = cloneStringSlice(rc.CustomDNS) + return &cloned +} + // needsDynamicTransport 判断是否需要动态 Transport func needsDynamicTransport(rc *RequestContext) bool { + if rc == nil { + return false + } return rc.Transport != nil || rc.TLSConfig != nil || rc.Proxy != "" || @@ -83,63 +122,67 @@ func needsDynamicTransport(rc *RequestContext) bool { rc.LookupIPFn != nil } -// injectRequestConfig 将请求配置注入到 context -func injectRequestConfig(ctx context.Context, config *RequestConfig) context.Context { - execCtx := ctx +func buildRequestContext(config *RequestConfig, defaultTLSServerName string) *RequestContext { + if config == nil { + return nil + } + + rc := &RequestContext{ + DialTimeout: config.Network.DialTimeout, + Timeout: config.Network.Timeout, + } // 处理 TLS 配置 var tlsConfig *tls.Config + tlsConfigCacheable := false if config.TLS.Config != nil { tlsConfig = config.TLS.Config.Clone() - if config.TLS.SkipVerify { - tlsConfig.InsecureSkipVerify = true - } - } else if config.TLS.SkipVerify { + } else if config.TLS.SkipVerify || config.TLS.ServerName != "" { tlsConfig = &tls.Config{ - NextProtos: []string{"h2", "http/1.1"}, - InsecureSkipVerify: true, + NextProtos: []string{"h2", "http/1.1"}, } + tlsConfigCacheable = true } + if config.TLS.SkipVerify && tlsConfig != nil { + tlsConfig.InsecureSkipVerify = true + } + if config.TLS.ServerName != "" && tlsConfig != nil { + tlsConfig.ServerName = config.TLS.ServerName + } if tlsConfig != nil { - execCtx = context.WithValue(execCtx, ctxKeyTLSConfig, tlsConfig) + rc.TLSConfig = tlsConfig + rc.TLSConfigCacheable = tlsConfigCacheable + } + if config.TLS.ServerName != "" { + rc.TLSServerName = config.TLS.ServerName + } else if defaultTLSServerName != "" { + rc.TLSServerName = defaultTLSServerName } - // 注入代理 - if config.Network.Proxy != "" { - execCtx = context.WithValue(execCtx, ctxKeyProxy, config.Network.Proxy) - } + rc.Proxy = config.Network.Proxy + rc.CustomIP = cloneStringSlice(config.DNS.CustomIP) + rc.CustomDNS = cloneStringSlice(config.DNS.CustomDNS) + rc.LookupIPFn = config.DNS.LookupFunc + rc.DialFn = config.Network.DialFunc - // 注入自定义 IP - if len(config.DNS.CustomIP) > 0 { - execCtx = context.WithValue(execCtx, ctxKeyCustomIP, config.DNS.CustomIP) - } - - // 注入自定义 DNS - if len(config.DNS.CustomDNS) > 0 { - execCtx = context.WithValue(execCtx, ctxKeyCustomDNS, config.DNS.CustomDNS) - } - - // 总是注入 DialTimeout(与原始代码一致) - if config.Network.DialTimeout > 0 { - execCtx = context.WithValue(execCtx, ctxKeyDialTimeout, config.Network.DialTimeout) - } - - // 注入 DNS 解析函数 - if config.DNS.LookupFunc != nil { - execCtx = context.WithValue(execCtx, ctxKeyLookupIP, config.DNS.LookupFunc) - } - - // 注入 Dial 函数 - if config.Network.DialFunc != nil { - execCtx = context.WithValue(execCtx, ctxKeyDialFunc, config.Network.DialFunc) - } - - // 注入自定义 Transport if config.CustomTransport && config.Transport != nil { - execCtx = context.WithValue(execCtx, ctxKeyTransport, config.Transport) + rc.Transport = config.Transport } - return execCtx + if !needsDynamicTransport(rc) { + return nil + } + + return rc +} + +// injectRequestConfig 将请求配置注入到 context +func injectRequestConfig(ctx context.Context, config *RequestConfig, defaultTLSServerName string) context.Context { + rc := buildRequestContext(config, defaultTLSServerName) + if rc == nil { + return ctx + } + return context.WithValue(ctx, ctxKeyRequestContext, rc) } diff --git a/defensive_copy_test.go b/defensive_copy_test.go index 3d0fdcc..b201fc2 100644 --- a/defensive_copy_test.go +++ b/defensive_copy_test.go @@ -57,3 +57,55 @@ func TestSetFormDataDefensiveCopy(t *testing.T) { t.Fatalf("form data mutated by external map change: got=%v want=[alice]", got) } } + +func TestWithBodyDefensiveCopy(t *testing.T) { + body := []byte("hello") + + req, err := NewRequest("http://example.com", "POST", WithBody(body)) + if err != nil { + t.Fatalf("NewRequest() error: %v", err) + } + + body[0] = 'j' + if string(req.config.Body.Bytes) != "hello" { + t.Fatalf("body mutated by external slice change: got=%q want=%q", string(req.config.Body.Bytes), "hello") + } +} + +func TestWithFormDataDefensiveCopy(t *testing.T) { + form := map[string][]string{ + "name": []string{"alice"}, + } + + req, err := NewRequest("http://example.com", "POST", WithFormData(form)) + if err != nil { + t.Fatalf("NewRequest() error: %v", err) + } + + form["name"][0] = "bob" + form["name"] = append(form["name"], "carol") + got := req.config.Body.FormData["name"] + if len(got) != 1 || got[0] != "alice" { + t.Fatalf("form data mutated by external map change: got=%v want=[alice]", got) + } +} + +func TestSetCustomIPDefensiveCopy(t *testing.T) { + ips := []string{"1.1.1.1", "8.8.8.8"} + req := NewSimpleRequest("http://example.com", "GET").SetCustomIP(ips) + + ips[0] = "9.9.9.9" + if got := req.config.DNS.CustomIP[0]; got != "1.1.1.1" { + t.Fatalf("custom ip mutated by external slice change: got=%q want=%q", got, "1.1.1.1") + } +} + +func TestSetCustomDNSDefensiveCopy(t *testing.T) { + servers := []string{"8.8.8.8", "1.1.1.1"} + req := NewSimpleRequest("http://example.com", "GET").SetCustomDNS(servers) + + servers[0] = "9.9.9.9" + if got := req.config.DNS.CustomDNS[0]; got != "8.8.8.8" { + t.Fatalf("custom dns mutated by external slice change: got=%q want=%q", got, "8.8.8.8") + } +} diff --git a/dialer.go b/dialer.go index d1ea6ce..69e09f5 100644 --- a/dialer.go +++ b/dialer.go @@ -9,10 +9,100 @@ import ( "time" ) +func traceDNSLookup(traceState *traceState, host string, lookup func() ([]net.IPAddr, error)) ([]net.IPAddr, error) { + if traceState != nil { + traceState.beginManualDNS() + defer traceState.endManualDNS() + traceState.dnsStart(TraceDNSStartInfo{Host: host}) + } + ipAddrs, err := lookup() + if traceState != nil { + traceState.dnsDone(TraceDNSDoneInfo{ + Addrs: append([]net.IPAddr(nil), ipAddrs...), + Err: err, + }) + } + return ipAddrs, err +} + +func resolveDialAddresses(ctx context.Context, reqCtx *RequestContext, host, port string, traceState *traceState) ([]string, error) { + if reqCtx == nil { + reqCtx = &RequestContext{} + } + + var addrs []string + + if len(reqCtx.CustomIP) > 0 { + for _, ip := range reqCtx.CustomIP { + addrs = append(addrs, joinResolvedHostPort(ip, port)) + } + return addrs, nil + } + + var ( + ipAddrs []net.IPAddr + err error + ) + + if reqCtx.LookupIPFn != nil { + ipAddrs, err = traceDNSLookup(traceState, host, func() ([]net.IPAddr, error) { + return reqCtx.LookupIPFn(ctx, host) + }) + } else if len(reqCtx.CustomDNS) > 0 { + dialTimeout := reqCtx.DialTimeout + if dialTimeout == 0 { + dialTimeout = DefaultDialTimeout + } + dialer := &net.Dialer{Timeout: dialTimeout} + resolver := &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { + var lastErr error + for _, dnsServer := range reqCtx.CustomDNS { + conn, err := dialer.DialContext(ctx, "udp", net.JoinHostPort(dnsServer, "53")) + if err != nil { + lastErr = err + continue + } + return conn, nil + } + return nil, lastErr + }, + } + ipAddrs, err = traceDNSLookup(traceState, host, func() ([]net.IPAddr, error) { + return resolver.LookupIPAddr(ctx, host) + }) + } else { + ipAddrs, err = traceDNSLookup(traceState, host, func() ([]net.IPAddr, error) { + return net.DefaultResolver.LookupIPAddr(ctx, host) + }) + } + + if err != nil { + return nil, wrapError(err, "lookup ip") + } + + for _, ipAddr := range ipAddrs { + addrs = append(addrs, joinResolvedHostPort(ipAddr.String(), port)) + } + return addrs, nil +} + +func joinResolvedHostPort(host, port string) string { + if port == "" { + if ip := net.ParseIP(host); ip != nil && ip.To4() == nil { + return "[" + host + "]" + } + return host + } + return net.JoinHostPort(host, port) +} + // defaultDialFunc 默认 Dial 函数(支持自定义 IP 和 DNS) func defaultDialFunc(ctx context.Context, network, addr string) (net.Conn, error) { // 提取配置 reqCtx := getRequestContext(ctx) + traceState := getTraceState(ctx) dialTimeout := reqCtx.DialTimeout if dialTimeout == 0 { @@ -25,52 +115,9 @@ func defaultDialFunc(ctx context.Context, network, addr string) (net.Conn, error return nil, wrapError(err, "split host port") } - // 获取 IP 地址列表 - var addrs []string - - // 优先级1:直接指定的 IP - if len(reqCtx.CustomIP) > 0 { - for _, ip := range reqCtx.CustomIP { - addrs = append(addrs, net.JoinHostPort(ip, port)) - } - } else { - // 优先级2:DNS 解析 - var ipAddrs []net.IPAddr - - // 使用自定义解析函数 - if reqCtx.LookupIPFn != nil { - ipAddrs, err = reqCtx.LookupIPFn(ctx, host) - } else if len(reqCtx.CustomDNS) > 0 { - // 使用自定义 DNS 服务器 - dialer := &net.Dialer{Timeout: dialTimeout} - resolver := &net.Resolver{ - PreferGo: true, - Dial: func(ctx context.Context, network, address string) (net.Conn, error) { - var lastErr error - for _, dnsServer := range reqCtx.CustomDNS { - conn, err := dialer.DialContext(ctx, "udp", net.JoinHostPort(dnsServer, "53")) - if err != nil { - lastErr = err - continue - } - return conn, nil - } - return nil, lastErr - }, - } - ipAddrs, err = resolver.LookupIPAddr(ctx, host) - } else { - // 使用默认解析器 - ipAddrs, err = net.DefaultResolver.LookupIPAddr(ctx, host) - } - - if err != nil { - return nil, wrapError(err, "lookup ip") - } - - for _, ipAddr := range ipAddrs { - addrs = append(addrs, net.JoinHostPort(ipAddr.String(), port)) - } + addrs, err := resolveDialAddresses(ctx, reqCtx, host, port, traceState) + if err != nil { + return nil, err } // 尝试连接所有地址 @@ -103,13 +150,17 @@ func defaultDialTLSFunc(ctx context.Context, network, addr string) (net.Conn, er // 提取 TLS 配置 reqCtx := getRequestContext(ctx) + traceState := getTraceState(ctx) tlsConfig := reqCtx.TLSConfig if tlsConfig == nil { tlsConfig = &tls.Config{} } - // ← 新增:如果 ServerName 为空且没有 InsecureSkipVerify,自动设置 - if tlsConfig.ServerName == "" && !tlsConfig.InsecureSkipVerify { + serverName := tlsConfig.ServerName + if serverName == "" { + serverName = reqCtx.TLSServerName + } + if serverName == "" && !tlsConfig.InsecureSkipVerify { host, _, err := net.SplitHostPort(addr) if err != nil { if idx := strings.LastIndex(addr, ":"); idx > 0 { @@ -118,8 +169,19 @@ func defaultDialTLSFunc(ctx context.Context, network, addr string) (net.Conn, er host = addr } } + serverName = host + } + if serverName != "" && tlsConfig.ServerName != serverName { tlsConfig = tlsConfig.Clone() // 避免修改原 config - tlsConfig.ServerName = host + tlsConfig.ServerName = serverName + } + if traceState != nil { + traceState.markCustomTLS() + traceState.tlsHandshakeStart(TraceTLSHandshakeStartInfo{ + Network: network, + Addr: addr, + ServerName: serverName, + }) } // 执行 TLS 握手 @@ -130,9 +192,25 @@ func defaultDialTLSFunc(ctx context.Context, network, addr string) (net.Conn, er tlsConn := tls.Client(conn, tlsConfig) if err := tlsConn.Handshake(); err != nil { + if traceState != nil { + traceState.tlsHandshakeDone(TraceTLSHandshakeDoneInfo{ + Network: network, + Addr: addr, + ServerName: serverName, + Err: err, + }) + } conn.Close() return nil, wrapError(err, "tls handshake") } + if traceState != nil { + traceState.tlsHandshakeDone(TraceTLSHandshakeDoneInfo{ + Network: network, + Addr: addr, + ServerName: serverName, + ConnectionState: tlsConn.ConnectionState(), + }) + } return tlsConn, nil } diff --git a/dynamic_transport_benchmark_test.go b/dynamic_transport_benchmark_test.go new file mode 100644 index 0000000..2bc82f1 --- /dev/null +++ b/dynamic_transport_benchmark_test.go @@ -0,0 +1,144 @@ +package starnet + +import ( + "crypto/tls" + "net/http" + "net/url" + "testing" +) + +func BenchmarkDynamicTransportCustomIP(b *testing.B) { + server := newIPv4Server(b, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("OK")) + })) + defer server.Close() + + targetURL := benchmarkTargetURL(b, server.URL, "bench-custom-ip.test") + client := NewClientNoErr() + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + resp, err := client.Get(targetURL, WithCustomIP([]string{"127.0.0.1"})) + if err != nil { + b.Fatalf("Get() error: %v", err) + } + _, _ = resp.Body().Bytes() + resp.Close() + } +} + +func BenchmarkDynamicTransportProxyTLSCacheable(b *testing.B) { + server := newIPv4TLSServer(b, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("OK")) + })) + defer server.Close() + + proxy := newIPv4ConnectProxyServer(b, nil) + defer proxy.Close() + + targetURL := httpsURLForHost(b, server, "bench-proxy-cacheable.test") + client := NewClientNoErr() + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + resp, err := client.Get(targetURL, + WithProxy(proxy.URL), + WithCustomIP([]string{"127.0.0.1"}), + WithSkipTLSVerify(true), + ) + if err != nil { + b.Fatalf("Get() error: %v", err) + } + _, _ = resp.Body().Bytes() + resp.Close() + } +} + +func BenchmarkDynamicTransportCustomIPTLSCacheable(b *testing.B) { + server := newIPv4TLSServer(b, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("OK")) + })) + defer server.Close() + + targetURL := httpsURLForHost(b, server, "bench-custom-ip-cacheable.test") + client := NewClientNoErr() + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + resp, err := client.Get(targetURL, + WithCustomIP([]string{"127.0.0.1"}), + WithSkipTLSVerify(true), + ) + if err != nil { + b.Fatalf("Get() error: %v", err) + } + _, _ = resp.Body().Bytes() + resp.Close() + } +} + +func BenchmarkDynamicTransportCustomIPUserTLSConfig(b *testing.B) { + server := newIPv4TLSServer(b, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("OK")) + })) + defer server.Close() + + targetURL := httpsURLForHost(b, server, "bench-user-tls.test") + client := NewClientNoErr() + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + resp, err := client.Get(targetURL, + WithCustomIP([]string{"127.0.0.1"}), + WithTLSConfig(&tls.Config{InsecureSkipVerify: true}), + ) + if err != nil { + b.Fatalf("Get() error: %v", err) + } + _, _ = resp.Body().Bytes() + resp.Close() + } +} + +func benchmarkTargetURL(tb testing.TB, rawURL, host string) string { + tb.Helper() + + parsed, err := url.Parse(rawURL) + if err != nil { + tb.Fatalf("url.Parse() error: %v", err) + } + + port := parsed.Port() + if port == "" { + switch parsed.Scheme { + case "https": + port = "443" + default: + port = "80" + } + } + + return parsed.Scheme + "://" + host + ":" + port + pathWithQuery(parsed.Path, parsed.RawQuery) +} + +func pathWithQuery(path, rawQuery string) string { + if path == "" { + path = "/" + } + if rawQuery == "" { + return path + } + return path + "?" + rawQuery +} diff --git a/host_tls_regression_test.go b/host_tls_regression_test.go new file mode 100644 index 0000000..91540f9 --- /dev/null +++ b/host_tls_regression_test.go @@ -0,0 +1,150 @@ +package starnet + +import ( + "crypto/tls" + "net/http" + "net/http/httptest" + "testing" +) + +func TestRequestSetURLDoesNotMutateProvidedTLSConfig(t *testing.T) { + cfg := &tls.Config{} + + req := NewSimpleRequest("https://example.com", http.MethodGet). + SetTLSConfig(cfg). + SetURL("https://other.example") + + if req.Err() != nil { + t.Fatalf("unexpected request error: %v", req.Err()) + } + if cfg.ServerName != "" { + t.Fatalf("provided tls.Config was mutated, ServerName=%q", cfg.ServerName) + } +} + +func TestRequestPrepareSetTLSServerNameDoesNotMutateProvidedTLSConfig(t *testing.T) { + cfg := &tls.Config{InsecureSkipVerify: true} + + req := NewSimpleRequest("https://example.com", http.MethodGet). + SetTLSConfig(cfg). + SetTLSServerName("override.example") + + if err := req.prepare(); err != nil { + t.Fatalf("prepare error: %v", err) + } + if cfg.ServerName != "" { + t.Fatalf("provided tls.Config was mutated, ServerName=%q", cfg.ServerName) + } + + rc := getRequestContext(req.execCtx) + if rc.TLSConfig == nil { + t.Fatal("expected injected tls config") + } + if rc.TLSConfig == cfg { + t.Fatal("expected injected tls config to be cloned") + } + if rc.TLSConfig.ServerName != "override.example" { + t.Fatalf("injected ServerName=%q", rc.TLSConfig.ServerName) + } +} + +func TestRequestPrepareWithTLSServerNameWithoutTLSConfig(t *testing.T) { + req := NewSimpleRequest("https://example.com", http.MethodGet). + SetTLSServerName("override.example") + + if err := req.prepare(); err != nil { + t.Fatalf("prepare error: %v", err) + } + + rc := getRequestContext(req.execCtx) + if rc.TLSConfig == nil { + t.Fatal("expected injected tls config") + } + if rc.TLSConfig.ServerName != "override.example" { + t.Fatalf("injected ServerName=%q", rc.TLSConfig.ServerName) + } +} + +func TestRequestPrepareDefaultPathSkipsRequestContextInjection(t *testing.T) { + req := NewSimpleRequest("https://example.com", http.MethodGet) + + if err := req.prepare(); err != nil { + t.Fatalf("prepare error: %v", err) + } + + if got := req.execCtx.Value(ctxKeyRequestContext); got != nil { + t.Fatalf("unexpected request context injection: %#v", got) + } + + rc := getRequestContext(req.execCtx) + if needsDynamicTransport(rc) { + t.Fatalf("default path unexpectedly marked dynamic: %#v", rc) + } + if rc.TLSServerName != "" { + t.Fatalf("default path unexpectedly injected tls server name: %q", rc.TLSServerName) + } +} + +func TestRequestPrepareDynamicPathInjectsAggregatedRequestContext(t *testing.T) { + req := NewSimpleRequest("https://example.com", http.MethodGet). + SetCustomIP([]string{"127.0.0.1"}). + SetSkipTLSVerify(true) + + if err := req.prepare(); err != nil { + t.Fatalf("prepare error: %v", err) + } + + raw := req.execCtx.Value(ctxKeyRequestContext) + rc, ok := raw.(*RequestContext) + if !ok || rc == nil { + t.Fatalf("expected aggregated request context, got %#v", raw) + } + if len(rc.CustomIP) != 1 || rc.CustomIP[0] != "127.0.0.1" { + t.Fatalf("custom ip=%v", rc.CustomIP) + } + if rc.TLSConfig == nil || !rc.TLSConfig.InsecureSkipVerify { + t.Fatal("expected tls config with skip verify") + } + if rc.TLSServerName != "example.com" { + t.Fatalf("default tls server name=%q", rc.TLSServerName) + } +} + +func TestRequestSetHostOverridesRequestHost(t *testing.T) { + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Host != "override.example" { + t.Fatalf("host=%q", r.Host) + } + w.WriteHeader(http.StatusOK) + })) + defer s.Close() + + resp, err := NewSimpleRequest(s.URL, http.MethodGet). + SetHost("override.example"). + Do() + if err != nil { + t.Fatalf("Do() error: %v", err) + } + defer resp.Close() +} + +func TestWithHostOverridesRequestHost(t *testing.T) { + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Host != "option.example" { + t.Fatalf("host=%q", r.Host) + } + w.WriteHeader(http.StatusOK) + })) + defer s.Close() + + resp, err := NewRequest(s.URL, http.MethodGet, WithHost("option.example")) + if err != nil { + t.Fatalf("NewRequest() error: %v", err) + } + + got, err := resp.Do() + if err != nil { + t.Fatalf("Do() error: %v", err) + } + defer got.Close() +} diff --git a/internal/pingcore/core.go b/internal/pingcore/core.go new file mode 100644 index 0000000..9ad0450 --- /dev/null +++ b/internal/pingcore/core.go @@ -0,0 +1,230 @@ +package pingcore + +import ( + "encoding/binary" + "net" + "os" + "sync/atomic" + "time" +) + +const icmpHeaderLen = 8 + +type ICMP struct { + Type uint8 + Code uint8 + CheckSum uint16 + Identifier uint16 + SequenceNum uint16 +} + +type Options struct { + Count int + Timeout time.Duration + Interval time.Duration + Deadline time.Time + PreferIPv4 bool + PreferIPv6 bool + SourceIP net.IP + PayloadSize int +} + +type Result struct { + Duration time.Duration + RecvCount int + RemoteIP string +} + +var identifierSeed uint32 + +func NextIdentifier() uint16 { + pid := uint32(os.Getpid() & 0xffff) + n := atomic.AddUint32(&identifierSeed, 1) + return uint16((pid + n) & 0xffff) +} + +func Payload(size int) []byte { + if size <= 0 { + return nil + } + payload := make([]byte, size) + for index := 0; index < len(payload); index++ { + payload[index] = byte(index) + } + return payload +} + +func BuildICMP(seq, identifier uint16, typ uint8, payload []byte) ICMP { + icmp := ICMP{ + Type: typ, + Code: 0, + CheckSum: 0, + Identifier: identifier, + SequenceNum: seq, + } + buf := MarshalPacket(icmp, payload) + icmp.CheckSum = Checksum(buf) + return icmp +} + +func Checksum(data []byte) uint16 { + var ( + sum uint32 + length = len(data) + index int + ) + for length > 1 { + sum += uint32(data[index])<<8 + uint32(data[index+1]) + index += 2 + length -= 2 + } + if length > 0 { + sum += uint32(data[index]) << 8 + } + for sum>>16 != 0 { + sum = (sum & 0xffff) + (sum >> 16) + } + return uint16(^sum) +} + +func Marshal(icmp ICMP) []byte { + return MarshalPacket(icmp, nil) +} + +func MarshalPacket(icmp ICMP, payload []byte) []byte { + buf := make([]byte, icmpHeaderLen+len(payload)) + buf[0] = icmp.Type + buf[1] = icmp.Code + binary.BigEndian.PutUint16(buf[2:], icmp.CheckSum) + binary.BigEndian.PutUint16(buf[4:], icmp.Identifier) + binary.BigEndian.PutUint16(buf[6:], icmp.SequenceNum) + copy(buf[icmpHeaderLen:], payload) + return buf +} + +func IsExpectedEchoReply(packet []byte, family int, expectedType uint8, identifier, seq uint16) bool { + for _, offset := range CandidateICMPOffsets(packet, family) { + if offset < 0 || offset+icmpHeaderLen > len(packet) { + continue + } + if packet[offset] != expectedType || packet[offset+1] != 0 { + continue + } + if binary.BigEndian.Uint16(packet[offset+4:offset+6]) != identifier { + continue + } + if binary.BigEndian.Uint16(packet[offset+6:offset+8]) != seq { + continue + } + return true + } + return false +} + +func CandidateICMPOffsets(packet []byte, family int) []int { + offsets := []int{0} + if len(packet) == 0 { + return offsets + } + + version := packet[0] >> 4 + if version == 4 && len(packet) >= 20 { + ihl := int(packet[0]&0x0f) * 4 + if ihl >= 20 && ihl <= len(packet)-icmpHeaderLen { + offsets = append(offsets, ihl) + } + } else if version == 6 && len(packet) >= 40+icmpHeaderLen { + offsets = append(offsets, 40) + } + + if family == 4 && len(packet) >= 20+icmpHeaderLen { + offsets = append(offsets, 20) + } + if family == 6 && len(packet) >= 40+icmpHeaderLen { + offsets = append(offsets, 40) + } + + return DedupOffsets(offsets) +} + +func DedupOffsets(offsets []int) []int { + if len(offsets) <= 1 { + return offsets + } + seen := make(map[int]struct{}, len(offsets)) + out := make([]int, 0, len(offsets)) + for _, offset := range offsets { + if _, ok := seen[offset]; ok { + continue + } + seen[offset] = struct{}{} + out = append(out, offset) + } + return out +} + +func ResolveTargets(host string, preferIPv4, preferIPv6 bool) ([]*net.IPAddr, error) { + if parsed := net.ParseIP(host); parsed != nil { + return []*net.IPAddr{{IP: parsed}}, nil + } + + var targets []*net.IPAddr + var err4 error + var err6 error + + if ip4, err := net.ResolveIPAddr("ip4", host); err == nil && ip4 != nil && ip4.IP != nil { + targets = append(targets, ip4) + } else { + err4 = err + } + + if ip6, err := net.ResolveIPAddr("ip6", host); err == nil && ip6 != nil && ip6.IP != nil { + targets = append(targets, ip6) + } else { + err6 = err + } + + if len(targets) > 0 { + return OrderTargets(targets, preferIPv4, preferIPv6), nil + } + if err4 != nil { + return nil, err4 + } + if err6 != nil { + return nil, err6 + } + return nil, nil +} + +func OrderTargets(targets []*net.IPAddr, preferIPv4, preferIPv6 bool) []*net.IPAddr { + if len(targets) <= 1 || preferIPv4 == preferIPv6 { + return targets + } + + ordered := make([]*net.IPAddr, 0, len(targets)) + if preferIPv4 { + for _, target := range targets { + if target != nil && target.IP != nil && target.IP.To4() != nil { + ordered = append(ordered, target) + } + } + for _, target := range targets { + if target != nil && target.IP != nil && target.IP.To4() == nil { + ordered = append(ordered, target) + } + } + return ordered + } + + for _, target := range targets { + if target != nil && target.IP != nil && target.IP.To4() == nil { + ordered = append(ordered, target) + } + } + for _, target := range targets { + if target != nil && target.IP != nil && target.IP.To4() != nil { + ordered = append(ordered, target) + } + } + return ordered +} diff --git a/internal/tlssniffercore/config.go b/internal/tlssniffercore/config.go new file mode 100644 index 0000000..df71c2f --- /dev/null +++ b/internal/tlssniffercore/config.go @@ -0,0 +1,123 @@ +package tlssniffercore + +import "crypto/tls" + +func ComposeServerTLSConfig(base, selected *tls.Config) *tls.Config { + if base == nil { + return selected + } + if selected == nil { + return base + } + + out := base.Clone() + ApplyServerTLSOverrides(out, selected) + return out +} + +func ApplyServerTLSOverrides(dst, src *tls.Config) { + if dst == nil || src == nil { + return + } + + if src.Rand != nil { + dst.Rand = src.Rand + } + if src.Time != nil { + dst.Time = src.Time + } + if len(src.Certificates) > 0 { + dst.Certificates = append([]tls.Certificate(nil), src.Certificates...) + } + if len(src.NameToCertificate) > 0 { + copied := make(map[string]*tls.Certificate, len(src.NameToCertificate)) + for name, cert := range src.NameToCertificate { + copied[name] = cert + } + dst.NameToCertificate = copied + } + if src.GetCertificate != nil { + dst.GetCertificate = src.GetCertificate + } + if src.GetClientCertificate != nil { + dst.GetClientCertificate = src.GetClientCertificate + } + if src.GetConfigForClient != nil { + dst.GetConfigForClient = src.GetConfigForClient + } + if src.VerifyPeerCertificate != nil { + dst.VerifyPeerCertificate = src.VerifyPeerCertificate + } + if src.VerifyConnection != nil { + dst.VerifyConnection = src.VerifyConnection + } + if src.RootCAs != nil { + dst.RootCAs = src.RootCAs + } + if len(src.NextProtos) > 0 { + dst.NextProtos = append([]string(nil), src.NextProtos...) + } + if src.ServerName != "" { + dst.ServerName = src.ServerName + } + if src.ClientAuth > dst.ClientAuth { + dst.ClientAuth = src.ClientAuth + } + if src.ClientCAs != nil { + dst.ClientCAs = src.ClientCAs + } + if src.InsecureSkipVerify { + dst.InsecureSkipVerify = true + } + if len(src.CipherSuites) > 0 { + dst.CipherSuites = append([]uint16(nil), src.CipherSuites...) + } + if src.PreferServerCipherSuites { + dst.PreferServerCipherSuites = true + } + if src.SessionTicketsDisabled { + dst.SessionTicketsDisabled = true + } + if src.SessionTicketKey != ([32]byte{}) { + dst.SessionTicketKey = src.SessionTicketKey + } + if src.ClientSessionCache != nil { + dst.ClientSessionCache = src.ClientSessionCache + } + if src.UnwrapSession != nil { + dst.UnwrapSession = src.UnwrapSession + } + if src.WrapSession != nil { + dst.WrapSession = src.WrapSession + } + if src.MinVersion != 0 && (dst.MinVersion == 0 || src.MinVersion > dst.MinVersion) { + dst.MinVersion = src.MinVersion + } + if src.MaxVersion != 0 && (dst.MaxVersion == 0 || src.MaxVersion < dst.MaxVersion) { + dst.MaxVersion = src.MaxVersion + } + if len(src.CurvePreferences) > 0 { + dst.CurvePreferences = append([]tls.CurveID(nil), src.CurvePreferences...) + } + if src.DynamicRecordSizingDisabled { + dst.DynamicRecordSizingDisabled = true + } + if src.Renegotiation != 0 { + dst.Renegotiation = src.Renegotiation + } + if src.KeyLogWriter != nil { + dst.KeyLogWriter = src.KeyLogWriter + } + if len(src.EncryptedClientHelloConfigList) > 0 { + dst.EncryptedClientHelloConfigList = append([]byte(nil), src.EncryptedClientHelloConfigList...) + } + if src.EncryptedClientHelloRejectionVerify != nil { + dst.EncryptedClientHelloRejectionVerify = src.EncryptedClientHelloRejectionVerify + } + if src.GetEncryptedClientHelloKeys != nil { + dst.GetEncryptedClientHelloKeys = src.GetEncryptedClientHelloKeys + } + if len(src.EncryptedClientHelloKeys) > 0 { + dst.EncryptedClientHelloKeys = append([]tls.EncryptedClientHelloKey(nil), src.EncryptedClientHelloKeys...) + } +} diff --git a/internal/tlssniffercore/parser.go b/internal/tlssniffercore/parser.go new file mode 100644 index 0000000..6321c57 --- /dev/null +++ b/internal/tlssniffercore/parser.go @@ -0,0 +1,237 @@ +package tlssniffercore + +import ( + "bytes" + "encoding/binary" + "io" + "net" +) + +type ClientHelloMeta struct { + ServerName string + LocalAddr net.Addr + RemoteAddr net.Addr + SupportedProtos []string + SupportedVersions []uint16 + CipherSuites []uint16 +} + +type SniffResult struct { + IsTLS bool + ClientHello *ClientHelloMeta + Buffer *bytes.Buffer +} + +type Sniffer struct{} + +func (s Sniffer) Sniff(conn net.Conn, maxBytes int) (SniffResult, error) { + if maxBytes <= 0 { + maxBytes = 64 * 1024 + } + + var buf bytes.Buffer + limited := &io.LimitedReader{R: conn, N: int64(maxBytes)} + meta, isTLS := sniffClientHello(limited, &buf, conn) + + out := SniffResult{ + IsTLS: isTLS, + Buffer: bytes.NewBuffer(append([]byte(nil), buf.Bytes()...)), + } + if isTLS { + out.ClientHello = meta + } + return out, nil +} + +func sniffClientHello(reader io.Reader, buf *bytes.Buffer, conn net.Conn) (*ClientHelloMeta, bool) { + meta := &ClientHelloMeta{ + LocalAddr: conn.LocalAddr(), + RemoteAddr: conn.RemoteAddr(), + } + + header, complete := readTLSRecordHeader(reader, buf) + if len(header) < 3 { + return nil, false + } + isTLS := header[0] == 0x16 && header[1] == 0x03 + if !isTLS { + return nil, false + } + if len(header) < 5 || !complete { + return meta, true + } + + recordLen := int(binary.BigEndian.Uint16(header[3:5])) + recordBody, bodyOK := readBufferedBytes(reader, buf, recordLen) + if !bodyOK { + return meta, true + } + if len(recordBody) < 4 || recordBody[0] != 0x01 { + return nil, false + } + + helloLen := int(recordBody[1])<<16 | int(recordBody[2])<<8 | int(recordBody[3]) + helloBytes := append([]byte(nil), recordBody[4:]...) + for len(helloBytes) < helloLen { + nextHeader, ok := readTLSRecordHeader(reader, buf) + if len(nextHeader) < 5 || !ok { + return meta, true + } + if nextHeader[0] != 0x16 || nextHeader[1] != 0x03 { + return meta, true + } + nextLen := int(binary.BigEndian.Uint16(nextHeader[3:5])) + nextBody, bodyOK := readBufferedBytes(reader, buf, nextLen) + if !bodyOK { + return meta, true + } + helloBytes = append(helloBytes, nextBody...) + } + + parseClientHelloBody(meta, helloBytes[:helloLen]) + return meta, true +} + +func readTLSRecordHeader(reader io.Reader, buf *bytes.Buffer) ([]byte, bool) { + return readBufferedBytes(reader, buf, 5) +} + +func readBufferedBytes(reader io.Reader, buf *bytes.Buffer, count int) ([]byte, bool) { + if count <= 0 { + return nil, true + } + tmp := make([]byte, count) + readN, err := io.ReadFull(reader, tmp) + if readN > 0 { + buf.Write(tmp[:readN]) + } + return append([]byte(nil), tmp[:readN]...), err == nil +} + +func parseClientHelloBody(meta *ClientHelloMeta, body []byte) { + if meta == nil || len(body) < 34 { + return + } + + offset := 2 + 32 + sessionIDLen := int(body[offset]) + offset++ + if offset+sessionIDLen > len(body) { + return + } + offset += sessionIDLen + + if offset+2 > len(body) { + return + } + cipherSuitesLen := int(binary.BigEndian.Uint16(body[offset : offset+2])) + offset += 2 + if offset+cipherSuitesLen > len(body) { + return + } + for index := 0; index+1 < cipherSuitesLen; index += 2 { + meta.CipherSuites = append(meta.CipherSuites, binary.BigEndian.Uint16(body[offset+index:offset+index+2])) + } + offset += cipherSuitesLen + + if offset >= len(body) { + return + } + compressionMethodsLen := int(body[offset]) + offset++ + if offset+compressionMethodsLen > len(body) { + return + } + offset += compressionMethodsLen + + if offset+2 > len(body) { + return + } + extensionsLen := int(binary.BigEndian.Uint16(body[offset : offset+2])) + offset += 2 + if offset+extensionsLen > len(body) { + return + } + + parseClientHelloExtensions(meta, body[offset:offset+extensionsLen]) +} + +func parseClientHelloExtensions(meta *ClientHelloMeta, exts []byte) { + for offset := 0; offset+4 <= len(exts); { + extType := binary.BigEndian.Uint16(exts[offset : offset+2]) + extLen := int(binary.BigEndian.Uint16(exts[offset+2 : offset+4])) + offset += 4 + if offset+extLen > len(exts) { + return + } + extData := exts[offset : offset+extLen] + offset += extLen + + switch extType { + case 0: + parseServerNameExtension(meta, extData) + case 16: + parseALPNExtension(meta, extData) + case 43: + parseSupportedVersionsExtension(meta, extData) + } + } +} + +func parseServerNameExtension(meta *ClientHelloMeta, data []byte) { + if len(data) < 2 { + return + } + listLen := int(binary.BigEndian.Uint16(data[:2])) + if listLen == 0 || 2+listLen > len(data) { + return + } + list := data[2 : 2+listLen] + for offset := 0; offset+3 <= len(list); { + nameType := list[offset] + nameLen := int(binary.BigEndian.Uint16(list[offset+1 : offset+3])) + offset += 3 + if offset+nameLen > len(list) { + return + } + if nameType == 0 { + meta.ServerName = string(list[offset : offset+nameLen]) + return + } + offset += nameLen + } +} + +func parseALPNExtension(meta *ClientHelloMeta, data []byte) { + if len(data) < 2 { + return + } + listLen := int(binary.BigEndian.Uint16(data[:2])) + if listLen == 0 || 2+listLen > len(data) { + return + } + list := data[2 : 2+listLen] + for offset := 0; offset < len(list); { + nameLen := int(list[offset]) + offset++ + if offset+nameLen > len(list) { + return + } + meta.SupportedProtos = append(meta.SupportedProtos, string(list[offset:offset+nameLen])) + offset += nameLen + } +} + +func parseSupportedVersionsExtension(meta *ClientHelloMeta, data []byte) { + if len(data) < 1 { + return + } + listLen := int(data[0]) + if listLen == 0 || 1+listLen > len(data) { + return + } + list := data[1 : 1+listLen] + for offset := 0; offset+1 < len(list); offset += 2 { + meta.SupportedVersions = append(meta.SupportedVersions, binary.BigEndian.Uint16(list[offset:offset+2])) + } +} diff --git a/options.go b/options.go deleted file mode 100644 index 70b29c3..0000000 --- a/options.go +++ /dev/null @@ -1,405 +0,0 @@ -package starnet - -import ( - "context" - "crypto/tls" - "encoding/json" - "fmt" - "io" - "net" - "net/http" - "os" - "time" -) - -// WithTimeout 设置请求总超时时间 -// timeout > 0: 为本次请求注入 context 超时 -// timeout = 0: 不额外设置请求总超时 -// timeout < 0: 禁用 starnet 默认总超时 -func WithTimeout(timeout time.Duration) RequestOpt { - return func(r *Request) error { - r.config.Network.Timeout = timeout - return nil - } -} - -// WithDialTimeout 设置连接超时时间 -func WithDialTimeout(timeout time.Duration) RequestOpt { - return func(r *Request) error { - r.config.Network.DialTimeout = timeout - return nil - } -} - -// WithProxy 设置代理 -func WithProxy(proxy string) RequestOpt { - return func(r *Request) error { - r.config.Network.Proxy = proxy - return nil - } -} - -// WithDialFunc 设置自定义 Dial 函数 -func WithDialFunc(fn func(ctx context.Context, network, addr string) (net.Conn, error)) RequestOpt { - return func(r *Request) error { - r.config.Network.DialFunc = fn - return nil - } -} - -// WithTLSConfig 设置 TLS 配置 -func WithTLSConfig(tlsConfig *tls.Config) RequestOpt { - return func(r *Request) error { - r.config.TLS.Config = tlsConfig - return nil - } -} - -// WithSkipTLSVerify 设置是否跳过 TLS 验证 -func WithSkipTLSVerify(skip bool) RequestOpt { - return func(r *Request) error { - r.config.TLS.SkipVerify = skip - return nil - } -} - -// WithCustomIP 设置自定义 IP -func WithCustomIP(ips []string) RequestOpt { - return func(r *Request) error { - for _, ip := range ips { - if net.ParseIP(ip) == nil { - return wrapError(ErrInvalidIP, "ip: %s", ip) - } - } - r.config.DNS.CustomIP = ips - return nil - } -} - -// WithAddCustomIP 添加自定义 IP -func WithAddCustomIP(ip string) RequestOpt { - return func(r *Request) error { - if net.ParseIP(ip) == nil { - return wrapError(ErrInvalidIP, "ip: %s", ip) - } - r.config.DNS.CustomIP = append(r.config.DNS.CustomIP, ip) - return nil - } -} - -// WithCustomDNS 设置自定义 DNS 服务器 -func WithCustomDNS(dnsServers []string) RequestOpt { - return func(r *Request) error { - for _, dns := range dnsServers { - if net.ParseIP(dns) == nil { - return wrapError(ErrInvalidDNS, "dns: %s", dns) - } - } - r.config.DNS.CustomDNS = dnsServers - return nil - } -} - -// WithAddCustomDNS 添加自定义 DNS 服务器 -func WithAddCustomDNS(dns string) RequestOpt { - return func(r *Request) error { - if net.ParseIP(dns) == nil { - return wrapError(ErrInvalidDNS, "dns: %s", dns) - } - r.config.DNS.CustomDNS = append(r.config.DNS.CustomDNS, dns) - return nil - } -} - -// WithLookupFunc 设置自定义 DNS 解析函数 -func WithLookupFunc(fn func(ctx context.Context, host string) ([]net.IPAddr, error)) RequestOpt { - return func(r *Request) error { - r.config.DNS.LookupFunc = fn - return nil - } -} - -// WithHeader 设置 Header -func WithHeader(key, value string) RequestOpt { - return func(r *Request) error { - r.config.Headers.Set(key, value) - return nil - } -} - -// WithHeaders 批量设置 Headers -func WithHeaders(headers map[string]string) RequestOpt { - return func(r *Request) error { - for k, v := range headers { - r.config.Headers.Set(k, v) - } - return nil - } -} - -// WithContentType 设置 Content-Type -func WithContentType(contentType string) RequestOpt { - return func(r *Request) error { - r.config.Headers.Set("Content-Type", contentType) - return nil - } -} - -// WithUserAgent 设置 User-Agent -func WithUserAgent(userAgent string) RequestOpt { - return func(r *Request) error { - r.config.Headers.Set("User-Agent", userAgent) - return nil - } -} - -// WithBearerToken 设置 Bearer Token -func WithBearerToken(token string) RequestOpt { - return func(r *Request) error { - r.config.Headers.Set("Authorization", "Bearer "+token) - return nil - } -} - -// WithBasicAuth 设置 Basic 认证 -func WithBasicAuth(username, password string) RequestOpt { - return func(r *Request) error { - r.config.BasicAuth = [2]string{username, password} - return nil - } -} - -// WithCookie 添加 Cookie -func WithCookie(name, value, path string) RequestOpt { - return func(r *Request) error { - r.config.Cookies = append(r.config.Cookies, &http.Cookie{ - Name: name, - Value: value, - Path: path, - }) - return nil - } -} - -// WithSimpleCookie 添加简单 Cookie(path 为 /) -func WithSimpleCookie(name, value string) RequestOpt { - return func(r *Request) error { - r.config.Cookies = append(r.config.Cookies, &http.Cookie{ - Name: name, - Value: value, - Path: "/", - }) - return nil - } -} - -// WithCookies 批量添加 Cookies -func WithCookies(cookies map[string]string) RequestOpt { - return func(r *Request) error { - for name, value := range cookies { - r.config.Cookies = append(r.config.Cookies, &http.Cookie{ - Name: name, - Value: value, - Path: "/", - }) - } - return nil - } -} - -// WithBody 设置请求体(字节) -func WithBody(body []byte) RequestOpt { - return func(r *Request) error { - r.config.Body.Bytes = body - r.config.Body.Reader = nil - return nil - } -} - -// WithBodyString 设置请求体(字符串) -func WithBodyString(body string) RequestOpt { - return func(r *Request) error { - r.config.Body.Bytes = []byte(body) - r.config.Body.Reader = nil - return nil - } -} - -// WithBodyReader 设置请求体(Reader) -func WithBodyReader(reader io.Reader) RequestOpt { - return func(r *Request) error { - r.config.Body.Reader = reader - r.config.Body.Bytes = nil - return nil - } -} - -// WithJSON 设置 JSON 请求体 -func WithJSON(v interface{}) RequestOpt { - return func(r *Request) error { - data, err := json.Marshal(v) - if err != nil { - return wrapError(err, "marshal json") - } - r.config.Headers.Set("Content-Type", ContentTypeJSON) - r.config.Body.Bytes = data - r.config.Body.Reader = nil - return nil - } -} - -// WithFormData 设置表单数据 -func WithFormData(data map[string][]string) RequestOpt { - return func(r *Request) error { - r.config.Body.FormData = data - return nil - } -} - -// WithFormDataMap 设置表单数据(简化版) -func WithFormDataMap(data map[string]string) RequestOpt { - return func(r *Request) error { - for k, v := range data { - r.config.Body.FormData[k] = []string{v} - } - return nil - } -} - -// WithAddFormData 添加表单数据 -func WithAddFormData(key, value string) RequestOpt { - return func(r *Request) error { - r.config.Body.FormData[key] = append(r.config.Body.FormData[key], value) - return nil - } -} - -// WithFile 添加文件 -func WithFile(formName, filePath string) RequestOpt { - return func(r *Request) error { - stat, err := os.Stat(filePath) - if err != nil { - return wrapError(ErrFileNotFound, "file: %s", filePath) - } - - r.config.Body.Files = append(r.config.Body.Files, RequestFile{ - FormName: formName, - FileName: stat.Name(), - FilePath: filePath, - FileSize: stat.Size(), - FileType: ContentTypeOctetStream, - }) - - return nil - } -} - -// WithFileStream 添加文件流 -func WithFileStream(formName, fileName string, size int64, reader io.Reader) RequestOpt { - return func(r *Request) error { - if reader == nil { - return ErrNilReader - } - - r.config.Body.Files = append(r.config.Body.Files, RequestFile{ - FormName: formName, - FileName: fileName, - FileData: reader, - FileSize: size, - FileType: ContentTypeOctetStream, - }) - - return nil - } -} - -// WithQuery 添加查询参数 -func WithQuery(key, value string) RequestOpt { - return func(r *Request) error { - r.config.Queries[key] = append(r.config.Queries[key], value) - return nil - } -} - -// WithQueries 批量添加查询参数 -func WithQueries(queries map[string]string) RequestOpt { - return func(r *Request) error { - for k, v := range queries { - r.config.Queries[k] = append(r.config.Queries[k], v) - } - return nil - } -} - -// WithContentLength 设置 Content-Length -func WithContentLength(length int64) RequestOpt { - return func(r *Request) error { - r.config.ContentLength = length - return nil - } -} - -// WithAutoCalcContentLength 设置是否自动计算 Content-Length -func WithAutoCalcContentLength(auto bool) RequestOpt { - return func(r *Request) error { - r.config.AutoCalcContentLength = auto - return nil - } -} - -// WithUploadProgress 设置文件上传进度回调 -func WithUploadProgress(fn UploadProgressFunc) RequestOpt { - return func(r *Request) error { - r.config.UploadProgress = fn - return nil - } -} - -// WithTransport 设置自定义 Transport -func WithTransport(transport *http.Transport) RequestOpt { - return func(r *Request) error { - r.config.Transport = transport - r.config.CustomTransport = true - return nil - } -} - -// WithAutoFetch 设置是否自动获取响应体 -func WithAutoFetch(auto bool) RequestOpt { - return func(r *Request) error { - r.autoFetch = auto - return nil - } -} - -// WithMaxRespBodyBytes 设置响应体最大读取字节数(<=0 表示不限制) -func WithMaxRespBodyBytes(maxBytes int64) RequestOpt { - return func(r *Request) error { - if maxBytes < 0 { - return fmt.Errorf("max response body bytes must be >= 0") - } - r.config.MaxRespBodyBytes = maxBytes - return nil - } -} - -// WithRawRequest 设置原始请求 -func WithRawRequest(httpReq *http.Request) RequestOpt { - return func(r *Request) error { - if httpReq == nil { - return fmt.Errorf("httpReq cannot be nil") - } - r.httpReq = httpReq - r.doRaw = true - return nil - } -} - -// WithContext 设置 context -func WithContext(ctx context.Context) RequestOpt { - return func(r *Request) error { - r.ctx = ctx - r.httpReq = r.httpReq.WithContext(ctx) - return nil - } -} diff --git a/options_body.go b/options_body.go new file mode 100644 index 0000000..2db8959 --- /dev/null +++ b/options_body.go @@ -0,0 +1,112 @@ +package starnet + +import ( + "encoding/json" + "io" + "os" +) + +// WithBody 设置请求体(字节) +func WithBody(body []byte) RequestOpt { + return func(r *Request) error { + setBytesBodyConfig(&r.config.Body, body) + return nil + } +} + +// WithBodyString 设置请求体(字符串) +func WithBodyString(body string) RequestOpt { + return func(r *Request) error { + setBytesBodyConfig(&r.config.Body, []byte(body)) + return nil + } +} + +// WithBodyReader 设置请求体(Reader)。 +// 出于避免重复写的保守策略,Reader 形态的 body 在非幂等方法上不会自动参与 retry。 +func WithBodyReader(reader io.Reader) RequestOpt { + return func(r *Request) error { + setReaderBodyConfig(&r.config.Body, reader) + return nil + } +} + +// WithJSON 设置 JSON 请求体 +func WithJSON(v interface{}) RequestOpt { + return func(r *Request) error { + data, err := json.Marshal(v) + if err != nil { + return wrapError(err, "marshal json") + } + r.config.Headers.Set("Content-Type", ContentTypeJSON) + setBytesBodyConfig(&r.config.Body, data) + return nil + } +} + +// WithFormData 设置表单数据 +func WithFormData(data map[string][]string) RequestOpt { + return func(r *Request) error { + setFormBodyConfig(&r.config.Body, data) + return nil + } +} + +// WithFormDataMap 设置表单数据(简化版) +func WithFormDataMap(data map[string]string) RequestOpt { + return func(r *Request) error { + setFormBodyConfig(&r.config.Body, nil) + for key, value := range data { + r.config.Body.FormData[key] = []string{value} + } + return nil + } +} + +// WithAddFormData 添加表单数据 +func WithAddFormData(key, value string) RequestOpt { + return func(r *Request) error { + ensureFormMode(&r.config.Body) + r.config.Body.FormData[key] = append(r.config.Body.FormData[key], value) + return nil + } +} + +// WithFile 添加文件 +func WithFile(formName, filePath string) RequestOpt { + return func(r *Request) error { + stat, err := os.Stat(filePath) + if err != nil { + return wrapError(ErrFileNotFound, "file: %s", filePath) + } + + ensureMultipartMode(&r.config.Body) + r.config.Body.Files = append(r.config.Body.Files, RequestFile{ + FormName: formName, + FileName: stat.Name(), + FilePath: filePath, + FileSize: stat.Size(), + FileType: ContentTypeOctetStream, + }) + return nil + } +} + +// WithFileStream 添加文件流 +func WithFileStream(formName, fileName string, size int64, reader io.Reader) RequestOpt { + return func(r *Request) error { + if reader == nil { + return ErrNilReader + } + + ensureMultipartMode(&r.config.Body) + r.config.Body.Files = append(r.config.Body.Files, RequestFile{ + FormName: formName, + FileName: fileName, + FileData: reader, + FileSize: size, + FileType: ContentTypeOctetStream, + }) + return nil + } +} diff --git a/options_config.go b/options_config.go new file mode 100644 index 0000000..c0941c2 --- /dev/null +++ b/options_config.go @@ -0,0 +1,132 @@ +package starnet + +import ( + "context" + "crypto/tls" + "net" + "net/http" + "time" +) + +// WithTimeout 设置请求总超时时间 +// timeout > 0: 为本次请求注入 context 超时 +// timeout = 0: 不额外设置请求总超时 +// timeout < 0: 禁用 starnet 默认总超时 +func WithTimeout(timeout time.Duration) RequestOpt { + return requestOptFromMutation(mutateTimeout(timeout)) +} + +// WithDialTimeout 设置连接超时时间 +func WithDialTimeout(timeout time.Duration) RequestOpt { + return requestOptFromMutation(mutateDialTimeout(timeout)) +} + +// WithProxy 设置代理 +func WithProxy(proxy string) RequestOpt { + return requestOptFromMutation(mutateProxy(proxy)) +} + +// WithDialFunc 设置自定义 Dial 函数 +func WithDialFunc(fn func(ctx context.Context, network, addr string) (net.Conn, error)) RequestOpt { + return requestOptFromMutation(mutateDialFunc(fn)) +} + +// WithTLSConfig 设置 TLS 配置 +func WithTLSConfig(tlsConfig *tls.Config) RequestOpt { + return requestOptFromMutation(mutateTLSConfig(tlsConfig)) +} + +// WithTLSServerName 设置显式 TLS ServerName/SNI。 +func WithTLSServerName(serverName string) RequestOpt { + return requestOptFromMutation(mutateTLSServerName(serverName)) +} + +// WithTraceHooks 设置请求 trace 回调。 +func WithTraceHooks(hooks *TraceHooks) RequestOpt { + return requestOptFromMutation(mutateTraceHooks(hooks)) +} + +// WithSkipTLSVerify 设置是否跳过 TLS 验证 +func WithSkipTLSVerify(skip bool) RequestOpt { + return requestOptFromMutation(mutateSkipTLSVerify(skip)) +} + +// WithCustomIP 设置自定义 IP +func WithCustomIP(ips []string) RequestOpt { + return requestOptFromMutation(mutateCustomIP(ips)) +} + +// WithAddCustomIP 添加自定义 IP +func WithAddCustomIP(ip string) RequestOpt { + return requestOptFromMutation(mutateAddCustomIP(ip)) +} + +// WithCustomDNS 设置自定义 DNS 服务器 +func WithCustomDNS(dnsServers []string) RequestOpt { + return requestOptFromMutation(mutateCustomDNS(dnsServers)) +} + +// WithAddCustomDNS 添加自定义 DNS 服务器 +func WithAddCustomDNS(dns string) RequestOpt { + return requestOptFromMutation(mutateAddCustomDNS(dns)) +} + +// WithLookupFunc 设置自定义 DNS 解析函数 +func WithLookupFunc(fn func(ctx context.Context, host string) ([]net.IPAddr, error)) RequestOpt { + return requestOptFromMutation(mutateLookupFunc(fn)) +} + +// WithBasicAuth 设置 Basic 认证 +func WithBasicAuth(username, password string) RequestOpt { + return requestOptFromMutation(mutateBasicAuth(username, password)) +} + +// WithQuery 添加查询参数 +func WithQuery(key, value string) RequestOpt { + return requestOptFromMutation(mutateAddQuery(key, value)) +} + +// WithQueries 批量添加查询参数 +func WithQueries(queries map[string]string) RequestOpt { + return requestOptFromMutation(mutateAddQueries(queries)) +} + +// WithContentLength 设置 Content-Length +func WithContentLength(length int64) RequestOpt { + return requestOptFromMutation(mutateContentLength(length)) +} + +// WithAutoCalcContentLength 设置是否自动计算 Content-Length +func WithAutoCalcContentLength(auto bool) RequestOpt { + return requestOptFromMutation(mutateAutoCalcContentLength(auto)) +} + +// WithUploadProgress 设置文件上传进度回调 +func WithUploadProgress(fn UploadProgressFunc) RequestOpt { + return requestOptFromMutation(mutateUploadProgress(fn)) +} + +// WithTransport 设置自定义 Transport +func WithTransport(transport *http.Transport) RequestOpt { + return requestOptFromMutation(mutateTransport(transport)) +} + +// WithAutoFetch 设置是否自动获取响应体 +func WithAutoFetch(auto bool) RequestOpt { + return requestOptFromMutation(mutateAutoFetch(auto)) +} + +// WithMaxRespBodyBytes 设置响应体最大读取字节数(<=0 表示不限制) +func WithMaxRespBodyBytes(maxBytes int64) RequestOpt { + return requestOptFromMutation(mutateMaxRespBodyBytes(maxBytes)) +} + +// WithRawRequest 设置原始请求 +func WithRawRequest(httpReq *http.Request) RequestOpt { + return requestOptFromMutation(mutateRawRequest(httpReq)) +} + +// WithContext 设置 context +func WithContext(ctx context.Context) RequestOpt { + return requestOptFromMutation(mutateContext(ctx)) +} diff --git a/options_header.go b/options_header.go new file mode 100644 index 0000000..84a872b --- /dev/null +++ b/options_header.go @@ -0,0 +1,99 @@ +package starnet + +import "net/http" + +// WithHeader 设置 Header +func WithHeader(key, value string) RequestOpt { + return func(r *Request) error { + if isHostHeaderKey(key) { + setRequestHostConfig(r.config, value) + return nil + } + r.config.Headers.Set(key, value) + return nil + } +} + +// WithHost 设置显式 Host 头覆盖。 +func WithHost(host string) RequestOpt { + return func(r *Request) error { + setRequestHostConfig(r.config, host) + return nil + } +} + +// WithHeaders 批量设置 Headers +func WithHeaders(headers map[string]string) RequestOpt { + return func(r *Request) error { + for key, value := range headers { + if isHostHeaderKey(key) { + setRequestHostConfig(r.config, value) + continue + } + r.config.Headers.Set(key, value) + } + return nil + } +} + +// WithContentType 设置 Content-Type +func WithContentType(contentType string) RequestOpt { + return func(r *Request) error { + r.config.Headers.Set("Content-Type", contentType) + return nil + } +} + +// WithUserAgent 设置 User-Agent +func WithUserAgent(userAgent string) RequestOpt { + return func(r *Request) error { + r.config.Headers.Set("User-Agent", userAgent) + return nil + } +} + +// WithBearerToken 设置 Bearer Token +func WithBearerToken(token string) RequestOpt { + return func(r *Request) error { + r.config.Headers.Set("Authorization", "Bearer "+token) + return nil + } +} + +// WithCookie 添加 Cookie +func WithCookie(name, value, path string) RequestOpt { + return func(r *Request) error { + r.config.Cookies = append(r.config.Cookies, &http.Cookie{ + Name: name, + Value: value, + Path: path, + }) + return nil + } +} + +// WithSimpleCookie 添加简单 Cookie(path 为 /) +func WithSimpleCookie(name, value string) RequestOpt { + return func(r *Request) error { + r.config.Cookies = append(r.config.Cookies, &http.Cookie{ + Name: name, + Value: value, + Path: "/", + }) + return nil + } +} + +// WithCookies 批量添加 Cookies +func WithCookies(cookies map[string]string) RequestOpt { + return func(r *Request) error { + for name, value := range cookies { + r.config.Cookies = append(r.config.Cookies, &http.Cookie{ + Name: name, + Value: value, + Path: "/", + }) + } + return nil + } +} diff --git a/ping.go b/ping.go index b31be44..28536ed 100644 --- a/ping.go +++ b/ping.go @@ -2,14 +2,14 @@ package starnet import ( "context" - "encoding/binary" "errors" "fmt" "net" "os" "strings" - "sync/atomic" "time" + + "b612.me/starnet/internal/pingcore" ) const ( @@ -18,7 +18,6 @@ const ( icmpTypeEchoRequestV6 = 128 icmpTypeEchoReplyV6 = 129 - icmpHeaderLen = 8 icmpReadBufSz = 1500 defaultPingAttemptTimeout = 2 * time.Second @@ -26,13 +25,7 @@ const ( maxPingPayloadSize = 65499 // 65507 - ICMP header(8) ) -type ICMP struct { - Type uint8 - Code uint8 - CheckSum uint16 - Identifier uint16 - SequenceNum uint16 -} +type ICMP = pingcore.ICMP type pingSocketSpec struct { network string @@ -42,53 +35,20 @@ type pingSocketSpec struct { } // PingOptions controls ping probing behavior. -type PingOptions struct { - Count int // ping attempts for Pingable, default 3 - Timeout time.Duration // per-attempt timeout, default 2s - Interval time.Duration // delay between attempts, default 0 - Deadline time.Time // overall deadline for Pingable/PingWithContext - PreferIPv4 bool // prefer IPv4 targets - PreferIPv6 bool // prefer IPv6 targets - SourceIP net.IP // optional source IP for raw socket bind - PayloadSize int // ICMP payload bytes, default 0 -} +type PingOptions = pingcore.Options -type PingResult struct { - Duration time.Duration - RecvCount int - RemoteIP string -} - -var pingIdentifierSeed uint32 +type PingResult = pingcore.Result func nextPingIdentifier() uint16 { - pid := uint32(os.Getpid() & 0xffff) - n := atomic.AddUint32(&pingIdentifierSeed, 1) - return uint16((pid + n) & 0xffff) + return pingcore.NextIdentifier() } func pingPayload(size int) []byte { - if size <= 0 { - return nil - } - payload := make([]byte, size) - for i := 0; i < len(payload); i++ { - payload[i] = byte(i) - } - return payload + return pingcore.Payload(size) } func getICMP(seq, identifier uint16, typ uint8, payload []byte) ICMP { - icmp := ICMP{ - Type: typ, - Code: 0, - CheckSum: 0, - Identifier: identifier, - SequenceNum: seq, - } - buf := marshalICMPPacket(icmp, payload) - icmp.CheckSum = checkSum(buf) - return icmp + return pingcore.BuildICMP(seq, identifier, typ, payload) } func sendICMPRequest(ctx context.Context, icmp ICMP, payload []byte, destAddr *net.IPAddr, sourceIP net.IP, spec pingSocketSpec, timeout time.Duration) (PingResult, error) { @@ -120,8 +80,8 @@ func sendICMPRequest(ctx context.Context, icmp ICMP, payload []byte, destAddr *n return res, wrapError(err, "ping write request") } - tStart := time.Now() - deadline := tStart.Add(timeout) + startedAt := time.Now() + deadline := startedAt.Add(timeout) if d, ok := ctx.Deadline(); ok && d.Before(deadline) { deadline = d } @@ -150,108 +110,34 @@ func sendICMPRequest(ctx context.Context, icmp ICMP, payload []byte, destAddr *n } if isExpectedEchoReply(recv[:n], spec.family, spec.replyType, icmp.Identifier, icmp.SequenceNum) { res.RecvCount = n - res.Duration = time.Since(tStart) + res.Duration = time.Since(startedAt) return res, nil } } } func checkSum(data []byte) uint16 { - var ( - sum uint32 - length int = len(data) - index int - ) - for length > 1 { - sum += uint32(data[index])<<8 + uint32(data[index+1]) - index += 2 - length -= 2 - } - if length > 0 { - sum += uint32(data[index]) << 8 - } - for sum>>16 != 0 { - sum = (sum & 0xffff) + (sum >> 16) - } - - return uint16(^sum) + return pingcore.Checksum(data) } func marshalICMP(icmp ICMP) []byte { - return marshalICMPPacket(icmp, nil) + return pingcore.Marshal(icmp) } func marshalICMPPacket(icmp ICMP, payload []byte) []byte { - buf := make([]byte, icmpHeaderLen+len(payload)) - buf[0] = icmp.Type - buf[1] = icmp.Code - binary.BigEndian.PutUint16(buf[2:], icmp.CheckSum) - binary.BigEndian.PutUint16(buf[4:], icmp.Identifier) - binary.BigEndian.PutUint16(buf[6:], icmp.SequenceNum) - copy(buf[icmpHeaderLen:], payload) - return buf + return pingcore.MarshalPacket(icmp, payload) } func isExpectedEchoReply(packet []byte, family int, expectedType uint8, identifier, seq uint16) bool { - for _, off := range candidateICMPOffsets(packet, family) { - if off < 0 || off+icmpHeaderLen > len(packet) { - continue - } - if packet[off] != expectedType || packet[off+1] != 0 { - continue - } - if binary.BigEndian.Uint16(packet[off+4:off+6]) != identifier { - continue - } - if binary.BigEndian.Uint16(packet[off+6:off+8]) != seq { - continue - } - return true - } - return false + return pingcore.IsExpectedEchoReply(packet, family, expectedType, identifier, seq) } func candidateICMPOffsets(packet []byte, family int) []int { - offsets := []int{0} - if len(packet) == 0 { - return offsets - } - - ver := packet[0] >> 4 - if ver == 4 && len(packet) >= 20 { - ihl := int(packet[0]&0x0f) * 4 - if ihl >= 20 && ihl <= len(packet)-icmpHeaderLen { - offsets = append(offsets, ihl) - } - } else if ver == 6 && len(packet) >= 40+icmpHeaderLen { - offsets = append(offsets, 40) - } - - // 某些平台/内核可能回包含链路层头部,保守再尝试常见偏移。 - if family == 4 && len(packet) >= 20+icmpHeaderLen { - offsets = append(offsets, 20) - } - if family == 6 && len(packet) >= 40+icmpHeaderLen { - offsets = append(offsets, 40) - } - - return dedupOffsets(offsets) + return pingcore.CandidateICMPOffsets(packet, family) } func dedupOffsets(offsets []int) []int { - if len(offsets) <= 1 { - return offsets - } - m := make(map[int]struct{}, len(offsets)) - out := make([]int, 0, len(offsets)) - for _, off := range offsets { - if _, ok := m[off]; ok { - continue - } - m[off] = struct{}{} - out = append(out, off) - } - return out + return pingcore.DedupOffsets(offsets) } func socketSpecForIP(ip net.IP) (pingSocketSpec, error) { @@ -297,70 +183,18 @@ func localIPAddrForFamily(sourceIP net.IP, family int) (*net.IPAddr, error) { } func resolvePingTargets(host string, preferIPv4, preferIPv6 bool) ([]*net.IPAddr, error) { - if parsed := net.ParseIP(host); parsed != nil { - return []*net.IPAddr{{IP: parsed}}, nil + targets, err := pingcore.ResolveTargets(host, preferIPv4, preferIPv6) + if err != nil { + return nil, err } - - var targets []*net.IPAddr - var err4 error - var err6 error - - if ip4, e := net.ResolveIPAddr("ip4", host); e == nil && ip4 != nil && ip4.IP != nil { - targets = append(targets, ip4) - } else { - err4 = e + if len(targets) == 0 { + return nil, ErrPingNoResolvedTarget } - - if ip6, e := net.ResolveIPAddr("ip6", host); e == nil && ip6 != nil && ip6.IP != nil { - targets = append(targets, ip6) - } else { - err6 = e - } - - if len(targets) > 0 { - return orderPingTargets(targets, preferIPv4, preferIPv6), nil - } - - if err4 != nil { - return nil, err4 - } - if err6 != nil { - return nil, err6 - } - return nil, ErrPingNoResolvedTarget + return targets, nil } func orderPingTargets(targets []*net.IPAddr, preferIPv4, preferIPv6 bool) []*net.IPAddr { - if len(targets) <= 1 || preferIPv4 == preferIPv6 { - return targets - } - - ordered := make([]*net.IPAddr, 0, len(targets)) - if preferIPv4 { - for _, t := range targets { - if t != nil && t.IP != nil && t.IP.To4() != nil { - ordered = append(ordered, t) - } - } - for _, t := range targets { - if t != nil && t.IP != nil && t.IP.To4() == nil { - ordered = append(ordered, t) - } - } - return ordered - } - - for _, t := range targets { - if t != nil && t.IP != nil && t.IP.To4() == nil { - ordered = append(ordered, t) - } - } - for _, t := range targets { - if t != nil && t.IP != nil && t.IP.To4() != nil { - ordered = append(ordered, t) - } - } - return ordered + return pingcore.OrderTargets(targets, preferIPv4, preferIPv6) } func normalizePingDialError(err error) error { @@ -450,7 +284,6 @@ func pingOnceWithOptions(ctx context.Context, host string, seq int, opts PingOpt return resp, nil } - // 权限问题通常与地址族无关,继续重试意义不大。 if errors.Is(err, ErrPingPermissionDenied) { return res, err } @@ -501,8 +334,8 @@ func Pingable(host string, opts *PingOptions) (bool, error) { } var lastErr error - for i := 0; i < cfg.Count; i++ { - _, err := pingOnceWithOptions(ctx, host, 29+i, cfg) + for index := 0; index < cfg.Count; index++ { + _, err := pingOnceWithOptions(ctx, host, 29+index, cfg) if err == nil { return true, nil } @@ -512,7 +345,7 @@ func Pingable(host string, opts *PingOptions) (bool, error) { break } - if i < cfg.Count-1 && cfg.Interval > 0 { + if index < cfg.Count-1 && cfg.Interval > 0 { timer := time.NewTimer(cfg.Interval) select { case <-ctx.Done(): diff --git a/proxy_custom_ip_test.go b/proxy_custom_ip_test.go new file mode 100644 index 0000000..0f432a2 --- /dev/null +++ b/proxy_custom_ip_test.go @@ -0,0 +1,110 @@ +package starnet + +import ( + "fmt" + "net" + "net/http" + "testing" +) + +func TestRequestProxyWithCustomIPTargetsOriginWithoutRewritingProxyDial(t *testing.T) { + tlsReqInfo := make(chan struct { + host string + sni string + }, 1) + tlsServer := newIPv4TLSServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + tlsReqInfo <- struct { + host string + sni string + }{ + host: r.Host, + sni: r.TLS.ServerName, + } + _, _ = w.Write([]byte("ok")) + })) + defer tlsServer.Close() + + _, port, err := net.SplitHostPort(tlsServer.Listener.Addr().String()) + if err != nil { + t.Fatalf("split tls server addr: %v", err) + } + + proxyServer := newIPv4ConnectProxyServer(t, nil) + defer proxyServer.Close() + + targetHost := "proxy-custom-ip.test" + reqURL := fmt.Sprintf("https://%s:%s", targetHost, port) + req := NewSimpleRequest(reqURL, http.MethodGet). + SetProxy(proxyServer.URL). + SetCustomIP([]string{"127.0.0.1"}). + SetSkipTLSVerify(true) + + resp, err := req.Do() + if err != nil { + t.Fatalf("Do error: %v", err) + } + defer resp.Close() + + targets := proxyServer.Targets() + if len(targets) != 1 { + t.Fatalf("connect targets=%v; want 1 target", targets) + } + gotConnectTarget := targets[0] + wantConnectTarget := net.JoinHostPort("127.0.0.1", port) + if gotConnectTarget != wantConnectTarget { + t.Fatalf("CONNECT target = %q; want %q", gotConnectTarget, wantConnectTarget) + } + + gotTLS := <-tlsReqInfo + wantHost := net.JoinHostPort(targetHost, port) + if gotTLS.host != wantHost { + t.Fatalf("request host = %q; want %q", gotTLS.host, wantHost) + } + if gotTLS.sni != targetHost { + t.Fatalf("tls sni = %q; want %q", gotTLS.sni, targetHost) + } +} + +func TestRequestCustomIPPreservesOriginalHostAndSNI(t *testing.T) { + tlsReqInfo := make(chan struct { + host string + sni string + }, 1) + tlsServer := newIPv4TLSServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + tlsReqInfo <- struct { + host string + sni string + }{ + host: r.Host, + sni: r.TLS.ServerName, + } + _, _ = w.Write([]byte("ok")) + })) + defer tlsServer.Close() + + _, port, err := net.SplitHostPort(tlsServer.Listener.Addr().String()) + if err != nil { + t.Fatalf("split tls server addr: %v", err) + } + + targetHost := "custom-ip-direct.test" + reqURL := fmt.Sprintf("https://%s:%s", targetHost, port) + req := NewSimpleRequest(reqURL, http.MethodGet). + SetCustomIP([]string{"127.0.0.1"}). + SetSkipTLSVerify(true) + + resp, err := req.Do() + if err != nil { + t.Fatalf("Do error: %v", err) + } + defer resp.Close() + + gotTLS := <-tlsReqInfo + wantHost := net.JoinHostPort(targetHost, port) + if gotTLS.host != wantHost { + t.Fatalf("request host = %q; want %q", gotTLS.host, wantHost) + } + if gotTLS.sni != targetHost { + t.Fatalf("tls sni = %q; want %q", gotTLS.sni, targetHost) + } +} diff --git a/proxy_local_helpers_test.go b/proxy_local_helpers_test.go new file mode 100644 index 0000000..ab753d1 --- /dev/null +++ b/proxy_local_helpers_test.go @@ -0,0 +1,331 @@ +package starnet + +import ( + "crypto/tls" + "crypto/x509" + "encoding/binary" + "fmt" + "io" + "net" + "net/http" + "net/http/httptest" + "sync" + "testing" +) + +type connectProxyServer struct { + *httptest.Server + mu sync.Mutex + targets []string +} + +func newIPv4Server(t testing.TB, handler http.Handler) *httptest.Server { + t.Helper() + + listener, err := net.Listen("tcp4", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen tcp4: %v", err) + } + + server := httptest.NewUnstartedServer(handler) + server.Listener = listener + server.Start() + return server +} + +func newIPv4TLSServer(t testing.TB, handler http.Handler) *httptest.Server { + t.Helper() + + listener, err := net.Listen("tcp4", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen tcp4: %v", err) + } + + server := httptest.NewUnstartedServer(handler) + server.Listener = listener + server.StartTLS() + return server +} + +func newTrustedIPv4TLSServer(t testing.TB, dnsName string, handler http.Handler) (*httptest.Server, *x509.CertPool) { + t.Helper() + + testT, ok := t.(*testing.T) + if !ok { + t.Fatal("newTrustedIPv4TLSServer requires *testing.T") + } + + certPEM, keyPEM := genSelfSignedCertPEM(testT, dnsName) + cert, err := tls.X509KeyPair(certPEM, keyPEM) + if err != nil { + t.Fatalf("X509KeyPair: %v", err) + } + + pool := x509.NewCertPool() + if !pool.AppendCertsFromPEM(certPEM) { + t.Fatal("AppendCertsFromPEM returned false") + } + + server := httptest.NewUnstartedServer(handler) + listener, err := net.Listen("tcp4", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen tcp4: %v", err) + } + server.Listener = listener + server.TLS = &tls.Config{ + Certificates: []tls.Certificate{cert}, + } + server.StartTLS() + return server, pool +} + +func httpsURLForHost(t testing.TB, server *httptest.Server, host string) string { + t.Helper() + + _, port, err := net.SplitHostPort(server.Listener.Addr().String()) + if err != nil { + t.Fatalf("split host port: %v", err) + } + return fmt.Sprintf("https://%s:%s", host, port) +} + +func newIPv4ConnectProxyServer(t testing.TB, dialTarget func(target string) (net.Conn, error)) *connectProxyServer { + t.Helper() + + proxy := &connectProxyServer{} + if dialTarget == nil { + dialTarget = func(target string) (net.Conn, error) { + return net.Dial("tcp", target) + } + } + + proxy.Server = newIPv4Server(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodConnect { + http.Error(w, "connect required", http.StatusMethodNotAllowed) + return + } + + proxy.mu.Lock() + proxy.targets = append(proxy.targets, r.Host) + proxy.mu.Unlock() + + targetConn, err := dialTarget(r.Host) + if err != nil { + http.Error(w, err.Error(), http.StatusBadGateway) + return + } + + hijacker, ok := w.(http.Hijacker) + if !ok { + targetConn.Close() + t.Fatal("proxy response writer is not a hijacker") + } + + clientConn, rw, err := hijacker.Hijack() + if err != nil { + targetConn.Close() + t.Fatalf("hijack proxy conn: %v", err) + } + if _, err := rw.WriteString("HTTP/1.1 200 Connection Established\r\n\r\n"); err != nil { + clientConn.Close() + targetConn.Close() + t.Fatalf("write connect response: %v", err) + } + if err := rw.Flush(); err != nil { + clientConn.Close() + targetConn.Close() + t.Fatalf("flush connect response: %v", err) + } + + relayProxyConns(clientConn, targetConn) + })) + return proxy +} + +func (p *connectProxyServer) Targets() []string { + p.mu.Lock() + defer p.mu.Unlock() + return append([]string(nil), p.targets...) +} + +type socks5ProxyServer struct { + ln net.Listener + addr string + dial func(target string) (net.Conn, error) + stopCh chan struct{} + wg sync.WaitGroup + mu sync.Mutex + targets []string +} + +func newSOCKS5ProxyServer(t testing.TB, dialTarget func(target string) (net.Conn, error)) *socks5ProxyServer { + t.Helper() + + if dialTarget == nil { + dialTarget = func(target string) (net.Conn, error) { + return net.Dial("tcp", target) + } + } + + ln, err := net.Listen("tcp4", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen tcp4 socks5: %v", err) + } + + proxy := &socks5ProxyServer{ + ln: ln, + addr: ln.Addr().String(), + dial: dialTarget, + stopCh: make(chan struct{}), + } + + proxy.wg.Add(1) + go func() { + defer proxy.wg.Done() + for { + conn, err := ln.Accept() + if err != nil { + select { + case <-proxy.stopCh: + return + default: + return + } + } + + proxy.wg.Add(1) + go func(c net.Conn) { + defer proxy.wg.Done() + proxy.handleConn(t, c) + }(conn) + } + }() + + return proxy +} + +func (p *socks5ProxyServer) URL() string { + return "socks5://" + p.addr +} + +func (p *socks5ProxyServer) Targets() []string { + p.mu.Lock() + defer p.mu.Unlock() + return append([]string(nil), p.targets...) +} + +func (p *socks5ProxyServer) Close() { + close(p.stopCh) + _ = p.ln.Close() + p.wg.Wait() +} + +func (p *socks5ProxyServer) handleConn(t testing.TB, conn net.Conn) { + t.Helper() + closeConn := true + defer func() { + if closeConn { + _ = conn.Close() + } + }() + + header := make([]byte, 2) + if _, err := io.ReadFull(conn, header); err != nil { + return + } + if header[0] != 0x05 { + return + } + + methods := make([]byte, int(header[1])) + if _, err := io.ReadFull(conn, methods); err != nil { + return + } + if _, err := conn.Write([]byte{0x05, 0x00}); err != nil { + return + } + + reqHeader := make([]byte, 4) + if _, err := io.ReadFull(conn, reqHeader); err != nil { + return + } + if reqHeader[0] != 0x05 || reqHeader[1] != 0x01 { + _, _ = conn.Write([]byte{0x05, 0x07, 0x00, 0x01, 0, 0, 0, 0, 0, 0}) + return + } + + host, err := readSOCKS5Addr(conn, reqHeader[3]) + if err != nil { + _, _ = conn.Write([]byte{0x05, 0x08, 0x00, 0x01, 0, 0, 0, 0, 0, 0}) + return + } + portBytes := make([]byte, 2) + if _, err := io.ReadFull(conn, portBytes); err != nil { + return + } + target := net.JoinHostPort(host, fmt.Sprintf("%d", binary.BigEndian.Uint16(portBytes))) + + p.mu.Lock() + p.targets = append(p.targets, target) + p.mu.Unlock() + + targetConn, err := p.dial(target) + if err != nil { + _, _ = conn.Write([]byte{0x05, 0x05, 0x00, 0x01, 0, 0, 0, 0, 0, 0}) + return + } + + if _, err := conn.Write([]byte{0x05, 0x00, 0x00, 0x01, 0, 0, 0, 0, 0, 0}); err != nil { + targetConn.Close() + return + } + + closeConn = false + relayProxyConns(conn, targetConn) +} + +func readSOCKS5Addr(r io.Reader, atyp byte) (string, error) { + switch atyp { + case 0x01: + buf := make([]byte, 4) + if _, err := io.ReadFull(r, buf); err != nil { + return "", err + } + return net.IP(buf).String(), nil + case 0x03: + var size [1]byte + if _, err := io.ReadFull(r, size[:]); err != nil { + return "", err + } + buf := make([]byte, int(size[0])) + if _, err := io.ReadFull(r, buf); err != nil { + return "", err + } + return string(buf), nil + case 0x04: + buf := make([]byte, 16) + if _, err := io.ReadFull(r, buf); err != nil { + return "", err + } + return net.IP(buf).String(), nil + default: + return "", fmt.Errorf("unsupported atyp: %d", atyp) + } +} + +func relayProxyConns(left, right net.Conn) { + var once sync.Once + closeBoth := func() { + _ = left.Close() + _ = right.Close() + } + + go func() { + _, _ = io.Copy(left, right) + once.Do(closeBoth) + }() + go func() { + _, _ = io.Copy(right, left) + once.Do(closeBoth) + }() +} diff --git a/request.go b/request.go index 9484cd5..a4a25e7 100644 --- a/request.go +++ b/request.go @@ -22,14 +22,131 @@ type Request struct { httpClient *http.Client httpReq *http.Request retry *retryPolicy + traceHooks *TraceHooks + traceState *traceState applied bool // 是否已应用配置 doRaw bool // 是否使用原始请求(不修改) autoFetch bool // 是否自动获取响应体 + + rawSourceExternal bool // 是否由 SetRawRequest/WithRawRequest 注入外部 raw request + rawTemplate *http.Request +} + +func normalizeContext(ctx context.Context) context.Context { + if ctx != nil { + return ctx + } + return context.Background() +} + +func cloneRawHTTPRequest(httpReq *http.Request, ctx context.Context) (*http.Request, error) { + if httpReq == nil { + return nil, fmt.Errorf("http request is nil") + } + + cloned := httpReq.Clone(normalizeContext(ctx)) + switch { + case httpReq.Body == nil || httpReq.Body == http.NoBody: + cloned.Body = httpReq.Body + case httpReq.GetBody != nil: + body, err := httpReq.GetBody() + if err != nil { + return cloned, wrapError(err, "clone raw request body") + } + cloned.Body = body + default: + return cloned, fmt.Errorf("cannot clone raw request with non-replayable body") + } + + return cloned, nil +} + +func (r *Request) rawBaseRequest() *http.Request { + if r == nil { + return nil + } + if r.rawTemplate != nil { + return r.rawTemplate + } + return r.httpReq +} + +func (r *Request) invalidatePreparedState() { + if r == nil { + return + } + if r.cancel != nil { + r.cancel() + r.cancel = nil + } + r.execCtx = nil + r.traceState = nil + r.httpClient = nil + + wasApplied := r.applied + r.applied = false + if !wasApplied || r.doRaw { + return + } + if err := r.rebuildPreparedRequestBase(); err != nil && r.err == nil { + r.err = err + } +} + +func (r *Request) rebuildPreparedRequestBase() error { + if r == nil || r.doRaw { + return nil + } + ctx := r.ctx + if ctx == nil { + ctx = context.Background() + } + httpReq, err := http.NewRequestWithContext(ctx, r.method, r.url, nil) + if err != nil { + return wrapError(err, "rebuild http request") + } + r.httpReq = httpReq + r.syncRequestHost() + return nil +} + +func (r *Request) rebuildRawRequestBase() error { + if r == nil || !r.doRaw { + return nil + } + baseReq := r.rawBaseRequest() + rawReq, err := cloneRawHTTPRequest(baseReq, normalizeContext(r.ctx)) + if err != nil && baseReq != nil && baseReq == r.httpReq { + r.httpReq = baseReq.WithContext(normalizeContext(r.ctx)) + return nil + } + if rawReq != nil { + r.httpReq = rawReq + } + return err +} + +func (r *Request) rebuildExecutionRequestBase() error { + if r == nil { + return nil + } + if r.cancel != nil { + r.cancel() + r.cancel = nil + } + r.execCtx = nil + r.traceState = nil + r.applied = false + if r.doRaw { + return r.rebuildRawRequestBase() + } + return r.rebuildPreparedRequestBase() } // newRequest 创建新请求(内部使用) func newRequest(ctx context.Context, urlStr string, method string, opts ...RequestOpt) (*Request, error) { + ctx = normalizeContext(ctx) if method == "" { method = http.MethodGet } @@ -133,6 +250,7 @@ func NewSimpleRequest(url, method string, opts ...RequestOpt) *Request { // NewSimpleRequestWithContext 创建新请求(带 context,忽略错误) func NewSimpleRequestWithContext(ctx context.Context, url, method string, opts ...RequestOpt) *Request { + ctx = normalizeContext(ctx) req, err := newRequest(ctx, url, method, opts...) if err != nil { return &Request{ @@ -163,16 +281,24 @@ func (r *Request) Clone() *Request { client: r.client, httpClient: r.httpClient, retry: cloneRetryPolicy(r.retry), + traceHooks: r.traceHooks, applied: false, // 重置应用状态 doRaw: r.doRaw, autoFetch: r.autoFetch, + + rawSourceExternal: r.rawSourceExternal, } // 重新创建 http.Request if !r.doRaw { cloned.httpReq, _ = http.NewRequestWithContext(cloned.ctx, cloned.method, cloned.url, nil) } else { - cloned.httpReq = r.httpReq + rawTemplate, err := cloneRawHTTPRequest(r.rawBaseRequest(), cloned.ctx) + cloned.rawTemplate = rawTemplate + cloned.httpReq = rawTemplate + if err != nil && cloned.err == nil { + cloned.err = err + } } return cloned @@ -190,12 +316,7 @@ func (r *Request) Context() context.Context { // SetContext 设置 context func (r *Request) SetContext(ctx context.Context) *Request { - if r.err != nil { - return r - } - r.ctx = ctx - r.httpReq = r.httpReq.WithContext(ctx) - return r + return r.applyMutation(mutateContext(ctx)) } // Method 获取 HTTP 方法 @@ -215,7 +336,13 @@ func (r *Request) SetMethod(method string) *Request { } r.method = method - r.httpReq.Method = method + if r.httpReq != nil { + r.httpReq.Method = method + } + if r.doRaw && r.rawTemplate != nil { + r.rawTemplate.Method = method + } + r.invalidatePreparedState() return r } @@ -243,45 +370,74 @@ func (r *Request) SetURL(urlStr string) *Request { r.url = urlStr u.Host = removeEmptyPort(u.Host) - r.httpReq.Host = u.Host r.httpReq.URL = u - - // 更新 TLS ServerName - if r.config.TLS.Config != nil { - r.config.TLS.Config.ServerName = u.Hostname() - } + r.syncRequestHost() + r.invalidatePreparedState() return r } +func (r *Request) effectiveRequestHost() string { + if r == nil { + return "" + } + if r.config != nil && r.config.Host != "" { + return r.config.Host + } + if r.httpReq != nil && r.httpReq.URL != nil { + return removeEmptyPort(r.httpReq.URL.Host) + } + if r.url == "" { + return "" + } + u, err := url.Parse(r.url) + if err != nil { + return "" + } + return removeEmptyPort(u.Host) +} + +func (r *Request) syncRequestHost() { + if r == nil || r.httpReq == nil { + return + } + r.httpReq.Host = r.effectiveRequestHost() +} + // RawRequest 获取底层 http.Request func (r *Request) RawRequest() *http.Request { + if r != nil && r.doRaw && r.rawTemplate != nil && !r.applied { + return r.rawTemplate + } return r.httpReq } // SetRawRequest 设置底层 http.Request(启用原始模式) func (r *Request) SetRawRequest(httpReq *http.Request) *Request { - if r.err != nil { - return r - } - r.httpReq = httpReq - r.doRaw = true - if httpReq == nil { - r.err = fmt.Errorf("httpReq cannot be nil") - return r - } - return r + return r.applyMutation(mutateRawRequest(httpReq)) } // EnableRawMode 启用原始模式(不修改请求) func (r *Request) EnableRawMode() *Request { + if r.doRaw { + return r + } r.doRaw = true + r.invalidatePreparedState() return r } // DisableRawMode 禁用原始模式 func (r *Request) DisableRawMode() *Request { + if !r.doRaw { + return r + } + if r.rawSourceExternal { + r.err = fmt.Errorf("cannot disable raw mode after SetRawRequest") + return r + } r.doRaw = false + r.invalidatePreparedState() return r } @@ -329,6 +485,10 @@ func (r *Request) Do() (*Response, error) { } func (r *Request) doOnce() (*Response, error) { + if err := r.rebuildExecutionRequestBase(); err != nil { + return nil, wrapError(err, "rebuild execution request") + } + // 准备请求 if err := r.prepare(); err != nil { return nil, wrapError(err, "prepare request") diff --git a/request_body.go b/request_body.go index 63e5ea3..fbd211a 100644 --- a/request_body.go +++ b/request_body.go @@ -1,16 +1,9 @@ package starnet import ( - "bytes" - "context" "encoding/json" - "fmt" "io" - "mime/multipart" - "net/http" - "net/url" "os" - "strings" ) // SetBody 设置请求体(字节) @@ -21,12 +14,13 @@ func (r *Request) SetBody(body []byte) *Request { if r.doRaw { return r } - r.config.Body.Bytes = body - r.config.Body.Reader = nil + setBytesBodyConfig(&r.config.Body, body) + r.invalidatePreparedState() return r } -// SetBodyReader 设置请求体(Reader) +// SetBodyReader 设置请求体(Reader)。 +// 出于避免重复写的保守策略,Reader 形态的 body 在非幂等方法上不会自动参与 retry。 func (r *Request) SetBodyReader(reader io.Reader) *Request { if r.err != nil { return r @@ -34,8 +28,8 @@ func (r *Request) SetBodyReader(reader io.Reader) *Request { if r.doRaw { return r } - r.config.Body.Reader = reader - r.config.Body.Bytes = nil + setReaderBodyConfig(&r.config.Body, reader) + r.invalidatePreparedState() return r } @@ -67,7 +61,8 @@ func (r *Request) SetFormData(data map[string][]string) *Request { if r.doRaw { return r } - r.config.Body.FormData = cloneStringMapSlice(data) + setFormBodyConfig(&r.config.Body, data) + r.invalidatePreparedState() return r } @@ -79,7 +74,9 @@ func (r *Request) AddFormData(key, value string) *Request { if r.doRaw { return r } + ensureFormMode(&r.config.Body) r.config.Body.FormData[key] = append(r.config.Body.FormData[key], value) + r.invalidatePreparedState() return r } @@ -91,9 +88,11 @@ func (r *Request) AddFormDataMap(data map[string]string) *Request { if r.doRaw { return r } - for k, v := range data { - r.config.Body.FormData[k] = append(r.config.Body.FormData[k], v) + ensureFormMode(&r.config.Body) + for key, value := range data { + r.config.Body.FormData[key] = append(r.config.Body.FormData[key], value) } + r.invalidatePreparedState() return r } @@ -109,6 +108,7 @@ func (r *Request) AddFile(formName, filePath string) *Request { return r } + ensureMultipartMode(&r.config.Body) r.config.Body.Files = append(r.config.Body.Files, RequestFile{ FormName: formName, FileName: stat.Name(), @@ -116,6 +116,7 @@ func (r *Request) AddFile(formName, filePath string) *Request { FileSize: stat.Size(), FileType: ContentTypeOctetStream, }) + r.invalidatePreparedState() return r } @@ -132,6 +133,7 @@ func (r *Request) AddFileWithName(formName, filePath, fileName string) *Request return r } + ensureMultipartMode(&r.config.Body) r.config.Body.Files = append(r.config.Body.Files, RequestFile{ FormName: formName, FileName: fileName, @@ -139,6 +141,7 @@ func (r *Request) AddFileWithName(formName, filePath, fileName string) *Request FileSize: stat.Size(), FileType: ContentTypeOctetStream, }) + r.invalidatePreparedState() return r } @@ -155,6 +158,7 @@ func (r *Request) AddFileWithType(formName, filePath, fileType string) *Request return r } + ensureMultipartMode(&r.config.Body) r.config.Body.Files = append(r.config.Body.Files, RequestFile{ FormName: formName, FileName: stat.Name(), @@ -162,6 +166,7 @@ func (r *Request) AddFileWithType(formName, filePath, fileType string) *Request FileSize: stat.Size(), FileType: fileType, }) + r.invalidatePreparedState() return r } @@ -177,6 +182,7 @@ func (r *Request) AddFileStream(formName, fileName string, size int64, reader io return r } + ensureMultipartMode(&r.config.Body) r.config.Body.Files = append(r.config.Body.Files, RequestFile{ FormName: formName, FileName: fileName, @@ -184,6 +190,7 @@ func (r *Request) AddFileStream(formName, fileName string, size int64, reader io FileSize: size, FileType: ContentTypeOctetStream, }) + r.invalidatePreparedState() return r } @@ -199,6 +206,7 @@ func (r *Request) AddFileStreamWithType(formName, fileName, fileType string, siz return r } + ensureMultipartMode(&r.config.Body) r.config.Body.Files = append(r.config.Body.Files, RequestFile{ FormName: formName, FileName: fileName, @@ -206,243 +214,7 @@ func (r *Request) AddFileStreamWithType(formName, fileName, fileType string, siz FileSize: size, FileType: fileType, }) + r.invalidatePreparedState() return r } - -// applyBody 应用请求体 -func (r *Request) applyBody() error { - // 优先级:Reader > Bytes > Files > FormData - - // 1. Reader - if r.config.Body.Reader != nil { - r.httpReq.Body = io.NopCloser(r.config.Body.Reader) - - // 尝试获取长度 - switch v := r.config.Body.Reader.(type) { - case *bytes.Buffer: - r.httpReq.ContentLength = int64(v.Len()) - case *bytes.Reader: - r.httpReq.ContentLength = int64(v.Len()) - case *strings.Reader: - r.httpReq.ContentLength = int64(v.Len()) - } - - return nil - } - - // 2. Bytes - if len(r.config.Body.Bytes) > 0 { - r.httpReq.Body = io.NopCloser(bytes.NewReader(r.config.Body.Bytes)) - r.httpReq.ContentLength = int64(len(r.config.Body.Bytes)) - return nil - } - - // 3. Files(multipart/form-data) - if len(r.config.Body.Files) > 0 { - return r.applyMultipartBody() - } - - // 4. FormData(application/x-www-form-urlencoded) - if len(r.config.Body.FormData) > 0 { - values := url.Values{} - for k, vs := range r.config.Body.FormData { - for _, v := range vs { - values.Add(k, v) - } - } - encoded := values.Encode() - r.httpReq.Body = io.NopCloser(strings.NewReader(encoded)) - r.httpReq.ContentLength = int64(len(encoded)) - return nil - } - - return nil -} - -// applyMultipartBody 应用 multipart 请求体 -func (r *Request) applyMultipartBody() error { - pr, pw := io.Pipe() - writer := multipart.NewWriter(pw) - - // 设置 Content-Type - r.httpReq.Header.Set("Content-Type", writer.FormDataContentType()) - r.httpReq.Body = pr - - // 在 goroutine 中写入数据 - go func() { - defer pw.Close() - defer writer.Close() - - // 写入表单字段 - for k, vs := range r.config.Body.FormData { - for _, v := range vs { - if err := writer.WriteField(k, v); err != nil { - pw.CloseWithError(wrapError(err, "write form field")) - return - } - } - } - - // 写入文件 - for _, file := range r.config.Body.Files { - if err := r.writeFile(writer, file); err != nil { - pw.CloseWithError(err) - return - } - } - }() - - return nil -} - -// writeFile 写入文件到 multipart writer -func (r *Request) writeFile(writer *multipart.Writer, file RequestFile) error { - // 创建文件字段 - part, err := writer.CreateFormFile(file.FormName, file.FileName) - if err != nil { - return wrapError(err, "create form file") - } - - // 获取文件数据源 - var reader io.Reader - if file.FileData != nil { - reader = file.FileData - } else if file.FilePath != "" { - f, err := os.Open(file.FilePath) - if err != nil { - return wrapError(err, "open file") - } - defer f.Close() - reader = f - } else { - return ErrNilReader - } - - // 复制文件数据(带进度) - if r.config.UploadProgress != nil { - _, err = copyWithProgress(r.ctx, part, reader, file.FileName, file.FileSize, r.config.UploadProgress) - } else { - _, err = io.Copy(part, reader) - } - - if err != nil { - return wrapError(err, "copy file data") - } - - return nil -} - -// prepare 准备请求(应用配置) -func (r *Request) prepare() error { - if r.applied { - return nil - } - - // 即使 raw 模式也要确保有 httpClient - if r.httpClient == nil { - var err error - r.httpClient, err = r.buildHTTPClient() - if err != nil { - return err // ← 失败时不设置 applied - } - } - - if r.httpReq == nil { - return fmt.Errorf("http request is nil") - } - - // 原始模式不修改请求内容 - if !r.doRaw { - // 应用查询参数 - if len(r.config.Queries) > 0 { - q := r.httpReq.URL.Query() - for k, values := range r.config.Queries { - for _, v := range values { - q.Add(k, v) - } - } - r.httpReq.URL.RawQuery = q.Encode() - } - - // 应用 Headers - for k, values := range r.config.Headers { - for _, v := range values { - r.httpReq.Header.Add(k, v) - } - } - - // 应用 Cookies - for _, cookie := range r.config.Cookies { - r.httpReq.AddCookie(cookie) - } - - // 应用 Basic Auth - if r.config.BasicAuth[0] != "" || r.config.BasicAuth[1] != "" { - r.httpReq.SetBasicAuth(r.config.BasicAuth[0], r.config.BasicAuth[1]) - } - - // 应用请求体 - if err := r.applyBody(); err != nil { - return err - } - - // 应用 Content-Length - if r.config.ContentLength > 0 { - r.httpReq.ContentLength = r.config.ContentLength - } else if r.config.ContentLength < 0 { - r.httpReq.ContentLength = 0 - } - - // 自动计算 Content-Length - if r.config.AutoCalcContentLength && r.httpReq.Body != nil { - data, err := io.ReadAll(r.httpReq.Body) - if err != nil { - return wrapError(err, "read body for content length") - } - r.httpReq.ContentLength = int64(len(data)) - r.httpReq.Body = io.NopCloser(bytes.NewBuffer(data)) - } - - // 设置 TLS ServerName(如果有 TLS Config) - if r.config.TLS.Config != nil && r.httpReq.URL != nil { - r.config.TLS.Config.ServerName = r.httpReq.URL.Hostname() - } - } - - execCtx := r.ctx - if !r.doRaw { - // raw 模式下不注入请求级网络配置,只应用 context/超时。 - execCtx = injectRequestConfig(execCtx, r.config) - } - - // 请求级总超时通过 context 控制,避免污染共享 http.Client。 - if r.config.Network.Timeout > 0 { - execCtx, r.cancel = context.WithTimeout(execCtx, r.config.Network.Timeout) - } - - r.execCtx = execCtx - r.httpReq = r.httpReq.WithContext(r.execCtx) - - r.applied = true - return nil -} - -// buildHTTPClient 构建 HTTP Client -func (r *Request) buildHTTPClient() (*http.Client, error) { - // 优先使用请求关联的 Client - if r.client != nil { - return r.client.HTTPClient(), nil - } - - // 自定义 Transport - if r.config.CustomTransport && r.config.Transport != nil { - return &http.Client{ - Transport: &Transport{base: r.config.Transport}, - Timeout: 0, - }, nil - } - - // 默认全局 client - return DefaultHTTPClient(), nil -} diff --git a/request_config.go b/request_config.go deleted file mode 100644 index 195ef22..0000000 --- a/request_config.go +++ /dev/null @@ -1,282 +0,0 @@ -package starnet - -import ( - "context" - "crypto/tls" - "fmt" - "net" - "net/http" - "time" -) - -// SetTimeout 设置请求总超时时间 -// timeout > 0: 为本次请求注入 context 超时 -// timeout = 0: 不额外设置请求总超时 -// timeout < 0: 禁用 starnet 默认总超时 -func (r *Request) SetTimeout(timeout time.Duration) *Request { - if r.err != nil { - return r - } - r.config.Network.Timeout = timeout - return r -} - -// SetDialTimeout 设置连接超时时间 -func (r *Request) SetDialTimeout(timeout time.Duration) *Request { - if r.err != nil { - return r - } - r.config.Network.DialTimeout = timeout - return r -} - -// SetProxy 设置代理 -func (r *Request) SetProxy(proxy string) *Request { - if r.err != nil { - return r - } - r.config.Network.Proxy = proxy - return r -} - -// SetDialFunc 设置自定义 Dial 函数 -func (r *Request) SetDialFunc(fn func(ctx context.Context, network, addr string) (net.Conn, error)) *Request { - if r.err != nil { - return r - } - r.config.Network.DialFunc = fn - return r -} - -// SetTLSConfig 设置 TLS 配置 -func (r *Request) SetTLSConfig(tlsConfig *tls.Config) *Request { - if r.err != nil { - return r - } - r.config.TLS.Config = tlsConfig - return r -} - -// SetSkipTLSVerify 设置是否跳过 TLS 验证 -func (r *Request) SetSkipTLSVerify(skip bool) *Request { - if r.err != nil { - return r - } - r.config.TLS.SkipVerify = skip - return r -} - -// SetCustomIP 设置自定义 IP(直接指定 IP,跳过 DNS) -func (r *Request) SetCustomIP(ips []string) *Request { - if r.err != nil { - return r - } - - // 验证 IP 格式 - for _, ip := range ips { - if net.ParseIP(ip) == nil { - r.err = wrapError(ErrInvalidIP, "ip: %s", ip) - return r - } - } - - r.config.DNS.CustomIP = ips - return r -} - -// AddCustomIP 添加自定义 IP -func (r *Request) AddCustomIP(ip string) *Request { - if r.err != nil { - return r - } - - if net.ParseIP(ip) == nil { - r.err = wrapError(ErrInvalidIP, "ip: %s", ip) - return r - } - - r.config.DNS.CustomIP = append(r.config.DNS.CustomIP, ip) - return r -} - -// SetCustomDNS 设置自定义 DNS 服务器 -func (r *Request) SetCustomDNS(dnsServers []string) *Request { - if r.err != nil { - return r - } - - // 验证 DNS 服务器格式 - for _, dns := range dnsServers { - if net.ParseIP(dns) == nil { - r.err = wrapError(ErrInvalidDNS, "dns: %s", dns) - return r - } - } - - r.config.DNS.CustomDNS = dnsServers - return r -} - -// AddCustomDNS 添加自定义 DNS 服务器 -func (r *Request) AddCustomDNS(dns string) *Request { - if r.err != nil { - return r - } - - if net.ParseIP(dns) == nil { - r.err = wrapError(ErrInvalidDNS, "dns: %s", dns) - return r - } - - r.config.DNS.CustomDNS = append(r.config.DNS.CustomDNS, dns) - return r -} - -// SetLookupFunc 设置自定义 DNS 解析函数 -func (r *Request) SetLookupFunc(fn func(ctx context.Context, host string) ([]net.IPAddr, error)) *Request { - if r.err != nil { - return r - } - r.config.DNS.LookupFunc = fn - return r -} - -// SetBasicAuth 设置 Basic 认证 -func (r *Request) SetBasicAuth(username, password string) *Request { - if r.err != nil { - return r - } - r.config.BasicAuth = [2]string{username, password} - return r -} - -// SetContentLength 设置 Content-Length -func (r *Request) SetContentLength(length int64) *Request { - if r.err != nil { - return r - } - r.config.ContentLength = length - return r -} - -// SetAutoCalcContentLength 设置是否自动计算 Content-Length -// 警告:启用后会将整个 body 读入内存 -func (r *Request) SetAutoCalcContentLength(auto bool) *Request { - if r.err != nil { - return r - } - - if r.doRaw { - r.err = fmt.Errorf("cannot set auto calc content length in raw mode") - return r - } - - r.config.AutoCalcContentLength = auto - return r -} - -// SetTransport 设置自定义 Transport -func (r *Request) SetTransport(transport *http.Transport) *Request { - if r.err != nil { - return r - } - r.config.Transport = transport - r.config.CustomTransport = true - return r -} - -// SetUploadProgress 设置文件上传进度回调 -func (r *Request) SetUploadProgress(fn UploadProgressFunc) *Request { - if r.err != nil { - return r - } - r.config.UploadProgress = fn - return r -} - -// SetMaxRespBodyBytes 设置响应体最大读取字节数(<=0 表示不限制) -func (r *Request) SetMaxRespBodyBytes(maxBytes int64) *Request { - if r.err != nil { - return r - } - if maxBytes < 0 { - r.err = fmt.Errorf("max response body bytes must be >= 0") - return r - } - r.config.MaxRespBodyBytes = maxBytes - return r -} - -// AddQuery 添加查询参数 -func (r *Request) AddQuery(key, value string) *Request { - if r.err != nil { - return r - } - r.config.Queries[key] = append(r.config.Queries[key], value) - return r -} - -// SetQuery 设置查询参数(覆盖) -func (r *Request) SetQuery(key, value string) *Request { - if r.err != nil { - return r - } - r.config.Queries[key] = []string{value} - return r -} - -// SetQueries 设置所有查询参数(覆盖) -func (r *Request) SetQueries(queries map[string][]string) *Request { - if r.err != nil { - return r - } - r.config.Queries = cloneStringMapSlice(queries) - return r -} - -// AddQueries 批量添加查询参数 -func (r *Request) AddQueries(queries map[string]string) *Request { - if r.err != nil { - return r - } - for k, v := range queries { - r.config.Queries[k] = append(r.config.Queries[k], v) - } - return r -} - -// DeleteQuery 删除查询参数 -func (r *Request) DeleteQuery(key string) *Request { - if r.err != nil { - return r - } - delete(r.config.Queries, key) - return r -} - -// DeleteQueryValue 删除查询参数的特定值 -func (r *Request) DeleteQueryValue(key, value string) *Request { - if r.err != nil { - return r - } - - values, ok := r.config.Queries[key] - if !ok { - return r - } - - newValues := make([]string, 0, len(values)) - for _, v := range values { - if v != value { - newValues = append(newValues, v) - } - } - - if len(newValues) == 0 { - delete(r.config.Queries, key) - } else { - r.config.Queries[key] = newValues - } - - return r -} diff --git a/request_execution.go b/request_execution.go new file mode 100644 index 0000000..2c97595 --- /dev/null +++ b/request_execution.go @@ -0,0 +1,34 @@ +package starnet + +import "net/http" + +// SetBasicAuth 设置 Basic 认证 +func (r *Request) SetBasicAuth(username, password string) *Request { + return r.applyMutation(mutateBasicAuth(username, password)) +} + +// SetContentLength 设置 Content-Length +func (r *Request) SetContentLength(length int64) *Request { + return r.applyMutation(mutateContentLength(length)) +} + +// SetAutoCalcContentLength 设置是否自动计算 Content-Length +// 警告:启用后会将整个 body 读入内存 +func (r *Request) SetAutoCalcContentLength(auto bool) *Request { + return r.applyMutation(mutateAutoCalcContentLength(auto)) +} + +// SetTransport 设置自定义 Transport +func (r *Request) SetTransport(transport *http.Transport) *Request { + return r.applyMutation(mutateTransport(transport)) +} + +// SetUploadProgress 设置文件上传进度回调 +func (r *Request) SetUploadProgress(fn UploadProgressFunc) *Request { + return r.applyMutation(mutateUploadProgress(fn)) +} + +// SetMaxRespBodyBytes 设置响应体最大读取字节数(<=0 表示不限制) +func (r *Request) SetMaxRespBodyBytes(maxBytes int64) *Request { + return r.applyMutation(mutateMaxRespBodyBytes(maxBytes)) +} diff --git a/request_execution_regression_test.go b/request_execution_regression_test.go new file mode 100644 index 0000000..41135fc --- /dev/null +++ b/request_execution_regression_test.go @@ -0,0 +1,172 @@ +package starnet + +import ( + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "strings" + "sync/atomic" + "testing" +) + +func TestRequestDoTwiceRebuildsExecutionState(t *testing.T) { + var attempts int32 + + req := NewSimpleRequest("http://example.com/path", http.MethodPost). + SetHeader("X-Test", "one"). + AddQuery("q", "v"). + SetBodyReader(strings.NewReader("payload")) + req.client = &Client{client: &http.Client{ + Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) { + if err := r.Context().Err(); err != nil { + t.Fatalf("request context already done: %v", err) + } + if values := r.Header.Values("X-Test"); len(values) != 1 || values[0] != "one" { + t.Fatalf("header values=%v", values) + } + if values := r.URL.Query()["q"]; len(values) != 1 || values[0] != "v" { + t.Fatalf("query values=%v", values) + } + body, err := io.ReadAll(r.Body) + if err != nil { + return nil, err + } + _ = r.Body.Close() + if string(body) != "payload" { + t.Fatalf("body=%q", string(body)) + } + + n := atomic.AddInt32(&attempts, 1) + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader(fmt.Sprintf("ok-%d", n))), + Request: r, + }, nil + }), + }} + + resp1, err := req.Do() + if err != nil { + t.Fatalf("first Do() error: %v", err) + } + if err := resp1.Close(); err != nil { + t.Fatalf("first Close() error: %v", err) + } + + resp2, err := req.Do() + if err != nil { + t.Fatalf("second Do() error: %v", err) + } + defer resp2.Close() + + if got := atomic.LoadInt32(&attempts); got != 2 { + t.Fatalf("attempts=%d; want 2", got) + } +} + +func TestRequestPrepareRawDynamicPathInjectsAggregatedRequestContext(t *testing.T) { + rawReq, err := http.NewRequest(http.MethodGet, "https://example.com/resource", nil) + if err != nil { + t.Fatalf("NewRequest() error: %v", err) + } + + req := NewSimpleRequest("", http.MethodGet). + SetRawRequest(rawReq). + SetProxy("http://proxy.example:8080"). + SetCustomIP([]string{"127.0.0.1"}). + SetSkipTLSVerify(true). + SetTLSServerName("override.example") + + if err := req.prepare(); err != nil { + t.Fatalf("prepare() error: %v", err) + } + + raw := req.execCtx.Value(ctxKeyRequestContext) + rc, ok := raw.(*RequestContext) + if !ok || rc == nil { + t.Fatalf("expected request context, got %#v", raw) + } + if rc.Proxy != "http://proxy.example:8080" { + t.Fatalf("proxy=%q", rc.Proxy) + } + if len(rc.CustomIP) != 1 || rc.CustomIP[0] != "127.0.0.1" { + t.Fatalf("custom ip=%v", rc.CustomIP) + } + if rc.TLSConfig == nil || !rc.TLSConfig.InsecureSkipVerify { + t.Fatalf("tls config=%#v", rc.TLSConfig) + } + if rc.TLSServerName != "override.example" { + t.Fatalf("tls server name=%q", rc.TLSServerName) + } +} + +func TestRequestSetFormDataOverridesBytesBody(t *testing.T) { + req := NewSimpleRequest("http://example.com", http.MethodPost). + SetBodyString("stale"). + SetFormData(map[string][]string{"k": []string{"v"}}) + + if req.config.Body.Mode != bodyModeForm { + t.Fatalf("body mode=%v", req.config.Body.Mode) + } + if req.config.Body.Reader != nil || req.config.Body.Bytes != nil || len(req.config.Body.Files) != 0 { + t.Fatalf("unexpected stale body state: %#v", req.config.Body) + } + + if err := req.prepare(); err != nil { + t.Fatalf("prepare() error: %v", err) + } + + body, err := req.httpReq.GetBody() + if err != nil { + t.Fatalf("GetBody() error: %v", err) + } + defer body.Close() + + data, err := io.ReadAll(body) + if err != nil { + t.Fatalf("ReadAll() error: %v", err) + } + if string(data) != "k=v" { + t.Fatalf("body=%q; want k=v", string(data)) + } +} + +func TestRequestAddFileClearsPreviousBytesBody(t *testing.T) { + tmpDir := t.TempDir() + filePath := filepath.Join(tmpDir, "payload.txt") + if err := os.WriteFile(filePath, []byte("file-body"), 0644); err != nil { + t.Fatalf("WriteFile() error: %v", err) + } + + req := NewSimpleRequest("http://example.com", http.MethodPost). + SetJSON(map[string]string{"old": "json-only"}). + AddFile("file", filePath) + + if req.config.Body.Mode != bodyModeMultipart { + t.Fatalf("body mode=%v", req.config.Body.Mode) + } + if req.config.Body.Reader != nil || req.config.Body.Bytes != nil { + t.Fatalf("unexpected stale simple body state: %#v", req.config.Body) + } + + if err := req.prepare(); err != nil { + t.Fatalf("prepare() error: %v", err) + } + + data, err := io.ReadAll(req.httpReq.Body) + if err != nil { + t.Fatalf("ReadAll() error: %v", err) + } + if !strings.Contains(req.httpReq.Header.Get("Content-Type"), "multipart/form-data") { + t.Fatalf("content-type=%q", req.httpReq.Header.Get("Content-Type")) + } + if !strings.Contains(string(data), "file-body") { + t.Fatalf("multipart body missing file content: %q", string(data)) + } + if strings.Contains(string(data), "json-only") { + t.Fatalf("multipart body still contains stale json: %q", string(data)) + } +} diff --git a/request_header.go b/request_header.go index 68ae11e..1e75926 100644 --- a/request_header.go +++ b/request_header.go @@ -4,6 +4,25 @@ import ( "net/http" ) +func isHostHeaderKey(key string) bool { + return http.CanonicalHeaderKey(key) == "Host" +} + +func setRequestHostConfig(config *RequestConfig, host string) { + if config == nil { + return + } + if config.Headers == nil { + config.Headers = make(http.Header) + } + config.Host = host + if host == "" { + config.Headers.Del("Host") + return + } + config.Headers.Set("Host", host) +} + // SetHeader 设置 Header(覆盖) func (r *Request) SetHeader(key, value string) *Request { if r.err != nil { @@ -12,7 +31,11 @@ func (r *Request) SetHeader(key, value string) *Request { if r.doRaw { return r } + if isHostHeaderKey(key) { + return r.SetHost(value) + } r.config.Headers.Set(key, value) + r.invalidatePreparedState() return r } @@ -24,7 +47,11 @@ func (r *Request) AddHeader(key, value string) *Request { if r.doRaw { return r } + if isHostHeaderKey(key) { + return r.SetHost(value) + } r.config.Headers.Add(key, value) + r.invalidatePreparedState() return r } @@ -37,6 +64,9 @@ func (r *Request) SetHeaders(headers http.Header) *Request { return r } r.config.Headers = cloneHeader(headers) + r.config.Host = r.config.Headers.Get("Host") + r.syncRequestHost() + r.invalidatePreparedState() return r } @@ -49,8 +79,14 @@ func (r *Request) AddHeaders(headers map[string]string) *Request { return r } for k, v := range headers { + if isHostHeaderKey(k) { + setRequestHostConfig(r.config, v) + continue + } r.config.Headers.Add(k, v) } + r.syncRequestHost() + r.invalidatePreparedState() return r } @@ -62,18 +98,56 @@ func (r *Request) DeleteHeader(key string) *Request { if r.doRaw { return r } + if isHostHeaderKey(key) { + setRequestHostConfig(r.config, "") + r.syncRequestHost() + r.invalidatePreparedState() + return r + } r.config.Headers.Del(key) + r.invalidatePreparedState() return r } // GetHeader 获取 Header func (r *Request) GetHeader(key string) string { + if isHostHeaderKey(key) { + return r.config.Host + } return r.config.Headers.Get(key) } // Headers 获取所有 Headers func (r *Request) Headers() http.Header { - return r.config.Headers + if r == nil || r.config == nil { + return make(http.Header) + } + return cloneHeader(r.config.Headers) +} + +// SetHost 设置请求 Host 头覆盖。 +func (r *Request) SetHost(host string) *Request { + if r.err != nil { + return r + } + if r.doRaw { + return r + } + setRequestHostConfig(r.config, host) + r.syncRequestHost() + r.invalidatePreparedState() + return r +} + +// Host 获取显式 Host 覆盖。 +func (r *Request) Host() string { + if r.config != nil && r.config.Host != "" { + return r.config.Host + } + if r.httpReq != nil { + return r.httpReq.Host + } + return "" } // SetContentType 设置 Content-Type @@ -104,7 +178,8 @@ func (r *Request) AddCookie(cookie *http.Cookie) *Request { if r.doRaw { return r } - r.config.Cookies = append(r.config.Cookies, cookie) + r.config.Cookies = append(r.config.Cookies, cloneCookie(cookie)) + r.invalidatePreparedState() return r } @@ -134,7 +209,8 @@ func (r *Request) SetCookies(cookies []*http.Cookie) *Request { if r.doRaw { return r } - r.config.Cookies = cookies + r.config.Cookies = cloneCookies(cookies) + r.invalidatePreparedState() return r } @@ -153,12 +229,16 @@ func (r *Request) AddCookies(cookies map[string]string) *Request { Path: "/", }) } + r.invalidatePreparedState() return r } // Cookies 获取所有 Cookies func (r *Request) Cookies() []*http.Cookie { - return r.config.Cookies + if r == nil || r.config == nil { + return nil + } + return cloneCookies(r.config.Cookies) } // ResetHeaders 重置所有 Headers @@ -167,6 +247,9 @@ func (r *Request) ResetHeaders() *Request { return r } r.config.Headers = make(http.Header) + r.config.Host = "" + r.syncRequestHost() + r.invalidatePreparedState() return r } @@ -176,5 +259,6 @@ func (r *Request) ResetCookies() *Request { return r } r.config.Cookies = []*http.Cookie{} + r.invalidatePreparedState() return r } diff --git a/request_multipart.go b/request_multipart.go new file mode 100644 index 0000000..cf7f04a --- /dev/null +++ b/request_multipart.go @@ -0,0 +1,69 @@ +package starnet + +import ( + "context" + "io" + "mime/multipart" + "os" +) + +// applyMultipartBody 应用 multipart 请求体 +func (r *Request) applyMultipartBody(execCtx context.Context) error { + pr, pw := io.Pipe() + writer := multipart.NewWriter(pw) + + r.httpReq.Header.Set("Content-Type", writer.FormDataContentType()) + r.httpReq.Body = pr + + go func() { + defer pw.Close() + defer writer.Close() + + for key, values := range r.config.Body.FormData { + for _, value := range values { + if err := writer.WriteField(key, value); err != nil { + pw.CloseWithError(wrapError(err, "write form field")) + return + } + } + } + + for _, file := range r.config.Body.Files { + if err := r.writeFile(execCtx, writer, file); err != nil { + pw.CloseWithError(err) + return + } + } + }() + + return nil +} + +// writeFile 写入文件到 multipart writer +func (r *Request) writeFile(execCtx context.Context, writer *multipart.Writer, file RequestFile) error { + part, err := writer.CreateFormFile(file.FormName, file.FileName) + if err != nil { + return wrapError(err, "create form file") + } + + var reader io.Reader + if file.FileData != nil { + reader = file.FileData + } else if file.FilePath != "" { + f, err := os.Open(file.FilePath) + if err != nil { + return wrapError(err, "open file") + } + defer f.Close() + reader = f + } else { + return ErrNilReader + } + + _, err = copyWithProgress(execCtx, part, reader, file.FileName, file.FileSize, r.config.UploadProgress) + if err != nil { + return wrapError(err, "copy file data") + } + + return nil +} diff --git a/request_mutation.go b/request_mutation.go new file mode 100644 index 0000000..2e33e26 --- /dev/null +++ b/request_mutation.go @@ -0,0 +1,326 @@ +package starnet + +import ( + "context" + "crypto/tls" + "fmt" + "net" + "net/http" + "net/url" + "time" +) + +type requestMutation func(*Request) error + +func (r *Request) applyMutation(mutation requestMutation) *Request { + if r == nil || r.err != nil { + return r + } + if err := mutation(r); err != nil { + r.err = err + return r + } + r.invalidatePreparedState() + return r +} + +func requestOptFromMutation(mutation requestMutation) RequestOpt { + return func(r *Request) error { + if r == nil { + return nil + } + return mutation(r) + } +} + +func validateCustomIPs(ips []string) error { + for _, ip := range ips { + if net.ParseIP(ip) == nil { + return wrapError(ErrInvalidIP, "ip: %s", ip) + } + } + return nil +} + +func validateCustomDNS(dnsServers []string) error { + for _, dns := range dnsServers { + if net.ParseIP(dns) == nil { + return wrapError(ErrInvalidDNS, "dns: %s", dns) + } + } + return nil +} + +func parseProxyURL(proxy string) (*url.URL, error) { + if proxy == "" { + return nil, nil + } + + proxyURL, err := url.Parse(proxy) + if err != nil { + return nil, wrapError(err, "parse proxy url") + } + if proxyURL.Scheme == "" { + return nil, fmt.Errorf("proxy scheme is required: %s", proxy) + } + if proxyURL.Host == "" { + return nil, fmt.Errorf("proxy host is required: %s", proxy) + } + + return proxyURL, nil +} + +func mutateTimeout(timeout time.Duration) requestMutation { + return func(r *Request) error { + r.config.Network.Timeout = timeout + return nil + } +} + +func mutateDialTimeout(timeout time.Duration) requestMutation { + return func(r *Request) error { + r.config.Network.DialTimeout = timeout + return nil + } +} + +func mutateProxy(proxy string) requestMutation { + return func(r *Request) error { + if _, err := parseProxyURL(proxy); err != nil { + return err + } + r.config.Network.Proxy = proxy + return nil + } +} + +func mutateDialFunc(fn func(ctx context.Context, network, addr string) (net.Conn, error)) requestMutation { + return func(r *Request) error { + r.config.Network.DialFunc = fn + return nil + } +} + +func mutateTLSConfig(tlsConfig *tls.Config) requestMutation { + return func(r *Request) error { + r.config.TLS.Config = tlsConfig + return nil + } +} + +func mutateTLSServerName(serverName string) requestMutation { + return func(r *Request) error { + r.config.TLS.ServerName = serverName + return nil + } +} + +func mutateTraceHooks(hooks *TraceHooks) requestMutation { + return func(r *Request) error { + r.traceHooks = hooks + return nil + } +} + +func mutateSkipTLSVerify(skip bool) requestMutation { + return func(r *Request) error { + r.config.TLS.SkipVerify = skip + return nil + } +} + +func mutateCustomIP(ips []string) requestMutation { + return func(r *Request) error { + if err := validateCustomIPs(ips); err != nil { + return err + } + r.config.DNS.CustomIP = cloneStringSlice(ips) + return nil + } +} + +func mutateAddCustomIP(ip string) requestMutation { + return func(r *Request) error { + if err := validateCustomIPs([]string{ip}); err != nil { + return err + } + r.config.DNS.CustomIP = append(r.config.DNS.CustomIP, ip) + return nil + } +} + +func mutateCustomDNS(dnsServers []string) requestMutation { + return func(r *Request) error { + if err := validateCustomDNS(dnsServers); err != nil { + return err + } + r.config.DNS.CustomDNS = cloneStringSlice(dnsServers) + return nil + } +} + +func mutateAddCustomDNS(dns string) requestMutation { + return func(r *Request) error { + if err := validateCustomDNS([]string{dns}); err != nil { + return err + } + r.config.DNS.CustomDNS = append(r.config.DNS.CustomDNS, dns) + return nil + } +} + +func mutateLookupFunc(fn func(ctx context.Context, host string) ([]net.IPAddr, error)) requestMutation { + return func(r *Request) error { + r.config.DNS.LookupFunc = fn + return nil + } +} + +func mutateBasicAuth(username, password string) requestMutation { + return func(r *Request) error { + r.config.BasicAuth = [2]string{username, password} + return nil + } +} + +func mutateContentLength(length int64) requestMutation { + return func(r *Request) error { + r.config.ContentLength = length + return nil + } +} + +func mutateAutoCalcContentLength(auto bool) requestMutation { + return func(r *Request) error { + if r.doRaw { + return fmt.Errorf("cannot set auto calc content length in raw mode") + } + r.config.AutoCalcContentLength = auto + return nil + } +} + +func mutateTransport(transport *http.Transport) requestMutation { + return func(r *Request) error { + r.config.Transport = transport + r.config.CustomTransport = true + return nil + } +} + +func mutateUploadProgress(fn UploadProgressFunc) requestMutation { + return func(r *Request) error { + r.config.UploadProgress = fn + return nil + } +} + +func mutateAutoFetch(auto bool) requestMutation { + return func(r *Request) error { + r.autoFetch = auto + return nil + } +} + +func mutateMaxRespBodyBytes(maxBytes int64) requestMutation { + return func(r *Request) error { + if maxBytes < 0 { + return fmt.Errorf("max response body bytes must be >= 0") + } + r.config.MaxRespBodyBytes = maxBytes + return nil + } +} + +func mutateContext(ctx context.Context) requestMutation { + return func(r *Request) error { + ctx = normalizeContext(ctx) + r.ctx = ctx + if r.doRaw && r.rawTemplate != nil { + r.rawTemplate = r.rawTemplate.WithContext(ctx) + } + if r.httpReq != nil { + r.httpReq = r.httpReq.WithContext(ctx) + } + return nil + } +} + +func mutateRawRequest(httpReq *http.Request) requestMutation { + return func(r *Request) error { + if httpReq == nil { + return fmt.Errorf("httpReq cannot be nil") + } + r.httpReq = httpReq + r.rawTemplate = httpReq + r.ctx = normalizeContext(httpReq.Context()) + r.method = httpReq.Method + if httpReq.URL != nil { + r.url = httpReq.URL.String() + } + r.doRaw = true + r.rawSourceExternal = true + return nil + } +} + +func mutateAddQuery(key, value string) requestMutation { + return func(r *Request) error { + r.config.Queries[key] = append(r.config.Queries[key], value) + return nil + } +} + +func mutateSetQuery(key, value string) requestMutation { + return func(r *Request) error { + r.config.Queries[key] = []string{value} + return nil + } +} + +func mutateSetQueries(queries map[string][]string) requestMutation { + return func(r *Request) error { + r.config.Queries = cloneStringMapSlice(queries) + return nil + } +} + +func mutateAddQueries(queries map[string]string) requestMutation { + return func(r *Request) error { + for key, value := range queries { + r.config.Queries[key] = append(r.config.Queries[key], value) + } + return nil + } +} + +func mutateDeleteQuery(key string) requestMutation { + return func(r *Request) error { + delete(r.config.Queries, key) + return nil + } +} + +func mutateDeleteQueryValue(key, value string) requestMutation { + return func(r *Request) error { + values, ok := r.config.Queries[key] + if !ok { + return nil + } + + newValues := make([]string, 0, len(values)) + for _, item := range values { + if item != value { + newValues = append(newValues, item) + } + } + + if len(newValues) == 0 { + delete(r.config.Queries, key) + return nil + } + + r.config.Queries[key] = newValues + return nil + } +} diff --git a/request_network.go b/request_network.go new file mode 100644 index 0000000..772c78a --- /dev/null +++ b/request_network.go @@ -0,0 +1,71 @@ +package starnet + +import ( + "context" + "crypto/tls" + "net" + "time" +) + +// SetTimeout 设置请求总超时时间 +// timeout > 0: 为本次请求注入 context 超时 +// timeout = 0: 不额外设置请求总超时 +// timeout < 0: 禁用 starnet 默认总超时 +func (r *Request) SetTimeout(timeout time.Duration) *Request { + return r.applyMutation(mutateTimeout(timeout)) +} + +// SetDialTimeout 设置连接超时时间 +func (r *Request) SetDialTimeout(timeout time.Duration) *Request { + return r.applyMutation(mutateDialTimeout(timeout)) +} + +// SetProxy 设置代理 +func (r *Request) SetProxy(proxy string) *Request { + return r.applyMutation(mutateProxy(proxy)) +} + +// SetDialFunc 设置自定义 Dial 函数 +func (r *Request) SetDialFunc(fn func(ctx context.Context, network, addr string) (net.Conn, error)) *Request { + return r.applyMutation(mutateDialFunc(fn)) +} + +// SetTLSConfig 设置 TLS 配置 +func (r *Request) SetTLSConfig(tlsConfig *tls.Config) *Request { + return r.applyMutation(mutateTLSConfig(tlsConfig)) +} + +// SetTLSServerName 设置显式 TLS ServerName/SNI。 +func (r *Request) SetTLSServerName(serverName string) *Request { + return r.applyMutation(mutateTLSServerName(serverName)) +} + +// SetSkipTLSVerify 设置是否跳过 TLS 验证 +func (r *Request) SetSkipTLSVerify(skip bool) *Request { + return r.applyMutation(mutateSkipTLSVerify(skip)) +} + +// SetCustomIP 设置自定义 IP(直接指定 IP,跳过 DNS) +func (r *Request) SetCustomIP(ips []string) *Request { + return r.applyMutation(mutateCustomIP(ips)) +} + +// AddCustomIP 添加自定义 IP +func (r *Request) AddCustomIP(ip string) *Request { + return r.applyMutation(mutateAddCustomIP(ip)) +} + +// SetCustomDNS 设置自定义 DNS 服务器 +func (r *Request) SetCustomDNS(dnsServers []string) *Request { + return r.applyMutation(mutateCustomDNS(dnsServers)) +} + +// AddCustomDNS 添加自定义 DNS 服务器 +func (r *Request) AddCustomDNS(dns string) *Request { + return r.applyMutation(mutateAddCustomDNS(dns)) +} + +// SetLookupFunc 设置自定义 DNS 解析函数 +func (r *Request) SetLookupFunc(fn func(ctx context.Context, host string) ([]net.IPAddr, error)) *Request { + return r.applyMutation(mutateLookupFunc(fn)) +} diff --git a/request_prepare.go b/request_prepare.go new file mode 100644 index 0000000..efcb7a5 --- /dev/null +++ b/request_prepare.go @@ -0,0 +1,314 @@ +package starnet + +import ( + "bytes" + "context" + "fmt" + "io" + "net/http" + "net/http/httptrace" + "net/url" + "strings" +) + +func setReplayableRequestBodyBytes(httpReq *http.Request, data []byte) { + if httpReq == nil { + return + } + httpReq.Body = io.NopCloser(bytes.NewReader(data)) + httpReq.ContentLength = int64(len(data)) + httpReq.GetBody = func() (io.ReadCloser, error) { + return io.NopCloser(bytes.NewReader(data)), nil + } +} + +func clearSimpleBodyState(body *BodyConfig) { + if body == nil { + return + } + body.Bytes = nil + body.Reader = nil +} + +func resetFormBodyState(body *BodyConfig) { + if body == nil { + return + } + body.FormData = make(map[string][]string) +} + +func resetMultipartBodyState(body *BodyConfig) { + if body == nil { + return + } + body.Files = nil +} + +func setBytesBodyConfig(body *BodyConfig, data []byte) { + if body == nil { + return + } + body.Mode = bodyModeBytes + body.Bytes = cloneBytes(data) + body.Reader = nil + resetFormBodyState(body) + resetMultipartBodyState(body) +} + +func setReaderBodyConfig(body *BodyConfig, reader io.Reader) { + if body == nil { + return + } + body.Mode = bodyModeReader + body.Reader = reader + body.Bytes = nil + resetFormBodyState(body) + resetMultipartBodyState(body) +} + +func setFormBodyConfig(body *BodyConfig, data map[string][]string) { + if body == nil { + return + } + body.Mode = bodyModeForm + clearSimpleBodyState(body) + resetMultipartBodyState(body) + body.FormData = cloneStringMapSlice(data) +} + +func ensureFormMode(body *BodyConfig) { + if body == nil { + return + } + if body.Mode == bodyModeForm || body.Mode == bodyModeMultipart { + if body.FormData == nil { + body.FormData = make(map[string][]string) + } + return + } + clearSimpleBodyState(body) + resetMultipartBodyState(body) + body.FormData = make(map[string][]string) + body.Mode = bodyModeForm +} + +func ensureMultipartMode(body *BodyConfig) { + if body == nil { + return + } + if body.Mode == bodyModeMultipart { + if body.FormData == nil { + body.FormData = make(map[string][]string) + } + return + } + if body.Mode != bodyModeForm { + clearSimpleBodyState(body) + body.FormData = make(map[string][]string) + } + body.Mode = bodyModeMultipart + if body.FormData == nil { + body.FormData = make(map[string][]string) + } +} + +func snapshotBytesReader(reader *bytes.Reader) ([]byte, error) { + if reader == nil { + return nil, nil + } + data := make([]byte, reader.Len()) + _, err := reader.ReadAt(data, reader.Size()-int64(reader.Len())) + if err != nil && err != io.EOF { + return nil, err + } + return data, nil +} + +func snapshotStringReader(reader *strings.Reader) ([]byte, error) { + if reader == nil { + return nil, nil + } + data := make([]byte, reader.Len()) + _, err := reader.ReadAt(data, reader.Size()-int64(reader.Len())) + if err != nil && err != io.EOF { + return nil, err + } + return data, nil +} + +// applyBody 应用请求体 +func (r *Request) applyBody(execCtx context.Context) error { + r.httpReq.Body = nil + r.httpReq.GetBody = nil + r.httpReq.ContentLength = 0 + + switch r.config.Body.Mode { + case bodyModeReader: + if r.config.Body.Reader == nil { + return nil + } + switch reader := r.config.Body.Reader.(type) { + case *bytes.Buffer: + setReplayableRequestBodyBytes(r.httpReq, append([]byte(nil), reader.Bytes()...)) + case *bytes.Reader: + data, err := snapshotBytesReader(reader) + if err != nil { + return wrapError(err, "snapshot bytes reader") + } + setReplayableRequestBodyBytes(r.httpReq, data) + case *strings.Reader: + data, err := snapshotStringReader(reader) + if err != nil { + return wrapError(err, "snapshot strings reader") + } + setReplayableRequestBodyBytes(r.httpReq, data) + default: + r.httpReq.Body = io.NopCloser(r.config.Body.Reader) + } + switch reader := r.config.Body.Reader.(type) { + case *bytes.Buffer: + r.httpReq.ContentLength = int64(reader.Len()) + case *bytes.Reader: + r.httpReq.ContentLength = int64(reader.Len()) + case *strings.Reader: + r.httpReq.ContentLength = int64(reader.Len()) + } + return nil + case bodyModeBytes: + setReplayableRequestBodyBytes(r.httpReq, r.config.Body.Bytes) + return nil + case bodyModeMultipart: + return r.applyMultipartBody(execCtx) + case bodyModeForm: + values := url.Values{} + for key, items := range r.config.Body.FormData { + for _, value := range items { + values.Add(key, value) + } + } + encoded := values.Encode() + setReplayableRequestBodyBytes(r.httpReq, []byte(encoded)) + return nil + } + + return nil +} + +// prepare 准备请求(应用配置) +func (r *Request) prepare() (err error) { + if r.applied { + return nil + } + + if r.httpReq == nil { + return fmt.Errorf("http request is nil") + } + + execCtx := r.ctx + if execCtx == nil { + execCtx = context.Background() + } + defaultTLSServerName := "" + if r.httpReq.URL != nil && r.httpReq.URL.Scheme == "https" { + defaultTLSServerName = r.httpReq.URL.Hostname() + } + execCtx = injectRequestConfig(execCtx, r.config, defaultTLSServerName) + + var traceState *traceState + if r.traceHooks != nil { + traceState = newTraceState(r.traceHooks) + execCtx = withTraceState(execCtx, traceState) + if clientTrace := traceState.clientTrace(); clientTrace != nil { + execCtx = httptrace.WithClientTrace(execCtx, clientTrace) + } + } + + var cancel context.CancelFunc + if r.config.Network.Timeout > 0 { + execCtx, cancel = context.WithTimeout(execCtx, r.config.Network.Timeout) + } + defer func() { + if err != nil && cancel != nil { + cancel() + } + }() + + if r.httpClient == nil { + r.httpClient, err = r.buildHTTPClient() + if err != nil { + return err + } + } + + if !r.doRaw { + if len(r.config.Queries) > 0 { + query := r.httpReq.URL.Query() + for key, values := range r.config.Queries { + for _, value := range values { + query.Add(key, value) + } + } + r.httpReq.URL.RawQuery = query.Encode() + } + + for key, values := range r.config.Headers { + if isHostHeaderKey(key) { + continue + } + for _, value := range values { + r.httpReq.Header.Add(key, value) + } + } + + for _, cookie := range r.config.Cookies { + r.httpReq.AddCookie(cookie) + } + + if r.config.BasicAuth[0] != "" || r.config.BasicAuth[1] != "" { + r.httpReq.SetBasicAuth(r.config.BasicAuth[0], r.config.BasicAuth[1]) + } + + if err := r.applyBody(execCtx); err != nil { + return err + } + + if r.config.ContentLength > 0 { + r.httpReq.ContentLength = r.config.ContentLength + } else if r.config.ContentLength < 0 { + r.httpReq.ContentLength = 0 + } + + if r.config.AutoCalcContentLength && r.httpReq.Body != nil { + data, err := io.ReadAll(r.httpReq.Body) + if err != nil { + return wrapError(err, "read body for content length") + } + setReplayableRequestBodyBytes(r.httpReq, data) + } + + r.syncRequestHost() + } + + r.execCtx = execCtx + r.traceState = traceState + r.cancel = cancel + r.httpReq = r.httpReq.WithContext(r.execCtx) + r.applied = true + return nil +} + +// buildHTTPClient 构建 HTTP Client +func (r *Request) buildHTTPClient() (*http.Client, error) { + if r.client != nil { + return r.client.HTTPClient(), nil + } + + if r.config.CustomTransport && r.config.Transport != nil { + return &http.Client{ + Transport: &Transport{base: r.config.Transport}, + Timeout: 0, + }, nil + } + + return DefaultHTTPClient(), nil +} diff --git a/request_prepare_regression_test.go b/request_prepare_regression_test.go new file mode 100644 index 0000000..2cc2de0 --- /dev/null +++ b/request_prepare_regression_test.go @@ -0,0 +1,335 @@ +package starnet + +import ( + "bytes" + "context" + "errors" + "io" + "mime/multipart" + "net/http" + "strings" + "sync/atomic" + "testing" + "time" +) + +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (fn roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return fn(req) +} + +func TestRequestPreparedMutationReappliesHeadersAndBody(t *testing.T) { + req := NewSimpleRequest("http://example.com", http.MethodPost). + SetHeader("X-Test", "one"). + SetBodyString("first") + req.client = &Client{client: &http.Client{ + Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) { + body, err := io.ReadAll(r.Body) + if err != nil { + return nil, err + } + _ = r.Body.Close() + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader(r.Header.Get("X-Test") + ":" + string(body))), + Request: r, + }, nil + }), + }} + + if _, err := req.HTTPClient(); err != nil { + t.Fatalf("HTTPClient() error: %v", err) + } + + req.SetHeader("X-Test", "two").SetBodyString("second") + + resp, err := req.Do() + if err != nil { + t.Fatalf("Do() error: %v", err) + } + defer resp.Close() + + body, err := resp.Body().String() + if err != nil { + t.Fatalf("Body().String() error: %v", err) + } + if body != "two:second" { + t.Fatalf("body=%q; want %q", body, "two:second") + } +} + +func TestRequestPreparedMutationReappliesTimeout(t *testing.T) { + var attempts int32 + req := NewSimpleRequest("http://example.com", http.MethodGet) + req.client = &Client{client: &http.Client{ + Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) { + if atomic.AddInt32(&attempts, 1) == 1 { + return &http.Response{ + StatusCode: http.StatusNoContent, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader("")), + Request: r, + }, nil + } + select { + case <-time.After(50 * time.Millisecond): + return &http.Response{ + StatusCode: http.StatusNoContent, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader("")), + Request: r, + }, nil + case <-r.Context().Done(): + return nil, r.Context().Err() + } + }), + }} + + resp, err := req.Do() + if err != nil { + t.Fatalf("first Do() error: %v", err) + } + _ = resp.Close() + + _, err = req.SetTimeout(10 * time.Millisecond).Do() + if err == nil { + t.Fatal("second Do() succeeded; want timeout error") + } + if !IsTimeout(err) && !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("second Do() error=%v; want timeout", err) + } +} + +func TestWriteFileUsesExecContextWithoutProgressHook(t *testing.T) { + req := NewSimpleRequest("http://example.com", http.MethodPost) + + pr, pw := io.Pipe() + writer := multipart.NewWriter(pw) + done := make(chan struct{}) + go func() { + _, _ = io.Copy(io.Discard, pr) + _ = pr.Close() + close(done) + }() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err := req.writeFile(ctx, writer, RequestFile{ + FormName: "file", + FileName: "payload.txt", + FileData: strings.NewReader("payload"), + FileSize: int64(len("payload")), + }) + _ = writer.Close() + _ = pw.Close() + <-done + + if !errors.Is(err, context.Canceled) { + t.Fatalf("writeFile() error=%v; want context.Canceled", err) + } +} + +func TestCopyWithProgressHonorsCanceledContextWithoutHook(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, err := copyWithProgress(ctx, io.Discard, strings.NewReader("payload"), "payload.txt", int64(len("payload")), nil) + if !errors.Is(err, context.Canceled) { + t.Fatalf("copyWithProgress() error=%v; want context.Canceled", err) + } +} + +func TestPrepareSetsGetBodyForReplayableBodies(t *testing.T) { + tests := []struct { + name string + req *Request + want string + }{ + { + name: "bytes", + req: NewSimpleRequest("http://example.com", http.MethodPost).SetBody([]byte("payload")), + want: "payload", + }, + { + name: "bytes-reader", + req: NewSimpleRequest("http://example.com", http.MethodPost).SetBodyReader(bytes.NewReader([]byte("payload"))), + want: "payload", + }, + { + name: "strings-reader", + req: NewSimpleRequest("http://example.com", http.MethodPost).SetBodyReader(strings.NewReader("payload")), + want: "payload", + }, + { + name: "form-data", + req: NewSimpleRequest("http://example.com", http.MethodPost).AddFormData("k", "v"), + want: "k=v", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := tt.req.prepare(); err != nil { + t.Fatalf("prepare() error: %v", err) + } + if tt.req.httpReq.GetBody == nil { + t.Fatal("GetBody is nil") + } + + body, err := tt.req.httpReq.GetBody() + if err != nil { + t.Fatalf("GetBody() error: %v", err) + } + defer body.Close() + + data, err := io.ReadAll(body) + if err != nil { + t.Fatalf("ReadAll() error: %v", err) + } + if string(data) != tt.want { + t.Fatalf("body=%q; want %q", string(data), tt.want) + } + }) + } +} + +type replayRoundTripper struct { + attempts int + bodies []string +} + +func (rt *replayRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + body, err := io.ReadAll(req.Body) + if err != nil { + return nil, err + } + _ = req.Body.Close() + + rt.attempts++ + rt.bodies = append(rt.bodies, string(body)) + if rt.attempts == 1 { + return nil, errors.New("first target failed") + } + + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader("ok")), + Request: req, + }, nil +} + +func TestRoundTripResolvedTargetsReplaysPreparedBody(t *testing.T) { + req := NewSimpleRequest("http://example.com/upload", http.MethodPut). + SetBodyReader(strings.NewReader("payload")) + + if err := req.prepare(); err != nil { + t.Fatalf("prepare() error: %v", err) + } + + rt := &replayRoundTripper{} + resp, err := roundTripResolvedTargets(rt, req.httpReq, []string{"127.0.0.2:80", "127.0.0.1:80"}) + if err != nil { + t.Fatalf("roundTripResolvedTargets() error: %v", err) + } + defer resp.Body.Close() + + if len(rt.bodies) != 2 { + t.Fatalf("attempt bodies=%v; want 2 attempts", rt.bodies) + } + if rt.bodies[0] != "payload" || rt.bodies[1] != "payload" { + t.Fatalf("attempt bodies=%v; want both payload", rt.bodies) + } +} + +func TestRoundTripResolvedTargetsDoesNotFallbackNonIdempotentRequest(t *testing.T) { + req := NewSimpleRequest("http://example.com/upload", http.MethodPost). + SetBodyReader(strings.NewReader("payload")) + + if err := req.prepare(); err != nil { + t.Fatalf("prepare() error: %v", err) + } + + rt := &replayRoundTripper{} + _, err := roundTripResolvedTargets(rt, req.httpReq, []string{"127.0.0.2:80", "127.0.0.1:80"}) + if err == nil { + t.Fatal("roundTripResolvedTargets() succeeded; want first target error") + } + if len(rt.bodies) != 1 { + t.Fatalf("attempt bodies=%v; want only first target attempt", rt.bodies) + } + if rt.bodies[0] != "payload" { + t.Fatalf("attempt body=%q; want payload", rt.bodies[0]) + } +} + +func TestRetryReplayableReaderBody(t *testing.T) { + var attempts int32 + req := NewSimpleRequest("http://example.com/upload", http.MethodPut). + SetBodyReader(strings.NewReader("payload")). + SetRetry(1, WithRetryBackoff(0, 0, 1), WithRetryJitter(0)) + req.client = &Client{client: &http.Client{ + Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) { + body, err := io.ReadAll(r.Body) + if err != nil { + return nil, err + } + _ = r.Body.Close() + if string(body) != "payload" { + t.Fatalf("body=%q; want payload", string(body)) + } + + if atomic.AddInt32(&attempts, 1) == 1 { + return &http.Response{ + StatusCode: http.StatusServiceUnavailable, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader("retry")), + Request: r, + }, nil + } + + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader("ok")), + Request: r, + }, nil + }), + }} + + resp, err := req.Do() + if err != nil { + t.Fatalf("Do() error: %v", err) + } + defer resp.Close() + + if got := atomic.LoadInt32(&attempts); got != 2 { + t.Fatalf("attempts=%d; want 2", got) + } +} + +func TestWithProxyInvalidReturnsError(t *testing.T) { + _, err := NewRequest("http://example.com", http.MethodGet, WithProxy("://bad-proxy")) + if err == nil { + t.Fatal("NewRequest() succeeded; want invalid proxy error") + } +} + +func TestClientNewRequestWithInvalidProxyReturnsError(t *testing.T) { + client := NewClientNoErr() + + _, err := client.NewRequest("http://example.com", http.MethodGet, WithProxy("://bad-proxy")) + if err == nil { + t.Fatal("Client.NewRequest() succeeded; want invalid proxy error") + } +} + +func TestNewClientWithInvalidProxyReturnsError(t *testing.T) { + _, err := NewClient(WithProxy("://bad-proxy")) + if err == nil { + t.Fatal("NewClient() succeeded; want invalid proxy error") + } +} diff --git a/request_query.go b/request_query.go new file mode 100644 index 0000000..356dab8 --- /dev/null +++ b/request_query.go @@ -0,0 +1,31 @@ +package starnet + +// AddQuery 添加查询参数 +func (r *Request) AddQuery(key, value string) *Request { + return r.applyMutation(mutateAddQuery(key, value)) +} + +// SetQuery 设置查询参数(覆盖) +func (r *Request) SetQuery(key, value string) *Request { + return r.applyMutation(mutateSetQuery(key, value)) +} + +// SetQueries 设置所有查询参数(覆盖) +func (r *Request) SetQueries(queries map[string][]string) *Request { + return r.applyMutation(mutateSetQueries(queries)) +} + +// AddQueries 批量添加查询参数 +func (r *Request) AddQueries(queries map[string]string) *Request { + return r.applyMutation(mutateAddQueries(queries)) +} + +// DeleteQuery 删除查询参数 +func (r *Request) DeleteQuery(key string) *Request { + return r.applyMutation(mutateDeleteQuery(key)) +} + +// DeleteQueryValue 删除查询参数的特定值 +func (r *Request) DeleteQueryValue(key, value string) *Request { + return r.applyMutation(mutateDeleteQueryValue(key, value)) +} diff --git a/request_state_boundary_test.go b/request_state_boundary_test.go new file mode 100644 index 0000000..d215e40 --- /dev/null +++ b/request_state_boundary_test.go @@ -0,0 +1,168 @@ +package starnet + +import ( + "io" + "net/http" + "net/url" + "strings" + "sync/atomic" + "testing" +) + +type stateRoundTripperFunc func(*http.Request) (*http.Response, error) + +func (fn stateRoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return fn(req) +} + +func TestSetContextNilUsesBackground(t *testing.T) { + req := NewSimpleRequest("http://example.com", http.MethodGet) + req.client = &Client{client: &http.Client{ + Transport: stateRoundTripperFunc(func(r *http.Request) (*http.Response, error) { + if r.Context() == nil { + t.Fatal("request context is nil") + } + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader("ok")), + Request: r, + }, nil + }), + }} + + resp, err := req.SetContext(nil).Do() + if err != nil { + t.Fatalf("Do() error: %v", err) + } + defer resp.Close() + + if req.Context() == nil { + t.Fatal("request Context() is nil") + } +} + +func TestWithContextNilRetryPathDoesNotPanic(t *testing.T) { + var hits int32 + req, err := NewRequest("http://example.com", http.MethodGet, WithContext(nil)) + if err != nil { + t.Fatalf("NewRequest() error: %v", err) + } + req.client = &Client{client: &http.Client{ + Transport: stateRoundTripperFunc(func(r *http.Request) (*http.Response, error) { + if r.Context() == nil { + t.Fatal("retry request context is nil") + } + if atomic.AddInt32(&hits, 1) == 1 { + return &http.Response{ + StatusCode: http.StatusServiceUnavailable, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader("retry")), + Request: r, + }, nil + } + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader("ok")), + Request: r, + }, nil + }), + }} + + resp, err := req. + SetTimeout(DefaultTimeout). + SetRetry(1, WithRetryBackoff(0, 0, 1), WithRetryJitter(0)). + Do() + if err != nil { + t.Fatalf("Do() error: %v", err) + } + defer resp.Close() + + if got := atomic.LoadInt32(&hits); got != 2 { + t.Fatalf("hits=%d; want 2", got) + } +} + +func TestCloneRawRequestCreatesIndependentCopy(t *testing.T) { + rawReq, err := http.NewRequest(http.MethodPost, "http://example.com/upload", strings.NewReader("payload")) + if err != nil { + t.Fatalf("NewRequest() error: %v", err) + } + rawReq.Header.Set("X-Test", "one") + + req := NewSimpleRequest("", http.MethodPost).SetRawRequest(rawReq) + cloned := req.Clone() + + if cloned.Err() != nil { + t.Fatalf("Clone() err = %v", cloned.Err()) + } + if cloned.RawRequest() == rawReq { + t.Fatal("raw request pointer reused") + } + + cloned.RawRequest().Header.Set("X-Test", "two") + if rawReq.Header.Get("X-Test") != "one" { + t.Fatalf("original header mutated: %q", rawReq.Header.Get("X-Test")) + } + + body, err := cloned.RawRequest().GetBody() + if err != nil { + t.Fatalf("GetBody() error: %v", err) + } + defer body.Close() + + data, err := io.ReadAll(body) + if err != nil { + t.Fatalf("ReadAll() error: %v", err) + } + if string(data) != "payload" { + t.Fatalf("body=%q; want payload", string(data)) + } +} + +func TestCloneRawRequestWithNonReplayableBodyFailsExplicitly(t *testing.T) { + rawReq := &http.Request{ + Method: http.MethodPost, + URL: mustParseURL(t, "http://example.com/upload"), + Header: make(http.Header), + Body: io.NopCloser(io.MultiReader(strings.NewReader("payload"))), + } + + req := NewSimpleRequest("", http.MethodPost).SetRawRequest(rawReq) + cloned := req.Clone() + + if cloned.Err() == nil { + t.Fatal("Clone() should fail for non-replayable raw body") + } + if !strings.Contains(cloned.Err().Error(), "non-replayable") { + t.Fatalf("Clone() err=%v; want non-replayable body error", cloned.Err()) + } +} + +func TestDisableRawModeAfterSetRawRequestReturnsError(t *testing.T) { + rawReq, err := http.NewRequest(http.MethodGet, "http://example.com", nil) + if err != nil { + t.Fatalf("NewRequest() error: %v", err) + } + + req := NewSimpleRequest("", http.MethodGet).SetRawRequest(rawReq).DisableRawMode() + if req.Err() == nil { + t.Fatal("DisableRawMode() should set error") + } + if !strings.Contains(req.Err().Error(), "cannot disable raw mode") { + t.Fatalf("DisableRawMode() err=%v", req.Err()) + } + if !req.doRaw { + t.Fatal("request should remain in raw mode") + } +} + +func mustParseURL(t *testing.T, raw string) *url.URL { + t.Helper() + parsed, err := url.Parse(raw) + if err != nil { + t.Fatalf("url.Parse() error: %v", err) + } + return parsed +} diff --git a/request_trace.go b/request_trace.go new file mode 100644 index 0000000..f052e28 --- /dev/null +++ b/request_trace.go @@ -0,0 +1,6 @@ +package starnet + +// SetTraceHooks 设置请求 trace 回调。 +func (r *Request) SetTraceHooks(hooks *TraceHooks) *Request { + return r.applyMutation(mutateTraceHooks(hooks)) +} diff --git a/retry.go b/retry.go index 3e5e415..f63628b 100644 --- a/retry.go +++ b/retry.go @@ -1,6 +1,7 @@ package starnet import ( + "bytes" "context" "errors" "fmt" @@ -9,6 +10,7 @@ import ( "math/rand" "net" "net/http" + "strings" "time" ) @@ -87,6 +89,9 @@ func buildRetryPolicy(max int, opts ...RetryOpt) (*retryPolicy, error) { return policy, nil } +// WithRetry 为请求启用自动重试。 +// 默认只重试幂等方法;即使显式关闭幂等限制,Reader 形态的 body 仍会对非幂等方法保持保守禁用, +// 以避免请求体已落地后再次发送。 func WithRetry(max int, opts ...RetryOpt) RequestOpt { return func(r *Request) error { policy, err := buildRetryPolicy(max, opts...) @@ -98,6 +103,9 @@ func WithRetry(max int, opts ...RetryOpt) RequestOpt { } } +// SetRetry 为请求启用自动重试。 +// 默认只重试幂等方法;即使显式关闭幂等限制,Reader 形态的 body 仍会对非幂等方法保持保守禁用, +// 以避免请求体已落地后再次发送。 func (r *Request) SetRetry(max int, opts ...RetryOpt) *Request { if r.err != nil { return r @@ -226,10 +234,10 @@ func (r *Request) doWithRetry() (*Response, error) { return r.doOnce() } - retryCtx := r.ctx + retryCtx := normalizeContext(r.ctx) retryCancel := func() {} if r.config.Network.Timeout > 0 { - retryCtx, retryCancel = context.WithTimeout(r.ctx, r.config.Network.Timeout) + retryCtx, retryCancel = context.WithTimeout(retryCtx, r.config.Network.Timeout) } defer retryCancel() @@ -238,6 +246,12 @@ func (r *Request) doWithRetry() (*Response, error) { var lastErr error for attempt := 0; attempt < maxAttempts; attempt++ { + attemptNo := attempt + 1 + emitRetryAttemptStart(r.traceHooks, TraceRetryAttemptStartInfo{ + Attempt: attemptNo, + MaxAttempts: maxAttempts, + }) + attemptReq, err := r.newRetryAttempt(retryCtx) if err != nil { return nil, wrapError(err, "build retry attempt") @@ -248,7 +262,19 @@ func (r *Request) doWithRetry() (*Response, error) { resp.request = r } - if !policy.shouldRetry(resp, err, attempt, maxAttempts, retryCtx) { + willRetry := policy.shouldRetry(resp, err, attempt, maxAttempts, retryCtx) + statusCode := 0 + if resp != nil { + statusCode = resp.StatusCode + } + emitRetryAttemptDone(r.traceHooks, TraceRetryAttemptDoneInfo{ + Attempt: attemptNo, + MaxAttempts: maxAttempts, + StatusCode: statusCode, + Err: err, + WillRetry: willRetry, + }) + if !willRetry { return resp, err } @@ -262,6 +288,10 @@ func (r *Request) doWithRetry() (*Response, error) { if delay <= 0 { continue } + emitRetryBackoff(r.traceHooks, TraceRetryBackoffInfo{ + Attempt: attemptNo, + Delay: delay, + }) timer := time.NewTimer(delay) select { @@ -293,19 +323,9 @@ func (r *Request) newRetryAttempt(ctx context.Context) (*Request, error) { return attempt, nil } - if r.httpReq == nil { - return nil, fmt.Errorf("http request is nil") - } - - raw := r.httpReq.Clone(ctx) - if r.httpReq.GetBody != nil { - body, err := r.httpReq.GetBody() - if err != nil { - return nil, wrapError(err, "get raw request body") - } - raw.Body = body - } else if r.httpReq.Body != nil && r.httpReq.Body != http.NoBody { - return nil, fmt.Errorf("raw request body is not replayable") + raw, err := cloneRawHTTPRequest(r.httpReq, ctx) + if err != nil { + return nil, err } attempt.httpReq = raw @@ -316,6 +336,9 @@ func (p *retryPolicy) canRetryRequest(r *Request) bool { if p.idempotentOnly && !isIdempotentMethod(r.method) { return false } + if hasReaderRequestBody(r) && !isIdempotentMethod(r.method) { + return false + } return isReplayableRequest(r) } @@ -347,20 +370,40 @@ func isReplayableRequest(r *Request) bool { return false } - // Reader / stream body 通常不可重放,保守地不重试。 - if r.config.Body.Reader != nil { + return isReplayableConfiguredBody(r.config.Body) +} + +func hasReaderRequestBody(r *Request) bool { + if r == nil || r.config == nil { return false } + return r.config.Body.Mode == bodyModeReader && r.config.Body.Reader != nil +} - for _, f := range r.config.Body.Files { - if f.FileData != nil || f.FilePath == "" { - return false +func isReplayableConfiguredBody(body BodyConfig) bool { + switch body.Mode { + case bodyModeReader: + return isReplayableBodyReader(body.Reader) + case bodyModeMultipart: + for _, file := range body.Files { + if file.FileData != nil || file.FilePath == "" { + return false + } } } return true } +func isReplayableBodyReader(reader io.Reader) bool { + switch reader.(type) { + case *bytes.Buffer, *bytes.Reader, *strings.Reader: + return true + default: + return false + } +} + func (p *retryPolicy) shouldRetry(resp *Response, err error, attempt, maxAttempts int, ctx context.Context) bool { if attempt >= maxAttempts-1 { return false diff --git a/review_regression_test.go b/review_regression_test.go new file mode 100644 index 0000000..c9a7b3a --- /dev/null +++ b/review_regression_test.go @@ -0,0 +1,244 @@ +package starnet + +import ( + "context" + "fmt" + "net" + "net/http" + "strconv" + "sync" + "testing" + "time" +) + +func TestRequestProxyWithCustomIPFallbackTriesNextResolvedTarget(t *testing.T) { + tlsServer := newIPv4TLSServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer tlsServer.Close() + + _, port, err := net.SplitHostPort(tlsServer.Listener.Addr().String()) + if err != nil { + t.Fatalf("split tls server addr: %v", err) + } + + firstTarget := net.JoinHostPort("127.0.0.2", port) + secondTarget := net.JoinHostPort("127.0.0.1", port) + + var ( + mu sync.Mutex + connectTargets []string + ) + proxyServer := newIPv4Server(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodConnect { + http.Error(w, "connect required", http.StatusMethodNotAllowed) + return + } + + mu.Lock() + connectTargets = append(connectTargets, r.Host) + mu.Unlock() + + if r.Host == firstTarget { + http.Error(w, "first target failed", http.StatusBadGateway) + return + } + + targetConn, err := net.Dial("tcp", r.Host) + if err != nil { + http.Error(w, err.Error(), http.StatusBadGateway) + return + } + + hijacker, ok := w.(http.Hijacker) + if !ok { + targetConn.Close() + t.Fatal("proxy response writer is not a hijacker") + } + + clientConn, rw, err := hijacker.Hijack() + if err != nil { + targetConn.Close() + t.Fatalf("hijack proxy conn: %v", err) + } + if _, err := rw.WriteString("HTTP/1.1 200 Connection Established\r\n\r\n"); err != nil { + clientConn.Close() + targetConn.Close() + t.Fatalf("write connect response: %v", err) + } + if err := rw.Flush(); err != nil { + clientConn.Close() + targetConn.Close() + t.Fatalf("flush connect response: %v", err) + } + + relayProxyConns(clientConn, targetConn) + })) + defer proxyServer.Close() + + reqURL := fmt.Sprintf("https://proxy-fallback.test:%s", port) + resp, err := NewSimpleRequest(reqURL, http.MethodGet). + SetProxy(proxyServer.URL). + SetCustomIP([]string{"127.0.0.2", "127.0.0.1"}). + SetSkipTLSVerify(true). + Do() + if err != nil { + t.Fatalf("Do() error: %v", err) + } + defer resp.Close() + + mu.Lock() + defer mu.Unlock() + if len(connectTargets) != 2 { + t.Fatalf("connect target attempts=%d; want 2 (%v)", len(connectTargets), connectTargets) + } + if connectTargets[0] != firstTarget { + t.Fatalf("first connect target=%q; want %q", connectTargets[0], firstTarget) + } + if connectTargets[1] != secondTarget { + t.Fatalf("second connect target=%q; want %q", connectTargets[1], secondTarget) + } +} + +func TestTraceHooksDefaultResolverEmitsDNSEvents(t *testing.T) { + server := newIPv4Server(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + addr, err := net.ResolveTCPAddr("tcp", server.Listener.Addr().String()) + if err != nil { + t.Fatalf("ResolveTCPAddr() error: %v", err) + } + + var ( + mu sync.Mutex + dnsStartCount int + dnsDoneCount int + lastHost string + ) + hooks := &TraceHooks{ + DNSStart: func(info TraceDNSStartInfo) { + mu.Lock() + dnsStartCount++ + lastHost = info.Host + mu.Unlock() + }, + DNSDone: func(info TraceDNSDoneInfo) { + mu.Lock() + dnsDoneCount++ + mu.Unlock() + if info.Err != nil { + t.Errorf("unexpected dns error: %v", info.Err) + } + }, + } + + reqURL := "http://localhost:" + strconv.Itoa(addr.Port) + resp, err := NewSimpleRequest(reqURL, http.MethodGet). + SetDialTimeout(DefaultDialTimeout + 200*time.Millisecond). + SetTraceHooks(hooks). + Do() + if err != nil { + t.Fatalf("Do() error: %v", err) + } + defer resp.Close() + + mu.Lock() + defer mu.Unlock() + if dnsStartCount != 1 { + t.Fatalf("dnsStartCount=%d", dnsStartCount) + } + if dnsDoneCount != 1 { + t.Fatalf("dnsDoneCount=%d", dnsDoneCount) + } + if lastHost != "localhost" { + t.Fatalf("lastHost=%q; want localhost", lastHost) + } +} + +func TestRequestHeadersReturnsCopy(t *testing.T) { + req := NewSimpleRequest("http://example.com", http.MethodGet). + SetHeader("X-Test", "one"). + SetHost("origin.example") + + headers := req.Headers() + headers.Set("X-Test", "two") + headers.Set("Host", "mutated.example") + + if got := req.GetHeader("X-Test"); got != "one" { + t.Fatalf("request header=%q; want one", got) + } + if got := req.Host(); got != "origin.example" { + t.Fatalf("request host=%q; want origin.example", got) + } +} + +func TestRequestCookiesIsolation(t *testing.T) { + req := NewSimpleRequest("http://example.com", http.MethodGet) + source := []*http.Cookie{{ + Name: "session", + Value: "one", + Path: "/", + }} + + req.SetCookies(source) + source[0].Value = "mutated-outside" + + got := req.Cookies() + if len(got) != 1 || got[0].Value != "one" { + t.Fatalf("cookies after SetCookies=%v", got) + } + + got[0].Value = "mutated-copy" + if latest := req.Cookies()[0].Value; latest != "one" { + t.Fatalf("internal cookie mutated via getter, got %q", latest) + } + + cookie := &http.Cookie{Name: "auth", Value: "token"} + req.ResetCookies().AddCookie(cookie) + cookie.Value = "changed" + + latest := req.Cookies() + if len(latest) != 1 || latest[0].Value != "token" { + t.Fatalf("cookies after AddCookie=%v", latest) + } +} + +func TestTraceHooksLookupFuncStillEmitsDNSEvents(t *testing.T) { + server := newIPv4Server(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + addr, err := net.ResolveTCPAddr("tcp", server.Listener.Addr().String()) + if err != nil { + t.Fatalf("ResolveTCPAddr() error: %v", err) + } + + var dnsStartCount int + var dnsDoneCount int + hooks := &TraceHooks{ + DNSStart: func(info TraceDNSStartInfo) { + dnsStartCount++ + }, + DNSDone: func(info TraceDNSDoneInfo) { + dnsDoneCount++ + }, + } + + resp, err := NewSimpleRequest("http://lookup-copy.test:"+strconv.Itoa(addr.Port), http.MethodGet). + SetLookupFunc(func(ctx context.Context, host string) ([]net.IPAddr, error) { + return []net.IPAddr{{IP: addr.IP}}, nil + }). + SetTraceHooks(hooks). + Do() + if err != nil { + t.Fatalf("Do() error: %v", err) + } + defer resp.Close() + + if dnsStartCount != 1 || dnsDoneCount != 1 { + t.Fatalf("dns trace counts start=%d done=%d", dnsStartCount, dnsDoneCount) + } +} diff --git a/tls_test.go b/tls_test.go index 01e4bb5..af00874 100644 --- a/tls_test.go +++ b/tls_test.go @@ -104,7 +104,34 @@ func TestRequestLevelTLSOverride(t *testing.T) { } func TestRequestTls(t *testing.T) { - resp, err := NewSimpleRequest("https://www.b612.me", "GET").Do() + var requestCount int + server, pool := newTrustedIPv4TLSServer(t, "localhost", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestCount++ + switch requestCount { + case 1: + if r.Header.Get("Hello") != "" { + t.Fatalf("unexpected hello header on first request: %q", r.Header.Get("Hello")) + } + if auth := r.Header.Get("Authorization"); auth != "" { + t.Fatalf("unexpected authorization on first request: %q", auth) + } + case 2: + if got := r.Header.Get("Hello"); got != "world" { + t.Fatalf("hello header=%q; want world", got) + } + if got := r.Header.Get("Authorization"); got != "Bearer ddddddd" { + t.Fatalf("authorization=%q; want bearer token", got) + } + } + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("OK")) + })) + defer server.Close() + + localURL := httpsURLForHost(t, server, "localhost") + resp, err := NewSimpleRequest(localURL, "GET"). + SetTLSConfig(&tls.Config{RootCAs: pool}). + Do() if err != nil { t.Fatalf("Do() error: %v", err) } @@ -114,11 +141,13 @@ func TestRequestTls(t *testing.T) { t.Errorf("StatusCode = %v; want %v", resp.StatusCode, http.StatusOK) } t.Logf("Response: %v", resp.Body().MustString()) + client, err := NewClient() if err != nil { t.Fatalf("NewClient() error: %v", err) } - resp, err = client.NewSimpleRequest("https://www.b612.me", "GET", + resp, err = client.NewSimpleRequest(localURL, "GET", + WithTLSConfig(&tls.Config{RootCAs: pool}), WithHeader("hello", "world"), WithContext(context.Background()), WithBearerToken("ddddddd")).Do() @@ -134,14 +163,24 @@ func TestRequestTls(t *testing.T) { } func TestTLSWithProxyPath(t *testing.T) { + server, pool := newTrustedIPv4TLSServer(t, "localhost", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("proxied")) + })) + defer server.Close() + + proxy := newIPv4ConnectProxyServer(t, nil) + defer proxy.Close() + client, err := NewClient() if err != nil { t.Fatal(err) } - req, err := client.NewRequest("https://registry-1.docker.io/v2/", "GET", + req, err := client.NewRequest(httpsURLForHost(t, server, "localhost"), "GET", WithTimeout(10*time.Second), - WithProxy("http://127.0.0.1:29992"), + WithProxy(proxy.URL), + WithTLSConfig(&tls.Config{RootCAs: pool}), ) if err != nil { t.Fatal(err) @@ -152,10 +191,22 @@ func TestTLSWithProxyPath(t *testing.T) { t.Fatalf("Do error: %v", err) } defer resp.Close() + if targets := proxy.Targets(); len(targets) != 1 { + t.Fatalf("proxy targets=%v; want 1 target", targets) + } t.Log(resp.Status) } func TestTLSWithProxyBug(t *testing.T) { + server, pool := newTrustedIPv4TLSServer(t, "proxy-bug.test", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + defer server.Close() + + proxy := newIPv4ConnectProxyServer(t, nil) + defer proxy.Close() + client, err := NewClient() if err != nil { t.Fatal(err) @@ -163,9 +214,11 @@ func TestTLSWithProxyBug(t *testing.T) { // 关键:使用 WithProxy 触发 needsDynamicTransport // 即使 proxy 是空串或无效地址,只要设置了就会走 buildDynamicTransport 分支 - req, err := client.NewRequest("https://registry-1.docker.io/v2/", "GET", + req, err := client.NewRequest(httpsURLForHost(t, server, "proxy-bug.test"), "GET", WithTimeout(10*time.Second), - WithProxy("http://127.0.0.1:29992"), // 随便一个 proxy 地址,触发动态 transport + WithProxy(proxy.URL), + WithCustomIP([]string{"127.0.0.1"}), + WithTLSConfig(&tls.Config{RootCAs: pool}), ) if err != nil { t.Fatal(err) @@ -177,20 +230,30 @@ func TestTLSWithProxyBug(t *testing.T) { t.Fatalf("Do error: %v", err) } defer resp.Close() + if targets := proxy.Targets(); len(targets) != 1 || targets[0] == "" { + t.Fatalf("proxy targets=%v", targets) + } t.Logf("Status: %s", resp.Status) } // 更精准的复现:直接测试有问题的分支 func TestTLSDialWithoutServerName(t *testing.T) { + server, pool := newTrustedIPv4TLSServer(t, "custom-ip.test", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + defer server.Close() + client, err := NewClient() if err != nil { t.Fatal(err) } // 使用 WithCustomIP 也能触发 defaultDialTLSFunc - req, err := client.NewRequest("https://www.google.com", "GET", + req, err := client.NewRequest(httpsURLForHost(t, server, "custom-ip.test"), "GET", WithTimeout(10*time.Second), - WithCustomIP([]string{"142.250.185.46"}), // Google 的一个 IP + WithCustomIP([]string{"127.0.0.1"}), + WithTLSConfig(&tls.Config{RootCAs: pool}), ) if err != nil { t.Fatal(err) @@ -206,14 +269,21 @@ func TestTLSDialWithoutServerName(t *testing.T) { // 最小复现:只要触发 needsDynamicTransport 即可 func TestMinimalTLSBug(t *testing.T) { + server, pool := newTrustedIPv4TLSServer(t, "localhost", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + defer server.Close() + client, err := NewClient() if err != nil { t.Fatal(err) } // WithDialTimeout 也会触发动态 transport - req, err := client.NewRequest("https://www.baidu.com", "GET", + req, err := client.NewRequest(httpsURLForHost(t, server, "localhost"), "GET", WithDialTimeout(5*time.Second), + WithTLSConfig(&tls.Config{RootCAs: pool}), ) if err != nil { t.Fatal(err) @@ -227,3 +297,40 @@ func TestMinimalTLSBug(t *testing.T) { defer resp.Close() t.Logf("Status: %s", resp.Status) } + +func TestTLSWithSOCKS5ProxyPath(t *testing.T) { + server, pool := newTrustedIPv4TLSServer(t, "socks5-proxy.test", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + defer server.Close() + + proxy := newSOCKS5ProxyServer(t, nil) + defer proxy.Close() + + client, err := NewClient() + if err != nil { + t.Fatal(err) + } + + req, err := client.NewRequest(httpsURLForHost(t, server, "socks5-proxy.test"), "GET", + WithTimeout(10*time.Second), + WithProxy(proxy.URL()), + WithCustomIP([]string{"127.0.0.1"}), + WithTLSConfig(&tls.Config{RootCAs: pool}), + ) + if err != nil { + t.Fatal(err) + } + + resp, err := req.Do() + if err != nil { + t.Fatalf("Do error: %v", err) + } + defer resp.Close() + + if targets := proxy.Targets(); len(targets) != 1 || targets[0] == "" { + t.Fatalf("socks5 targets=%v", targets) + } + t.Logf("Status: %s", resp.Status) +} diff --git a/tlssniffer.go b/tlssniffer.go index acc58c7..f28ae04 100644 --- a/tlssniffer.go +++ b/tlssniffer.go @@ -4,12 +4,13 @@ import ( "bytes" "context" "crypto/tls" - "encoding/binary" "errors" "io" "net" "sync" "time" + + "b612.me/starnet/internal/tlssniffercore" ) // replayConn replays buffered bytes first, then reads from live conn. @@ -51,214 +52,35 @@ type TLSSniffer struct{} // Sniff detects TLS and extracts SNI when possible. func (s TLSSniffer) Sniff(conn net.Conn, maxBytes int) (SniffResult, error) { - if maxBytes <= 0 { - maxBytes = 64 * 1024 + res, err := (tlssniffercore.Sniffer{}).Sniff(conn, maxBytes) + if err != nil { + return SniffResult{}, err } + return convertCoreSniffResult(res), nil +} - var buf bytes.Buffer - limited := &io.LimitedReader{R: conn, N: int64(maxBytes)} - meta, isTLS := sniffClientHello(limited, &buf, conn) - +func convertCoreSniffResult(res tlssniffercore.SniffResult) SniffResult { out := SniffResult{ - IsTLS: isTLS, - Buffer: bytes.NewBuffer(append([]byte(nil), buf.Bytes()...)), + IsTLS: res.IsTLS, + Buffer: res.Buffer, } - if isTLS { - out.ClientHello = meta + if res.ClientHello != nil { + out.ClientHello = convertCoreClientHelloMeta(res.ClientHello) } - return out, nil + return out } -func sniffClientHello(r io.Reader, buf *bytes.Buffer, conn net.Conn) (*ClientHelloMeta, bool) { - meta := &ClientHelloMeta{ - LocalAddr: conn.LocalAddr(), - RemoteAddr: conn.RemoteAddr(), +func convertCoreClientHelloMeta(meta *tlssniffercore.ClientHelloMeta) *ClientHelloMeta { + if meta == nil { + return nil } - - header, complete := readTLSRecordHeader(r, buf) - if len(header) < 3 { - return nil, false - } - isTLS := header[0] == 0x16 && header[1] == 0x03 - if !isTLS { - return nil, false - } - if len(header) < 5 || !complete { - return meta, true - } - - recordLen := int(binary.BigEndian.Uint16(header[3:5])) - recordBody, bodyOK := readBufferedBytes(r, buf, recordLen) - if !bodyOK { - return meta, true - } - if len(recordBody) < 4 || recordBody[0] != 0x01 { - return nil, false - } - - helloLen := int(recordBody[1])<<16 | int(recordBody[2])<<8 | int(recordBody[3]) - helloBytes := append([]byte(nil), recordBody[4:]...) - for len(helloBytes) < helloLen { - nextHeader, nextOK := readTLSRecordHeader(r, buf) - if len(nextHeader) < 5 || !nextOK { - return meta, true - } - if nextHeader[0] != 0x16 || nextHeader[1] != 0x03 { - return meta, true - } - nextLen := int(binary.BigEndian.Uint16(nextHeader[3:5])) - nextBody, nextBodyOK := readBufferedBytes(r, buf, nextLen) - if !nextBodyOK { - return meta, true - } - helloBytes = append(helloBytes, nextBody...) - } - - parseClientHelloBody(meta, helloBytes[:helloLen]) - return meta, true -} - -func readTLSRecordHeader(r io.Reader, buf *bytes.Buffer) ([]byte, bool) { - return readBufferedBytes(r, buf, 5) -} - -func readBufferedBytes(r io.Reader, buf *bytes.Buffer, n int) ([]byte, bool) { - if n <= 0 { - return nil, true - } - tmp := make([]byte, n) - readN, err := io.ReadFull(r, tmp) - if readN > 0 { - buf.Write(tmp[:readN]) - } - return append([]byte(nil), tmp[:readN]...), err == nil -} - -func parseClientHelloBody(meta *ClientHelloMeta, body []byte) { - if meta == nil || len(body) < 34 { - return - } - - offset := 2 + 32 - sessionIDLen := int(body[offset]) - offset++ - if offset+sessionIDLen > len(body) { - return - } - offset += sessionIDLen - - if offset+2 > len(body) { - return - } - cipherSuitesLen := int(binary.BigEndian.Uint16(body[offset : offset+2])) - offset += 2 - if offset+cipherSuitesLen > len(body) { - return - } - for i := 0; i+1 < cipherSuitesLen; i += 2 { - meta.CipherSuites = append(meta.CipherSuites, binary.BigEndian.Uint16(body[offset+i:offset+i+2])) - } - offset += cipherSuitesLen - - if offset >= len(body) { - return - } - compressionMethodsLen := int(body[offset]) - offset++ - if offset+compressionMethodsLen > len(body) { - return - } - offset += compressionMethodsLen - - if offset+2 > len(body) { - return - } - extensionsLen := int(binary.BigEndian.Uint16(body[offset : offset+2])) - offset += 2 - if offset+extensionsLen > len(body) { - return - } - - parseClientHelloExtensions(meta, body[offset:offset+extensionsLen]) -} - -func parseClientHelloExtensions(meta *ClientHelloMeta, exts []byte) { - for offset := 0; offset+4 <= len(exts); { - extType := binary.BigEndian.Uint16(exts[offset : offset+2]) - extLen := int(binary.BigEndian.Uint16(exts[offset+2 : offset+4])) - offset += 4 - if offset+extLen > len(exts) { - return - } - extData := exts[offset : offset+extLen] - offset += extLen - - switch extType { - case 0: - parseServerNameExtension(meta, extData) - case 16: - parseALPNExtension(meta, extData) - case 43: - parseSupportedVersionsExtension(meta, extData) - } - } -} - -func parseServerNameExtension(meta *ClientHelloMeta, data []byte) { - if len(data) < 2 { - return - } - listLen := int(binary.BigEndian.Uint16(data[:2])) - if listLen == 0 || 2+listLen > len(data) { - return - } - list := data[2 : 2+listLen] - for offset := 0; offset+3 <= len(list); { - nameType := list[offset] - nameLen := int(binary.BigEndian.Uint16(list[offset+1 : offset+3])) - offset += 3 - if offset+nameLen > len(list) { - return - } - if nameType == 0 { - meta.ServerName = string(list[offset : offset+nameLen]) - return - } - offset += nameLen - } -} - -func parseALPNExtension(meta *ClientHelloMeta, data []byte) { - if len(data) < 2 { - return - } - listLen := int(binary.BigEndian.Uint16(data[:2])) - if listLen == 0 || 2+listLen > len(data) { - return - } - list := data[2 : 2+listLen] - for offset := 0; offset < len(list); { - nameLen := int(list[offset]) - offset++ - if offset+nameLen > len(list) { - return - } - meta.SupportedProtos = append(meta.SupportedProtos, string(list[offset:offset+nameLen])) - offset += nameLen - } -} - -func parseSupportedVersionsExtension(meta *ClientHelloMeta, data []byte) { - if len(data) < 1 { - return - } - listLen := int(data[0]) - if listLen == 0 || 1+listLen > len(data) { - return - } - list := data[1 : 1+listLen] - for offset := 0; offset+1 < len(list); offset += 2 { - meta.SupportedVersions = append(meta.SupportedVersions, binary.BigEndian.Uint16(list[offset:offset+2])) + return &ClientHelloMeta{ + ServerName: meta.ServerName, + LocalAddr: meta.LocalAddr, + RemoteAddr: meta.RemoteAddr, + SupportedProtos: append([]string(nil), meta.SupportedProtos...), + SupportedVersions: append([]uint16(nil), meta.SupportedVersions...), + CipherSuites: append([]uint16(nil), meta.CipherSuites...), } } @@ -290,17 +112,17 @@ type Conn struct { func newConn(raw net.Conn, cfg ListenerConfig, stats *Stats) *Conn { return &Conn{ - Conn: raw, - plainConn: raw, - baseTLSConfig: cfg.BaseTLSConfig, - getConfigForClient: cfg.GetConfigForClient, + Conn: raw, + plainConn: raw, + baseTLSConfig: cfg.BaseTLSConfig, + getConfigForClient: cfg.GetConfigForClient, getConfigForClientHello: cfg.GetConfigForClientHello, - allowNonTLS: cfg.AllowNonTLS, - sniffer: TLSSniffer{}, - sniffTimeout: cfg.SniffTimeout, - maxClientHello: cfg.MaxClientHelloBytes, - logger: cfg.Logger, - stats: stats, + allowNonTLS: cfg.AllowNonTLS, + sniffer: TLSSniffer{}, + sniffTimeout: cfg.SniffTimeout, + maxClientHello: cfg.MaxClientHelloBytes, + logger: cfg.Logger, + stats: stats, } } @@ -433,123 +255,11 @@ func (c *Conn) serverName() string { } func composeServerTLSConfig(base, selected *tls.Config) *tls.Config { - if base == nil { - return selected - } - if selected == nil { - return base - } - - out := base.Clone() - applyServerTLSOverrides(out, selected) - return out + return tlssniffercore.ComposeServerTLSConfig(base, selected) } func applyServerTLSOverrides(dst, src *tls.Config) { - if dst == nil || src == nil { - return - } - - if src.Rand != nil { - dst.Rand = src.Rand - } - if src.Time != nil { - dst.Time = src.Time - } - if len(src.Certificates) > 0 { - dst.Certificates = append([]tls.Certificate(nil), src.Certificates...) - } - if len(src.NameToCertificate) > 0 { - m := make(map[string]*tls.Certificate, len(src.NameToCertificate)) - for k, v := range src.NameToCertificate { - m[k] = v - } - dst.NameToCertificate = m - } - if src.GetCertificate != nil { - dst.GetCertificate = src.GetCertificate - } - if src.GetClientCertificate != nil { - dst.GetClientCertificate = src.GetClientCertificate - } - if src.GetConfigForClient != nil { - dst.GetConfigForClient = src.GetConfigForClient - } - if src.VerifyPeerCertificate != nil { - dst.VerifyPeerCertificate = src.VerifyPeerCertificate - } - if src.VerifyConnection != nil { - dst.VerifyConnection = src.VerifyConnection - } - if src.RootCAs != nil { - dst.RootCAs = src.RootCAs - } - if len(src.NextProtos) > 0 { - dst.NextProtos = append([]string(nil), src.NextProtos...) - } - if src.ServerName != "" { - dst.ServerName = src.ServerName - } - if src.ClientAuth > dst.ClientAuth { - dst.ClientAuth = src.ClientAuth - } - if src.ClientCAs != nil { - dst.ClientCAs = src.ClientCAs - } - if src.InsecureSkipVerify { - dst.InsecureSkipVerify = true - } - if len(src.CipherSuites) > 0 { - dst.CipherSuites = append([]uint16(nil), src.CipherSuites...) - } - if src.PreferServerCipherSuites { - dst.PreferServerCipherSuites = true - } - if src.SessionTicketsDisabled { - dst.SessionTicketsDisabled = true - } - if src.SessionTicketKey != ([32]byte{}) { - dst.SessionTicketKey = src.SessionTicketKey - } - if src.ClientSessionCache != nil { - dst.ClientSessionCache = src.ClientSessionCache - } - if src.UnwrapSession != nil { - dst.UnwrapSession = src.UnwrapSession - } - if src.WrapSession != nil { - dst.WrapSession = src.WrapSession - } - if src.MinVersion != 0 && (dst.MinVersion == 0 || src.MinVersion > dst.MinVersion) { - dst.MinVersion = src.MinVersion - } - if src.MaxVersion != 0 && (dst.MaxVersion == 0 || src.MaxVersion < dst.MaxVersion) { - dst.MaxVersion = src.MaxVersion - } - if len(src.CurvePreferences) > 0 { - dst.CurvePreferences = append([]tls.CurveID(nil), src.CurvePreferences...) - } - if src.DynamicRecordSizingDisabled { - dst.DynamicRecordSizingDisabled = true - } - if src.Renegotiation != 0 { - dst.Renegotiation = src.Renegotiation - } - if src.KeyLogWriter != nil { - dst.KeyLogWriter = src.KeyLogWriter - } - if len(src.EncryptedClientHelloConfigList) > 0 { - dst.EncryptedClientHelloConfigList = append([]byte(nil), src.EncryptedClientHelloConfigList...) - } - if src.EncryptedClientHelloRejectionVerify != nil { - dst.EncryptedClientHelloRejectionVerify = src.EncryptedClientHelloRejectionVerify - } - if src.GetEncryptedClientHelloKeys != nil { - dst.GetEncryptedClientHelloKeys = src.GetEncryptedClientHelloKeys - } - if len(src.EncryptedClientHelloKeys) > 0 { - dst.EncryptedClientHelloKeys = append([]tls.EncryptedClientHelloKey(nil), src.EncryptedClientHelloKeys...) - } + tlssniffercore.ApplyServerTLSOverrides(dst, src) } func (c *Conn) IsTLS() bool { diff --git a/trace.go b/trace.go new file mode 100644 index 0000000..7b1153d --- /dev/null +++ b/trace.go @@ -0,0 +1,340 @@ +package starnet + +import ( + "context" + "crypto/tls" + "net" + "net/http/httptrace" + "sync/atomic" + "time" +) + +type traceContextKey struct{} + +// TraceHooks defines optional callbacks for network lifecycle events. +// Hooks may be called concurrently. +type TraceHooks struct { + GetConn func(TraceGetConnInfo) + GotConn func(TraceGotConnInfo) + PutIdleConn func(TracePutIdleConnInfo) + DNSStart func(TraceDNSStartInfo) + DNSDone func(TraceDNSDoneInfo) + ConnectStart func(TraceConnectStartInfo) + ConnectDone func(TraceConnectDoneInfo) + TLSHandshakeStart func(TraceTLSHandshakeStartInfo) + TLSHandshakeDone func(TraceTLSHandshakeDoneInfo) + WroteHeaderField func(TraceWroteHeaderFieldInfo) + WroteHeaders func() + WroteRequest func(TraceWroteRequestInfo) + GotFirstResponseByte func() + RetryAttemptStart func(TraceRetryAttemptStartInfo) + RetryAttemptDone func(TraceRetryAttemptDoneInfo) + RetryBackoff func(TraceRetryBackoffInfo) +} + +type TraceGetConnInfo struct { + Addr string +} + +type TraceGotConnInfo struct { + Conn net.Conn + Reused bool + WasIdle bool + IdleTime time.Duration +} + +type TracePutIdleConnInfo struct { + Err error +} + +type TraceDNSStartInfo struct { + Host string +} + +type TraceDNSDoneInfo struct { + Addrs []net.IPAddr + Coalesced bool + Err error +} + +type TraceConnectStartInfo struct { + Network string + Addr string +} + +type TraceConnectDoneInfo struct { + Network string + Addr string + Err error +} + +type TraceTLSHandshakeStartInfo struct { + Network string + Addr string + ServerName string +} + +type TraceTLSHandshakeDoneInfo struct { + Network string + Addr string + ServerName string + ConnectionState tls.ConnectionState + Err error +} + +type TraceWroteHeaderFieldInfo struct { + Key string + Values []string +} + +type TraceWroteRequestInfo struct { + Err error +} + +type TraceRetryAttemptStartInfo struct { + Attempt int + MaxAttempts int +} + +type TraceRetryAttemptDoneInfo struct { + Attempt int + MaxAttempts int + StatusCode int + Err error + WillRetry bool +} + +type TraceRetryBackoffInfo struct { + Attempt int + Delay time.Duration +} + +type traceState struct { + hooks *TraceHooks + customTLS atomic.Uint32 + manualDNSRefs atomic.Int32 +} + +func newTraceState(hooks *TraceHooks) *traceState { + if hooks == nil { + return nil + } + return &traceState{hooks: hooks} +} + +func withTraceState(ctx context.Context, state *traceState) context.Context { + if state == nil { + return ctx + } + return context.WithValue(ctx, traceContextKey{}, state) +} + +func getTraceState(ctx context.Context) *traceState { + if ctx == nil { + return nil + } + state, _ := ctx.Value(traceContextKey{}).(*traceState) + return state +} + +func (t *traceState) needsHTTPTrace() bool { + if t == nil || t.hooks == nil { + return false + } + h := t.hooks + return h.GetConn != nil || + h.GotConn != nil || + h.PutIdleConn != nil || + h.DNSStart != nil || + h.DNSDone != nil || + h.ConnectStart != nil || + h.ConnectDone != nil || + h.TLSHandshakeStart != nil || + h.TLSHandshakeDone != nil || + h.WroteHeaderField != nil || + h.WroteHeaders != nil || + h.WroteRequest != nil || + h.GotFirstResponseByte != nil +} + +func (t *traceState) clientTrace() *httptrace.ClientTrace { + if !t.needsHTTPTrace() { + return nil + } + + h := t.hooks + trace := &httptrace.ClientTrace{} + if h.GetConn != nil { + trace.GetConn = func(hostPort string) { + h.GetConn(TraceGetConnInfo{Addr: hostPort}) + } + } + if h.GotConn != nil { + trace.GotConn = func(info httptrace.GotConnInfo) { + h.GotConn(TraceGotConnInfo{ + Conn: info.Conn, + Reused: info.Reused, + WasIdle: info.WasIdle, + IdleTime: info.IdleTime, + }) + } + } + if h.PutIdleConn != nil { + trace.PutIdleConn = func(err error) { + h.PutIdleConn(TracePutIdleConnInfo{Err: err}) + } + } + if h.DNSStart != nil { + trace.DNSStart = func(info httptrace.DNSStartInfo) { + if t.usesManualDNS() { + return + } + h.DNSStart(TraceDNSStartInfo{Host: info.Host}) + } + } + if h.DNSDone != nil { + trace.DNSDone = func(info httptrace.DNSDoneInfo) { + if t.usesManualDNS() { + return + } + h.DNSDone(TraceDNSDoneInfo{ + Addrs: append([]net.IPAddr(nil), info.Addrs...), + Coalesced: info.Coalesced, + Err: info.Err, + }) + } + } + if h.ConnectStart != nil { + trace.ConnectStart = func(network, addr string) { + h.ConnectStart(TraceConnectStartInfo{Network: network, Addr: addr}) + } + } + if h.ConnectDone != nil { + trace.ConnectDone = func(network, addr string, err error) { + h.ConnectDone(TraceConnectDoneInfo{Network: network, Addr: addr, Err: err}) + } + } + if h.TLSHandshakeStart != nil { + trace.TLSHandshakeStart = func() { + if t.usesCustomTLS() { + return + } + h.TLSHandshakeStart(TraceTLSHandshakeStartInfo{}) + } + } + if h.TLSHandshakeDone != nil { + trace.TLSHandshakeDone = func(state tls.ConnectionState, err error) { + if t.usesCustomTLS() { + return + } + h.TLSHandshakeDone(TraceTLSHandshakeDoneInfo{ + ConnectionState: state, + Err: err, + }) + } + } + if h.WroteHeaderField != nil { + trace.WroteHeaderField = func(key string, value []string) { + h.WroteHeaderField(TraceWroteHeaderFieldInfo{ + Key: key, + Values: value, + }) + } + } + if h.WroteHeaders != nil { + trace.WroteHeaders = h.WroteHeaders + } + if h.WroteRequest != nil { + trace.WroteRequest = func(info httptrace.WroteRequestInfo) { + h.WroteRequest(TraceWroteRequestInfo{Err: info.Err}) + } + } + if h.GotFirstResponseByte != nil { + trace.GotFirstResponseByte = h.GotFirstResponseByte + } + return trace +} + +func (t *traceState) markCustomTLS() { + if t == nil { + return + } + t.customTLS.Store(1) +} + +func (t *traceState) usesCustomTLS() bool { + if t == nil { + return false + } + return t.customTLS.Load() != 0 +} + +func (t *traceState) beginManualDNS() { + if t == nil { + return + } + t.manualDNSRefs.Add(1) +} + +func (t *traceState) endManualDNS() { + if t == nil { + return + } + t.manualDNSRefs.Add(-1) +} + +func (t *traceState) usesManualDNS() bool { + if t == nil { + return false + } + return t.manualDNSRefs.Load() > 0 +} + +func (t *traceState) tlsHandshakeStart(info TraceTLSHandshakeStartInfo) { + if t == nil || t.hooks == nil || t.hooks.TLSHandshakeStart == nil { + return + } + t.hooks.TLSHandshakeStart(info) +} + +func (t *traceState) tlsHandshakeDone(info TraceTLSHandshakeDoneInfo) { + if t == nil || t.hooks == nil || t.hooks.TLSHandshakeDone == nil { + return + } + t.hooks.TLSHandshakeDone(info) +} + +func (t *traceState) dnsStart(info TraceDNSStartInfo) { + if t == nil || t.hooks == nil || t.hooks.DNSStart == nil { + return + } + t.hooks.DNSStart(info) +} + +func (t *traceState) dnsDone(info TraceDNSDoneInfo) { + if t == nil || t.hooks == nil || t.hooks.DNSDone == nil { + return + } + t.hooks.DNSDone(info) +} + +func emitRetryAttemptStart(hooks *TraceHooks, info TraceRetryAttemptStartInfo) { + if hooks == nil || hooks.RetryAttemptStart == nil { + return + } + hooks.RetryAttemptStart(info) +} + +func emitRetryAttemptDone(hooks *TraceHooks, info TraceRetryAttemptDoneInfo) { + if hooks == nil || hooks.RetryAttemptDone == nil { + return + } + hooks.RetryAttemptDone(info) +} + +func emitRetryBackoff(hooks *TraceHooks, info TraceRetryBackoffInfo) { + if hooks == nil || hooks.RetryBackoff == nil { + return + } + hooks.RetryBackoff(info) +} diff --git a/trace_test.go b/trace_test.go new file mode 100644 index 0000000..f2ec128 --- /dev/null +++ b/trace_test.go @@ -0,0 +1,324 @@ +package starnet + +import ( + "context" + "errors" + "net" + "net/http" + "net/http/httptest" + "strconv" + "sync" + "testing" + "time" +) + +func TestTraceHooksStandardHTTPSPath(t *testing.T) { + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + defer server.Close() + + var mu sync.Mutex + events := map[string]int{} + hooks := &TraceHooks{ + GetConn: func(info TraceGetConnInfo) { + mu.Lock() + events["get_conn"]++ + mu.Unlock() + }, + GotConn: func(info TraceGotConnInfo) { + mu.Lock() + events["got_conn"]++ + mu.Unlock() + }, + TLSHandshakeStart: func(info TraceTLSHandshakeStartInfo) { + mu.Lock() + events["tls_start"]++ + mu.Unlock() + }, + TLSHandshakeDone: func(info TraceTLSHandshakeDoneInfo) { + mu.Lock() + events["tls_done"]++ + mu.Unlock() + if info.Err != nil { + t.Errorf("unexpected tls handshake error: %v", info.Err) + } + }, + WroteHeaders: func() { + mu.Lock() + events["wrote_headers"]++ + mu.Unlock() + }, + WroteRequest: func(info TraceWroteRequestInfo) { + mu.Lock() + events["wrote_request"]++ + mu.Unlock() + if info.Err != nil { + t.Errorf("unexpected write error: %v", info.Err) + } + }, + GotFirstResponseByte: func() { + mu.Lock() + events["first_byte"]++ + mu.Unlock() + }, + } + + resp, err := NewSimpleRequest(server.URL, http.MethodGet). + SetSkipTLSVerify(true). + SetTraceHooks(hooks). + Do() + if err != nil { + t.Fatalf("Do() error: %v", err) + } + defer resp.Close() + + mu.Lock() + defer mu.Unlock() + for _, key := range []string{"get_conn", "got_conn", "tls_start", "tls_done", "wrote_headers", "wrote_request", "first_byte"} { + if events[key] == 0 { + t.Fatalf("expected trace event %q", key) + } + } +} + +func TestTraceHooksDynamicHTTPSPathDoesNotDuplicateTLSHandshake(t *testing.T) { + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + var mu sync.Mutex + tlsStartCount := 0 + tlsDoneCount := 0 + var lastInfo TraceTLSHandshakeDoneInfo + hooks := &TraceHooks{ + TLSHandshakeStart: func(info TraceTLSHandshakeStartInfo) { + mu.Lock() + tlsStartCount++ + mu.Unlock() + }, + TLSHandshakeDone: func(info TraceTLSHandshakeDoneInfo) { + mu.Lock() + tlsDoneCount++ + lastInfo = info + mu.Unlock() + }, + } + + resp, err := NewSimpleRequest(server.URL, http.MethodGet). + SetSkipTLSVerify(true). + SetDialTimeout(1500 * time.Millisecond). + SetTraceHooks(hooks). + Do() + if err != nil { + t.Fatalf("Do() error: %v", err) + } + defer resp.Close() + + mu.Lock() + defer mu.Unlock() + if tlsStartCount != 1 { + t.Fatalf("tlsStartCount=%d", tlsStartCount) + } + if tlsDoneCount != 1 { + t.Fatalf("tlsDoneCount=%d", tlsDoneCount) + } + if lastInfo.Err != nil { + t.Fatalf("unexpected tls handshake error: %v", lastInfo.Err) + } + if lastInfo.ConnectionState.Version == 0 { + t.Fatal("expected tls connection state") + } +} + +func TestTraceHooksCustomLookupFuncEmitsDNSEvents(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + addr, err := net.ResolveTCPAddr("tcp", server.Listener.Addr().String()) + if err != nil { + t.Fatalf("ResolveTCPAddr() error: %v", err) + } + + var mu sync.Mutex + dnsStartCount := 0 + dnsDoneCount := 0 + var dnsStartHost string + hooks := &TraceHooks{ + DNSStart: func(info TraceDNSStartInfo) { + mu.Lock() + dnsStartCount++ + dnsStartHost = info.Host + mu.Unlock() + }, + DNSDone: func(info TraceDNSDoneInfo) { + mu.Lock() + dnsDoneCount++ + mu.Unlock() + if info.Err != nil { + t.Errorf("unexpected dns error: %v", info.Err) + } + }, + } + + url := "http://trace.example.test:" + strconv.Itoa(addr.Port) + resp, err := NewSimpleRequest(url, http.MethodGet). + SetLookupFunc(func(ctx context.Context, host string) ([]net.IPAddr, error) { + return []net.IPAddr{{IP: addr.IP}}, nil + }). + SetTraceHooks(hooks). + Do() + if err != nil { + t.Fatalf("Do() error: %v", err) + } + defer resp.Close() + + mu.Lock() + defer mu.Unlock() + if dnsStartCount != 1 { + t.Fatalf("dnsStartCount=%d", dnsStartCount) + } + if dnsDoneCount != 1 { + t.Fatalf("dnsDoneCount=%d", dnsDoneCount) + } + if dnsStartHost != "trace.example.test" { + t.Fatalf("dnsStartHost=%q", dnsStartHost) + } +} + +func TestTraceHooksCustomDialFuncEmitsConnectEvents(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + var mu sync.Mutex + connectStartCount := 0 + connectDoneCount := 0 + hooks := &TraceHooks{ + ConnectStart: func(info TraceConnectStartInfo) { + mu.Lock() + connectStartCount++ + mu.Unlock() + }, + ConnectDone: func(info TraceConnectDoneInfo) { + mu.Lock() + connectDoneCount++ + mu.Unlock() + if info.Err != nil { + t.Errorf("unexpected connect error: %v", info.Err) + } + }, + } + + resp, err := NewSimpleRequest(server.URL, http.MethodGet). + SetDialFunc(func(ctx context.Context, network, addr string) (net.Conn, error) { + var dialer net.Dialer + return dialer.DialContext(context.Background(), network, addr) + }). + SetTraceHooks(hooks). + Do() + if err != nil { + t.Fatalf("Do() error: %v", err) + } + defer resp.Close() + + mu.Lock() + defer mu.Unlock() + if connectStartCount != 1 { + t.Fatalf("connectStartCount=%d", connectStartCount) + } + if connectDoneCount != 1 { + t.Fatalf("connectDoneCount=%d", connectDoneCount) + } +} + +func TestTraceHooksRetryEvents(t *testing.T) { + var hits int + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + hits++ + if hits == 1 { + w.WriteHeader(http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + var mu sync.Mutex + starts := 0 + dones := 0 + backoffs := 0 + var finalDone TraceRetryAttemptDoneInfo + hooks := &TraceHooks{ + RetryAttemptStart: func(info TraceRetryAttemptStartInfo) { + mu.Lock() + starts++ + mu.Unlock() + }, + RetryAttemptDone: func(info TraceRetryAttemptDoneInfo) { + mu.Lock() + dones++ + finalDone = info + mu.Unlock() + }, + RetryBackoff: func(info TraceRetryBackoffInfo) { + mu.Lock() + backoffs++ + mu.Unlock() + }, + } + + resp, err := NewSimpleRequest(server.URL, http.MethodGet). + SetRetry(1, WithRetryBackoff(time.Millisecond, time.Millisecond, 1), WithRetryJitter(0)). + SetTraceHooks(hooks). + Do() + if err != nil { + t.Fatalf("Do() error: %v", err) + } + defer resp.Close() + + mu.Lock() + defer mu.Unlock() + if starts != 2 { + t.Fatalf("starts=%d", starts) + } + if dones != 2 { + t.Fatalf("dones=%d", dones) + } + if backoffs != 1 { + t.Fatalf("backoffs=%d", backoffs) + } + if finalDone.WillRetry { + t.Fatal("expected final attempt not to retry") + } + if finalDone.StatusCode != http.StatusOK { + t.Fatalf("final status=%d", finalDone.StatusCode) + } +} + +func TestTraceHooksCustomLookupFuncPropagatesDNSError(t *testing.T) { + var gotErr error + hooks := &TraceHooks{ + DNSDone: func(info TraceDNSDoneInfo) { + gotErr = info.Err + }, + } + + _, err := NewSimpleRequest("http://trace.example.test:80", http.MethodGet). + SetLookupFunc(func(ctx context.Context, host string) ([]net.IPAddr, error) { + return nil, errors.New("lookup failed") + }). + SetTraceHooks(hooks). + Do() + if err == nil { + t.Fatal("expected request error") + } + if gotErr == nil || gotErr.Error() != "lookup failed" { + t.Fatalf("gotErr=%v", gotErr) + } +} diff --git a/transport.go b/transport.go index 504a9d9..27a0774 100644 --- a/transport.go +++ b/transport.go @@ -1,61 +1,220 @@ package starnet import ( + "context" + "crypto/tls" + "net" "net/http" "net/url" + "strings" "sync" "time" ) +const dynamicTransportCacheMaxEntries = 64 + +type dynamicTransportCacheKey struct { + proxyKey string + dialTimeout time.Duration + customIPs string + customDNS string + tlsServerName string + skipVerify bool +} + // Transport 自定义 Transport(支持请求级配置) type Transport struct { - base *http.Transport - mu sync.RWMutex + base *http.Transport + dynamicCache map[dynamicTransportCacheKey]*http.Transport + dynamicCacheOrder []dynamicTransportCacheKey + mu sync.RWMutex } // RoundTrip 实现 http.RoundTripper 接口 func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { - // 确保 base 已初始化 - if t.base == nil { - t.mu.Lock() - if t.base == nil { - t.base = &http.Transport{ - ForceAttemptHTTP2: true, - MaxIdleConns: 100, - MaxIdleConnsPerHost: 10, - IdleConnTimeout: 90 * time.Second, - TLSHandshakeTimeout: 10 * time.Second, - ExpectContinueTimeout: 1 * time.Second, - } - } - t.mu.Unlock() - } + t.ensureBase() // 提取请求级别的配置 reqCtx := getRequestContext(req.Context()) + traceState := getTraceState(req.Context()) + execReq := req + execReqCtx := reqCtx + var targetAddrs []string // 优先级1:完全自定义的 transport - if reqCtx.Transport != nil { - return reqCtx.Transport.RoundTrip(req) + if execReqCtx.Transport != nil { + return execReqCtx.Transport.RoundTrip(execReq) + } + + var err error + execReq, execReqCtx, targetAddrs, err = prepareProxyTargetRequest(execReq, execReqCtx, traceState) + if err != nil { + return nil, err } // 优先级2:需要动态配置 - if needsDynamicTransport(reqCtx) { - dynamicTransport := t.buildDynamicTransport(reqCtx) - return dynamicTransport.RoundTrip(req) + if needsDynamicTransport(execReqCtx) { + dynamicTransport := t.getDynamicTransport(execReqCtx, traceState) + if len(targetAddrs) > 0 { + return roundTripResolvedTargets(dynamicTransport, execReq, targetAddrs) + } + return dynamicTransport.RoundTrip(execReq) } // 优先级3:使用基础 transport t.mu.RLock() - defer t.mu.RUnlock() - return t.base.RoundTrip(req) + baseTransport := t.base + t.mu.RUnlock() + if len(targetAddrs) > 0 { + return roundTripResolvedTargets(baseTransport, execReq, targetAddrs) + } + return baseTransport.RoundTrip(execReq) +} + +func newBaseHTTPTransport() *http.Transport { + return &http.Transport{ + ForceAttemptHTTP2: true, + MaxIdleConns: 100, + MaxIdleConnsPerHost: 10, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + } +} + +func (t *Transport) ensureBase() { + if t.base != nil { + return + } + t.mu.Lock() + defer t.mu.Unlock() + t.ensureBaseLocked() +} + +func (t *Transport) ensureBaseLocked() { + if t.base == nil { + t.base = newBaseHTTPTransport() + } +} + +func (t *Transport) getDynamicTransport(rc *RequestContext, traceState *traceState) *http.Transport { + if key, ok := newDynamicTransportCacheKey(rc); ok { + return t.getOrCreateCachedDynamicTransport(key, rc) + } + return t.buildDynamicTransport(rc, traceState) +} + +func (t *Transport) getOrCreateCachedDynamicTransport(key dynamicTransportCacheKey, rc *RequestContext) *http.Transport { + t.mu.RLock() + if transport := t.dynamicCache[key]; transport != nil { + t.mu.RUnlock() + return transport + } + t.mu.RUnlock() + + t.mu.Lock() + defer t.mu.Unlock() + + t.ensureBaseLocked() + if transport := t.dynamicCache[key]; transport != nil { + return transport + } + + transport := buildDynamicTransportFromBase(t.base, rc, nil) + if t.dynamicCache == nil { + t.dynamicCache = make(map[dynamicTransportCacheKey]*http.Transport) + } + if len(t.dynamicCacheOrder) >= dynamicTransportCacheMaxEntries { + oldestKey := t.dynamicCacheOrder[0] + t.dynamicCacheOrder = t.dynamicCacheOrder[1:] + if oldest := t.dynamicCache[oldestKey]; oldest != nil { + oldest.CloseIdleConnections() + delete(t.dynamicCache, oldestKey) + } + } + t.dynamicCache[key] = transport + t.dynamicCacheOrder = append(t.dynamicCacheOrder, key) + return transport +} + +func (t *Transport) resetDynamicTransportCacheLocked() { + for _, key := range t.dynamicCacheOrder { + if transport := t.dynamicCache[key]; transport != nil { + transport.CloseIdleConnections() + } + } + t.dynamicCache = nil + t.dynamicCacheOrder = nil +} + +func newDynamicTransportCacheKey(rc *RequestContext) (dynamicTransportCacheKey, bool) { + if rc == nil { + return dynamicTransportCacheKey{}, false + } + if rc.Transport != nil || rc.DialFn != nil || rc.LookupIPFn != nil { + return dynamicTransportCacheKey{}, false + } + if rc.TLSConfig != nil && !rc.TLSConfigCacheable { + return dynamicTransportCacheKey{}, false + } + + key := dynamicTransportCacheKey{ + proxyKey: normalizeProxyCacheKey(rc.Proxy), + dialTimeout: rc.DialTimeout, + customIPs: serializeTransportCacheList(rc.CustomIP), + customDNS: serializeTransportCacheList(rc.CustomDNS), + tlsServerName: effectiveTLSServerName(rc), + } + if rc.TLSConfig != nil { + key.skipVerify = rc.TLSConfig.InsecureSkipVerify + } + return key, true +} + +func normalizeProxyCacheKey(proxy string) string { + if proxy == "" { + return "" + } + proxyURL, err := parseProxyURL(proxy) + if err != nil { + return "\x00invalid:" + proxy + } + return proxyURL.String() +} + +func serializeTransportCacheList(values []string) string { + if len(values) == 0 { + return "" + } + var builder strings.Builder + for _, value := range values { + builder.WriteString(value) + builder.WriteByte(0) + } + return builder.String() +} + +func effectiveTLSServerName(rc *RequestContext) string { + if rc == nil { + return "" + } + if rc.TLSConfig != nil && rc.TLSConfig.ServerName != "" { + return rc.TLSConfig.ServerName + } + return rc.TLSServerName } // buildDynamicTransport 构建动态 Transport -func (t *Transport) buildDynamicTransport(rc *RequestContext) *http.Transport { +func (t *Transport) buildDynamicTransport(rc *RequestContext, traceState *traceState) *http.Transport { + t.ensureBase() t.mu.RLock() - transport := t.base.Clone() + baseTransport := t.base t.mu.RUnlock() + return buildDynamicTransportFromBase(baseTransport, rc, traceState) +} + +func buildDynamicTransportFromBase(baseTransport *http.Transport, rc *RequestContext, traceState *traceState) *http.Transport { + transport := baseTransport.Clone() // 应用 TLS 配置(即使为 nil 也要检查 SkipVerify) if rc.TLSConfig != nil { @@ -64,15 +223,33 @@ func (t *Transport) buildDynamicTransport(rc *RequestContext) *http.Transport { // 应用代理配置 if rc.Proxy != "" { - proxyURL, err := url.Parse(rc.Proxy) - if err == nil { + proxyURL, err := parseProxyURL(rc.Proxy) + if err != nil { + transport.Proxy = func(*http.Request) (*url.URL, error) { + return nil, err + } + } else { transport.Proxy = http.ProxyURL(proxyURL) } } // 应用自定义 Dial 函数 if rc.DialFn != nil { - transport.DialContext = rc.DialFn + if traceState != nil && traceState.hooks != nil && (traceState.hooks.ConnectStart != nil || traceState.hooks.ConnectDone != nil) { + dialFn := rc.DialFn + transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + if traceState.hooks.ConnectStart != nil { + traceState.hooks.ConnectStart(TraceConnectStartInfo{Network: network, Addr: addr}) + } + conn, err := dialFn(ctx, network, addr) + if traceState.hooks.ConnectDone != nil { + traceState.hooks.ConnectDone(TraceConnectDoneInfo{Network: network, Addr: addr, Err: err}) + } + return conn, err + } + } else { + transport.DialContext = rc.DialFn + } } else if len(rc.CustomIP) > 0 || len(rc.CustomDNS) > 0 || rc.DialTimeout > 0 || rc.LookupIPFn != nil { // 使用默认 Dial 函数(会从 context 读取配置) transport.DialContext = defaultDialFunc @@ -93,5 +270,147 @@ func (t *Transport) Base() *http.Transport { func (t *Transport) SetBase(base *http.Transport) { t.mu.Lock() t.base = base + t.resetDynamicTransportCacheLocked() t.mu.Unlock() } + +func prepareProxyTargetRequest(req *http.Request, reqCtx *RequestContext, traceState *traceState) (*http.Request, *RequestContext, []string, error) { + if req == nil || req.URL == nil || reqCtx == nil { + return req, reqCtx, nil, nil + } + if reqCtx.Proxy == "" || reqCtx.DialFn != nil { + return req, reqCtx, nil, nil + } + if len(reqCtx.CustomIP) == 0 && len(reqCtx.CustomDNS) == 0 && reqCtx.LookupIPFn == nil { + return req, reqCtx, nil, nil + } + + host := req.URL.Hostname() + if host == "" { + return req, reqCtx, nil, nil + } + + targetAddrs, err := resolveDialAddresses(req.Context(), reqCtx, host, req.URL.Port(), traceState) + if err != nil { + return nil, nil, nil, err + } + if len(targetAddrs) == 0 { + return req, reqCtx, nil, nil + } + + execReqCtx := *reqCtx + execReqCtx.CustomIP = nil + execReqCtx.CustomDNS = nil + execReqCtx.LookupIPFn = nil + + if req.URL.Scheme == "https" { + execReqCtx.TLSConfig = withDefaultServerName(execReqCtx.TLSConfig, host) + if execReqCtx.TLSConfigCacheable || reqCtx.TLSConfig == nil { + execReqCtx.TLSConfigCacheable = true + } + } + + execCtx := clearTargetResolutionContext(req.Context()) + execReq := req.Clone(execCtx) + execReq.Host = req.Host + if len(targetAddrs) == 1 { + execReq.URL.Host = targetAddrs[0] + return execReq, &execReqCtx, nil, nil + } + + return execReq, &execReqCtx, targetAddrs, nil +} + +func clearTargetResolutionContext(ctx context.Context) context.Context { + if v := ctx.Value(ctxKeyRequestContext); v != nil { + if rc, ok := v.(*RequestContext); ok && rc != nil { + cloned := cloneRequestContext(rc) + cloned.CustomIP = nil + cloned.CustomDNS = nil + cloned.LookupIPFn = nil + ctx = context.WithValue(ctx, ctxKeyRequestContext, cloned) + } + } + ctx = context.WithValue(ctx, ctxKeyCustomIP, []string(nil)) + ctx = context.WithValue(ctx, ctxKeyCustomDNS, []string(nil)) + ctx = context.WithValue(ctx, ctxKeyLookupIP, (func(context.Context, string) ([]net.IPAddr, error))(nil)) + return ctx +} + +func withDefaultServerName(cfg *tls.Config, serverName string) *tls.Config { + if serverName == "" { + return cfg + } + if cfg != nil { + if cfg.ServerName != "" { + return cfg + } + cloned := cfg.Clone() + cloned.ServerName = serverName + return cloned + } + return &tls.Config{ + ServerName: serverName, + NextProtos: []string{"h2", "http/1.1"}, + } +} + +func roundTripResolvedTargets(rt http.RoundTripper, baseReq *http.Request, targetAddrs []string) (*http.Response, error) { + if rt == nil || baseReq == nil || len(targetAddrs) == 0 { + return rt.RoundTrip(baseReq) + } + + if !requestAllowsResolvedTargetFallback(baseReq) && len(targetAddrs) > 1 { + targetAddrs = targetAddrs[:1] + } + + var lastErr error + for _, targetAddr := range targetAddrs { + attemptReq, err := cloneRequestForResolvedTarget(baseReq, targetAddr) + if err != nil { + return nil, err + } + + resp, err := rt.RoundTrip(attemptReq) + if err == nil { + return resp, nil + } + lastErr = err + } + + return nil, lastErr +} + +func requestAllowsResolvedTargetFallback(req *http.Request) bool { + if req == nil { + return false + } + if !isIdempotentMethod(req.Method) { + return false + } + if req.Body == nil || req.Body == http.NoBody { + return true + } + return req.GetBody != nil +} + +func cloneRequestForResolvedTarget(baseReq *http.Request, targetAddr string) (*http.Request, error) { + req := baseReq.Clone(baseReq.Context()) + + switch { + case baseReq.Body == nil || baseReq.Body == http.NoBody: + req.Body = baseReq.Body + case baseReq.GetBody != nil: + body, err := baseReq.GetBody() + if err != nil { + return nil, wrapError(err, "clone request body for resolved target") + } + req.Body = body + default: + req.Body = baseReq.Body + } + + req.URL.Host = targetAddr + req.Host = baseReq.Host + return req, nil +} diff --git a/transport_cache_test.go b/transport_cache_test.go new file mode 100644 index 0000000..00ebc02 --- /dev/null +++ b/transport_cache_test.go @@ -0,0 +1,224 @@ +package starnet + +import ( + "crypto/tls" + "net" + "net/http" + "strconv" + "sync" + "testing" + "time" +) + +func TestTransportDynamicCacheReusesSafeProfile(t *testing.T) { + transport := &Transport{base: newBaseHTTPTransport()} + + first := transport.getDynamicTransport(&RequestContext{ + Proxy: "http://127.0.0.1:8080", + DialTimeout: 2 * time.Second, + CustomIP: []string{"127.0.0.1"}, + TLSServerName: "cache.test", + }, nil) + second := transport.getDynamicTransport(&RequestContext{ + Proxy: "http://127.0.0.1:8080", + DialTimeout: 2 * time.Second, + CustomIP: []string{"127.0.0.1"}, + TLSServerName: "cache.test", + }, nil) + + if first != second { + t.Fatal("expected cached dynamic transport to be reused") + } + if got := len(transport.dynamicCache); got != 1 { + t.Fatalf("dynamic cache size=%d; want 1", got) + } +} + +func TestTransportDynamicCacheSeparatesTLSServerName(t *testing.T) { + transport := &Transport{base: newBaseHTTPTransport()} + + first := transport.getDynamicTransport(&RequestContext{ + CustomIP: []string{"127.0.0.1"}, + TLSServerName: "first.test", + }, nil) + second := transport.getDynamicTransport(&RequestContext{ + CustomIP: []string{"127.0.0.1"}, + TLSServerName: "second.test", + }, nil) + + if first == second { + t.Fatal("expected distinct tls server names to use different transports") + } + if got := len(transport.dynamicCache); got != 2 { + t.Fatalf("dynamic cache size=%d; want 2", got) + } +} + +func TestTransportDynamicCacheSkipsUserTLSConfig(t *testing.T) { + transport := &Transport{base: newBaseHTTPTransport()} + reqCtx := &RequestContext{ + CustomIP: []string{"127.0.0.1"}, + TLSConfig: &tls.Config{InsecureSkipVerify: true}, + } + + first := transport.getDynamicTransport(reqCtx, nil) + second := transport.getDynamicTransport(reqCtx, nil) + + if first == second { + t.Fatal("expected user tls config to bypass dynamic transport cache") + } + if got := len(transport.dynamicCache); got != 0 { + t.Fatalf("dynamic cache size=%d; want 0", got) + } +} + +func TestTransportDynamicCacheResetOnDefaultTLSChange(t *testing.T) { + client := NewClientNoErr() + transport, ok := client.HTTPClient().Transport.(*Transport) + if !ok { + t.Fatalf("transport type=%T; want *Transport", client.HTTPClient().Transport) + } + + reqCtx := &RequestContext{CustomIP: []string{"127.0.0.1"}} + first := transport.getDynamicTransport(reqCtx, nil) + if got := len(transport.dynamicCache); got != 1 { + t.Fatalf("dynamic cache size=%d; want 1 before reset", got) + } + + client.SetDefaultSkipTLSVerify(true) + if got := len(transport.dynamicCache); got != 0 { + t.Fatalf("dynamic cache size=%d; want 0 after reset", got) + } + + second := transport.getDynamicTransport(reqCtx, nil) + if first == second { + t.Fatal("expected cache reset after default tls change") + } +} + +func TestDynamicTransportCacheReusesConnectionForCustomIP(t *testing.T) { + server := newIPv4Server(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + defer server.Close() + + addr, err := net.ResolveTCPAddr("tcp", server.Listener.Addr().String()) + if err != nil { + t.Fatalf("ResolveTCPAddr() error: %v", err) + } + + client := NewClientNoErr() + targetURL := "http://cache-reuse.test:" + strconv.Itoa(addr.Port) + + runRequest := func() bool { + var ( + mu sync.Mutex + gotConn bool + reused bool + ) + + resp, err := client.NewSimpleRequest(targetURL, http.MethodGet). + SetCustomIP([]string{"127.0.0.1"}). + SetTraceHooks(&TraceHooks{ + GotConn: func(info TraceGotConnInfo) { + mu.Lock() + gotConn = true + reused = info.Reused + mu.Unlock() + }, + }). + Do() + if err != nil { + t.Fatalf("Do() error: %v", err) + } + defer resp.Close() + + if _, err := resp.Body().Bytes(); err != nil { + t.Fatalf("Body().Bytes() error: %v", err) + } + + mu.Lock() + defer mu.Unlock() + if !gotConn { + t.Fatal("expected GotConn trace event") + } + return reused + } + + if runRequest() { + t.Fatal("first request unexpectedly reused a connection") + } + if !runRequest() { + t.Fatal("second request did not reuse cached dynamic transport connection") + } + + transport, ok := client.HTTPClient().Transport.(*Transport) + if !ok { + t.Fatalf("transport type=%T; want *Transport", client.HTTPClient().Transport) + } + if got := len(transport.dynamicCache); got != 1 { + t.Fatalf("dynamic cache size=%d; want 1", got) + } +} + +func TestPrepareProxyTargetRequestSingleTargetRewritesExecRequest(t *testing.T) { + req, err := http.NewRequest(http.MethodGet, "https://proxy-single.test:8443/path", nil) + if err != nil { + t.Fatalf("http.NewRequest() error: %v", err) + } + req.Host = req.URL.Host + + execReq, execReqCtx, targetAddrs, err := prepareProxyTargetRequest(req, &RequestContext{ + Proxy: "http://127.0.0.1:8080", + CustomIP: []string{"127.0.0.1"}, + }, nil) + if err != nil { + t.Fatalf("prepareProxyTargetRequest() error: %v", err) + } + + if execReq == req { + t.Fatal("expected cloned request for proxy target preparation") + } + if got := execReq.URL.Host; got != "127.0.0.1:8443" { + t.Fatalf("execReq.URL.Host=%q; want %q", got, "127.0.0.1:8443") + } + if got := req.URL.Host; got != "proxy-single.test:8443" { + t.Fatalf("original req.URL.Host=%q; want %q", got, "proxy-single.test:8443") + } + if len(targetAddrs) != 0 { + t.Fatalf("targetAddrs=%v; want empty after single target rewrite", targetAddrs) + } + if execReqCtx == nil || execReqCtx.TLSConfig == nil { + t.Fatal("expected synthesized tls config for single target proxy request") + } + if got := execReqCtx.TLSConfig.ServerName; got != "proxy-single.test" { + t.Fatalf("tls server name=%q; want %q", got, "proxy-single.test") + } +} + +func TestPrepareProxyTargetRequestMultiTargetPreservesFallbackList(t *testing.T) { + req, err := http.NewRequest(http.MethodGet, "https://proxy-multi.test:9443/path", nil) + if err != nil { + t.Fatalf("http.NewRequest() error: %v", err) + } + req.Host = req.URL.Host + + execReq, _, targetAddrs, err := prepareProxyTargetRequest(req, &RequestContext{ + Proxy: "http://127.0.0.1:8080", + CustomIP: []string{"127.0.0.1", "127.0.0.2"}, + }, nil) + if err != nil { + t.Fatalf("prepareProxyTargetRequest() error: %v", err) + } + + if got := execReq.URL.Host; got != "proxy-multi.test:9443" { + t.Fatalf("execReq.URL.Host=%q; want original host", got) + } + if len(targetAddrs) != 2 { + t.Fatalf("targetAddrs=%v; want 2 targets", targetAddrs) + } + if targetAddrs[0] != "127.0.0.1:9443" || targetAddrs[1] != "127.0.0.2:9443" { + t.Fatalf("targetAddrs=%v; want ordered fallback targets", targetAddrs) + } +} diff --git a/types.go b/types.go index 425e763..e0e43f5 100644 --- a/types.go +++ b/types.go @@ -53,6 +53,7 @@ type NetworkConfig struct { type TLSConfig struct { Config *tls.Config // TLS 配置 SkipVerify bool // 跳过证书验证 + ServerName string // 显式 TLS ServerName/SNI 覆盖 } // DNSConfig DNS 配置 @@ -62,8 +63,19 @@ type DNSConfig struct { LookupFunc func(ctx context.Context, host string) ([]net.IPAddr, error) // 自定义解析函数 } +type bodyMode uint8 + +const ( + bodyModeUnset bodyMode = iota + bodyModeBytes + bodyModeReader + bodyModeForm + bodyModeMultipart +) + // BodyConfig 请求体配置 type BodyConfig struct { + Mode bodyMode // 当前 body 来源模式 Bytes []byte // 原始字节 Reader io.Reader // 数据流 FormData map[string][]string // 表单数据 @@ -82,6 +94,7 @@ type RequestConfig struct { // 其他配置 BasicAuth [2]string // Basic 认证 + Host string // 显式 Host 头覆盖 ContentLength int64 // 手动设置的 Content-Length AutoCalcContentLength bool // 自动计算 Content-Length MaxRespBodyBytes int64 // 响应体最大读取字节数(<=0 表示不限制) @@ -104,6 +117,7 @@ func (c *RequestConfig) Clone() *RequestConfig { TLS: TLSConfig{ Config: cloneTLSConfig(c.TLS.Config), SkipVerify: c.TLS.SkipVerify, + ServerName: c.TLS.ServerName, }, DNS: DNSConfig{ CustomIP: cloneStringSlice(c.DNS.CustomIP), @@ -111,6 +125,7 @@ func (c *RequestConfig) Clone() *RequestConfig { LookupFunc: c.DNS.LookupFunc, }, Body: BodyConfig{ + Mode: c.Body.Mode, Bytes: cloneBytes(c.Body.Bytes), Reader: c.Body.Reader, // Reader 不可克隆 FormData: cloneStringMapSlice(c.Body.FormData), @@ -120,6 +135,7 @@ func (c *RequestConfig) Clone() *RequestConfig { Cookies: cloneCookies(c.Cookies), Queries: cloneStringMapSlice(c.Queries), BasicAuth: c.BasicAuth, + Host: c.Host, ContentLength: c.ContentLength, AutoCalcContentLength: c.AutoCalcContentLength, MaxRespBodyBytes: c.MaxRespBodyBytes, diff --git a/utils.go b/utils.go index 5ffc470..886b0c0 100644 --- a/utils.go +++ b/utils.go @@ -101,24 +101,31 @@ func cloneCookies(cookies []*http.Cookie) []*http.Cookie { } newCookies := make([]*http.Cookie, len(cookies)) for i, c := range cookies { - newCookies[i] = &http.Cookie{ - Name: c.Name, - Value: c.Value, - Path: c.Path, - Domain: c.Domain, - Expires: c.Expires, - RawExpires: c.RawExpires, - MaxAge: c.MaxAge, - Secure: c.Secure, - HttpOnly: c.HttpOnly, - SameSite: c.SameSite, - Raw: c.Raw, - Unparsed: append([]string(nil), c.Unparsed...), - } + newCookies[i] = cloneCookie(c) } return newCookies } +func cloneCookie(cookie *http.Cookie) *http.Cookie { + if cookie == nil { + return nil + } + return &http.Cookie{ + Name: cookie.Name, + Value: cookie.Value, + Path: cookie.Path, + Domain: cookie.Domain, + Expires: cookie.Expires, + RawExpires: cookie.RawExpires, + MaxAge: cookie.MaxAge, + Secure: cookie.Secure, + HttpOnly: cookie.HttpOnly, + SameSite: cookie.SameSite, + Raw: cookie.Raw, + Unparsed: append([]string(nil), cookie.Unparsed...), + } +} + // cloneStringMapSlice 克隆 map[string][]string func cloneStringMapSlice(m map[string][]string) map[string][]string { if m == nil { @@ -171,8 +178,8 @@ func cloneTLSConfig(cfg *tls.Config) *tls.Config { // copyWithProgress 带进度的复制 func copyWithProgress(ctx context.Context, dst io.Writer, src io.Reader, filename string, total int64, progress UploadProgressFunc) (int64, error) { - if progress == nil { - return io.Copy(dst, src) + if ctx == nil { + ctx = context.Background() } var written int64 @@ -190,8 +197,10 @@ func copyWithProgress(ctx context.Context, dst io.Writer, src io.Reader, filenam nw, ew := dst.Write(buf[:nr]) if nw > 0 { written += int64(nw) - // 同步调用进度回调(不使用 goroutine) - progress(filename, written, total) + if progress != nil { + // 同步调用进度回调(不使用 goroutine) + progress(filename, written, total) + } } if ew != nil { return written, ew @@ -202,8 +211,10 @@ func copyWithProgress(ctx context.Context, dst io.Writer, src io.Reader, filenam } if err != nil { if err == io.EOF { - // 最后一次进度回调 - progress(filename, written, total) + if progress != nil { + // 最后一次进度回调 + progress(filename, written, total) + } return written, nil } return written, err