fix(starnet): 重构请求执行链路并补齐代理/重试/trace边界

- 分离 Request 的配置态与执行态,修复二次 Do、raw 模式网络配置失效和 body 来源互斥问题
  - 新增 starnet trace 抽象,补齐 DNS/连接/TLS/重试事件,并优化动态 transport 缓存与代理解析路径
  - 收紧非法代理为 fail-fast,多目标目标回退仅限幂等请求,修复 Host/TLS/SNI 等语义边界
  - 补充防御性拷贝、专项回归测试、本地代理/TLS 用例与 README 行为说明
This commit is contained in:
兔子 2026-04-19 15:39:51 +08:00
parent 9ac9b65bc5
commit 732e81316c
Signed by: b612
GPG Key ID: 99DD2222B612B612
43 changed files with 5633 additions and 1728 deletions

2
.gitignore vendored
View File

@ -2,3 +2,5 @@
.sentrux/
agent_readme.md
target.md
agents.md
.codex

View File

@ -1,58 +1,58 @@
# starnet
`starnet` is a Go network toolkit focused on practical HTTP request control, TLS sniff utilities, and ICMP ping capabilities.
`starnet` 是一个面向 Go 的网络工具库,提供 HTTP 请求控制、TLS 嗅探和 ICMP Ping 能力。
## Highlights
## 功能概览
- Request-level timeout by context (without mutating shared `http.Client` timeout)
- Fine-grained network controls: custom DNS/IP, dial timeout, proxy, TLS config
- Built-in retry with replay safety checks and configurable backoff/jitter/statuses
- Response body safety guard via max body bytes limit
- Error classification helpers (`ClassifyError`, `IsTimeout`, `IsDNS`, `IsTLS`, `IsProxy`, `IsCanceled`)
- TLS sniffer listener/dialer utilities for mixed TLS/plain traffic scenarios
- ICMP ping with IPv4/IPv6 target handling and option-based probing API
- 基于 `context` 的请求级超时控制,不修改共享 `http.Client` 的全局超时
- 请求级网络控制:代理、自定义 IP / DNS、拨号超时、TLS 配置
- 内置重试机制,支持重试次数、退避、抖动、状态码白名单和自定义错误判定
- 响应体大小限制,避免一次性读取过大内容
- 错误分类辅助:`ClassifyError``IsTimeout``IsDNS``IsTLS``IsProxy``IsCanceled`
- TLS 嗅探监听 / 拨号工具,适用于 TLS 与明文混合场景
- ICMP Ping支持 IPv4 / IPv6 目标和选项化探测
## Main Features
## 主要能力
### HTTP Client and Request
### HTTP 客户端与请求构建
- Fluent APIs with both `WithXxx` options and `SetXxx` chain methods
- Methods: `Get/Post/Put/Delete/Head/Patch/Options/Trace/Connect`
- Request body helpers: JSON, form data, multipart file upload, stream body
- Header/cookie/query helpers with defensive copy on key setters
- Request cloning for safe reuse in concurrent or variant calls
- 同时提供 `WithXxx` 选项和 `SetXxx` 链式调用两套接口
- 支持 `Get``Post``Put``Delete``Head``Patch``Options``Trace``Connect`
- 支持 JSON、表单、`multipart/form-data`、流式请求体等常见请求体形态
- Header、Cookie、Query 等输入在关键路径上做防御性拷贝,降低外部可变状态污染风险
- `Request.Clone()` 可用于并发场景或同一基础请求的变体构造
### Timeout and Retry
### 超时与重试
- Request timeout is applied by context deadline, not global client timeout
- Retry supports:
- max attempts
- backoff factor/base/max
- jitter
- retry status whitelist
- idempotent-only guard
- custom retry-on-error callback
- Retry keeps original request pointer in final response for consistency
- 请求超时通过 `context` 截止时间控制,不污染共享客户端配置
- 重试支持:
- 最大尝试次数
- 基础退避、最大退避和退避因子
- 抖动比例
- 可重试状态码集合
- 仅幂等方法重试
- 自定义错误判定函数
- 重试成功后返回的 `Response` 仍保持对原始 `Request` 的引用
### Response Handling
### 响应处理
- `Bytes/String/JSON/Reader` helpers
- optional auto-fetch mode
- configurable max response body bytes to prevent oversized reads
- 提供 `Bytes``String``JSON``Reader` 等响应体读取接口
- 支持自动预取响应体
- 支持按字节数限制响应体读取上限
### Ping Module
### Ping 模块
- `Ping`, `PingWithContext`, `Pingable`, and compatibility helper `IsIpPingable`
- `PingOptions` for count/timeout/interval/deadline/address preference/source IP/payload size
- explicit error semantics for permission/protocol/timeout/resolve failures
- 提供 `Ping``PingWithContext``Pingable` 以及兼容函数 `IsIpPingable`
- `PingOptions` 支持次数、超时、间隔、截止时间、地址族偏好、源地址、负载长度等参数
- 对权限不足、协议不支持、超时、解析失败等情况提供明确错误语义
## Install
## 安装
```bash
go get b612.me/starnet
```
## Quick Example
## 快速示例
```go
package main
@ -94,13 +94,18 @@ func main() {
}
```
## Stability Notes
## 行为说明
- Raw ICMP ping may require elevated privileges on some systems.
- Integration tests that rely on external network are environment-dependent.
- `NewClient``NewRequest` 以及请求构造相关接口在遇到非法选项时会直接返回错误,例如格式不合法的代理地址。
- `NewClientNoErr` 是便利构造函数;如果选项校验失败,仍可能返回一个占位 `Client`,需要严格校验配置时应优先使用 `NewClient`
- 重试默认仅对幂等方法生效。即使显式关闭“仅幂等方法重试”,通过 `SetBodyReader``WithBodyReader` 构造的请求在非幂等方法上仍不会自动重试。
- 当同时使用 `proxy + custom IP/DNS` 且解析出多个目标地址时,自动目标回退仅对幂等请求生效,以避免重复写入。
## License
## 稳定性说明
This project is licensed under the Apache License 2.0.
See [LICENSE](./LICENSE).
- 原始 ICMP Ping 在部分系统上需要额外权限。
- 依赖外部网络环境的集成测试结果可能受运行环境影响。
## 许可证
本项目采用 Apache License 2.0,详见 [LICENSE](./LICENSE)。

View File

@ -148,6 +148,33 @@ func BenchmarkRequestCreation(b *testing.B) {
}
}
func BenchmarkRequestPrepareDefaultPath(b *testing.B) {
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
req := NewSimpleRequest("https://example.com", "GET")
if err := req.prepare(); err != nil {
b.Fatalf("prepare() error: %v", err)
}
}
}
func BenchmarkRequestPrepareDynamicPath(b *testing.B) {
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
req := NewSimpleRequest("https://example.com", "GET",
WithCustomIP([]string{"127.0.0.1"}),
WithSkipTLSVerify(true),
)
if err := req.prepare(); err != nil {
b.Fatalf("prepare() error: %v", err)
}
}
}
func BenchmarkResponseBodyRead(b *testing.B) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("test response data"))

View File

@ -6,7 +6,6 @@ import (
"fmt"
"net/http"
"sync"
"time"
)
// Client HTTP 客户端封装
@ -19,14 +18,7 @@ type Client struct {
// NewClient 创建新的 Client
func NewClient(opts ...RequestOpt) (*Client, error) {
// 创建基础 Transport
baseTransport := &http.Transport{
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
baseTransport := newBaseHTTPTransport()
httpClient := &http.Client{
Transport: &Transport{base: baseTransport},
@ -40,6 +32,9 @@ func NewClient(opts ...RequestOpt) (*Client, error) {
if err != nil {
return nil, wrapError(err, "create client")
}
if req.err != nil {
return nil, wrapError(req.err, "create client")
}
/*
// 如果选项中有自定义配置,应用到 httpClient
@ -61,7 +56,9 @@ func NewClient(opts ...RequestOpt) (*Client, error) {
}, nil
}
// NewClientNoErr 创建新的 Client忽略错误
// NewClientNoErr 创建新的 Client忽略错误
// 当 opts 校验失败时,它仍会返回一个可用的 Client 占位对象;
// 如果调用方需要感知选项错误或依赖默认 starnet Transport 行为,应优先使用 NewClient。
func NewClientNoErr(opts ...RequestOpt) *Client {
client, _ := NewClient(opts...)
if client == nil {
@ -172,11 +169,13 @@ func (c *Client) Clone() *Client {
func (c *Client) SetDefaultTLSConfig(tlsConfig *tls.Config) *Client {
if transport, ok := c.client.Transport.(*Transport); ok {
transport.mu.Lock()
transport.ensureBaseLocked()
if tlsConfig != nil {
transport.base.TLSClientConfig = tlsConfig.Clone()
} else {
transport.base.TLSClientConfig = nil
}
transport.resetDynamicTransportCacheLocked()
transport.mu.Unlock()
}
return c
@ -186,12 +185,14 @@ func (c *Client) SetDefaultTLSConfig(tlsConfig *tls.Config) *Client {
func (c *Client) SetDefaultSkipTLSVerify(skip bool) *Client {
if transport, ok := c.client.Transport.(*Transport); ok {
transport.mu.Lock()
transport.ensureBaseLocked()
if transport.base.TLSClientConfig == nil {
transport.base.TLSClientConfig = &tls.Config{}
} else {
transport.base.TLSClientConfig = transport.base.TLSClientConfig.Clone()
}
transport.base.TLSClientConfig.InsecureSkipVerify = skip
transport.resetDynamicTransportCacheLocked()
transport.mu.Unlock()
}
return c
@ -227,6 +228,9 @@ func (c *Client) NewRequestWithContext(ctx context.Context, url, method string,
if err != nil {
return nil, err
}
if req.err != nil {
return nil, req.err
}
req.client = c
req.httpClient = c.client

View File

@ -14,6 +14,8 @@ type contextKey int
const (
ctxKeyTransport contextKey = iota
ctxKeyTLSConfig
ctxKeyTLSConfigCacheable
ctxKeyTLSServerName
ctxKeyProxy
ctxKeyCustomIP
ctxKeyCustomDNS
@ -21,12 +23,15 @@ const (
ctxKeyTimeout
ctxKeyLookupIP
ctxKeyDialFunc
ctxKeyRequestContext
)
// RequestContext 从 context 中提取的请求配置
type RequestContext struct {
Transport *http.Transport
TLSConfig *tls.Config
TLSConfigCacheable bool
TLSServerName string
Proxy string
CustomIP []string
CustomDNS []string
@ -36,43 +41,77 @@ type RequestContext struct {
DialFn func(ctx context.Context, network, addr string) (net.Conn, error)
}
var emptyRequestContext = &RequestContext{}
// getRequestContext 从 context 中提取请求配置
func getRequestContext(ctx context.Context) *RequestContext {
rc := &RequestContext{}
if v := ctx.Value(ctxKeyRequestContext); v != nil {
if rc, ok := v.(*RequestContext); ok && rc != nil {
return rc
}
}
var rc *RequestContext
ensure := func() *RequestContext {
if rc == nil {
rc = &RequestContext{}
}
return rc
}
if v := ctx.Value(ctxKeyTransport); v != nil {
rc.Transport, _ = v.(*http.Transport)
ensure().Transport, _ = v.(*http.Transport)
}
if v := ctx.Value(ctxKeyTLSConfig); v != nil {
rc.TLSConfig, _ = v.(*tls.Config)
ensure().TLSConfig, _ = v.(*tls.Config)
}
if v := ctx.Value(ctxKeyTLSConfigCacheable); v != nil {
ensure().TLSConfigCacheable, _ = v.(bool)
}
if v := ctx.Value(ctxKeyTLSServerName); v != nil {
ensure().TLSServerName, _ = v.(string)
}
if v := ctx.Value(ctxKeyProxy); v != nil {
rc.Proxy, _ = v.(string)
ensure().Proxy, _ = v.(string)
}
if v := ctx.Value(ctxKeyCustomIP); v != nil {
rc.CustomIP, _ = v.([]string)
ensure().CustomIP, _ = v.([]string)
}
if v := ctx.Value(ctxKeyCustomDNS); v != nil {
rc.CustomDNS, _ = v.([]string)
ensure().CustomDNS, _ = v.([]string)
}
if v := ctx.Value(ctxKeyDialTimeout); v != nil {
rc.DialTimeout, _ = v.(time.Duration)
ensure().DialTimeout, _ = v.(time.Duration)
}
if v := ctx.Value(ctxKeyTimeout); v != nil {
rc.Timeout, _ = v.(time.Duration)
ensure().Timeout, _ = v.(time.Duration)
}
if v := ctx.Value(ctxKeyLookupIP); v != nil {
rc.LookupIPFn, _ = v.(func(context.Context, string) ([]net.IPAddr, error))
ensure().LookupIPFn, _ = v.(func(context.Context, string) ([]net.IPAddr, error))
}
if v := ctx.Value(ctxKeyDialFunc); v != nil {
rc.DialFn, _ = v.(func(context.Context, string, string) (net.Conn, error))
ensure().DialFn, _ = v.(func(context.Context, string, string) (net.Conn, error))
}
if rc == nil {
return emptyRequestContext
}
return rc
}
func cloneRequestContext(rc *RequestContext) *RequestContext {
if rc == nil {
return nil
}
cloned := *rc
cloned.CustomIP = cloneStringSlice(rc.CustomIP)
cloned.CustomDNS = cloneStringSlice(rc.CustomDNS)
return &cloned
}
// needsDynamicTransport 判断是否需要动态 Transport
func needsDynamicTransport(rc *RequestContext) bool {
if rc == nil {
return false
}
return rc.Transport != nil ||
rc.TLSConfig != nil ||
rc.Proxy != "" ||
@ -83,63 +122,67 @@ func needsDynamicTransport(rc *RequestContext) bool {
rc.LookupIPFn != nil
}
// injectRequestConfig 将请求配置注入到 context
func injectRequestConfig(ctx context.Context, config *RequestConfig) context.Context {
execCtx := ctx
func buildRequestContext(config *RequestConfig, defaultTLSServerName string) *RequestContext {
if config == nil {
return nil
}
rc := &RequestContext{
DialTimeout: config.Network.DialTimeout,
Timeout: config.Network.Timeout,
}
// 处理 TLS 配置
var tlsConfig *tls.Config
tlsConfigCacheable := false
if config.TLS.Config != nil {
tlsConfig = config.TLS.Config.Clone()
if config.TLS.SkipVerify {
tlsConfig.InsecureSkipVerify = true
}
} else if config.TLS.SkipVerify {
} else if config.TLS.SkipVerify || config.TLS.ServerName != "" {
tlsConfig = &tls.Config{
NextProtos: []string{"h2", "http/1.1"},
InsecureSkipVerify: true,
}
tlsConfigCacheable = true
}
if config.TLS.SkipVerify && tlsConfig != nil {
tlsConfig.InsecureSkipVerify = true
}
if config.TLS.ServerName != "" && tlsConfig != nil {
tlsConfig.ServerName = config.TLS.ServerName
}
if tlsConfig != nil {
execCtx = context.WithValue(execCtx, ctxKeyTLSConfig, tlsConfig)
rc.TLSConfig = tlsConfig
rc.TLSConfigCacheable = tlsConfigCacheable
}
if config.TLS.ServerName != "" {
rc.TLSServerName = config.TLS.ServerName
} else if defaultTLSServerName != "" {
rc.TLSServerName = defaultTLSServerName
}
// 注入代理
if config.Network.Proxy != "" {
execCtx = context.WithValue(execCtx, ctxKeyProxy, config.Network.Proxy)
}
rc.Proxy = config.Network.Proxy
rc.CustomIP = cloneStringSlice(config.DNS.CustomIP)
rc.CustomDNS = cloneStringSlice(config.DNS.CustomDNS)
rc.LookupIPFn = config.DNS.LookupFunc
rc.DialFn = config.Network.DialFunc
// 注入自定义 IP
if len(config.DNS.CustomIP) > 0 {
execCtx = context.WithValue(execCtx, ctxKeyCustomIP, config.DNS.CustomIP)
}
// 注入自定义 DNS
if len(config.DNS.CustomDNS) > 0 {
execCtx = context.WithValue(execCtx, ctxKeyCustomDNS, config.DNS.CustomDNS)
}
// 总是注入 DialTimeout与原始代码一致
if config.Network.DialTimeout > 0 {
execCtx = context.WithValue(execCtx, ctxKeyDialTimeout, config.Network.DialTimeout)
}
// 注入 DNS 解析函数
if config.DNS.LookupFunc != nil {
execCtx = context.WithValue(execCtx, ctxKeyLookupIP, config.DNS.LookupFunc)
}
// 注入 Dial 函数
if config.Network.DialFunc != nil {
execCtx = context.WithValue(execCtx, ctxKeyDialFunc, config.Network.DialFunc)
}
// 注入自定义 Transport
if config.CustomTransport && config.Transport != nil {
execCtx = context.WithValue(execCtx, ctxKeyTransport, config.Transport)
rc.Transport = config.Transport
}
return execCtx
if !needsDynamicTransport(rc) {
return nil
}
return rc
}
// injectRequestConfig 将请求配置注入到 context
func injectRequestConfig(ctx context.Context, config *RequestConfig, defaultTLSServerName string) context.Context {
rc := buildRequestContext(config, defaultTLSServerName)
if rc == nil {
return ctx
}
return context.WithValue(ctx, ctxKeyRequestContext, rc)
}

View File

@ -57,3 +57,55 @@ func TestSetFormDataDefensiveCopy(t *testing.T) {
t.Fatalf("form data mutated by external map change: got=%v want=[alice]", got)
}
}
func TestWithBodyDefensiveCopy(t *testing.T) {
body := []byte("hello")
req, err := NewRequest("http://example.com", "POST", WithBody(body))
if err != nil {
t.Fatalf("NewRequest() error: %v", err)
}
body[0] = 'j'
if string(req.config.Body.Bytes) != "hello" {
t.Fatalf("body mutated by external slice change: got=%q want=%q", string(req.config.Body.Bytes), "hello")
}
}
func TestWithFormDataDefensiveCopy(t *testing.T) {
form := map[string][]string{
"name": []string{"alice"},
}
req, err := NewRequest("http://example.com", "POST", WithFormData(form))
if err != nil {
t.Fatalf("NewRequest() error: %v", err)
}
form["name"][0] = "bob"
form["name"] = append(form["name"], "carol")
got := req.config.Body.FormData["name"]
if len(got) != 1 || got[0] != "alice" {
t.Fatalf("form data mutated by external map change: got=%v want=[alice]", got)
}
}
func TestSetCustomIPDefensiveCopy(t *testing.T) {
ips := []string{"1.1.1.1", "8.8.8.8"}
req := NewSimpleRequest("http://example.com", "GET").SetCustomIP(ips)
ips[0] = "9.9.9.9"
if got := req.config.DNS.CustomIP[0]; got != "1.1.1.1" {
t.Fatalf("custom ip mutated by external slice change: got=%q want=%q", got, "1.1.1.1")
}
}
func TestSetCustomDNSDefensiveCopy(t *testing.T) {
servers := []string{"8.8.8.8", "1.1.1.1"}
req := NewSimpleRequest("http://example.com", "GET").SetCustomDNS(servers)
servers[0] = "9.9.9.9"
if got := req.config.DNS.CustomDNS[0]; got != "8.8.8.8" {
t.Fatalf("custom dns mutated by external slice change: got=%q want=%q", got, "8.8.8.8")
}
}

148
dialer.go
View File

@ -9,39 +9,50 @@ import (
"time"
)
// defaultDialFunc 默认 Dial 函数(支持自定义 IP 和 DNS
func defaultDialFunc(ctx context.Context, network, addr string) (net.Conn, error) {
// 提取配置
reqCtx := getRequestContext(ctx)
func traceDNSLookup(traceState *traceState, host string, lookup func() ([]net.IPAddr, error)) ([]net.IPAddr, error) {
if traceState != nil {
traceState.beginManualDNS()
defer traceState.endManualDNS()
traceState.dnsStart(TraceDNSStartInfo{Host: host})
}
ipAddrs, err := lookup()
if traceState != nil {
traceState.dnsDone(TraceDNSDoneInfo{
Addrs: append([]net.IPAddr(nil), ipAddrs...),
Err: err,
})
}
return ipAddrs, err
}
func resolveDialAddresses(ctx context.Context, reqCtx *RequestContext, host, port string, traceState *traceState) ([]string, error) {
if reqCtx == nil {
reqCtx = &RequestContext{}
}
var addrs []string
if len(reqCtx.CustomIP) > 0 {
for _, ip := range reqCtx.CustomIP {
addrs = append(addrs, joinResolvedHostPort(ip, port))
}
return addrs, nil
}
var (
ipAddrs []net.IPAddr
err error
)
if reqCtx.LookupIPFn != nil {
ipAddrs, err = traceDNSLookup(traceState, host, func() ([]net.IPAddr, error) {
return reqCtx.LookupIPFn(ctx, host)
})
} else if len(reqCtx.CustomDNS) > 0 {
dialTimeout := reqCtx.DialTimeout
if dialTimeout == 0 {
dialTimeout = DefaultDialTimeout
}
// 解析地址
host, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, wrapError(err, "split host port")
}
// 获取 IP 地址列表
var addrs []string
// 优先级1直接指定的 IP
if len(reqCtx.CustomIP) > 0 {
for _, ip := range reqCtx.CustomIP {
addrs = append(addrs, net.JoinHostPort(ip, port))
}
} else {
// 优先级2DNS 解析
var ipAddrs []net.IPAddr
// 使用自定义解析函数
if reqCtx.LookupIPFn != nil {
ipAddrs, err = reqCtx.LookupIPFn(ctx, host)
} else if len(reqCtx.CustomDNS) > 0 {
// 使用自定义 DNS 服务器
dialer := &net.Dialer{Timeout: dialTimeout}
resolver := &net.Resolver{
PreferGo: true,
@ -58,10 +69,13 @@ func defaultDialFunc(ctx context.Context, network, addr string) (net.Conn, error
return nil, lastErr
},
}
ipAddrs, err = resolver.LookupIPAddr(ctx, host)
ipAddrs, err = traceDNSLookup(traceState, host, func() ([]net.IPAddr, error) {
return resolver.LookupIPAddr(ctx, host)
})
} else {
// 使用默认解析器
ipAddrs, err = net.DefaultResolver.LookupIPAddr(ctx, host)
ipAddrs, err = traceDNSLookup(traceState, host, func() ([]net.IPAddr, error) {
return net.DefaultResolver.LookupIPAddr(ctx, host)
})
}
if err != nil {
@ -69,8 +83,41 @@ func defaultDialFunc(ctx context.Context, network, addr string) (net.Conn, error
}
for _, ipAddr := range ipAddrs {
addrs = append(addrs, net.JoinHostPort(ipAddr.String(), port))
addrs = append(addrs, joinResolvedHostPort(ipAddr.String(), port))
}
return addrs, nil
}
func joinResolvedHostPort(host, port string) string {
if port == "" {
if ip := net.ParseIP(host); ip != nil && ip.To4() == nil {
return "[" + host + "]"
}
return host
}
return net.JoinHostPort(host, port)
}
// defaultDialFunc 默认 Dial 函数(支持自定义 IP 和 DNS
func defaultDialFunc(ctx context.Context, network, addr string) (net.Conn, error) {
// 提取配置
reqCtx := getRequestContext(ctx)
traceState := getTraceState(ctx)
dialTimeout := reqCtx.DialTimeout
if dialTimeout == 0 {
dialTimeout = DefaultDialTimeout
}
// 解析地址
host, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, wrapError(err, "split host port")
}
addrs, err := resolveDialAddresses(ctx, reqCtx, host, port, traceState)
if err != nil {
return nil, err
}
// 尝试连接所有地址
@ -103,13 +150,17 @@ func defaultDialTLSFunc(ctx context.Context, network, addr string) (net.Conn, er
// 提取 TLS 配置
reqCtx := getRequestContext(ctx)
traceState := getTraceState(ctx)
tlsConfig := reqCtx.TLSConfig
if tlsConfig == nil {
tlsConfig = &tls.Config{}
}
// ← 新增:如果 ServerName 为空且没有 InsecureSkipVerify自动设置
if tlsConfig.ServerName == "" && !tlsConfig.InsecureSkipVerify {
serverName := tlsConfig.ServerName
if serverName == "" {
serverName = reqCtx.TLSServerName
}
if serverName == "" && !tlsConfig.InsecureSkipVerify {
host, _, err := net.SplitHostPort(addr)
if err != nil {
if idx := strings.LastIndex(addr, ":"); idx > 0 {
@ -118,8 +169,19 @@ func defaultDialTLSFunc(ctx context.Context, network, addr string) (net.Conn, er
host = addr
}
}
serverName = host
}
if serverName != "" && tlsConfig.ServerName != serverName {
tlsConfig = tlsConfig.Clone() // 避免修改原 config
tlsConfig.ServerName = host
tlsConfig.ServerName = serverName
}
if traceState != nil {
traceState.markCustomTLS()
traceState.tlsHandshakeStart(TraceTLSHandshakeStartInfo{
Network: network,
Addr: addr,
ServerName: serverName,
})
}
// 执行 TLS 握手
@ -130,9 +192,25 @@ func defaultDialTLSFunc(ctx context.Context, network, addr string) (net.Conn, er
tlsConn := tls.Client(conn, tlsConfig)
if err := tlsConn.Handshake(); err != nil {
if traceState != nil {
traceState.tlsHandshakeDone(TraceTLSHandshakeDoneInfo{
Network: network,
Addr: addr,
ServerName: serverName,
Err: err,
})
}
conn.Close()
return nil, wrapError(err, "tls handshake")
}
if traceState != nil {
traceState.tlsHandshakeDone(TraceTLSHandshakeDoneInfo{
Network: network,
Addr: addr,
ServerName: serverName,
ConnectionState: tlsConn.ConnectionState(),
})
}
return tlsConn, nil
}

View File

@ -0,0 +1,144 @@
package starnet
import (
"crypto/tls"
"net/http"
"net/url"
"testing"
)
func BenchmarkDynamicTransportCustomIP(b *testing.B) {
server := newIPv4Server(b, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("OK"))
}))
defer server.Close()
targetURL := benchmarkTargetURL(b, server.URL, "bench-custom-ip.test")
client := NewClientNoErr()
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
resp, err := client.Get(targetURL, WithCustomIP([]string{"127.0.0.1"}))
if err != nil {
b.Fatalf("Get() error: %v", err)
}
_, _ = resp.Body().Bytes()
resp.Close()
}
}
func BenchmarkDynamicTransportProxyTLSCacheable(b *testing.B) {
server := newIPv4TLSServer(b, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("OK"))
}))
defer server.Close()
proxy := newIPv4ConnectProxyServer(b, nil)
defer proxy.Close()
targetURL := httpsURLForHost(b, server, "bench-proxy-cacheable.test")
client := NewClientNoErr()
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
resp, err := client.Get(targetURL,
WithProxy(proxy.URL),
WithCustomIP([]string{"127.0.0.1"}),
WithSkipTLSVerify(true),
)
if err != nil {
b.Fatalf("Get() error: %v", err)
}
_, _ = resp.Body().Bytes()
resp.Close()
}
}
func BenchmarkDynamicTransportCustomIPTLSCacheable(b *testing.B) {
server := newIPv4TLSServer(b, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("OK"))
}))
defer server.Close()
targetURL := httpsURLForHost(b, server, "bench-custom-ip-cacheable.test")
client := NewClientNoErr()
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
resp, err := client.Get(targetURL,
WithCustomIP([]string{"127.0.0.1"}),
WithSkipTLSVerify(true),
)
if err != nil {
b.Fatalf("Get() error: %v", err)
}
_, _ = resp.Body().Bytes()
resp.Close()
}
}
func BenchmarkDynamicTransportCustomIPUserTLSConfig(b *testing.B) {
server := newIPv4TLSServer(b, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("OK"))
}))
defer server.Close()
targetURL := httpsURLForHost(b, server, "bench-user-tls.test")
client := NewClientNoErr()
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
resp, err := client.Get(targetURL,
WithCustomIP([]string{"127.0.0.1"}),
WithTLSConfig(&tls.Config{InsecureSkipVerify: true}),
)
if err != nil {
b.Fatalf("Get() error: %v", err)
}
_, _ = resp.Body().Bytes()
resp.Close()
}
}
func benchmarkTargetURL(tb testing.TB, rawURL, host string) string {
tb.Helper()
parsed, err := url.Parse(rawURL)
if err != nil {
tb.Fatalf("url.Parse() error: %v", err)
}
port := parsed.Port()
if port == "" {
switch parsed.Scheme {
case "https":
port = "443"
default:
port = "80"
}
}
return parsed.Scheme + "://" + host + ":" + port + pathWithQuery(parsed.Path, parsed.RawQuery)
}
func pathWithQuery(path, rawQuery string) string {
if path == "" {
path = "/"
}
if rawQuery == "" {
return path
}
return path + "?" + rawQuery
}

150
host_tls_regression_test.go Normal file
View File

@ -0,0 +1,150 @@
package starnet
import (
"crypto/tls"
"net/http"
"net/http/httptest"
"testing"
)
func TestRequestSetURLDoesNotMutateProvidedTLSConfig(t *testing.T) {
cfg := &tls.Config{}
req := NewSimpleRequest("https://example.com", http.MethodGet).
SetTLSConfig(cfg).
SetURL("https://other.example")
if req.Err() != nil {
t.Fatalf("unexpected request error: %v", req.Err())
}
if cfg.ServerName != "" {
t.Fatalf("provided tls.Config was mutated, ServerName=%q", cfg.ServerName)
}
}
func TestRequestPrepareSetTLSServerNameDoesNotMutateProvidedTLSConfig(t *testing.T) {
cfg := &tls.Config{InsecureSkipVerify: true}
req := NewSimpleRequest("https://example.com", http.MethodGet).
SetTLSConfig(cfg).
SetTLSServerName("override.example")
if err := req.prepare(); err != nil {
t.Fatalf("prepare error: %v", err)
}
if cfg.ServerName != "" {
t.Fatalf("provided tls.Config was mutated, ServerName=%q", cfg.ServerName)
}
rc := getRequestContext(req.execCtx)
if rc.TLSConfig == nil {
t.Fatal("expected injected tls config")
}
if rc.TLSConfig == cfg {
t.Fatal("expected injected tls config to be cloned")
}
if rc.TLSConfig.ServerName != "override.example" {
t.Fatalf("injected ServerName=%q", rc.TLSConfig.ServerName)
}
}
func TestRequestPrepareWithTLSServerNameWithoutTLSConfig(t *testing.T) {
req := NewSimpleRequest("https://example.com", http.MethodGet).
SetTLSServerName("override.example")
if err := req.prepare(); err != nil {
t.Fatalf("prepare error: %v", err)
}
rc := getRequestContext(req.execCtx)
if rc.TLSConfig == nil {
t.Fatal("expected injected tls config")
}
if rc.TLSConfig.ServerName != "override.example" {
t.Fatalf("injected ServerName=%q", rc.TLSConfig.ServerName)
}
}
func TestRequestPrepareDefaultPathSkipsRequestContextInjection(t *testing.T) {
req := NewSimpleRequest("https://example.com", http.MethodGet)
if err := req.prepare(); err != nil {
t.Fatalf("prepare error: %v", err)
}
if got := req.execCtx.Value(ctxKeyRequestContext); got != nil {
t.Fatalf("unexpected request context injection: %#v", got)
}
rc := getRequestContext(req.execCtx)
if needsDynamicTransport(rc) {
t.Fatalf("default path unexpectedly marked dynamic: %#v", rc)
}
if rc.TLSServerName != "" {
t.Fatalf("default path unexpectedly injected tls server name: %q", rc.TLSServerName)
}
}
func TestRequestPrepareDynamicPathInjectsAggregatedRequestContext(t *testing.T) {
req := NewSimpleRequest("https://example.com", http.MethodGet).
SetCustomIP([]string{"127.0.0.1"}).
SetSkipTLSVerify(true)
if err := req.prepare(); err != nil {
t.Fatalf("prepare error: %v", err)
}
raw := req.execCtx.Value(ctxKeyRequestContext)
rc, ok := raw.(*RequestContext)
if !ok || rc == nil {
t.Fatalf("expected aggregated request context, got %#v", raw)
}
if len(rc.CustomIP) != 1 || rc.CustomIP[0] != "127.0.0.1" {
t.Fatalf("custom ip=%v", rc.CustomIP)
}
if rc.TLSConfig == nil || !rc.TLSConfig.InsecureSkipVerify {
t.Fatal("expected tls config with skip verify")
}
if rc.TLSServerName != "example.com" {
t.Fatalf("default tls server name=%q", rc.TLSServerName)
}
}
func TestRequestSetHostOverridesRequestHost(t *testing.T) {
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Host != "override.example" {
t.Fatalf("host=%q", r.Host)
}
w.WriteHeader(http.StatusOK)
}))
defer s.Close()
resp, err := NewSimpleRequest(s.URL, http.MethodGet).
SetHost("override.example").
Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
}
func TestWithHostOverridesRequestHost(t *testing.T) {
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Host != "option.example" {
t.Fatalf("host=%q", r.Host)
}
w.WriteHeader(http.StatusOK)
}))
defer s.Close()
resp, err := NewRequest(s.URL, http.MethodGet, WithHost("option.example"))
if err != nil {
t.Fatalf("NewRequest() error: %v", err)
}
got, err := resp.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer got.Close()
}

230
internal/pingcore/core.go Normal file
View File

@ -0,0 +1,230 @@
package pingcore
import (
"encoding/binary"
"net"
"os"
"sync/atomic"
"time"
)
const icmpHeaderLen = 8
type ICMP struct {
Type uint8
Code uint8
CheckSum uint16
Identifier uint16
SequenceNum uint16
}
type Options struct {
Count int
Timeout time.Duration
Interval time.Duration
Deadline time.Time
PreferIPv4 bool
PreferIPv6 bool
SourceIP net.IP
PayloadSize int
}
type Result struct {
Duration time.Duration
RecvCount int
RemoteIP string
}
var identifierSeed uint32
func NextIdentifier() uint16 {
pid := uint32(os.Getpid() & 0xffff)
n := atomic.AddUint32(&identifierSeed, 1)
return uint16((pid + n) & 0xffff)
}
func Payload(size int) []byte {
if size <= 0 {
return nil
}
payload := make([]byte, size)
for index := 0; index < len(payload); index++ {
payload[index] = byte(index)
}
return payload
}
func BuildICMP(seq, identifier uint16, typ uint8, payload []byte) ICMP {
icmp := ICMP{
Type: typ,
Code: 0,
CheckSum: 0,
Identifier: identifier,
SequenceNum: seq,
}
buf := MarshalPacket(icmp, payload)
icmp.CheckSum = Checksum(buf)
return icmp
}
func Checksum(data []byte) uint16 {
var (
sum uint32
length = len(data)
index int
)
for length > 1 {
sum += uint32(data[index])<<8 + uint32(data[index+1])
index += 2
length -= 2
}
if length > 0 {
sum += uint32(data[index]) << 8
}
for sum>>16 != 0 {
sum = (sum & 0xffff) + (sum >> 16)
}
return uint16(^sum)
}
func Marshal(icmp ICMP) []byte {
return MarshalPacket(icmp, nil)
}
func MarshalPacket(icmp ICMP, payload []byte) []byte {
buf := make([]byte, icmpHeaderLen+len(payload))
buf[0] = icmp.Type
buf[1] = icmp.Code
binary.BigEndian.PutUint16(buf[2:], icmp.CheckSum)
binary.BigEndian.PutUint16(buf[4:], icmp.Identifier)
binary.BigEndian.PutUint16(buf[6:], icmp.SequenceNum)
copy(buf[icmpHeaderLen:], payload)
return buf
}
func IsExpectedEchoReply(packet []byte, family int, expectedType uint8, identifier, seq uint16) bool {
for _, offset := range CandidateICMPOffsets(packet, family) {
if offset < 0 || offset+icmpHeaderLen > len(packet) {
continue
}
if packet[offset] != expectedType || packet[offset+1] != 0 {
continue
}
if binary.BigEndian.Uint16(packet[offset+4:offset+6]) != identifier {
continue
}
if binary.BigEndian.Uint16(packet[offset+6:offset+8]) != seq {
continue
}
return true
}
return false
}
func CandidateICMPOffsets(packet []byte, family int) []int {
offsets := []int{0}
if len(packet) == 0 {
return offsets
}
version := packet[0] >> 4
if version == 4 && len(packet) >= 20 {
ihl := int(packet[0]&0x0f) * 4
if ihl >= 20 && ihl <= len(packet)-icmpHeaderLen {
offsets = append(offsets, ihl)
}
} else if version == 6 && len(packet) >= 40+icmpHeaderLen {
offsets = append(offsets, 40)
}
if family == 4 && len(packet) >= 20+icmpHeaderLen {
offsets = append(offsets, 20)
}
if family == 6 && len(packet) >= 40+icmpHeaderLen {
offsets = append(offsets, 40)
}
return DedupOffsets(offsets)
}
func DedupOffsets(offsets []int) []int {
if len(offsets) <= 1 {
return offsets
}
seen := make(map[int]struct{}, len(offsets))
out := make([]int, 0, len(offsets))
for _, offset := range offsets {
if _, ok := seen[offset]; ok {
continue
}
seen[offset] = struct{}{}
out = append(out, offset)
}
return out
}
func ResolveTargets(host string, preferIPv4, preferIPv6 bool) ([]*net.IPAddr, error) {
if parsed := net.ParseIP(host); parsed != nil {
return []*net.IPAddr{{IP: parsed}}, nil
}
var targets []*net.IPAddr
var err4 error
var err6 error
if ip4, err := net.ResolveIPAddr("ip4", host); err == nil && ip4 != nil && ip4.IP != nil {
targets = append(targets, ip4)
} else {
err4 = err
}
if ip6, err := net.ResolveIPAddr("ip6", host); err == nil && ip6 != nil && ip6.IP != nil {
targets = append(targets, ip6)
} else {
err6 = err
}
if len(targets) > 0 {
return OrderTargets(targets, preferIPv4, preferIPv6), nil
}
if err4 != nil {
return nil, err4
}
if err6 != nil {
return nil, err6
}
return nil, nil
}
func OrderTargets(targets []*net.IPAddr, preferIPv4, preferIPv6 bool) []*net.IPAddr {
if len(targets) <= 1 || preferIPv4 == preferIPv6 {
return targets
}
ordered := make([]*net.IPAddr, 0, len(targets))
if preferIPv4 {
for _, target := range targets {
if target != nil && target.IP != nil && target.IP.To4() != nil {
ordered = append(ordered, target)
}
}
for _, target := range targets {
if target != nil && target.IP != nil && target.IP.To4() == nil {
ordered = append(ordered, target)
}
}
return ordered
}
for _, target := range targets {
if target != nil && target.IP != nil && target.IP.To4() == nil {
ordered = append(ordered, target)
}
}
for _, target := range targets {
if target != nil && target.IP != nil && target.IP.To4() != nil {
ordered = append(ordered, target)
}
}
return ordered
}

View File

@ -0,0 +1,123 @@
package tlssniffercore
import "crypto/tls"
func ComposeServerTLSConfig(base, selected *tls.Config) *tls.Config {
if base == nil {
return selected
}
if selected == nil {
return base
}
out := base.Clone()
ApplyServerTLSOverrides(out, selected)
return out
}
func ApplyServerTLSOverrides(dst, src *tls.Config) {
if dst == nil || src == nil {
return
}
if src.Rand != nil {
dst.Rand = src.Rand
}
if src.Time != nil {
dst.Time = src.Time
}
if len(src.Certificates) > 0 {
dst.Certificates = append([]tls.Certificate(nil), src.Certificates...)
}
if len(src.NameToCertificate) > 0 {
copied := make(map[string]*tls.Certificate, len(src.NameToCertificate))
for name, cert := range src.NameToCertificate {
copied[name] = cert
}
dst.NameToCertificate = copied
}
if src.GetCertificate != nil {
dst.GetCertificate = src.GetCertificate
}
if src.GetClientCertificate != nil {
dst.GetClientCertificate = src.GetClientCertificate
}
if src.GetConfigForClient != nil {
dst.GetConfigForClient = src.GetConfigForClient
}
if src.VerifyPeerCertificate != nil {
dst.VerifyPeerCertificate = src.VerifyPeerCertificate
}
if src.VerifyConnection != nil {
dst.VerifyConnection = src.VerifyConnection
}
if src.RootCAs != nil {
dst.RootCAs = src.RootCAs
}
if len(src.NextProtos) > 0 {
dst.NextProtos = append([]string(nil), src.NextProtos...)
}
if src.ServerName != "" {
dst.ServerName = src.ServerName
}
if src.ClientAuth > dst.ClientAuth {
dst.ClientAuth = src.ClientAuth
}
if src.ClientCAs != nil {
dst.ClientCAs = src.ClientCAs
}
if src.InsecureSkipVerify {
dst.InsecureSkipVerify = true
}
if len(src.CipherSuites) > 0 {
dst.CipherSuites = append([]uint16(nil), src.CipherSuites...)
}
if src.PreferServerCipherSuites {
dst.PreferServerCipherSuites = true
}
if src.SessionTicketsDisabled {
dst.SessionTicketsDisabled = true
}
if src.SessionTicketKey != ([32]byte{}) {
dst.SessionTicketKey = src.SessionTicketKey
}
if src.ClientSessionCache != nil {
dst.ClientSessionCache = src.ClientSessionCache
}
if src.UnwrapSession != nil {
dst.UnwrapSession = src.UnwrapSession
}
if src.WrapSession != nil {
dst.WrapSession = src.WrapSession
}
if src.MinVersion != 0 && (dst.MinVersion == 0 || src.MinVersion > dst.MinVersion) {
dst.MinVersion = src.MinVersion
}
if src.MaxVersion != 0 && (dst.MaxVersion == 0 || src.MaxVersion < dst.MaxVersion) {
dst.MaxVersion = src.MaxVersion
}
if len(src.CurvePreferences) > 0 {
dst.CurvePreferences = append([]tls.CurveID(nil), src.CurvePreferences...)
}
if src.DynamicRecordSizingDisabled {
dst.DynamicRecordSizingDisabled = true
}
if src.Renegotiation != 0 {
dst.Renegotiation = src.Renegotiation
}
if src.KeyLogWriter != nil {
dst.KeyLogWriter = src.KeyLogWriter
}
if len(src.EncryptedClientHelloConfigList) > 0 {
dst.EncryptedClientHelloConfigList = append([]byte(nil), src.EncryptedClientHelloConfigList...)
}
if src.EncryptedClientHelloRejectionVerify != nil {
dst.EncryptedClientHelloRejectionVerify = src.EncryptedClientHelloRejectionVerify
}
if src.GetEncryptedClientHelloKeys != nil {
dst.GetEncryptedClientHelloKeys = src.GetEncryptedClientHelloKeys
}
if len(src.EncryptedClientHelloKeys) > 0 {
dst.EncryptedClientHelloKeys = append([]tls.EncryptedClientHelloKey(nil), src.EncryptedClientHelloKeys...)
}
}

View File

@ -0,0 +1,237 @@
package tlssniffercore
import (
"bytes"
"encoding/binary"
"io"
"net"
)
type ClientHelloMeta struct {
ServerName string
LocalAddr net.Addr
RemoteAddr net.Addr
SupportedProtos []string
SupportedVersions []uint16
CipherSuites []uint16
}
type SniffResult struct {
IsTLS bool
ClientHello *ClientHelloMeta
Buffer *bytes.Buffer
}
type Sniffer struct{}
func (s Sniffer) Sniff(conn net.Conn, maxBytes int) (SniffResult, error) {
if maxBytes <= 0 {
maxBytes = 64 * 1024
}
var buf bytes.Buffer
limited := &io.LimitedReader{R: conn, N: int64(maxBytes)}
meta, isTLS := sniffClientHello(limited, &buf, conn)
out := SniffResult{
IsTLS: isTLS,
Buffer: bytes.NewBuffer(append([]byte(nil), buf.Bytes()...)),
}
if isTLS {
out.ClientHello = meta
}
return out, nil
}
func sniffClientHello(reader io.Reader, buf *bytes.Buffer, conn net.Conn) (*ClientHelloMeta, bool) {
meta := &ClientHelloMeta{
LocalAddr: conn.LocalAddr(),
RemoteAddr: conn.RemoteAddr(),
}
header, complete := readTLSRecordHeader(reader, buf)
if len(header) < 3 {
return nil, false
}
isTLS := header[0] == 0x16 && header[1] == 0x03
if !isTLS {
return nil, false
}
if len(header) < 5 || !complete {
return meta, true
}
recordLen := int(binary.BigEndian.Uint16(header[3:5]))
recordBody, bodyOK := readBufferedBytes(reader, buf, recordLen)
if !bodyOK {
return meta, true
}
if len(recordBody) < 4 || recordBody[0] != 0x01 {
return nil, false
}
helloLen := int(recordBody[1])<<16 | int(recordBody[2])<<8 | int(recordBody[3])
helloBytes := append([]byte(nil), recordBody[4:]...)
for len(helloBytes) < helloLen {
nextHeader, ok := readTLSRecordHeader(reader, buf)
if len(nextHeader) < 5 || !ok {
return meta, true
}
if nextHeader[0] != 0x16 || nextHeader[1] != 0x03 {
return meta, true
}
nextLen := int(binary.BigEndian.Uint16(nextHeader[3:5]))
nextBody, bodyOK := readBufferedBytes(reader, buf, nextLen)
if !bodyOK {
return meta, true
}
helloBytes = append(helloBytes, nextBody...)
}
parseClientHelloBody(meta, helloBytes[:helloLen])
return meta, true
}
func readTLSRecordHeader(reader io.Reader, buf *bytes.Buffer) ([]byte, bool) {
return readBufferedBytes(reader, buf, 5)
}
func readBufferedBytes(reader io.Reader, buf *bytes.Buffer, count int) ([]byte, bool) {
if count <= 0 {
return nil, true
}
tmp := make([]byte, count)
readN, err := io.ReadFull(reader, tmp)
if readN > 0 {
buf.Write(tmp[:readN])
}
return append([]byte(nil), tmp[:readN]...), err == nil
}
func parseClientHelloBody(meta *ClientHelloMeta, body []byte) {
if meta == nil || len(body) < 34 {
return
}
offset := 2 + 32
sessionIDLen := int(body[offset])
offset++
if offset+sessionIDLen > len(body) {
return
}
offset += sessionIDLen
if offset+2 > len(body) {
return
}
cipherSuitesLen := int(binary.BigEndian.Uint16(body[offset : offset+2]))
offset += 2
if offset+cipherSuitesLen > len(body) {
return
}
for index := 0; index+1 < cipherSuitesLen; index += 2 {
meta.CipherSuites = append(meta.CipherSuites, binary.BigEndian.Uint16(body[offset+index:offset+index+2]))
}
offset += cipherSuitesLen
if offset >= len(body) {
return
}
compressionMethodsLen := int(body[offset])
offset++
if offset+compressionMethodsLen > len(body) {
return
}
offset += compressionMethodsLen
if offset+2 > len(body) {
return
}
extensionsLen := int(binary.BigEndian.Uint16(body[offset : offset+2]))
offset += 2
if offset+extensionsLen > len(body) {
return
}
parseClientHelloExtensions(meta, body[offset:offset+extensionsLen])
}
func parseClientHelloExtensions(meta *ClientHelloMeta, exts []byte) {
for offset := 0; offset+4 <= len(exts); {
extType := binary.BigEndian.Uint16(exts[offset : offset+2])
extLen := int(binary.BigEndian.Uint16(exts[offset+2 : offset+4]))
offset += 4
if offset+extLen > len(exts) {
return
}
extData := exts[offset : offset+extLen]
offset += extLen
switch extType {
case 0:
parseServerNameExtension(meta, extData)
case 16:
parseALPNExtension(meta, extData)
case 43:
parseSupportedVersionsExtension(meta, extData)
}
}
}
func parseServerNameExtension(meta *ClientHelloMeta, data []byte) {
if len(data) < 2 {
return
}
listLen := int(binary.BigEndian.Uint16(data[:2]))
if listLen == 0 || 2+listLen > len(data) {
return
}
list := data[2 : 2+listLen]
for offset := 0; offset+3 <= len(list); {
nameType := list[offset]
nameLen := int(binary.BigEndian.Uint16(list[offset+1 : offset+3]))
offset += 3
if offset+nameLen > len(list) {
return
}
if nameType == 0 {
meta.ServerName = string(list[offset : offset+nameLen])
return
}
offset += nameLen
}
}
func parseALPNExtension(meta *ClientHelloMeta, data []byte) {
if len(data) < 2 {
return
}
listLen := int(binary.BigEndian.Uint16(data[:2]))
if listLen == 0 || 2+listLen > len(data) {
return
}
list := data[2 : 2+listLen]
for offset := 0; offset < len(list); {
nameLen := int(list[offset])
offset++
if offset+nameLen > len(list) {
return
}
meta.SupportedProtos = append(meta.SupportedProtos, string(list[offset:offset+nameLen]))
offset += nameLen
}
}
func parseSupportedVersionsExtension(meta *ClientHelloMeta, data []byte) {
if len(data) < 1 {
return
}
listLen := int(data[0])
if listLen == 0 || 1+listLen > len(data) {
return
}
list := data[1 : 1+listLen]
for offset := 0; offset+1 < len(list); offset += 2 {
meta.SupportedVersions = append(meta.SupportedVersions, binary.BigEndian.Uint16(list[offset:offset+2]))
}
}

View File

@ -1,405 +0,0 @@
package starnet
import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"os"
"time"
)
// WithTimeout 设置请求总超时时间
// timeout > 0: 为本次请求注入 context 超时
// timeout = 0: 不额外设置请求总超时
// timeout < 0: 禁用 starnet 默认总超时
func WithTimeout(timeout time.Duration) RequestOpt {
return func(r *Request) error {
r.config.Network.Timeout = timeout
return nil
}
}
// WithDialTimeout 设置连接超时时间
func WithDialTimeout(timeout time.Duration) RequestOpt {
return func(r *Request) error {
r.config.Network.DialTimeout = timeout
return nil
}
}
// WithProxy 设置代理
func WithProxy(proxy string) RequestOpt {
return func(r *Request) error {
r.config.Network.Proxy = proxy
return nil
}
}
// WithDialFunc 设置自定义 Dial 函数
func WithDialFunc(fn func(ctx context.Context, network, addr string) (net.Conn, error)) RequestOpt {
return func(r *Request) error {
r.config.Network.DialFunc = fn
return nil
}
}
// WithTLSConfig 设置 TLS 配置
func WithTLSConfig(tlsConfig *tls.Config) RequestOpt {
return func(r *Request) error {
r.config.TLS.Config = tlsConfig
return nil
}
}
// WithSkipTLSVerify 设置是否跳过 TLS 验证
func WithSkipTLSVerify(skip bool) RequestOpt {
return func(r *Request) error {
r.config.TLS.SkipVerify = skip
return nil
}
}
// WithCustomIP 设置自定义 IP
func WithCustomIP(ips []string) RequestOpt {
return func(r *Request) error {
for _, ip := range ips {
if net.ParseIP(ip) == nil {
return wrapError(ErrInvalidIP, "ip: %s", ip)
}
}
r.config.DNS.CustomIP = ips
return nil
}
}
// WithAddCustomIP 添加自定义 IP
func WithAddCustomIP(ip string) RequestOpt {
return func(r *Request) error {
if net.ParseIP(ip) == nil {
return wrapError(ErrInvalidIP, "ip: %s", ip)
}
r.config.DNS.CustomIP = append(r.config.DNS.CustomIP, ip)
return nil
}
}
// WithCustomDNS 设置自定义 DNS 服务器
func WithCustomDNS(dnsServers []string) RequestOpt {
return func(r *Request) error {
for _, dns := range dnsServers {
if net.ParseIP(dns) == nil {
return wrapError(ErrInvalidDNS, "dns: %s", dns)
}
}
r.config.DNS.CustomDNS = dnsServers
return nil
}
}
// WithAddCustomDNS 添加自定义 DNS 服务器
func WithAddCustomDNS(dns string) RequestOpt {
return func(r *Request) error {
if net.ParseIP(dns) == nil {
return wrapError(ErrInvalidDNS, "dns: %s", dns)
}
r.config.DNS.CustomDNS = append(r.config.DNS.CustomDNS, dns)
return nil
}
}
// WithLookupFunc 设置自定义 DNS 解析函数
func WithLookupFunc(fn func(ctx context.Context, host string) ([]net.IPAddr, error)) RequestOpt {
return func(r *Request) error {
r.config.DNS.LookupFunc = fn
return nil
}
}
// WithHeader 设置 Header
func WithHeader(key, value string) RequestOpt {
return func(r *Request) error {
r.config.Headers.Set(key, value)
return nil
}
}
// WithHeaders 批量设置 Headers
func WithHeaders(headers map[string]string) RequestOpt {
return func(r *Request) error {
for k, v := range headers {
r.config.Headers.Set(k, v)
}
return nil
}
}
// WithContentType 设置 Content-Type
func WithContentType(contentType string) RequestOpt {
return func(r *Request) error {
r.config.Headers.Set("Content-Type", contentType)
return nil
}
}
// WithUserAgent 设置 User-Agent
func WithUserAgent(userAgent string) RequestOpt {
return func(r *Request) error {
r.config.Headers.Set("User-Agent", userAgent)
return nil
}
}
// WithBearerToken 设置 Bearer Token
func WithBearerToken(token string) RequestOpt {
return func(r *Request) error {
r.config.Headers.Set("Authorization", "Bearer "+token)
return nil
}
}
// WithBasicAuth 设置 Basic 认证
func WithBasicAuth(username, password string) RequestOpt {
return func(r *Request) error {
r.config.BasicAuth = [2]string{username, password}
return nil
}
}
// WithCookie 添加 Cookie
func WithCookie(name, value, path string) RequestOpt {
return func(r *Request) error {
r.config.Cookies = append(r.config.Cookies, &http.Cookie{
Name: name,
Value: value,
Path: path,
})
return nil
}
}
// WithSimpleCookie 添加简单 Cookiepath 为 /
func WithSimpleCookie(name, value string) RequestOpt {
return func(r *Request) error {
r.config.Cookies = append(r.config.Cookies, &http.Cookie{
Name: name,
Value: value,
Path: "/",
})
return nil
}
}
// WithCookies 批量添加 Cookies
func WithCookies(cookies map[string]string) RequestOpt {
return func(r *Request) error {
for name, value := range cookies {
r.config.Cookies = append(r.config.Cookies, &http.Cookie{
Name: name,
Value: value,
Path: "/",
})
}
return nil
}
}
// WithBody 设置请求体(字节)
func WithBody(body []byte) RequestOpt {
return func(r *Request) error {
r.config.Body.Bytes = body
r.config.Body.Reader = nil
return nil
}
}
// WithBodyString 设置请求体(字符串)
func WithBodyString(body string) RequestOpt {
return func(r *Request) error {
r.config.Body.Bytes = []byte(body)
r.config.Body.Reader = nil
return nil
}
}
// WithBodyReader 设置请求体Reader
func WithBodyReader(reader io.Reader) RequestOpt {
return func(r *Request) error {
r.config.Body.Reader = reader
r.config.Body.Bytes = nil
return nil
}
}
// WithJSON 设置 JSON 请求体
func WithJSON(v interface{}) RequestOpt {
return func(r *Request) error {
data, err := json.Marshal(v)
if err != nil {
return wrapError(err, "marshal json")
}
r.config.Headers.Set("Content-Type", ContentTypeJSON)
r.config.Body.Bytes = data
r.config.Body.Reader = nil
return nil
}
}
// WithFormData 设置表单数据
func WithFormData(data map[string][]string) RequestOpt {
return func(r *Request) error {
r.config.Body.FormData = data
return nil
}
}
// WithFormDataMap 设置表单数据(简化版)
func WithFormDataMap(data map[string]string) RequestOpt {
return func(r *Request) error {
for k, v := range data {
r.config.Body.FormData[k] = []string{v}
}
return nil
}
}
// WithAddFormData 添加表单数据
func WithAddFormData(key, value string) RequestOpt {
return func(r *Request) error {
r.config.Body.FormData[key] = append(r.config.Body.FormData[key], value)
return nil
}
}
// WithFile 添加文件
func WithFile(formName, filePath string) RequestOpt {
return func(r *Request) error {
stat, err := os.Stat(filePath)
if err != nil {
return wrapError(ErrFileNotFound, "file: %s", filePath)
}
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
FormName: formName,
FileName: stat.Name(),
FilePath: filePath,
FileSize: stat.Size(),
FileType: ContentTypeOctetStream,
})
return nil
}
}
// WithFileStream 添加文件流
func WithFileStream(formName, fileName string, size int64, reader io.Reader) RequestOpt {
return func(r *Request) error {
if reader == nil {
return ErrNilReader
}
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
FormName: formName,
FileName: fileName,
FileData: reader,
FileSize: size,
FileType: ContentTypeOctetStream,
})
return nil
}
}
// WithQuery 添加查询参数
func WithQuery(key, value string) RequestOpt {
return func(r *Request) error {
r.config.Queries[key] = append(r.config.Queries[key], value)
return nil
}
}
// WithQueries 批量添加查询参数
func WithQueries(queries map[string]string) RequestOpt {
return func(r *Request) error {
for k, v := range queries {
r.config.Queries[k] = append(r.config.Queries[k], v)
}
return nil
}
}
// WithContentLength 设置 Content-Length
func WithContentLength(length int64) RequestOpt {
return func(r *Request) error {
r.config.ContentLength = length
return nil
}
}
// WithAutoCalcContentLength 设置是否自动计算 Content-Length
func WithAutoCalcContentLength(auto bool) RequestOpt {
return func(r *Request) error {
r.config.AutoCalcContentLength = auto
return nil
}
}
// WithUploadProgress 设置文件上传进度回调
func WithUploadProgress(fn UploadProgressFunc) RequestOpt {
return func(r *Request) error {
r.config.UploadProgress = fn
return nil
}
}
// WithTransport 设置自定义 Transport
func WithTransport(transport *http.Transport) RequestOpt {
return func(r *Request) error {
r.config.Transport = transport
r.config.CustomTransport = true
return nil
}
}
// WithAutoFetch 设置是否自动获取响应体
func WithAutoFetch(auto bool) RequestOpt {
return func(r *Request) error {
r.autoFetch = auto
return nil
}
}
// WithMaxRespBodyBytes 设置响应体最大读取字节数(<=0 表示不限制)
func WithMaxRespBodyBytes(maxBytes int64) RequestOpt {
return func(r *Request) error {
if maxBytes < 0 {
return fmt.Errorf("max response body bytes must be >= 0")
}
r.config.MaxRespBodyBytes = maxBytes
return nil
}
}
// WithRawRequest 设置原始请求
func WithRawRequest(httpReq *http.Request) RequestOpt {
return func(r *Request) error {
if httpReq == nil {
return fmt.Errorf("httpReq cannot be nil")
}
r.httpReq = httpReq
r.doRaw = true
return nil
}
}
// WithContext 设置 context
func WithContext(ctx context.Context) RequestOpt {
return func(r *Request) error {
r.ctx = ctx
r.httpReq = r.httpReq.WithContext(ctx)
return nil
}
}

112
options_body.go Normal file
View File

@ -0,0 +1,112 @@
package starnet
import (
"encoding/json"
"io"
"os"
)
// WithBody 设置请求体(字节)
func WithBody(body []byte) RequestOpt {
return func(r *Request) error {
setBytesBodyConfig(&r.config.Body, body)
return nil
}
}
// WithBodyString 设置请求体(字符串)
func WithBodyString(body string) RequestOpt {
return func(r *Request) error {
setBytesBodyConfig(&r.config.Body, []byte(body))
return nil
}
}
// WithBodyReader 设置请求体Reader
// 出于避免重复写的保守策略Reader 形态的 body 在非幂等方法上不会自动参与 retry。
func WithBodyReader(reader io.Reader) RequestOpt {
return func(r *Request) error {
setReaderBodyConfig(&r.config.Body, reader)
return nil
}
}
// WithJSON 设置 JSON 请求体
func WithJSON(v interface{}) RequestOpt {
return func(r *Request) error {
data, err := json.Marshal(v)
if err != nil {
return wrapError(err, "marshal json")
}
r.config.Headers.Set("Content-Type", ContentTypeJSON)
setBytesBodyConfig(&r.config.Body, data)
return nil
}
}
// WithFormData 设置表单数据
func WithFormData(data map[string][]string) RequestOpt {
return func(r *Request) error {
setFormBodyConfig(&r.config.Body, data)
return nil
}
}
// WithFormDataMap 设置表单数据(简化版)
func WithFormDataMap(data map[string]string) RequestOpt {
return func(r *Request) error {
setFormBodyConfig(&r.config.Body, nil)
for key, value := range data {
r.config.Body.FormData[key] = []string{value}
}
return nil
}
}
// WithAddFormData 添加表单数据
func WithAddFormData(key, value string) RequestOpt {
return func(r *Request) error {
ensureFormMode(&r.config.Body)
r.config.Body.FormData[key] = append(r.config.Body.FormData[key], value)
return nil
}
}
// WithFile 添加文件
func WithFile(formName, filePath string) RequestOpt {
return func(r *Request) error {
stat, err := os.Stat(filePath)
if err != nil {
return wrapError(ErrFileNotFound, "file: %s", filePath)
}
ensureMultipartMode(&r.config.Body)
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
FormName: formName,
FileName: stat.Name(),
FilePath: filePath,
FileSize: stat.Size(),
FileType: ContentTypeOctetStream,
})
return nil
}
}
// WithFileStream 添加文件流
func WithFileStream(formName, fileName string, size int64, reader io.Reader) RequestOpt {
return func(r *Request) error {
if reader == nil {
return ErrNilReader
}
ensureMultipartMode(&r.config.Body)
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
FormName: formName,
FileName: fileName,
FileData: reader,
FileSize: size,
FileType: ContentTypeOctetStream,
})
return nil
}
}

132
options_config.go Normal file
View File

@ -0,0 +1,132 @@
package starnet
import (
"context"
"crypto/tls"
"net"
"net/http"
"time"
)
// WithTimeout 设置请求总超时时间
// timeout > 0: 为本次请求注入 context 超时
// timeout = 0: 不额外设置请求总超时
// timeout < 0: 禁用 starnet 默认总超时
func WithTimeout(timeout time.Duration) RequestOpt {
return requestOptFromMutation(mutateTimeout(timeout))
}
// WithDialTimeout 设置连接超时时间
func WithDialTimeout(timeout time.Duration) RequestOpt {
return requestOptFromMutation(mutateDialTimeout(timeout))
}
// WithProxy 设置代理
func WithProxy(proxy string) RequestOpt {
return requestOptFromMutation(mutateProxy(proxy))
}
// WithDialFunc 设置自定义 Dial 函数
func WithDialFunc(fn func(ctx context.Context, network, addr string) (net.Conn, error)) RequestOpt {
return requestOptFromMutation(mutateDialFunc(fn))
}
// WithTLSConfig 设置 TLS 配置
func WithTLSConfig(tlsConfig *tls.Config) RequestOpt {
return requestOptFromMutation(mutateTLSConfig(tlsConfig))
}
// WithTLSServerName 设置显式 TLS ServerName/SNI。
func WithTLSServerName(serverName string) RequestOpt {
return requestOptFromMutation(mutateTLSServerName(serverName))
}
// WithTraceHooks 设置请求 trace 回调。
func WithTraceHooks(hooks *TraceHooks) RequestOpt {
return requestOptFromMutation(mutateTraceHooks(hooks))
}
// WithSkipTLSVerify 设置是否跳过 TLS 验证
func WithSkipTLSVerify(skip bool) RequestOpt {
return requestOptFromMutation(mutateSkipTLSVerify(skip))
}
// WithCustomIP 设置自定义 IP
func WithCustomIP(ips []string) RequestOpt {
return requestOptFromMutation(mutateCustomIP(ips))
}
// WithAddCustomIP 添加自定义 IP
func WithAddCustomIP(ip string) RequestOpt {
return requestOptFromMutation(mutateAddCustomIP(ip))
}
// WithCustomDNS 设置自定义 DNS 服务器
func WithCustomDNS(dnsServers []string) RequestOpt {
return requestOptFromMutation(mutateCustomDNS(dnsServers))
}
// WithAddCustomDNS 添加自定义 DNS 服务器
func WithAddCustomDNS(dns string) RequestOpt {
return requestOptFromMutation(mutateAddCustomDNS(dns))
}
// WithLookupFunc 设置自定义 DNS 解析函数
func WithLookupFunc(fn func(ctx context.Context, host string) ([]net.IPAddr, error)) RequestOpt {
return requestOptFromMutation(mutateLookupFunc(fn))
}
// WithBasicAuth 设置 Basic 认证
func WithBasicAuth(username, password string) RequestOpt {
return requestOptFromMutation(mutateBasicAuth(username, password))
}
// WithQuery 添加查询参数
func WithQuery(key, value string) RequestOpt {
return requestOptFromMutation(mutateAddQuery(key, value))
}
// WithQueries 批量添加查询参数
func WithQueries(queries map[string]string) RequestOpt {
return requestOptFromMutation(mutateAddQueries(queries))
}
// WithContentLength 设置 Content-Length
func WithContentLength(length int64) RequestOpt {
return requestOptFromMutation(mutateContentLength(length))
}
// WithAutoCalcContentLength 设置是否自动计算 Content-Length
func WithAutoCalcContentLength(auto bool) RequestOpt {
return requestOptFromMutation(mutateAutoCalcContentLength(auto))
}
// WithUploadProgress 设置文件上传进度回调
func WithUploadProgress(fn UploadProgressFunc) RequestOpt {
return requestOptFromMutation(mutateUploadProgress(fn))
}
// WithTransport 设置自定义 Transport
func WithTransport(transport *http.Transport) RequestOpt {
return requestOptFromMutation(mutateTransport(transport))
}
// WithAutoFetch 设置是否自动获取响应体
func WithAutoFetch(auto bool) RequestOpt {
return requestOptFromMutation(mutateAutoFetch(auto))
}
// WithMaxRespBodyBytes 设置响应体最大读取字节数(<=0 表示不限制)
func WithMaxRespBodyBytes(maxBytes int64) RequestOpt {
return requestOptFromMutation(mutateMaxRespBodyBytes(maxBytes))
}
// WithRawRequest 设置原始请求
func WithRawRequest(httpReq *http.Request) RequestOpt {
return requestOptFromMutation(mutateRawRequest(httpReq))
}
// WithContext 设置 context
func WithContext(ctx context.Context) RequestOpt {
return requestOptFromMutation(mutateContext(ctx))
}

99
options_header.go Normal file
View File

@ -0,0 +1,99 @@
package starnet
import "net/http"
// WithHeader 设置 Header
func WithHeader(key, value string) RequestOpt {
return func(r *Request) error {
if isHostHeaderKey(key) {
setRequestHostConfig(r.config, value)
return nil
}
r.config.Headers.Set(key, value)
return nil
}
}
// WithHost 设置显式 Host 头覆盖。
func WithHost(host string) RequestOpt {
return func(r *Request) error {
setRequestHostConfig(r.config, host)
return nil
}
}
// WithHeaders 批量设置 Headers
func WithHeaders(headers map[string]string) RequestOpt {
return func(r *Request) error {
for key, value := range headers {
if isHostHeaderKey(key) {
setRequestHostConfig(r.config, value)
continue
}
r.config.Headers.Set(key, value)
}
return nil
}
}
// WithContentType 设置 Content-Type
func WithContentType(contentType string) RequestOpt {
return func(r *Request) error {
r.config.Headers.Set("Content-Type", contentType)
return nil
}
}
// WithUserAgent 设置 User-Agent
func WithUserAgent(userAgent string) RequestOpt {
return func(r *Request) error {
r.config.Headers.Set("User-Agent", userAgent)
return nil
}
}
// WithBearerToken 设置 Bearer Token
func WithBearerToken(token string) RequestOpt {
return func(r *Request) error {
r.config.Headers.Set("Authorization", "Bearer "+token)
return nil
}
}
// WithCookie 添加 Cookie
func WithCookie(name, value, path string) RequestOpt {
return func(r *Request) error {
r.config.Cookies = append(r.config.Cookies, &http.Cookie{
Name: name,
Value: value,
Path: path,
})
return nil
}
}
// WithSimpleCookie 添加简单 Cookiepath 为 /
func WithSimpleCookie(name, value string) RequestOpt {
return func(r *Request) error {
r.config.Cookies = append(r.config.Cookies, &http.Cookie{
Name: name,
Value: value,
Path: "/",
})
return nil
}
}
// WithCookies 批量添加 Cookies
func WithCookies(cookies map[string]string) RequestOpt {
return func(r *Request) error {
for name, value := range cookies {
r.config.Cookies = append(r.config.Cookies, &http.Cookie{
Name: name,
Value: value,
Path: "/",
})
}
return nil
}
}

221
ping.go
View File

@ -2,14 +2,14 @@ package starnet
import (
"context"
"encoding/binary"
"errors"
"fmt"
"net"
"os"
"strings"
"sync/atomic"
"time"
"b612.me/starnet/internal/pingcore"
)
const (
@ -18,7 +18,6 @@ const (
icmpTypeEchoRequestV6 = 128
icmpTypeEchoReplyV6 = 129
icmpHeaderLen = 8
icmpReadBufSz = 1500
defaultPingAttemptTimeout = 2 * time.Second
@ -26,13 +25,7 @@ const (
maxPingPayloadSize = 65499 // 65507 - ICMP header(8)
)
type ICMP struct {
Type uint8
Code uint8
CheckSum uint16
Identifier uint16
SequenceNum uint16
}
type ICMP = pingcore.ICMP
type pingSocketSpec struct {
network string
@ -42,53 +35,20 @@ type pingSocketSpec struct {
}
// PingOptions controls ping probing behavior.
type PingOptions struct {
Count int // ping attempts for Pingable, default 3
Timeout time.Duration // per-attempt timeout, default 2s
Interval time.Duration // delay between attempts, default 0
Deadline time.Time // overall deadline for Pingable/PingWithContext
PreferIPv4 bool // prefer IPv4 targets
PreferIPv6 bool // prefer IPv6 targets
SourceIP net.IP // optional source IP for raw socket bind
PayloadSize int // ICMP payload bytes, default 0
}
type PingOptions = pingcore.Options
type PingResult struct {
Duration time.Duration
RecvCount int
RemoteIP string
}
var pingIdentifierSeed uint32
type PingResult = pingcore.Result
func nextPingIdentifier() uint16 {
pid := uint32(os.Getpid() & 0xffff)
n := atomic.AddUint32(&pingIdentifierSeed, 1)
return uint16((pid + n) & 0xffff)
return pingcore.NextIdentifier()
}
func pingPayload(size int) []byte {
if size <= 0 {
return nil
}
payload := make([]byte, size)
for i := 0; i < len(payload); i++ {
payload[i] = byte(i)
}
return payload
return pingcore.Payload(size)
}
func getICMP(seq, identifier uint16, typ uint8, payload []byte) ICMP {
icmp := ICMP{
Type: typ,
Code: 0,
CheckSum: 0,
Identifier: identifier,
SequenceNum: seq,
}
buf := marshalICMPPacket(icmp, payload)
icmp.CheckSum = checkSum(buf)
return icmp
return pingcore.BuildICMP(seq, identifier, typ, payload)
}
func sendICMPRequest(ctx context.Context, icmp ICMP, payload []byte, destAddr *net.IPAddr, sourceIP net.IP, spec pingSocketSpec, timeout time.Duration) (PingResult, error) {
@ -120,8 +80,8 @@ func sendICMPRequest(ctx context.Context, icmp ICMP, payload []byte, destAddr *n
return res, wrapError(err, "ping write request")
}
tStart := time.Now()
deadline := tStart.Add(timeout)
startedAt := time.Now()
deadline := startedAt.Add(timeout)
if d, ok := ctx.Deadline(); ok && d.Before(deadline) {
deadline = d
}
@ -150,108 +110,34 @@ func sendICMPRequest(ctx context.Context, icmp ICMP, payload []byte, destAddr *n
}
if isExpectedEchoReply(recv[:n], spec.family, spec.replyType, icmp.Identifier, icmp.SequenceNum) {
res.RecvCount = n
res.Duration = time.Since(tStart)
res.Duration = time.Since(startedAt)
return res, nil
}
}
}
func checkSum(data []byte) uint16 {
var (
sum uint32
length int = len(data)
index int
)
for length > 1 {
sum += uint32(data[index])<<8 + uint32(data[index+1])
index += 2
length -= 2
}
if length > 0 {
sum += uint32(data[index]) << 8
}
for sum>>16 != 0 {
sum = (sum & 0xffff) + (sum >> 16)
}
return uint16(^sum)
return pingcore.Checksum(data)
}
func marshalICMP(icmp ICMP) []byte {
return marshalICMPPacket(icmp, nil)
return pingcore.Marshal(icmp)
}
func marshalICMPPacket(icmp ICMP, payload []byte) []byte {
buf := make([]byte, icmpHeaderLen+len(payload))
buf[0] = icmp.Type
buf[1] = icmp.Code
binary.BigEndian.PutUint16(buf[2:], icmp.CheckSum)
binary.BigEndian.PutUint16(buf[4:], icmp.Identifier)
binary.BigEndian.PutUint16(buf[6:], icmp.SequenceNum)
copy(buf[icmpHeaderLen:], payload)
return buf
return pingcore.MarshalPacket(icmp, payload)
}
func isExpectedEchoReply(packet []byte, family int, expectedType uint8, identifier, seq uint16) bool {
for _, off := range candidateICMPOffsets(packet, family) {
if off < 0 || off+icmpHeaderLen > len(packet) {
continue
}
if packet[off] != expectedType || packet[off+1] != 0 {
continue
}
if binary.BigEndian.Uint16(packet[off+4:off+6]) != identifier {
continue
}
if binary.BigEndian.Uint16(packet[off+6:off+8]) != seq {
continue
}
return true
}
return false
return pingcore.IsExpectedEchoReply(packet, family, expectedType, identifier, seq)
}
func candidateICMPOffsets(packet []byte, family int) []int {
offsets := []int{0}
if len(packet) == 0 {
return offsets
}
ver := packet[0] >> 4
if ver == 4 && len(packet) >= 20 {
ihl := int(packet[0]&0x0f) * 4
if ihl >= 20 && ihl <= len(packet)-icmpHeaderLen {
offsets = append(offsets, ihl)
}
} else if ver == 6 && len(packet) >= 40+icmpHeaderLen {
offsets = append(offsets, 40)
}
// 某些平台/内核可能回包含链路层头部,保守再尝试常见偏移。
if family == 4 && len(packet) >= 20+icmpHeaderLen {
offsets = append(offsets, 20)
}
if family == 6 && len(packet) >= 40+icmpHeaderLen {
offsets = append(offsets, 40)
}
return dedupOffsets(offsets)
return pingcore.CandidateICMPOffsets(packet, family)
}
func dedupOffsets(offsets []int) []int {
if len(offsets) <= 1 {
return offsets
}
m := make(map[int]struct{}, len(offsets))
out := make([]int, 0, len(offsets))
for _, off := range offsets {
if _, ok := m[off]; ok {
continue
}
m[off] = struct{}{}
out = append(out, off)
}
return out
return pingcore.DedupOffsets(offsets)
}
func socketSpecForIP(ip net.IP) (pingSocketSpec, error) {
@ -297,70 +183,18 @@ func localIPAddrForFamily(sourceIP net.IP, family int) (*net.IPAddr, error) {
}
func resolvePingTargets(host string, preferIPv4, preferIPv6 bool) ([]*net.IPAddr, error) {
if parsed := net.ParseIP(host); parsed != nil {
return []*net.IPAddr{{IP: parsed}}, nil
}
var targets []*net.IPAddr
var err4 error
var err6 error
if ip4, e := net.ResolveIPAddr("ip4", host); e == nil && ip4 != nil && ip4.IP != nil {
targets = append(targets, ip4)
} else {
err4 = e
}
if ip6, e := net.ResolveIPAddr("ip6", host); e == nil && ip6 != nil && ip6.IP != nil {
targets = append(targets, ip6)
} else {
err6 = e
}
if len(targets) > 0 {
return orderPingTargets(targets, preferIPv4, preferIPv6), nil
}
if err4 != nil {
return nil, err4
}
if err6 != nil {
return nil, err6
targets, err := pingcore.ResolveTargets(host, preferIPv4, preferIPv6)
if err != nil {
return nil, err
}
if len(targets) == 0 {
return nil, ErrPingNoResolvedTarget
}
return targets, nil
}
func orderPingTargets(targets []*net.IPAddr, preferIPv4, preferIPv6 bool) []*net.IPAddr {
if len(targets) <= 1 || preferIPv4 == preferIPv6 {
return targets
}
ordered := make([]*net.IPAddr, 0, len(targets))
if preferIPv4 {
for _, t := range targets {
if t != nil && t.IP != nil && t.IP.To4() != nil {
ordered = append(ordered, t)
}
}
for _, t := range targets {
if t != nil && t.IP != nil && t.IP.To4() == nil {
ordered = append(ordered, t)
}
}
return ordered
}
for _, t := range targets {
if t != nil && t.IP != nil && t.IP.To4() == nil {
ordered = append(ordered, t)
}
}
for _, t := range targets {
if t != nil && t.IP != nil && t.IP.To4() != nil {
ordered = append(ordered, t)
}
}
return ordered
return pingcore.OrderTargets(targets, preferIPv4, preferIPv6)
}
func normalizePingDialError(err error) error {
@ -450,7 +284,6 @@ func pingOnceWithOptions(ctx context.Context, host string, seq int, opts PingOpt
return resp, nil
}
// 权限问题通常与地址族无关,继续重试意义不大。
if errors.Is(err, ErrPingPermissionDenied) {
return res, err
}
@ -501,8 +334,8 @@ func Pingable(host string, opts *PingOptions) (bool, error) {
}
var lastErr error
for i := 0; i < cfg.Count; i++ {
_, err := pingOnceWithOptions(ctx, host, 29+i, cfg)
for index := 0; index < cfg.Count; index++ {
_, err := pingOnceWithOptions(ctx, host, 29+index, cfg)
if err == nil {
return true, nil
}
@ -512,7 +345,7 @@ func Pingable(host string, opts *PingOptions) (bool, error) {
break
}
if i < cfg.Count-1 && cfg.Interval > 0 {
if index < cfg.Count-1 && cfg.Interval > 0 {
timer := time.NewTimer(cfg.Interval)
select {
case <-ctx.Done():

110
proxy_custom_ip_test.go Normal file
View File

@ -0,0 +1,110 @@
package starnet
import (
"fmt"
"net"
"net/http"
"testing"
)
func TestRequestProxyWithCustomIPTargetsOriginWithoutRewritingProxyDial(t *testing.T) {
tlsReqInfo := make(chan struct {
host string
sni string
}, 1)
tlsServer := newIPv4TLSServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
tlsReqInfo <- struct {
host string
sni string
}{
host: r.Host,
sni: r.TLS.ServerName,
}
_, _ = w.Write([]byte("ok"))
}))
defer tlsServer.Close()
_, port, err := net.SplitHostPort(tlsServer.Listener.Addr().String())
if err != nil {
t.Fatalf("split tls server addr: %v", err)
}
proxyServer := newIPv4ConnectProxyServer(t, nil)
defer proxyServer.Close()
targetHost := "proxy-custom-ip.test"
reqURL := fmt.Sprintf("https://%s:%s", targetHost, port)
req := NewSimpleRequest(reqURL, http.MethodGet).
SetProxy(proxyServer.URL).
SetCustomIP([]string{"127.0.0.1"}).
SetSkipTLSVerify(true)
resp, err := req.Do()
if err != nil {
t.Fatalf("Do error: %v", err)
}
defer resp.Close()
targets := proxyServer.Targets()
if len(targets) != 1 {
t.Fatalf("connect targets=%v; want 1 target", targets)
}
gotConnectTarget := targets[0]
wantConnectTarget := net.JoinHostPort("127.0.0.1", port)
if gotConnectTarget != wantConnectTarget {
t.Fatalf("CONNECT target = %q; want %q", gotConnectTarget, wantConnectTarget)
}
gotTLS := <-tlsReqInfo
wantHost := net.JoinHostPort(targetHost, port)
if gotTLS.host != wantHost {
t.Fatalf("request host = %q; want %q", gotTLS.host, wantHost)
}
if gotTLS.sni != targetHost {
t.Fatalf("tls sni = %q; want %q", gotTLS.sni, targetHost)
}
}
func TestRequestCustomIPPreservesOriginalHostAndSNI(t *testing.T) {
tlsReqInfo := make(chan struct {
host string
sni string
}, 1)
tlsServer := newIPv4TLSServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
tlsReqInfo <- struct {
host string
sni string
}{
host: r.Host,
sni: r.TLS.ServerName,
}
_, _ = w.Write([]byte("ok"))
}))
defer tlsServer.Close()
_, port, err := net.SplitHostPort(tlsServer.Listener.Addr().String())
if err != nil {
t.Fatalf("split tls server addr: %v", err)
}
targetHost := "custom-ip-direct.test"
reqURL := fmt.Sprintf("https://%s:%s", targetHost, port)
req := NewSimpleRequest(reqURL, http.MethodGet).
SetCustomIP([]string{"127.0.0.1"}).
SetSkipTLSVerify(true)
resp, err := req.Do()
if err != nil {
t.Fatalf("Do error: %v", err)
}
defer resp.Close()
gotTLS := <-tlsReqInfo
wantHost := net.JoinHostPort(targetHost, port)
if gotTLS.host != wantHost {
t.Fatalf("request host = %q; want %q", gotTLS.host, wantHost)
}
if gotTLS.sni != targetHost {
t.Fatalf("tls sni = %q; want %q", gotTLS.sni, targetHost)
}
}

331
proxy_local_helpers_test.go Normal file
View File

@ -0,0 +1,331 @@
package starnet
import (
"crypto/tls"
"crypto/x509"
"encoding/binary"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"sync"
"testing"
)
type connectProxyServer struct {
*httptest.Server
mu sync.Mutex
targets []string
}
func newIPv4Server(t testing.TB, handler http.Handler) *httptest.Server {
t.Helper()
listener, err := net.Listen("tcp4", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen tcp4: %v", err)
}
server := httptest.NewUnstartedServer(handler)
server.Listener = listener
server.Start()
return server
}
func newIPv4TLSServer(t testing.TB, handler http.Handler) *httptest.Server {
t.Helper()
listener, err := net.Listen("tcp4", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen tcp4: %v", err)
}
server := httptest.NewUnstartedServer(handler)
server.Listener = listener
server.StartTLS()
return server
}
func newTrustedIPv4TLSServer(t testing.TB, dnsName string, handler http.Handler) (*httptest.Server, *x509.CertPool) {
t.Helper()
testT, ok := t.(*testing.T)
if !ok {
t.Fatal("newTrustedIPv4TLSServer requires *testing.T")
}
certPEM, keyPEM := genSelfSignedCertPEM(testT, dnsName)
cert, err := tls.X509KeyPair(certPEM, keyPEM)
if err != nil {
t.Fatalf("X509KeyPair: %v", err)
}
pool := x509.NewCertPool()
if !pool.AppendCertsFromPEM(certPEM) {
t.Fatal("AppendCertsFromPEM returned false")
}
server := httptest.NewUnstartedServer(handler)
listener, err := net.Listen("tcp4", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen tcp4: %v", err)
}
server.Listener = listener
server.TLS = &tls.Config{
Certificates: []tls.Certificate{cert},
}
server.StartTLS()
return server, pool
}
func httpsURLForHost(t testing.TB, server *httptest.Server, host string) string {
t.Helper()
_, port, err := net.SplitHostPort(server.Listener.Addr().String())
if err != nil {
t.Fatalf("split host port: %v", err)
}
return fmt.Sprintf("https://%s:%s", host, port)
}
func newIPv4ConnectProxyServer(t testing.TB, dialTarget func(target string) (net.Conn, error)) *connectProxyServer {
t.Helper()
proxy := &connectProxyServer{}
if dialTarget == nil {
dialTarget = func(target string) (net.Conn, error) {
return net.Dial("tcp", target)
}
}
proxy.Server = newIPv4Server(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodConnect {
http.Error(w, "connect required", http.StatusMethodNotAllowed)
return
}
proxy.mu.Lock()
proxy.targets = append(proxy.targets, r.Host)
proxy.mu.Unlock()
targetConn, err := dialTarget(r.Host)
if err != nil {
http.Error(w, err.Error(), http.StatusBadGateway)
return
}
hijacker, ok := w.(http.Hijacker)
if !ok {
targetConn.Close()
t.Fatal("proxy response writer is not a hijacker")
}
clientConn, rw, err := hijacker.Hijack()
if err != nil {
targetConn.Close()
t.Fatalf("hijack proxy conn: %v", err)
}
if _, err := rw.WriteString("HTTP/1.1 200 Connection Established\r\n\r\n"); err != nil {
clientConn.Close()
targetConn.Close()
t.Fatalf("write connect response: %v", err)
}
if err := rw.Flush(); err != nil {
clientConn.Close()
targetConn.Close()
t.Fatalf("flush connect response: %v", err)
}
relayProxyConns(clientConn, targetConn)
}))
return proxy
}
func (p *connectProxyServer) Targets() []string {
p.mu.Lock()
defer p.mu.Unlock()
return append([]string(nil), p.targets...)
}
type socks5ProxyServer struct {
ln net.Listener
addr string
dial func(target string) (net.Conn, error)
stopCh chan struct{}
wg sync.WaitGroup
mu sync.Mutex
targets []string
}
func newSOCKS5ProxyServer(t testing.TB, dialTarget func(target string) (net.Conn, error)) *socks5ProxyServer {
t.Helper()
if dialTarget == nil {
dialTarget = func(target string) (net.Conn, error) {
return net.Dial("tcp", target)
}
}
ln, err := net.Listen("tcp4", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen tcp4 socks5: %v", err)
}
proxy := &socks5ProxyServer{
ln: ln,
addr: ln.Addr().String(),
dial: dialTarget,
stopCh: make(chan struct{}),
}
proxy.wg.Add(1)
go func() {
defer proxy.wg.Done()
for {
conn, err := ln.Accept()
if err != nil {
select {
case <-proxy.stopCh:
return
default:
return
}
}
proxy.wg.Add(1)
go func(c net.Conn) {
defer proxy.wg.Done()
proxy.handleConn(t, c)
}(conn)
}
}()
return proxy
}
func (p *socks5ProxyServer) URL() string {
return "socks5://" + p.addr
}
func (p *socks5ProxyServer) Targets() []string {
p.mu.Lock()
defer p.mu.Unlock()
return append([]string(nil), p.targets...)
}
func (p *socks5ProxyServer) Close() {
close(p.stopCh)
_ = p.ln.Close()
p.wg.Wait()
}
func (p *socks5ProxyServer) handleConn(t testing.TB, conn net.Conn) {
t.Helper()
closeConn := true
defer func() {
if closeConn {
_ = conn.Close()
}
}()
header := make([]byte, 2)
if _, err := io.ReadFull(conn, header); err != nil {
return
}
if header[0] != 0x05 {
return
}
methods := make([]byte, int(header[1]))
if _, err := io.ReadFull(conn, methods); err != nil {
return
}
if _, err := conn.Write([]byte{0x05, 0x00}); err != nil {
return
}
reqHeader := make([]byte, 4)
if _, err := io.ReadFull(conn, reqHeader); err != nil {
return
}
if reqHeader[0] != 0x05 || reqHeader[1] != 0x01 {
_, _ = conn.Write([]byte{0x05, 0x07, 0x00, 0x01, 0, 0, 0, 0, 0, 0})
return
}
host, err := readSOCKS5Addr(conn, reqHeader[3])
if err != nil {
_, _ = conn.Write([]byte{0x05, 0x08, 0x00, 0x01, 0, 0, 0, 0, 0, 0})
return
}
portBytes := make([]byte, 2)
if _, err := io.ReadFull(conn, portBytes); err != nil {
return
}
target := net.JoinHostPort(host, fmt.Sprintf("%d", binary.BigEndian.Uint16(portBytes)))
p.mu.Lock()
p.targets = append(p.targets, target)
p.mu.Unlock()
targetConn, err := p.dial(target)
if err != nil {
_, _ = conn.Write([]byte{0x05, 0x05, 0x00, 0x01, 0, 0, 0, 0, 0, 0})
return
}
if _, err := conn.Write([]byte{0x05, 0x00, 0x00, 0x01, 0, 0, 0, 0, 0, 0}); err != nil {
targetConn.Close()
return
}
closeConn = false
relayProxyConns(conn, targetConn)
}
func readSOCKS5Addr(r io.Reader, atyp byte) (string, error) {
switch atyp {
case 0x01:
buf := make([]byte, 4)
if _, err := io.ReadFull(r, buf); err != nil {
return "", err
}
return net.IP(buf).String(), nil
case 0x03:
var size [1]byte
if _, err := io.ReadFull(r, size[:]); err != nil {
return "", err
}
buf := make([]byte, int(size[0]))
if _, err := io.ReadFull(r, buf); err != nil {
return "", err
}
return string(buf), nil
case 0x04:
buf := make([]byte, 16)
if _, err := io.ReadFull(r, buf); err != nil {
return "", err
}
return net.IP(buf).String(), nil
default:
return "", fmt.Errorf("unsupported atyp: %d", atyp)
}
}
func relayProxyConns(left, right net.Conn) {
var once sync.Once
closeBoth := func() {
_ = left.Close()
_ = right.Close()
}
go func() {
_, _ = io.Copy(left, right)
once.Do(closeBoth)
}()
go func() {
_, _ = io.Copy(right, left)
once.Do(closeBoth)
}()
}

View File

@ -22,14 +22,131 @@ type Request struct {
httpClient *http.Client
httpReq *http.Request
retry *retryPolicy
traceHooks *TraceHooks
traceState *traceState
applied bool // 是否已应用配置
doRaw bool // 是否使用原始请求(不修改)
autoFetch bool // 是否自动获取响应体
rawSourceExternal bool // 是否由 SetRawRequest/WithRawRequest 注入外部 raw request
rawTemplate *http.Request
}
func normalizeContext(ctx context.Context) context.Context {
if ctx != nil {
return ctx
}
return context.Background()
}
func cloneRawHTTPRequest(httpReq *http.Request, ctx context.Context) (*http.Request, error) {
if httpReq == nil {
return nil, fmt.Errorf("http request is nil")
}
cloned := httpReq.Clone(normalizeContext(ctx))
switch {
case httpReq.Body == nil || httpReq.Body == http.NoBody:
cloned.Body = httpReq.Body
case httpReq.GetBody != nil:
body, err := httpReq.GetBody()
if err != nil {
return cloned, wrapError(err, "clone raw request body")
}
cloned.Body = body
default:
return cloned, fmt.Errorf("cannot clone raw request with non-replayable body")
}
return cloned, nil
}
func (r *Request) rawBaseRequest() *http.Request {
if r == nil {
return nil
}
if r.rawTemplate != nil {
return r.rawTemplate
}
return r.httpReq
}
func (r *Request) invalidatePreparedState() {
if r == nil {
return
}
if r.cancel != nil {
r.cancel()
r.cancel = nil
}
r.execCtx = nil
r.traceState = nil
r.httpClient = nil
wasApplied := r.applied
r.applied = false
if !wasApplied || r.doRaw {
return
}
if err := r.rebuildPreparedRequestBase(); err != nil && r.err == nil {
r.err = err
}
}
func (r *Request) rebuildPreparedRequestBase() error {
if r == nil || r.doRaw {
return nil
}
ctx := r.ctx
if ctx == nil {
ctx = context.Background()
}
httpReq, err := http.NewRequestWithContext(ctx, r.method, r.url, nil)
if err != nil {
return wrapError(err, "rebuild http request")
}
r.httpReq = httpReq
r.syncRequestHost()
return nil
}
func (r *Request) rebuildRawRequestBase() error {
if r == nil || !r.doRaw {
return nil
}
baseReq := r.rawBaseRequest()
rawReq, err := cloneRawHTTPRequest(baseReq, normalizeContext(r.ctx))
if err != nil && baseReq != nil && baseReq == r.httpReq {
r.httpReq = baseReq.WithContext(normalizeContext(r.ctx))
return nil
}
if rawReq != nil {
r.httpReq = rawReq
}
return err
}
func (r *Request) rebuildExecutionRequestBase() error {
if r == nil {
return nil
}
if r.cancel != nil {
r.cancel()
r.cancel = nil
}
r.execCtx = nil
r.traceState = nil
r.applied = false
if r.doRaw {
return r.rebuildRawRequestBase()
}
return r.rebuildPreparedRequestBase()
}
// newRequest 创建新请求(内部使用)
func newRequest(ctx context.Context, urlStr string, method string, opts ...RequestOpt) (*Request, error) {
ctx = normalizeContext(ctx)
if method == "" {
method = http.MethodGet
}
@ -133,6 +250,7 @@ func NewSimpleRequest(url, method string, opts ...RequestOpt) *Request {
// NewSimpleRequestWithContext 创建新请求(带 context忽略错误
func NewSimpleRequestWithContext(ctx context.Context, url, method string, opts ...RequestOpt) *Request {
ctx = normalizeContext(ctx)
req, err := newRequest(ctx, url, method, opts...)
if err != nil {
return &Request{
@ -163,16 +281,24 @@ func (r *Request) Clone() *Request {
client: r.client,
httpClient: r.httpClient,
retry: cloneRetryPolicy(r.retry),
traceHooks: r.traceHooks,
applied: false, // 重置应用状态
doRaw: r.doRaw,
autoFetch: r.autoFetch,
rawSourceExternal: r.rawSourceExternal,
}
// 重新创建 http.Request
if !r.doRaw {
cloned.httpReq, _ = http.NewRequestWithContext(cloned.ctx, cloned.method, cloned.url, nil)
} else {
cloned.httpReq = r.httpReq
rawTemplate, err := cloneRawHTTPRequest(r.rawBaseRequest(), cloned.ctx)
cloned.rawTemplate = rawTemplate
cloned.httpReq = rawTemplate
if err != nil && cloned.err == nil {
cloned.err = err
}
}
return cloned
@ -190,12 +316,7 @@ func (r *Request) Context() context.Context {
// SetContext 设置 context
func (r *Request) SetContext(ctx context.Context) *Request {
if r.err != nil {
return r
}
r.ctx = ctx
r.httpReq = r.httpReq.WithContext(ctx)
return r
return r.applyMutation(mutateContext(ctx))
}
// Method 获取 HTTP 方法
@ -215,7 +336,13 @@ func (r *Request) SetMethod(method string) *Request {
}
r.method = method
if r.httpReq != nil {
r.httpReq.Method = method
}
if r.doRaw && r.rawTemplate != nil {
r.rawTemplate.Method = method
}
r.invalidatePreparedState()
return r
}
@ -243,45 +370,74 @@ func (r *Request) SetURL(urlStr string) *Request {
r.url = urlStr
u.Host = removeEmptyPort(u.Host)
r.httpReq.Host = u.Host
r.httpReq.URL = u
// 更新 TLS ServerName
if r.config.TLS.Config != nil {
r.config.TLS.Config.ServerName = u.Hostname()
}
r.syncRequestHost()
r.invalidatePreparedState()
return r
}
func (r *Request) effectiveRequestHost() string {
if r == nil {
return ""
}
if r.config != nil && r.config.Host != "" {
return r.config.Host
}
if r.httpReq != nil && r.httpReq.URL != nil {
return removeEmptyPort(r.httpReq.URL.Host)
}
if r.url == "" {
return ""
}
u, err := url.Parse(r.url)
if err != nil {
return ""
}
return removeEmptyPort(u.Host)
}
func (r *Request) syncRequestHost() {
if r == nil || r.httpReq == nil {
return
}
r.httpReq.Host = r.effectiveRequestHost()
}
// RawRequest 获取底层 http.Request
func (r *Request) RawRequest() *http.Request {
if r != nil && r.doRaw && r.rawTemplate != nil && !r.applied {
return r.rawTemplate
}
return r.httpReq
}
// SetRawRequest 设置底层 http.Request启用原始模式
func (r *Request) SetRawRequest(httpReq *http.Request) *Request {
if r.err != nil {
return r
}
r.httpReq = httpReq
r.doRaw = true
if httpReq == nil {
r.err = fmt.Errorf("httpReq cannot be nil")
return r
}
return r
return r.applyMutation(mutateRawRequest(httpReq))
}
// EnableRawMode 启用原始模式(不修改请求)
func (r *Request) EnableRawMode() *Request {
if r.doRaw {
return r
}
r.doRaw = true
r.invalidatePreparedState()
return r
}
// DisableRawMode 禁用原始模式
func (r *Request) DisableRawMode() *Request {
if !r.doRaw {
return r
}
if r.rawSourceExternal {
r.err = fmt.Errorf("cannot disable raw mode after SetRawRequest")
return r
}
r.doRaw = false
r.invalidatePreparedState()
return r
}
@ -329,6 +485,10 @@ func (r *Request) Do() (*Response, error) {
}
func (r *Request) doOnce() (*Response, error) {
if err := r.rebuildExecutionRequestBase(); err != nil {
return nil, wrapError(err, "rebuild execution request")
}
// 准备请求
if err := r.prepare(); err != nil {
return nil, wrapError(err, "prepare request")

View File

@ -1,16 +1,9 @@
package starnet
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"mime/multipart"
"net/http"
"net/url"
"os"
"strings"
)
// SetBody 设置请求体(字节)
@ -21,12 +14,13 @@ func (r *Request) SetBody(body []byte) *Request {
if r.doRaw {
return r
}
r.config.Body.Bytes = body
r.config.Body.Reader = nil
setBytesBodyConfig(&r.config.Body, body)
r.invalidatePreparedState()
return r
}
// SetBodyReader 设置请求体Reader
// SetBodyReader 设置请求体Reader
// 出于避免重复写的保守策略Reader 形态的 body 在非幂等方法上不会自动参与 retry。
func (r *Request) SetBodyReader(reader io.Reader) *Request {
if r.err != nil {
return r
@ -34,8 +28,8 @@ func (r *Request) SetBodyReader(reader io.Reader) *Request {
if r.doRaw {
return r
}
r.config.Body.Reader = reader
r.config.Body.Bytes = nil
setReaderBodyConfig(&r.config.Body, reader)
r.invalidatePreparedState()
return r
}
@ -67,7 +61,8 @@ func (r *Request) SetFormData(data map[string][]string) *Request {
if r.doRaw {
return r
}
r.config.Body.FormData = cloneStringMapSlice(data)
setFormBodyConfig(&r.config.Body, data)
r.invalidatePreparedState()
return r
}
@ -79,7 +74,9 @@ func (r *Request) AddFormData(key, value string) *Request {
if r.doRaw {
return r
}
ensureFormMode(&r.config.Body)
r.config.Body.FormData[key] = append(r.config.Body.FormData[key], value)
r.invalidatePreparedState()
return r
}
@ -91,9 +88,11 @@ func (r *Request) AddFormDataMap(data map[string]string) *Request {
if r.doRaw {
return r
}
for k, v := range data {
r.config.Body.FormData[k] = append(r.config.Body.FormData[k], v)
ensureFormMode(&r.config.Body)
for key, value := range data {
r.config.Body.FormData[key] = append(r.config.Body.FormData[key], value)
}
r.invalidatePreparedState()
return r
}
@ -109,6 +108,7 @@ func (r *Request) AddFile(formName, filePath string) *Request {
return r
}
ensureMultipartMode(&r.config.Body)
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
FormName: formName,
FileName: stat.Name(),
@ -116,6 +116,7 @@ func (r *Request) AddFile(formName, filePath string) *Request {
FileSize: stat.Size(),
FileType: ContentTypeOctetStream,
})
r.invalidatePreparedState()
return r
}
@ -132,6 +133,7 @@ func (r *Request) AddFileWithName(formName, filePath, fileName string) *Request
return r
}
ensureMultipartMode(&r.config.Body)
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
FormName: formName,
FileName: fileName,
@ -139,6 +141,7 @@ func (r *Request) AddFileWithName(formName, filePath, fileName string) *Request
FileSize: stat.Size(),
FileType: ContentTypeOctetStream,
})
r.invalidatePreparedState()
return r
}
@ -155,6 +158,7 @@ func (r *Request) AddFileWithType(formName, filePath, fileType string) *Request
return r
}
ensureMultipartMode(&r.config.Body)
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
FormName: formName,
FileName: stat.Name(),
@ -162,6 +166,7 @@ func (r *Request) AddFileWithType(formName, filePath, fileType string) *Request
FileSize: stat.Size(),
FileType: fileType,
})
r.invalidatePreparedState()
return r
}
@ -177,6 +182,7 @@ func (r *Request) AddFileStream(formName, fileName string, size int64, reader io
return r
}
ensureMultipartMode(&r.config.Body)
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
FormName: formName,
FileName: fileName,
@ -184,6 +190,7 @@ func (r *Request) AddFileStream(formName, fileName string, size int64, reader io
FileSize: size,
FileType: ContentTypeOctetStream,
})
r.invalidatePreparedState()
return r
}
@ -199,6 +206,7 @@ func (r *Request) AddFileStreamWithType(formName, fileName, fileType string, siz
return r
}
ensureMultipartMode(&r.config.Body)
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
FormName: formName,
FileName: fileName,
@ -206,243 +214,7 @@ func (r *Request) AddFileStreamWithType(formName, fileName, fileType string, siz
FileSize: size,
FileType: fileType,
})
r.invalidatePreparedState()
return r
}
// applyBody 应用请求体
func (r *Request) applyBody() error {
// 优先级Reader > Bytes > Files > FormData
// 1. Reader
if r.config.Body.Reader != nil {
r.httpReq.Body = io.NopCloser(r.config.Body.Reader)
// 尝试获取长度
switch v := r.config.Body.Reader.(type) {
case *bytes.Buffer:
r.httpReq.ContentLength = int64(v.Len())
case *bytes.Reader:
r.httpReq.ContentLength = int64(v.Len())
case *strings.Reader:
r.httpReq.ContentLength = int64(v.Len())
}
return nil
}
// 2. Bytes
if len(r.config.Body.Bytes) > 0 {
r.httpReq.Body = io.NopCloser(bytes.NewReader(r.config.Body.Bytes))
r.httpReq.ContentLength = int64(len(r.config.Body.Bytes))
return nil
}
// 3. Filesmultipart/form-data
if len(r.config.Body.Files) > 0 {
return r.applyMultipartBody()
}
// 4. FormDataapplication/x-www-form-urlencoded
if len(r.config.Body.FormData) > 0 {
values := url.Values{}
for k, vs := range r.config.Body.FormData {
for _, v := range vs {
values.Add(k, v)
}
}
encoded := values.Encode()
r.httpReq.Body = io.NopCloser(strings.NewReader(encoded))
r.httpReq.ContentLength = int64(len(encoded))
return nil
}
return nil
}
// applyMultipartBody 应用 multipart 请求体
func (r *Request) applyMultipartBody() error {
pr, pw := io.Pipe()
writer := multipart.NewWriter(pw)
// 设置 Content-Type
r.httpReq.Header.Set("Content-Type", writer.FormDataContentType())
r.httpReq.Body = pr
// 在 goroutine 中写入数据
go func() {
defer pw.Close()
defer writer.Close()
// 写入表单字段
for k, vs := range r.config.Body.FormData {
for _, v := range vs {
if err := writer.WriteField(k, v); err != nil {
pw.CloseWithError(wrapError(err, "write form field"))
return
}
}
}
// 写入文件
for _, file := range r.config.Body.Files {
if err := r.writeFile(writer, file); err != nil {
pw.CloseWithError(err)
return
}
}
}()
return nil
}
// writeFile 写入文件到 multipart writer
func (r *Request) writeFile(writer *multipart.Writer, file RequestFile) error {
// 创建文件字段
part, err := writer.CreateFormFile(file.FormName, file.FileName)
if err != nil {
return wrapError(err, "create form file")
}
// 获取文件数据源
var reader io.Reader
if file.FileData != nil {
reader = file.FileData
} else if file.FilePath != "" {
f, err := os.Open(file.FilePath)
if err != nil {
return wrapError(err, "open file")
}
defer f.Close()
reader = f
} else {
return ErrNilReader
}
// 复制文件数据(带进度)
if r.config.UploadProgress != nil {
_, err = copyWithProgress(r.ctx, part, reader, file.FileName, file.FileSize, r.config.UploadProgress)
} else {
_, err = io.Copy(part, reader)
}
if err != nil {
return wrapError(err, "copy file data")
}
return nil
}
// prepare 准备请求(应用配置)
func (r *Request) prepare() error {
if r.applied {
return nil
}
// 即使 raw 模式也要确保有 httpClient
if r.httpClient == nil {
var err error
r.httpClient, err = r.buildHTTPClient()
if err != nil {
return err // ← 失败时不设置 applied
}
}
if r.httpReq == nil {
return fmt.Errorf("http request is nil")
}
// 原始模式不修改请求内容
if !r.doRaw {
// 应用查询参数
if len(r.config.Queries) > 0 {
q := r.httpReq.URL.Query()
for k, values := range r.config.Queries {
for _, v := range values {
q.Add(k, v)
}
}
r.httpReq.URL.RawQuery = q.Encode()
}
// 应用 Headers
for k, values := range r.config.Headers {
for _, v := range values {
r.httpReq.Header.Add(k, v)
}
}
// 应用 Cookies
for _, cookie := range r.config.Cookies {
r.httpReq.AddCookie(cookie)
}
// 应用 Basic Auth
if r.config.BasicAuth[0] != "" || r.config.BasicAuth[1] != "" {
r.httpReq.SetBasicAuth(r.config.BasicAuth[0], r.config.BasicAuth[1])
}
// 应用请求体
if err := r.applyBody(); err != nil {
return err
}
// 应用 Content-Length
if r.config.ContentLength > 0 {
r.httpReq.ContentLength = r.config.ContentLength
} else if r.config.ContentLength < 0 {
r.httpReq.ContentLength = 0
}
// 自动计算 Content-Length
if r.config.AutoCalcContentLength && r.httpReq.Body != nil {
data, err := io.ReadAll(r.httpReq.Body)
if err != nil {
return wrapError(err, "read body for content length")
}
r.httpReq.ContentLength = int64(len(data))
r.httpReq.Body = io.NopCloser(bytes.NewBuffer(data))
}
// 设置 TLS ServerName如果有 TLS Config
if r.config.TLS.Config != nil && r.httpReq.URL != nil {
r.config.TLS.Config.ServerName = r.httpReq.URL.Hostname()
}
}
execCtx := r.ctx
if !r.doRaw {
// raw 模式下不注入请求级网络配置,只应用 context/超时。
execCtx = injectRequestConfig(execCtx, r.config)
}
// 请求级总超时通过 context 控制,避免污染共享 http.Client。
if r.config.Network.Timeout > 0 {
execCtx, r.cancel = context.WithTimeout(execCtx, r.config.Network.Timeout)
}
r.execCtx = execCtx
r.httpReq = r.httpReq.WithContext(r.execCtx)
r.applied = true
return nil
}
// buildHTTPClient 构建 HTTP Client
func (r *Request) buildHTTPClient() (*http.Client, error) {
// 优先使用请求关联的 Client
if r.client != nil {
return r.client.HTTPClient(), nil
}
// 自定义 Transport
if r.config.CustomTransport && r.config.Transport != nil {
return &http.Client{
Transport: &Transport{base: r.config.Transport},
Timeout: 0,
}, nil
}
// 默认全局 client
return DefaultHTTPClient(), nil
}

View File

@ -1,282 +0,0 @@
package starnet
import (
"context"
"crypto/tls"
"fmt"
"net"
"net/http"
"time"
)
// SetTimeout 设置请求总超时时间
// timeout > 0: 为本次请求注入 context 超时
// timeout = 0: 不额外设置请求总超时
// timeout < 0: 禁用 starnet 默认总超时
func (r *Request) SetTimeout(timeout time.Duration) *Request {
if r.err != nil {
return r
}
r.config.Network.Timeout = timeout
return r
}
// SetDialTimeout 设置连接超时时间
func (r *Request) SetDialTimeout(timeout time.Duration) *Request {
if r.err != nil {
return r
}
r.config.Network.DialTimeout = timeout
return r
}
// SetProxy 设置代理
func (r *Request) SetProxy(proxy string) *Request {
if r.err != nil {
return r
}
r.config.Network.Proxy = proxy
return r
}
// SetDialFunc 设置自定义 Dial 函数
func (r *Request) SetDialFunc(fn func(ctx context.Context, network, addr string) (net.Conn, error)) *Request {
if r.err != nil {
return r
}
r.config.Network.DialFunc = fn
return r
}
// SetTLSConfig 设置 TLS 配置
func (r *Request) SetTLSConfig(tlsConfig *tls.Config) *Request {
if r.err != nil {
return r
}
r.config.TLS.Config = tlsConfig
return r
}
// SetSkipTLSVerify 设置是否跳过 TLS 验证
func (r *Request) SetSkipTLSVerify(skip bool) *Request {
if r.err != nil {
return r
}
r.config.TLS.SkipVerify = skip
return r
}
// SetCustomIP 设置自定义 IP直接指定 IP跳过 DNS
func (r *Request) SetCustomIP(ips []string) *Request {
if r.err != nil {
return r
}
// 验证 IP 格式
for _, ip := range ips {
if net.ParseIP(ip) == nil {
r.err = wrapError(ErrInvalidIP, "ip: %s", ip)
return r
}
}
r.config.DNS.CustomIP = ips
return r
}
// AddCustomIP 添加自定义 IP
func (r *Request) AddCustomIP(ip string) *Request {
if r.err != nil {
return r
}
if net.ParseIP(ip) == nil {
r.err = wrapError(ErrInvalidIP, "ip: %s", ip)
return r
}
r.config.DNS.CustomIP = append(r.config.DNS.CustomIP, ip)
return r
}
// SetCustomDNS 设置自定义 DNS 服务器
func (r *Request) SetCustomDNS(dnsServers []string) *Request {
if r.err != nil {
return r
}
// 验证 DNS 服务器格式
for _, dns := range dnsServers {
if net.ParseIP(dns) == nil {
r.err = wrapError(ErrInvalidDNS, "dns: %s", dns)
return r
}
}
r.config.DNS.CustomDNS = dnsServers
return r
}
// AddCustomDNS 添加自定义 DNS 服务器
func (r *Request) AddCustomDNS(dns string) *Request {
if r.err != nil {
return r
}
if net.ParseIP(dns) == nil {
r.err = wrapError(ErrInvalidDNS, "dns: %s", dns)
return r
}
r.config.DNS.CustomDNS = append(r.config.DNS.CustomDNS, dns)
return r
}
// SetLookupFunc 设置自定义 DNS 解析函数
func (r *Request) SetLookupFunc(fn func(ctx context.Context, host string) ([]net.IPAddr, error)) *Request {
if r.err != nil {
return r
}
r.config.DNS.LookupFunc = fn
return r
}
// SetBasicAuth 设置 Basic 认证
func (r *Request) SetBasicAuth(username, password string) *Request {
if r.err != nil {
return r
}
r.config.BasicAuth = [2]string{username, password}
return r
}
// SetContentLength 设置 Content-Length
func (r *Request) SetContentLength(length int64) *Request {
if r.err != nil {
return r
}
r.config.ContentLength = length
return r
}
// SetAutoCalcContentLength 设置是否自动计算 Content-Length
// 警告:启用后会将整个 body 读入内存
func (r *Request) SetAutoCalcContentLength(auto bool) *Request {
if r.err != nil {
return r
}
if r.doRaw {
r.err = fmt.Errorf("cannot set auto calc content length in raw mode")
return r
}
r.config.AutoCalcContentLength = auto
return r
}
// SetTransport 设置自定义 Transport
func (r *Request) SetTransport(transport *http.Transport) *Request {
if r.err != nil {
return r
}
r.config.Transport = transport
r.config.CustomTransport = true
return r
}
// SetUploadProgress 设置文件上传进度回调
func (r *Request) SetUploadProgress(fn UploadProgressFunc) *Request {
if r.err != nil {
return r
}
r.config.UploadProgress = fn
return r
}
// SetMaxRespBodyBytes 设置响应体最大读取字节数(<=0 表示不限制)
func (r *Request) SetMaxRespBodyBytes(maxBytes int64) *Request {
if r.err != nil {
return r
}
if maxBytes < 0 {
r.err = fmt.Errorf("max response body bytes must be >= 0")
return r
}
r.config.MaxRespBodyBytes = maxBytes
return r
}
// AddQuery 添加查询参数
func (r *Request) AddQuery(key, value string) *Request {
if r.err != nil {
return r
}
r.config.Queries[key] = append(r.config.Queries[key], value)
return r
}
// SetQuery 设置查询参数(覆盖)
func (r *Request) SetQuery(key, value string) *Request {
if r.err != nil {
return r
}
r.config.Queries[key] = []string{value}
return r
}
// SetQueries 设置所有查询参数(覆盖)
func (r *Request) SetQueries(queries map[string][]string) *Request {
if r.err != nil {
return r
}
r.config.Queries = cloneStringMapSlice(queries)
return r
}
// AddQueries 批量添加查询参数
func (r *Request) AddQueries(queries map[string]string) *Request {
if r.err != nil {
return r
}
for k, v := range queries {
r.config.Queries[k] = append(r.config.Queries[k], v)
}
return r
}
// DeleteQuery 删除查询参数
func (r *Request) DeleteQuery(key string) *Request {
if r.err != nil {
return r
}
delete(r.config.Queries, key)
return r
}
// DeleteQueryValue 删除查询参数的特定值
func (r *Request) DeleteQueryValue(key, value string) *Request {
if r.err != nil {
return r
}
values, ok := r.config.Queries[key]
if !ok {
return r
}
newValues := make([]string, 0, len(values))
for _, v := range values {
if v != value {
newValues = append(newValues, v)
}
}
if len(newValues) == 0 {
delete(r.config.Queries, key)
} else {
r.config.Queries[key] = newValues
}
return r
}

34
request_execution.go Normal file
View File

@ -0,0 +1,34 @@
package starnet
import "net/http"
// SetBasicAuth 设置 Basic 认证
func (r *Request) SetBasicAuth(username, password string) *Request {
return r.applyMutation(mutateBasicAuth(username, password))
}
// SetContentLength 设置 Content-Length
func (r *Request) SetContentLength(length int64) *Request {
return r.applyMutation(mutateContentLength(length))
}
// SetAutoCalcContentLength 设置是否自动计算 Content-Length
// 警告:启用后会将整个 body 读入内存
func (r *Request) SetAutoCalcContentLength(auto bool) *Request {
return r.applyMutation(mutateAutoCalcContentLength(auto))
}
// SetTransport 设置自定义 Transport
func (r *Request) SetTransport(transport *http.Transport) *Request {
return r.applyMutation(mutateTransport(transport))
}
// SetUploadProgress 设置文件上传进度回调
func (r *Request) SetUploadProgress(fn UploadProgressFunc) *Request {
return r.applyMutation(mutateUploadProgress(fn))
}
// SetMaxRespBodyBytes 设置响应体最大读取字节数(<=0 表示不限制)
func (r *Request) SetMaxRespBodyBytes(maxBytes int64) *Request {
return r.applyMutation(mutateMaxRespBodyBytes(maxBytes))
}

View File

@ -0,0 +1,172 @@
package starnet
import (
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"sync/atomic"
"testing"
)
func TestRequestDoTwiceRebuildsExecutionState(t *testing.T) {
var attempts int32
req := NewSimpleRequest("http://example.com/path", http.MethodPost).
SetHeader("X-Test", "one").
AddQuery("q", "v").
SetBodyReader(strings.NewReader("payload"))
req.client = &Client{client: &http.Client{
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
if err := r.Context().Err(); err != nil {
t.Fatalf("request context already done: %v", err)
}
if values := r.Header.Values("X-Test"); len(values) != 1 || values[0] != "one" {
t.Fatalf("header values=%v", values)
}
if values := r.URL.Query()["q"]; len(values) != 1 || values[0] != "v" {
t.Fatalf("query values=%v", values)
}
body, err := io.ReadAll(r.Body)
if err != nil {
return nil, err
}
_ = r.Body.Close()
if string(body) != "payload" {
t.Fatalf("body=%q", string(body))
}
n := atomic.AddInt32(&attempts, 1)
return &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader(fmt.Sprintf("ok-%d", n))),
Request: r,
}, nil
}),
}}
resp1, err := req.Do()
if err != nil {
t.Fatalf("first Do() error: %v", err)
}
if err := resp1.Close(); err != nil {
t.Fatalf("first Close() error: %v", err)
}
resp2, err := req.Do()
if err != nil {
t.Fatalf("second Do() error: %v", err)
}
defer resp2.Close()
if got := atomic.LoadInt32(&attempts); got != 2 {
t.Fatalf("attempts=%d; want 2", got)
}
}
func TestRequestPrepareRawDynamicPathInjectsAggregatedRequestContext(t *testing.T) {
rawReq, err := http.NewRequest(http.MethodGet, "https://example.com/resource", nil)
if err != nil {
t.Fatalf("NewRequest() error: %v", err)
}
req := NewSimpleRequest("", http.MethodGet).
SetRawRequest(rawReq).
SetProxy("http://proxy.example:8080").
SetCustomIP([]string{"127.0.0.1"}).
SetSkipTLSVerify(true).
SetTLSServerName("override.example")
if err := req.prepare(); err != nil {
t.Fatalf("prepare() error: %v", err)
}
raw := req.execCtx.Value(ctxKeyRequestContext)
rc, ok := raw.(*RequestContext)
if !ok || rc == nil {
t.Fatalf("expected request context, got %#v", raw)
}
if rc.Proxy != "http://proxy.example:8080" {
t.Fatalf("proxy=%q", rc.Proxy)
}
if len(rc.CustomIP) != 1 || rc.CustomIP[0] != "127.0.0.1" {
t.Fatalf("custom ip=%v", rc.CustomIP)
}
if rc.TLSConfig == nil || !rc.TLSConfig.InsecureSkipVerify {
t.Fatalf("tls config=%#v", rc.TLSConfig)
}
if rc.TLSServerName != "override.example" {
t.Fatalf("tls server name=%q", rc.TLSServerName)
}
}
func TestRequestSetFormDataOverridesBytesBody(t *testing.T) {
req := NewSimpleRequest("http://example.com", http.MethodPost).
SetBodyString("stale").
SetFormData(map[string][]string{"k": []string{"v"}})
if req.config.Body.Mode != bodyModeForm {
t.Fatalf("body mode=%v", req.config.Body.Mode)
}
if req.config.Body.Reader != nil || req.config.Body.Bytes != nil || len(req.config.Body.Files) != 0 {
t.Fatalf("unexpected stale body state: %#v", req.config.Body)
}
if err := req.prepare(); err != nil {
t.Fatalf("prepare() error: %v", err)
}
body, err := req.httpReq.GetBody()
if err != nil {
t.Fatalf("GetBody() error: %v", err)
}
defer body.Close()
data, err := io.ReadAll(body)
if err != nil {
t.Fatalf("ReadAll() error: %v", err)
}
if string(data) != "k=v" {
t.Fatalf("body=%q; want k=v", string(data))
}
}
func TestRequestAddFileClearsPreviousBytesBody(t *testing.T) {
tmpDir := t.TempDir()
filePath := filepath.Join(tmpDir, "payload.txt")
if err := os.WriteFile(filePath, []byte("file-body"), 0644); err != nil {
t.Fatalf("WriteFile() error: %v", err)
}
req := NewSimpleRequest("http://example.com", http.MethodPost).
SetJSON(map[string]string{"old": "json-only"}).
AddFile("file", filePath)
if req.config.Body.Mode != bodyModeMultipart {
t.Fatalf("body mode=%v", req.config.Body.Mode)
}
if req.config.Body.Reader != nil || req.config.Body.Bytes != nil {
t.Fatalf("unexpected stale simple body state: %#v", req.config.Body)
}
if err := req.prepare(); err != nil {
t.Fatalf("prepare() error: %v", err)
}
data, err := io.ReadAll(req.httpReq.Body)
if err != nil {
t.Fatalf("ReadAll() error: %v", err)
}
if !strings.Contains(req.httpReq.Header.Get("Content-Type"), "multipart/form-data") {
t.Fatalf("content-type=%q", req.httpReq.Header.Get("Content-Type"))
}
if !strings.Contains(string(data), "file-body") {
t.Fatalf("multipart body missing file content: %q", string(data))
}
if strings.Contains(string(data), "json-only") {
t.Fatalf("multipart body still contains stale json: %q", string(data))
}
}

View File

@ -4,6 +4,25 @@ import (
"net/http"
)
func isHostHeaderKey(key string) bool {
return http.CanonicalHeaderKey(key) == "Host"
}
func setRequestHostConfig(config *RequestConfig, host string) {
if config == nil {
return
}
if config.Headers == nil {
config.Headers = make(http.Header)
}
config.Host = host
if host == "" {
config.Headers.Del("Host")
return
}
config.Headers.Set("Host", host)
}
// SetHeader 设置 Header覆盖
func (r *Request) SetHeader(key, value string) *Request {
if r.err != nil {
@ -12,7 +31,11 @@ func (r *Request) SetHeader(key, value string) *Request {
if r.doRaw {
return r
}
if isHostHeaderKey(key) {
return r.SetHost(value)
}
r.config.Headers.Set(key, value)
r.invalidatePreparedState()
return r
}
@ -24,7 +47,11 @@ func (r *Request) AddHeader(key, value string) *Request {
if r.doRaw {
return r
}
if isHostHeaderKey(key) {
return r.SetHost(value)
}
r.config.Headers.Add(key, value)
r.invalidatePreparedState()
return r
}
@ -37,6 +64,9 @@ func (r *Request) SetHeaders(headers http.Header) *Request {
return r
}
r.config.Headers = cloneHeader(headers)
r.config.Host = r.config.Headers.Get("Host")
r.syncRequestHost()
r.invalidatePreparedState()
return r
}
@ -49,8 +79,14 @@ func (r *Request) AddHeaders(headers map[string]string) *Request {
return r
}
for k, v := range headers {
if isHostHeaderKey(k) {
setRequestHostConfig(r.config, v)
continue
}
r.config.Headers.Add(k, v)
}
r.syncRequestHost()
r.invalidatePreparedState()
return r
}
@ -62,18 +98,56 @@ func (r *Request) DeleteHeader(key string) *Request {
if r.doRaw {
return r
}
if isHostHeaderKey(key) {
setRequestHostConfig(r.config, "")
r.syncRequestHost()
r.invalidatePreparedState()
return r
}
r.config.Headers.Del(key)
r.invalidatePreparedState()
return r
}
// GetHeader 获取 Header
func (r *Request) GetHeader(key string) string {
if isHostHeaderKey(key) {
return r.config.Host
}
return r.config.Headers.Get(key)
}
// Headers 获取所有 Headers
func (r *Request) Headers() http.Header {
return r.config.Headers
if r == nil || r.config == nil {
return make(http.Header)
}
return cloneHeader(r.config.Headers)
}
// SetHost 设置请求 Host 头覆盖。
func (r *Request) SetHost(host string) *Request {
if r.err != nil {
return r
}
if r.doRaw {
return r
}
setRequestHostConfig(r.config, host)
r.syncRequestHost()
r.invalidatePreparedState()
return r
}
// Host 获取显式 Host 覆盖。
func (r *Request) Host() string {
if r.config != nil && r.config.Host != "" {
return r.config.Host
}
if r.httpReq != nil {
return r.httpReq.Host
}
return ""
}
// SetContentType 设置 Content-Type
@ -104,7 +178,8 @@ func (r *Request) AddCookie(cookie *http.Cookie) *Request {
if r.doRaw {
return r
}
r.config.Cookies = append(r.config.Cookies, cookie)
r.config.Cookies = append(r.config.Cookies, cloneCookie(cookie))
r.invalidatePreparedState()
return r
}
@ -134,7 +209,8 @@ func (r *Request) SetCookies(cookies []*http.Cookie) *Request {
if r.doRaw {
return r
}
r.config.Cookies = cookies
r.config.Cookies = cloneCookies(cookies)
r.invalidatePreparedState()
return r
}
@ -153,12 +229,16 @@ func (r *Request) AddCookies(cookies map[string]string) *Request {
Path: "/",
})
}
r.invalidatePreparedState()
return r
}
// Cookies 获取所有 Cookies
func (r *Request) Cookies() []*http.Cookie {
return r.config.Cookies
if r == nil || r.config == nil {
return nil
}
return cloneCookies(r.config.Cookies)
}
// ResetHeaders 重置所有 Headers
@ -167,6 +247,9 @@ func (r *Request) ResetHeaders() *Request {
return r
}
r.config.Headers = make(http.Header)
r.config.Host = ""
r.syncRequestHost()
r.invalidatePreparedState()
return r
}
@ -176,5 +259,6 @@ func (r *Request) ResetCookies() *Request {
return r
}
r.config.Cookies = []*http.Cookie{}
r.invalidatePreparedState()
return r
}

69
request_multipart.go Normal file
View File

@ -0,0 +1,69 @@
package starnet
import (
"context"
"io"
"mime/multipart"
"os"
)
// applyMultipartBody 应用 multipart 请求体
func (r *Request) applyMultipartBody(execCtx context.Context) error {
pr, pw := io.Pipe()
writer := multipart.NewWriter(pw)
r.httpReq.Header.Set("Content-Type", writer.FormDataContentType())
r.httpReq.Body = pr
go func() {
defer pw.Close()
defer writer.Close()
for key, values := range r.config.Body.FormData {
for _, value := range values {
if err := writer.WriteField(key, value); err != nil {
pw.CloseWithError(wrapError(err, "write form field"))
return
}
}
}
for _, file := range r.config.Body.Files {
if err := r.writeFile(execCtx, writer, file); err != nil {
pw.CloseWithError(err)
return
}
}
}()
return nil
}
// writeFile 写入文件到 multipart writer
func (r *Request) writeFile(execCtx context.Context, writer *multipart.Writer, file RequestFile) error {
part, err := writer.CreateFormFile(file.FormName, file.FileName)
if err != nil {
return wrapError(err, "create form file")
}
var reader io.Reader
if file.FileData != nil {
reader = file.FileData
} else if file.FilePath != "" {
f, err := os.Open(file.FilePath)
if err != nil {
return wrapError(err, "open file")
}
defer f.Close()
reader = f
} else {
return ErrNilReader
}
_, err = copyWithProgress(execCtx, part, reader, file.FileName, file.FileSize, r.config.UploadProgress)
if err != nil {
return wrapError(err, "copy file data")
}
return nil
}

326
request_mutation.go Normal file
View File

@ -0,0 +1,326 @@
package starnet
import (
"context"
"crypto/tls"
"fmt"
"net"
"net/http"
"net/url"
"time"
)
type requestMutation func(*Request) error
func (r *Request) applyMutation(mutation requestMutation) *Request {
if r == nil || r.err != nil {
return r
}
if err := mutation(r); err != nil {
r.err = err
return r
}
r.invalidatePreparedState()
return r
}
func requestOptFromMutation(mutation requestMutation) RequestOpt {
return func(r *Request) error {
if r == nil {
return nil
}
return mutation(r)
}
}
func validateCustomIPs(ips []string) error {
for _, ip := range ips {
if net.ParseIP(ip) == nil {
return wrapError(ErrInvalidIP, "ip: %s", ip)
}
}
return nil
}
func validateCustomDNS(dnsServers []string) error {
for _, dns := range dnsServers {
if net.ParseIP(dns) == nil {
return wrapError(ErrInvalidDNS, "dns: %s", dns)
}
}
return nil
}
func parseProxyURL(proxy string) (*url.URL, error) {
if proxy == "" {
return nil, nil
}
proxyURL, err := url.Parse(proxy)
if err != nil {
return nil, wrapError(err, "parse proxy url")
}
if proxyURL.Scheme == "" {
return nil, fmt.Errorf("proxy scheme is required: %s", proxy)
}
if proxyURL.Host == "" {
return nil, fmt.Errorf("proxy host is required: %s", proxy)
}
return proxyURL, nil
}
func mutateTimeout(timeout time.Duration) requestMutation {
return func(r *Request) error {
r.config.Network.Timeout = timeout
return nil
}
}
func mutateDialTimeout(timeout time.Duration) requestMutation {
return func(r *Request) error {
r.config.Network.DialTimeout = timeout
return nil
}
}
func mutateProxy(proxy string) requestMutation {
return func(r *Request) error {
if _, err := parseProxyURL(proxy); err != nil {
return err
}
r.config.Network.Proxy = proxy
return nil
}
}
func mutateDialFunc(fn func(ctx context.Context, network, addr string) (net.Conn, error)) requestMutation {
return func(r *Request) error {
r.config.Network.DialFunc = fn
return nil
}
}
func mutateTLSConfig(tlsConfig *tls.Config) requestMutation {
return func(r *Request) error {
r.config.TLS.Config = tlsConfig
return nil
}
}
func mutateTLSServerName(serverName string) requestMutation {
return func(r *Request) error {
r.config.TLS.ServerName = serverName
return nil
}
}
func mutateTraceHooks(hooks *TraceHooks) requestMutation {
return func(r *Request) error {
r.traceHooks = hooks
return nil
}
}
func mutateSkipTLSVerify(skip bool) requestMutation {
return func(r *Request) error {
r.config.TLS.SkipVerify = skip
return nil
}
}
func mutateCustomIP(ips []string) requestMutation {
return func(r *Request) error {
if err := validateCustomIPs(ips); err != nil {
return err
}
r.config.DNS.CustomIP = cloneStringSlice(ips)
return nil
}
}
func mutateAddCustomIP(ip string) requestMutation {
return func(r *Request) error {
if err := validateCustomIPs([]string{ip}); err != nil {
return err
}
r.config.DNS.CustomIP = append(r.config.DNS.CustomIP, ip)
return nil
}
}
func mutateCustomDNS(dnsServers []string) requestMutation {
return func(r *Request) error {
if err := validateCustomDNS(dnsServers); err != nil {
return err
}
r.config.DNS.CustomDNS = cloneStringSlice(dnsServers)
return nil
}
}
func mutateAddCustomDNS(dns string) requestMutation {
return func(r *Request) error {
if err := validateCustomDNS([]string{dns}); err != nil {
return err
}
r.config.DNS.CustomDNS = append(r.config.DNS.CustomDNS, dns)
return nil
}
}
func mutateLookupFunc(fn func(ctx context.Context, host string) ([]net.IPAddr, error)) requestMutation {
return func(r *Request) error {
r.config.DNS.LookupFunc = fn
return nil
}
}
func mutateBasicAuth(username, password string) requestMutation {
return func(r *Request) error {
r.config.BasicAuth = [2]string{username, password}
return nil
}
}
func mutateContentLength(length int64) requestMutation {
return func(r *Request) error {
r.config.ContentLength = length
return nil
}
}
func mutateAutoCalcContentLength(auto bool) requestMutation {
return func(r *Request) error {
if r.doRaw {
return fmt.Errorf("cannot set auto calc content length in raw mode")
}
r.config.AutoCalcContentLength = auto
return nil
}
}
func mutateTransport(transport *http.Transport) requestMutation {
return func(r *Request) error {
r.config.Transport = transport
r.config.CustomTransport = true
return nil
}
}
func mutateUploadProgress(fn UploadProgressFunc) requestMutation {
return func(r *Request) error {
r.config.UploadProgress = fn
return nil
}
}
func mutateAutoFetch(auto bool) requestMutation {
return func(r *Request) error {
r.autoFetch = auto
return nil
}
}
func mutateMaxRespBodyBytes(maxBytes int64) requestMutation {
return func(r *Request) error {
if maxBytes < 0 {
return fmt.Errorf("max response body bytes must be >= 0")
}
r.config.MaxRespBodyBytes = maxBytes
return nil
}
}
func mutateContext(ctx context.Context) requestMutation {
return func(r *Request) error {
ctx = normalizeContext(ctx)
r.ctx = ctx
if r.doRaw && r.rawTemplate != nil {
r.rawTemplate = r.rawTemplate.WithContext(ctx)
}
if r.httpReq != nil {
r.httpReq = r.httpReq.WithContext(ctx)
}
return nil
}
}
func mutateRawRequest(httpReq *http.Request) requestMutation {
return func(r *Request) error {
if httpReq == nil {
return fmt.Errorf("httpReq cannot be nil")
}
r.httpReq = httpReq
r.rawTemplate = httpReq
r.ctx = normalizeContext(httpReq.Context())
r.method = httpReq.Method
if httpReq.URL != nil {
r.url = httpReq.URL.String()
}
r.doRaw = true
r.rawSourceExternal = true
return nil
}
}
func mutateAddQuery(key, value string) requestMutation {
return func(r *Request) error {
r.config.Queries[key] = append(r.config.Queries[key], value)
return nil
}
}
func mutateSetQuery(key, value string) requestMutation {
return func(r *Request) error {
r.config.Queries[key] = []string{value}
return nil
}
}
func mutateSetQueries(queries map[string][]string) requestMutation {
return func(r *Request) error {
r.config.Queries = cloneStringMapSlice(queries)
return nil
}
}
func mutateAddQueries(queries map[string]string) requestMutation {
return func(r *Request) error {
for key, value := range queries {
r.config.Queries[key] = append(r.config.Queries[key], value)
}
return nil
}
}
func mutateDeleteQuery(key string) requestMutation {
return func(r *Request) error {
delete(r.config.Queries, key)
return nil
}
}
func mutateDeleteQueryValue(key, value string) requestMutation {
return func(r *Request) error {
values, ok := r.config.Queries[key]
if !ok {
return nil
}
newValues := make([]string, 0, len(values))
for _, item := range values {
if item != value {
newValues = append(newValues, item)
}
}
if len(newValues) == 0 {
delete(r.config.Queries, key)
return nil
}
r.config.Queries[key] = newValues
return nil
}
}

71
request_network.go Normal file
View File

@ -0,0 +1,71 @@
package starnet
import (
"context"
"crypto/tls"
"net"
"time"
)
// SetTimeout 设置请求总超时时间
// timeout > 0: 为本次请求注入 context 超时
// timeout = 0: 不额外设置请求总超时
// timeout < 0: 禁用 starnet 默认总超时
func (r *Request) SetTimeout(timeout time.Duration) *Request {
return r.applyMutation(mutateTimeout(timeout))
}
// SetDialTimeout 设置连接超时时间
func (r *Request) SetDialTimeout(timeout time.Duration) *Request {
return r.applyMutation(mutateDialTimeout(timeout))
}
// SetProxy 设置代理
func (r *Request) SetProxy(proxy string) *Request {
return r.applyMutation(mutateProxy(proxy))
}
// SetDialFunc 设置自定义 Dial 函数
func (r *Request) SetDialFunc(fn func(ctx context.Context, network, addr string) (net.Conn, error)) *Request {
return r.applyMutation(mutateDialFunc(fn))
}
// SetTLSConfig 设置 TLS 配置
func (r *Request) SetTLSConfig(tlsConfig *tls.Config) *Request {
return r.applyMutation(mutateTLSConfig(tlsConfig))
}
// SetTLSServerName 设置显式 TLS ServerName/SNI。
func (r *Request) SetTLSServerName(serverName string) *Request {
return r.applyMutation(mutateTLSServerName(serverName))
}
// SetSkipTLSVerify 设置是否跳过 TLS 验证
func (r *Request) SetSkipTLSVerify(skip bool) *Request {
return r.applyMutation(mutateSkipTLSVerify(skip))
}
// SetCustomIP 设置自定义 IP直接指定 IP跳过 DNS
func (r *Request) SetCustomIP(ips []string) *Request {
return r.applyMutation(mutateCustomIP(ips))
}
// AddCustomIP 添加自定义 IP
func (r *Request) AddCustomIP(ip string) *Request {
return r.applyMutation(mutateAddCustomIP(ip))
}
// SetCustomDNS 设置自定义 DNS 服务器
func (r *Request) SetCustomDNS(dnsServers []string) *Request {
return r.applyMutation(mutateCustomDNS(dnsServers))
}
// AddCustomDNS 添加自定义 DNS 服务器
func (r *Request) AddCustomDNS(dns string) *Request {
return r.applyMutation(mutateAddCustomDNS(dns))
}
// SetLookupFunc 设置自定义 DNS 解析函数
func (r *Request) SetLookupFunc(fn func(ctx context.Context, host string) ([]net.IPAddr, error)) *Request {
return r.applyMutation(mutateLookupFunc(fn))
}

314
request_prepare.go Normal file
View File

@ -0,0 +1,314 @@
package starnet
import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"net/http/httptrace"
"net/url"
"strings"
)
func setReplayableRequestBodyBytes(httpReq *http.Request, data []byte) {
if httpReq == nil {
return
}
httpReq.Body = io.NopCloser(bytes.NewReader(data))
httpReq.ContentLength = int64(len(data))
httpReq.GetBody = func() (io.ReadCloser, error) {
return io.NopCloser(bytes.NewReader(data)), nil
}
}
func clearSimpleBodyState(body *BodyConfig) {
if body == nil {
return
}
body.Bytes = nil
body.Reader = nil
}
func resetFormBodyState(body *BodyConfig) {
if body == nil {
return
}
body.FormData = make(map[string][]string)
}
func resetMultipartBodyState(body *BodyConfig) {
if body == nil {
return
}
body.Files = nil
}
func setBytesBodyConfig(body *BodyConfig, data []byte) {
if body == nil {
return
}
body.Mode = bodyModeBytes
body.Bytes = cloneBytes(data)
body.Reader = nil
resetFormBodyState(body)
resetMultipartBodyState(body)
}
func setReaderBodyConfig(body *BodyConfig, reader io.Reader) {
if body == nil {
return
}
body.Mode = bodyModeReader
body.Reader = reader
body.Bytes = nil
resetFormBodyState(body)
resetMultipartBodyState(body)
}
func setFormBodyConfig(body *BodyConfig, data map[string][]string) {
if body == nil {
return
}
body.Mode = bodyModeForm
clearSimpleBodyState(body)
resetMultipartBodyState(body)
body.FormData = cloneStringMapSlice(data)
}
func ensureFormMode(body *BodyConfig) {
if body == nil {
return
}
if body.Mode == bodyModeForm || body.Mode == bodyModeMultipart {
if body.FormData == nil {
body.FormData = make(map[string][]string)
}
return
}
clearSimpleBodyState(body)
resetMultipartBodyState(body)
body.FormData = make(map[string][]string)
body.Mode = bodyModeForm
}
func ensureMultipartMode(body *BodyConfig) {
if body == nil {
return
}
if body.Mode == bodyModeMultipart {
if body.FormData == nil {
body.FormData = make(map[string][]string)
}
return
}
if body.Mode != bodyModeForm {
clearSimpleBodyState(body)
body.FormData = make(map[string][]string)
}
body.Mode = bodyModeMultipart
if body.FormData == nil {
body.FormData = make(map[string][]string)
}
}
func snapshotBytesReader(reader *bytes.Reader) ([]byte, error) {
if reader == nil {
return nil, nil
}
data := make([]byte, reader.Len())
_, err := reader.ReadAt(data, reader.Size()-int64(reader.Len()))
if err != nil && err != io.EOF {
return nil, err
}
return data, nil
}
func snapshotStringReader(reader *strings.Reader) ([]byte, error) {
if reader == nil {
return nil, nil
}
data := make([]byte, reader.Len())
_, err := reader.ReadAt(data, reader.Size()-int64(reader.Len()))
if err != nil && err != io.EOF {
return nil, err
}
return data, nil
}
// applyBody 应用请求体
func (r *Request) applyBody(execCtx context.Context) error {
r.httpReq.Body = nil
r.httpReq.GetBody = nil
r.httpReq.ContentLength = 0
switch r.config.Body.Mode {
case bodyModeReader:
if r.config.Body.Reader == nil {
return nil
}
switch reader := r.config.Body.Reader.(type) {
case *bytes.Buffer:
setReplayableRequestBodyBytes(r.httpReq, append([]byte(nil), reader.Bytes()...))
case *bytes.Reader:
data, err := snapshotBytesReader(reader)
if err != nil {
return wrapError(err, "snapshot bytes reader")
}
setReplayableRequestBodyBytes(r.httpReq, data)
case *strings.Reader:
data, err := snapshotStringReader(reader)
if err != nil {
return wrapError(err, "snapshot strings reader")
}
setReplayableRequestBodyBytes(r.httpReq, data)
default:
r.httpReq.Body = io.NopCloser(r.config.Body.Reader)
}
switch reader := r.config.Body.Reader.(type) {
case *bytes.Buffer:
r.httpReq.ContentLength = int64(reader.Len())
case *bytes.Reader:
r.httpReq.ContentLength = int64(reader.Len())
case *strings.Reader:
r.httpReq.ContentLength = int64(reader.Len())
}
return nil
case bodyModeBytes:
setReplayableRequestBodyBytes(r.httpReq, r.config.Body.Bytes)
return nil
case bodyModeMultipart:
return r.applyMultipartBody(execCtx)
case bodyModeForm:
values := url.Values{}
for key, items := range r.config.Body.FormData {
for _, value := range items {
values.Add(key, value)
}
}
encoded := values.Encode()
setReplayableRequestBodyBytes(r.httpReq, []byte(encoded))
return nil
}
return nil
}
// prepare 准备请求(应用配置)
func (r *Request) prepare() (err error) {
if r.applied {
return nil
}
if r.httpReq == nil {
return fmt.Errorf("http request is nil")
}
execCtx := r.ctx
if execCtx == nil {
execCtx = context.Background()
}
defaultTLSServerName := ""
if r.httpReq.URL != nil && r.httpReq.URL.Scheme == "https" {
defaultTLSServerName = r.httpReq.URL.Hostname()
}
execCtx = injectRequestConfig(execCtx, r.config, defaultTLSServerName)
var traceState *traceState
if r.traceHooks != nil {
traceState = newTraceState(r.traceHooks)
execCtx = withTraceState(execCtx, traceState)
if clientTrace := traceState.clientTrace(); clientTrace != nil {
execCtx = httptrace.WithClientTrace(execCtx, clientTrace)
}
}
var cancel context.CancelFunc
if r.config.Network.Timeout > 0 {
execCtx, cancel = context.WithTimeout(execCtx, r.config.Network.Timeout)
}
defer func() {
if err != nil && cancel != nil {
cancel()
}
}()
if r.httpClient == nil {
r.httpClient, err = r.buildHTTPClient()
if err != nil {
return err
}
}
if !r.doRaw {
if len(r.config.Queries) > 0 {
query := r.httpReq.URL.Query()
for key, values := range r.config.Queries {
for _, value := range values {
query.Add(key, value)
}
}
r.httpReq.URL.RawQuery = query.Encode()
}
for key, values := range r.config.Headers {
if isHostHeaderKey(key) {
continue
}
for _, value := range values {
r.httpReq.Header.Add(key, value)
}
}
for _, cookie := range r.config.Cookies {
r.httpReq.AddCookie(cookie)
}
if r.config.BasicAuth[0] != "" || r.config.BasicAuth[1] != "" {
r.httpReq.SetBasicAuth(r.config.BasicAuth[0], r.config.BasicAuth[1])
}
if err := r.applyBody(execCtx); err != nil {
return err
}
if r.config.ContentLength > 0 {
r.httpReq.ContentLength = r.config.ContentLength
} else if r.config.ContentLength < 0 {
r.httpReq.ContentLength = 0
}
if r.config.AutoCalcContentLength && r.httpReq.Body != nil {
data, err := io.ReadAll(r.httpReq.Body)
if err != nil {
return wrapError(err, "read body for content length")
}
setReplayableRequestBodyBytes(r.httpReq, data)
}
r.syncRequestHost()
}
r.execCtx = execCtx
r.traceState = traceState
r.cancel = cancel
r.httpReq = r.httpReq.WithContext(r.execCtx)
r.applied = true
return nil
}
// buildHTTPClient 构建 HTTP Client
func (r *Request) buildHTTPClient() (*http.Client, error) {
if r.client != nil {
return r.client.HTTPClient(), nil
}
if r.config.CustomTransport && r.config.Transport != nil {
return &http.Client{
Transport: &Transport{base: r.config.Transport},
Timeout: 0,
}, nil
}
return DefaultHTTPClient(), nil
}

View File

@ -0,0 +1,335 @@
package starnet
import (
"bytes"
"context"
"errors"
"io"
"mime/multipart"
"net/http"
"strings"
"sync/atomic"
"testing"
"time"
)
type roundTripFunc func(*http.Request) (*http.Response, error)
func (fn roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return fn(req)
}
func TestRequestPreparedMutationReappliesHeadersAndBody(t *testing.T) {
req := NewSimpleRequest("http://example.com", http.MethodPost).
SetHeader("X-Test", "one").
SetBodyString("first")
req.client = &Client{client: &http.Client{
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
body, err := io.ReadAll(r.Body)
if err != nil {
return nil, err
}
_ = r.Body.Close()
return &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader(r.Header.Get("X-Test") + ":" + string(body))),
Request: r,
}, nil
}),
}}
if _, err := req.HTTPClient(); err != nil {
t.Fatalf("HTTPClient() error: %v", err)
}
req.SetHeader("X-Test", "two").SetBodyString("second")
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
body, err := resp.Body().String()
if err != nil {
t.Fatalf("Body().String() error: %v", err)
}
if body != "two:second" {
t.Fatalf("body=%q; want %q", body, "two:second")
}
}
func TestRequestPreparedMutationReappliesTimeout(t *testing.T) {
var attempts int32
req := NewSimpleRequest("http://example.com", http.MethodGet)
req.client = &Client{client: &http.Client{
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
if atomic.AddInt32(&attempts, 1) == 1 {
return &http.Response{
StatusCode: http.StatusNoContent,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader("")),
Request: r,
}, nil
}
select {
case <-time.After(50 * time.Millisecond):
return &http.Response{
StatusCode: http.StatusNoContent,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader("")),
Request: r,
}, nil
case <-r.Context().Done():
return nil, r.Context().Err()
}
}),
}}
resp, err := req.Do()
if err != nil {
t.Fatalf("first Do() error: %v", err)
}
_ = resp.Close()
_, err = req.SetTimeout(10 * time.Millisecond).Do()
if err == nil {
t.Fatal("second Do() succeeded; want timeout error")
}
if !IsTimeout(err) && !errors.Is(err, context.DeadlineExceeded) {
t.Fatalf("second Do() error=%v; want timeout", err)
}
}
func TestWriteFileUsesExecContextWithoutProgressHook(t *testing.T) {
req := NewSimpleRequest("http://example.com", http.MethodPost)
pr, pw := io.Pipe()
writer := multipart.NewWriter(pw)
done := make(chan struct{})
go func() {
_, _ = io.Copy(io.Discard, pr)
_ = pr.Close()
close(done)
}()
ctx, cancel := context.WithCancel(context.Background())
cancel()
err := req.writeFile(ctx, writer, RequestFile{
FormName: "file",
FileName: "payload.txt",
FileData: strings.NewReader("payload"),
FileSize: int64(len("payload")),
})
_ = writer.Close()
_ = pw.Close()
<-done
if !errors.Is(err, context.Canceled) {
t.Fatalf("writeFile() error=%v; want context.Canceled", err)
}
}
func TestCopyWithProgressHonorsCanceledContextWithoutHook(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err := copyWithProgress(ctx, io.Discard, strings.NewReader("payload"), "payload.txt", int64(len("payload")), nil)
if !errors.Is(err, context.Canceled) {
t.Fatalf("copyWithProgress() error=%v; want context.Canceled", err)
}
}
func TestPrepareSetsGetBodyForReplayableBodies(t *testing.T) {
tests := []struct {
name string
req *Request
want string
}{
{
name: "bytes",
req: NewSimpleRequest("http://example.com", http.MethodPost).SetBody([]byte("payload")),
want: "payload",
},
{
name: "bytes-reader",
req: NewSimpleRequest("http://example.com", http.MethodPost).SetBodyReader(bytes.NewReader([]byte("payload"))),
want: "payload",
},
{
name: "strings-reader",
req: NewSimpleRequest("http://example.com", http.MethodPost).SetBodyReader(strings.NewReader("payload")),
want: "payload",
},
{
name: "form-data",
req: NewSimpleRequest("http://example.com", http.MethodPost).AddFormData("k", "v"),
want: "k=v",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.req.prepare(); err != nil {
t.Fatalf("prepare() error: %v", err)
}
if tt.req.httpReq.GetBody == nil {
t.Fatal("GetBody is nil")
}
body, err := tt.req.httpReq.GetBody()
if err != nil {
t.Fatalf("GetBody() error: %v", err)
}
defer body.Close()
data, err := io.ReadAll(body)
if err != nil {
t.Fatalf("ReadAll() error: %v", err)
}
if string(data) != tt.want {
t.Fatalf("body=%q; want %q", string(data), tt.want)
}
})
}
}
type replayRoundTripper struct {
attempts int
bodies []string
}
func (rt *replayRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
body, err := io.ReadAll(req.Body)
if err != nil {
return nil, err
}
_ = req.Body.Close()
rt.attempts++
rt.bodies = append(rt.bodies, string(body))
if rt.attempts == 1 {
return nil, errors.New("first target failed")
}
return &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader("ok")),
Request: req,
}, nil
}
func TestRoundTripResolvedTargetsReplaysPreparedBody(t *testing.T) {
req := NewSimpleRequest("http://example.com/upload", http.MethodPut).
SetBodyReader(strings.NewReader("payload"))
if err := req.prepare(); err != nil {
t.Fatalf("prepare() error: %v", err)
}
rt := &replayRoundTripper{}
resp, err := roundTripResolvedTargets(rt, req.httpReq, []string{"127.0.0.2:80", "127.0.0.1:80"})
if err != nil {
t.Fatalf("roundTripResolvedTargets() error: %v", err)
}
defer resp.Body.Close()
if len(rt.bodies) != 2 {
t.Fatalf("attempt bodies=%v; want 2 attempts", rt.bodies)
}
if rt.bodies[0] != "payload" || rt.bodies[1] != "payload" {
t.Fatalf("attempt bodies=%v; want both payload", rt.bodies)
}
}
func TestRoundTripResolvedTargetsDoesNotFallbackNonIdempotentRequest(t *testing.T) {
req := NewSimpleRequest("http://example.com/upload", http.MethodPost).
SetBodyReader(strings.NewReader("payload"))
if err := req.prepare(); err != nil {
t.Fatalf("prepare() error: %v", err)
}
rt := &replayRoundTripper{}
_, err := roundTripResolvedTargets(rt, req.httpReq, []string{"127.0.0.2:80", "127.0.0.1:80"})
if err == nil {
t.Fatal("roundTripResolvedTargets() succeeded; want first target error")
}
if len(rt.bodies) != 1 {
t.Fatalf("attempt bodies=%v; want only first target attempt", rt.bodies)
}
if rt.bodies[0] != "payload" {
t.Fatalf("attempt body=%q; want payload", rt.bodies[0])
}
}
func TestRetryReplayableReaderBody(t *testing.T) {
var attempts int32
req := NewSimpleRequest("http://example.com/upload", http.MethodPut).
SetBodyReader(strings.NewReader("payload")).
SetRetry(1, WithRetryBackoff(0, 0, 1), WithRetryJitter(0))
req.client = &Client{client: &http.Client{
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
body, err := io.ReadAll(r.Body)
if err != nil {
return nil, err
}
_ = r.Body.Close()
if string(body) != "payload" {
t.Fatalf("body=%q; want payload", string(body))
}
if atomic.AddInt32(&attempts, 1) == 1 {
return &http.Response{
StatusCode: http.StatusServiceUnavailable,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader("retry")),
Request: r,
}, nil
}
return &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader("ok")),
Request: r,
}, nil
}),
}}
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
if got := atomic.LoadInt32(&attempts); got != 2 {
t.Fatalf("attempts=%d; want 2", got)
}
}
func TestWithProxyInvalidReturnsError(t *testing.T) {
_, err := NewRequest("http://example.com", http.MethodGet, WithProxy("://bad-proxy"))
if err == nil {
t.Fatal("NewRequest() succeeded; want invalid proxy error")
}
}
func TestClientNewRequestWithInvalidProxyReturnsError(t *testing.T) {
client := NewClientNoErr()
_, err := client.NewRequest("http://example.com", http.MethodGet, WithProxy("://bad-proxy"))
if err == nil {
t.Fatal("Client.NewRequest() succeeded; want invalid proxy error")
}
}
func TestNewClientWithInvalidProxyReturnsError(t *testing.T) {
_, err := NewClient(WithProxy("://bad-proxy"))
if err == nil {
t.Fatal("NewClient() succeeded; want invalid proxy error")
}
}

31
request_query.go Normal file
View File

@ -0,0 +1,31 @@
package starnet
// AddQuery 添加查询参数
func (r *Request) AddQuery(key, value string) *Request {
return r.applyMutation(mutateAddQuery(key, value))
}
// SetQuery 设置查询参数(覆盖)
func (r *Request) SetQuery(key, value string) *Request {
return r.applyMutation(mutateSetQuery(key, value))
}
// SetQueries 设置所有查询参数(覆盖)
func (r *Request) SetQueries(queries map[string][]string) *Request {
return r.applyMutation(mutateSetQueries(queries))
}
// AddQueries 批量添加查询参数
func (r *Request) AddQueries(queries map[string]string) *Request {
return r.applyMutation(mutateAddQueries(queries))
}
// DeleteQuery 删除查询参数
func (r *Request) DeleteQuery(key string) *Request {
return r.applyMutation(mutateDeleteQuery(key))
}
// DeleteQueryValue 删除查询参数的特定值
func (r *Request) DeleteQueryValue(key, value string) *Request {
return r.applyMutation(mutateDeleteQueryValue(key, value))
}

View File

@ -0,0 +1,168 @@
package starnet
import (
"io"
"net/http"
"net/url"
"strings"
"sync/atomic"
"testing"
)
type stateRoundTripperFunc func(*http.Request) (*http.Response, error)
func (fn stateRoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return fn(req)
}
func TestSetContextNilUsesBackground(t *testing.T) {
req := NewSimpleRequest("http://example.com", http.MethodGet)
req.client = &Client{client: &http.Client{
Transport: stateRoundTripperFunc(func(r *http.Request) (*http.Response, error) {
if r.Context() == nil {
t.Fatal("request context is nil")
}
return &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader("ok")),
Request: r,
}, nil
}),
}}
resp, err := req.SetContext(nil).Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
if req.Context() == nil {
t.Fatal("request Context() is nil")
}
}
func TestWithContextNilRetryPathDoesNotPanic(t *testing.T) {
var hits int32
req, err := NewRequest("http://example.com", http.MethodGet, WithContext(nil))
if err != nil {
t.Fatalf("NewRequest() error: %v", err)
}
req.client = &Client{client: &http.Client{
Transport: stateRoundTripperFunc(func(r *http.Request) (*http.Response, error) {
if r.Context() == nil {
t.Fatal("retry request context is nil")
}
if atomic.AddInt32(&hits, 1) == 1 {
return &http.Response{
StatusCode: http.StatusServiceUnavailable,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader("retry")),
Request: r,
}, nil
}
return &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader("ok")),
Request: r,
}, nil
}),
}}
resp, err := req.
SetTimeout(DefaultTimeout).
SetRetry(1, WithRetryBackoff(0, 0, 1), WithRetryJitter(0)).
Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
if got := atomic.LoadInt32(&hits); got != 2 {
t.Fatalf("hits=%d; want 2", got)
}
}
func TestCloneRawRequestCreatesIndependentCopy(t *testing.T) {
rawReq, err := http.NewRequest(http.MethodPost, "http://example.com/upload", strings.NewReader("payload"))
if err != nil {
t.Fatalf("NewRequest() error: %v", err)
}
rawReq.Header.Set("X-Test", "one")
req := NewSimpleRequest("", http.MethodPost).SetRawRequest(rawReq)
cloned := req.Clone()
if cloned.Err() != nil {
t.Fatalf("Clone() err = %v", cloned.Err())
}
if cloned.RawRequest() == rawReq {
t.Fatal("raw request pointer reused")
}
cloned.RawRequest().Header.Set("X-Test", "two")
if rawReq.Header.Get("X-Test") != "one" {
t.Fatalf("original header mutated: %q", rawReq.Header.Get("X-Test"))
}
body, err := cloned.RawRequest().GetBody()
if err != nil {
t.Fatalf("GetBody() error: %v", err)
}
defer body.Close()
data, err := io.ReadAll(body)
if err != nil {
t.Fatalf("ReadAll() error: %v", err)
}
if string(data) != "payload" {
t.Fatalf("body=%q; want payload", string(data))
}
}
func TestCloneRawRequestWithNonReplayableBodyFailsExplicitly(t *testing.T) {
rawReq := &http.Request{
Method: http.MethodPost,
URL: mustParseURL(t, "http://example.com/upload"),
Header: make(http.Header),
Body: io.NopCloser(io.MultiReader(strings.NewReader("payload"))),
}
req := NewSimpleRequest("", http.MethodPost).SetRawRequest(rawReq)
cloned := req.Clone()
if cloned.Err() == nil {
t.Fatal("Clone() should fail for non-replayable raw body")
}
if !strings.Contains(cloned.Err().Error(), "non-replayable") {
t.Fatalf("Clone() err=%v; want non-replayable body error", cloned.Err())
}
}
func TestDisableRawModeAfterSetRawRequestReturnsError(t *testing.T) {
rawReq, err := http.NewRequest(http.MethodGet, "http://example.com", nil)
if err != nil {
t.Fatalf("NewRequest() error: %v", err)
}
req := NewSimpleRequest("", http.MethodGet).SetRawRequest(rawReq).DisableRawMode()
if req.Err() == nil {
t.Fatal("DisableRawMode() should set error")
}
if !strings.Contains(req.Err().Error(), "cannot disable raw mode") {
t.Fatalf("DisableRawMode() err=%v", req.Err())
}
if !req.doRaw {
t.Fatal("request should remain in raw mode")
}
}
func mustParseURL(t *testing.T, raw string) *url.URL {
t.Helper()
parsed, err := url.Parse(raw)
if err != nil {
t.Fatalf("url.Parse() error: %v", err)
}
return parsed
}

6
request_trace.go Normal file
View File

@ -0,0 +1,6 @@
package starnet
// SetTraceHooks 设置请求 trace 回调。
func (r *Request) SetTraceHooks(hooks *TraceHooks) *Request {
return r.applyMutation(mutateTraceHooks(hooks))
}

View File

@ -1,6 +1,7 @@
package starnet
import (
"bytes"
"context"
"errors"
"fmt"
@ -9,6 +10,7 @@ import (
"math/rand"
"net"
"net/http"
"strings"
"time"
)
@ -87,6 +89,9 @@ func buildRetryPolicy(max int, opts ...RetryOpt) (*retryPolicy, error) {
return policy, nil
}
// WithRetry 为请求启用自动重试。
// 默认只重试幂等方法即使显式关闭幂等限制Reader 形态的 body 仍会对非幂等方法保持保守禁用,
// 以避免请求体已落地后再次发送。
func WithRetry(max int, opts ...RetryOpt) RequestOpt {
return func(r *Request) error {
policy, err := buildRetryPolicy(max, opts...)
@ -98,6 +103,9 @@ func WithRetry(max int, opts ...RetryOpt) RequestOpt {
}
}
// SetRetry 为请求启用自动重试。
// 默认只重试幂等方法即使显式关闭幂等限制Reader 形态的 body 仍会对非幂等方法保持保守禁用,
// 以避免请求体已落地后再次发送。
func (r *Request) SetRetry(max int, opts ...RetryOpt) *Request {
if r.err != nil {
return r
@ -226,10 +234,10 @@ func (r *Request) doWithRetry() (*Response, error) {
return r.doOnce()
}
retryCtx := r.ctx
retryCtx := normalizeContext(r.ctx)
retryCancel := func() {}
if r.config.Network.Timeout > 0 {
retryCtx, retryCancel = context.WithTimeout(r.ctx, r.config.Network.Timeout)
retryCtx, retryCancel = context.WithTimeout(retryCtx, r.config.Network.Timeout)
}
defer retryCancel()
@ -238,6 +246,12 @@ func (r *Request) doWithRetry() (*Response, error) {
var lastErr error
for attempt := 0; attempt < maxAttempts; attempt++ {
attemptNo := attempt + 1
emitRetryAttemptStart(r.traceHooks, TraceRetryAttemptStartInfo{
Attempt: attemptNo,
MaxAttempts: maxAttempts,
})
attemptReq, err := r.newRetryAttempt(retryCtx)
if err != nil {
return nil, wrapError(err, "build retry attempt")
@ -248,7 +262,19 @@ func (r *Request) doWithRetry() (*Response, error) {
resp.request = r
}
if !policy.shouldRetry(resp, err, attempt, maxAttempts, retryCtx) {
willRetry := policy.shouldRetry(resp, err, attempt, maxAttempts, retryCtx)
statusCode := 0
if resp != nil {
statusCode = resp.StatusCode
}
emitRetryAttemptDone(r.traceHooks, TraceRetryAttemptDoneInfo{
Attempt: attemptNo,
MaxAttempts: maxAttempts,
StatusCode: statusCode,
Err: err,
WillRetry: willRetry,
})
if !willRetry {
return resp, err
}
@ -262,6 +288,10 @@ func (r *Request) doWithRetry() (*Response, error) {
if delay <= 0 {
continue
}
emitRetryBackoff(r.traceHooks, TraceRetryBackoffInfo{
Attempt: attemptNo,
Delay: delay,
})
timer := time.NewTimer(delay)
select {
@ -293,19 +323,9 @@ func (r *Request) newRetryAttempt(ctx context.Context) (*Request, error) {
return attempt, nil
}
if r.httpReq == nil {
return nil, fmt.Errorf("http request is nil")
}
raw := r.httpReq.Clone(ctx)
if r.httpReq.GetBody != nil {
body, err := r.httpReq.GetBody()
raw, err := cloneRawHTTPRequest(r.httpReq, ctx)
if err != nil {
return nil, wrapError(err, "get raw request body")
}
raw.Body = body
} else if r.httpReq.Body != nil && r.httpReq.Body != http.NoBody {
return nil, fmt.Errorf("raw request body is not replayable")
return nil, err
}
attempt.httpReq = raw
@ -316,6 +336,9 @@ func (p *retryPolicy) canRetryRequest(r *Request) bool {
if p.idempotentOnly && !isIdempotentMethod(r.method) {
return false
}
if hasReaderRequestBody(r) && !isIdempotentMethod(r.method) {
return false
}
return isReplayableRequest(r)
}
@ -347,20 +370,40 @@ func isReplayableRequest(r *Request) bool {
return false
}
// Reader / stream body 通常不可重放,保守地不重试。
if r.config.Body.Reader != nil {
return isReplayableConfiguredBody(r.config.Body)
}
func hasReaderRequestBody(r *Request) bool {
if r == nil || r.config == nil {
return false
}
return r.config.Body.Mode == bodyModeReader && r.config.Body.Reader != nil
}
for _, f := range r.config.Body.Files {
if f.FileData != nil || f.FilePath == "" {
func isReplayableConfiguredBody(body BodyConfig) bool {
switch body.Mode {
case bodyModeReader:
return isReplayableBodyReader(body.Reader)
case bodyModeMultipart:
for _, file := range body.Files {
if file.FileData != nil || file.FilePath == "" {
return false
}
}
}
return true
}
func isReplayableBodyReader(reader io.Reader) bool {
switch reader.(type) {
case *bytes.Buffer, *bytes.Reader, *strings.Reader:
return true
default:
return false
}
}
func (p *retryPolicy) shouldRetry(resp *Response, err error, attempt, maxAttempts int, ctx context.Context) bool {
if attempt >= maxAttempts-1 {
return false

244
review_regression_test.go Normal file
View File

@ -0,0 +1,244 @@
package starnet
import (
"context"
"fmt"
"net"
"net/http"
"strconv"
"sync"
"testing"
"time"
)
func TestRequestProxyWithCustomIPFallbackTriesNextResolvedTarget(t *testing.T) {
tlsServer := newIPv4TLSServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer tlsServer.Close()
_, port, err := net.SplitHostPort(tlsServer.Listener.Addr().String())
if err != nil {
t.Fatalf("split tls server addr: %v", err)
}
firstTarget := net.JoinHostPort("127.0.0.2", port)
secondTarget := net.JoinHostPort("127.0.0.1", port)
var (
mu sync.Mutex
connectTargets []string
)
proxyServer := newIPv4Server(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodConnect {
http.Error(w, "connect required", http.StatusMethodNotAllowed)
return
}
mu.Lock()
connectTargets = append(connectTargets, r.Host)
mu.Unlock()
if r.Host == firstTarget {
http.Error(w, "first target failed", http.StatusBadGateway)
return
}
targetConn, err := net.Dial("tcp", r.Host)
if err != nil {
http.Error(w, err.Error(), http.StatusBadGateway)
return
}
hijacker, ok := w.(http.Hijacker)
if !ok {
targetConn.Close()
t.Fatal("proxy response writer is not a hijacker")
}
clientConn, rw, err := hijacker.Hijack()
if err != nil {
targetConn.Close()
t.Fatalf("hijack proxy conn: %v", err)
}
if _, err := rw.WriteString("HTTP/1.1 200 Connection Established\r\n\r\n"); err != nil {
clientConn.Close()
targetConn.Close()
t.Fatalf("write connect response: %v", err)
}
if err := rw.Flush(); err != nil {
clientConn.Close()
targetConn.Close()
t.Fatalf("flush connect response: %v", err)
}
relayProxyConns(clientConn, targetConn)
}))
defer proxyServer.Close()
reqURL := fmt.Sprintf("https://proxy-fallback.test:%s", port)
resp, err := NewSimpleRequest(reqURL, http.MethodGet).
SetProxy(proxyServer.URL).
SetCustomIP([]string{"127.0.0.2", "127.0.0.1"}).
SetSkipTLSVerify(true).
Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
mu.Lock()
defer mu.Unlock()
if len(connectTargets) != 2 {
t.Fatalf("connect target attempts=%d; want 2 (%v)", len(connectTargets), connectTargets)
}
if connectTargets[0] != firstTarget {
t.Fatalf("first connect target=%q; want %q", connectTargets[0], firstTarget)
}
if connectTargets[1] != secondTarget {
t.Fatalf("second connect target=%q; want %q", connectTargets[1], secondTarget)
}
}
func TestTraceHooksDefaultResolverEmitsDNSEvents(t *testing.T) {
server := newIPv4Server(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
addr, err := net.ResolveTCPAddr("tcp", server.Listener.Addr().String())
if err != nil {
t.Fatalf("ResolveTCPAddr() error: %v", err)
}
var (
mu sync.Mutex
dnsStartCount int
dnsDoneCount int
lastHost string
)
hooks := &TraceHooks{
DNSStart: func(info TraceDNSStartInfo) {
mu.Lock()
dnsStartCount++
lastHost = info.Host
mu.Unlock()
},
DNSDone: func(info TraceDNSDoneInfo) {
mu.Lock()
dnsDoneCount++
mu.Unlock()
if info.Err != nil {
t.Errorf("unexpected dns error: %v", info.Err)
}
},
}
reqURL := "http://localhost:" + strconv.Itoa(addr.Port)
resp, err := NewSimpleRequest(reqURL, http.MethodGet).
SetDialTimeout(DefaultDialTimeout + 200*time.Millisecond).
SetTraceHooks(hooks).
Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
mu.Lock()
defer mu.Unlock()
if dnsStartCount != 1 {
t.Fatalf("dnsStartCount=%d", dnsStartCount)
}
if dnsDoneCount != 1 {
t.Fatalf("dnsDoneCount=%d", dnsDoneCount)
}
if lastHost != "localhost" {
t.Fatalf("lastHost=%q; want localhost", lastHost)
}
}
func TestRequestHeadersReturnsCopy(t *testing.T) {
req := NewSimpleRequest("http://example.com", http.MethodGet).
SetHeader("X-Test", "one").
SetHost("origin.example")
headers := req.Headers()
headers.Set("X-Test", "two")
headers.Set("Host", "mutated.example")
if got := req.GetHeader("X-Test"); got != "one" {
t.Fatalf("request header=%q; want one", got)
}
if got := req.Host(); got != "origin.example" {
t.Fatalf("request host=%q; want origin.example", got)
}
}
func TestRequestCookiesIsolation(t *testing.T) {
req := NewSimpleRequest("http://example.com", http.MethodGet)
source := []*http.Cookie{{
Name: "session",
Value: "one",
Path: "/",
}}
req.SetCookies(source)
source[0].Value = "mutated-outside"
got := req.Cookies()
if len(got) != 1 || got[0].Value != "one" {
t.Fatalf("cookies after SetCookies=%v", got)
}
got[0].Value = "mutated-copy"
if latest := req.Cookies()[0].Value; latest != "one" {
t.Fatalf("internal cookie mutated via getter, got %q", latest)
}
cookie := &http.Cookie{Name: "auth", Value: "token"}
req.ResetCookies().AddCookie(cookie)
cookie.Value = "changed"
latest := req.Cookies()
if len(latest) != 1 || latest[0].Value != "token" {
t.Fatalf("cookies after AddCookie=%v", latest)
}
}
func TestTraceHooksLookupFuncStillEmitsDNSEvents(t *testing.T) {
server := newIPv4Server(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
addr, err := net.ResolveTCPAddr("tcp", server.Listener.Addr().String())
if err != nil {
t.Fatalf("ResolveTCPAddr() error: %v", err)
}
var dnsStartCount int
var dnsDoneCount int
hooks := &TraceHooks{
DNSStart: func(info TraceDNSStartInfo) {
dnsStartCount++
},
DNSDone: func(info TraceDNSDoneInfo) {
dnsDoneCount++
},
}
resp, err := NewSimpleRequest("http://lookup-copy.test:"+strconv.Itoa(addr.Port), http.MethodGet).
SetLookupFunc(func(ctx context.Context, host string) ([]net.IPAddr, error) {
return []net.IPAddr{{IP: addr.IP}}, nil
}).
SetTraceHooks(hooks).
Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
if dnsStartCount != 1 || dnsDoneCount != 1 {
t.Fatalf("dns trace counts start=%d done=%d", dnsStartCount, dnsDoneCount)
}
}

View File

@ -104,7 +104,34 @@ func TestRequestLevelTLSOverride(t *testing.T) {
}
func TestRequestTls(t *testing.T) {
resp, err := NewSimpleRequest("https://www.b612.me", "GET").Do()
var requestCount int
server, pool := newTrustedIPv4TLSServer(t, "localhost", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requestCount++
switch requestCount {
case 1:
if r.Header.Get("Hello") != "" {
t.Fatalf("unexpected hello header on first request: %q", r.Header.Get("Hello"))
}
if auth := r.Header.Get("Authorization"); auth != "" {
t.Fatalf("unexpected authorization on first request: %q", auth)
}
case 2:
if got := r.Header.Get("Hello"); got != "world" {
t.Fatalf("hello header=%q; want world", got)
}
if got := r.Header.Get("Authorization"); got != "Bearer ddddddd" {
t.Fatalf("authorization=%q; want bearer token", got)
}
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("OK"))
}))
defer server.Close()
localURL := httpsURLForHost(t, server, "localhost")
resp, err := NewSimpleRequest(localURL, "GET").
SetTLSConfig(&tls.Config{RootCAs: pool}).
Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
@ -114,11 +141,13 @@ func TestRequestTls(t *testing.T) {
t.Errorf("StatusCode = %v; want %v", resp.StatusCode, http.StatusOK)
}
t.Logf("Response: %v", resp.Body().MustString())
client, err := NewClient()
if err != nil {
t.Fatalf("NewClient() error: %v", err)
}
resp, err = client.NewSimpleRequest("https://www.b612.me", "GET",
resp, err = client.NewSimpleRequest(localURL, "GET",
WithTLSConfig(&tls.Config{RootCAs: pool}),
WithHeader("hello", "world"),
WithContext(context.Background()),
WithBearerToken("ddddddd")).Do()
@ -134,14 +163,24 @@ func TestRequestTls(t *testing.T) {
}
func TestTLSWithProxyPath(t *testing.T) {
server, pool := newTrustedIPv4TLSServer(t, "localhost", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("proxied"))
}))
defer server.Close()
proxy := newIPv4ConnectProxyServer(t, nil)
defer proxy.Close()
client, err := NewClient()
if err != nil {
t.Fatal(err)
}
req, err := client.NewRequest("https://registry-1.docker.io/v2/", "GET",
req, err := client.NewRequest(httpsURLForHost(t, server, "localhost"), "GET",
WithTimeout(10*time.Second),
WithProxy("http://127.0.0.1:29992"),
WithProxy(proxy.URL),
WithTLSConfig(&tls.Config{RootCAs: pool}),
)
if err != nil {
t.Fatal(err)
@ -152,10 +191,22 @@ func TestTLSWithProxyPath(t *testing.T) {
t.Fatalf("Do error: %v", err)
}
defer resp.Close()
if targets := proxy.Targets(); len(targets) != 1 {
t.Fatalf("proxy targets=%v; want 1 target", targets)
}
t.Log(resp.Status)
}
func TestTLSWithProxyBug(t *testing.T) {
server, pool := newTrustedIPv4TLSServer(t, "proxy-bug.test", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer server.Close()
proxy := newIPv4ConnectProxyServer(t, nil)
defer proxy.Close()
client, err := NewClient()
if err != nil {
t.Fatal(err)
@ -163,9 +214,11 @@ func TestTLSWithProxyBug(t *testing.T) {
// 关键:使用 WithProxy 触发 needsDynamicTransport
// 即使 proxy 是空串或无效地址,只要设置了就会走 buildDynamicTransport 分支
req, err := client.NewRequest("https://registry-1.docker.io/v2/", "GET",
req, err := client.NewRequest(httpsURLForHost(t, server, "proxy-bug.test"), "GET",
WithTimeout(10*time.Second),
WithProxy("http://127.0.0.1:29992"), // 随便一个 proxy 地址,触发动态 transport
WithProxy(proxy.URL),
WithCustomIP([]string{"127.0.0.1"}),
WithTLSConfig(&tls.Config{RootCAs: pool}),
)
if err != nil {
t.Fatal(err)
@ -177,20 +230,30 @@ func TestTLSWithProxyBug(t *testing.T) {
t.Fatalf("Do error: %v", err)
}
defer resp.Close()
if targets := proxy.Targets(); len(targets) != 1 || targets[0] == "" {
t.Fatalf("proxy targets=%v", targets)
}
t.Logf("Status: %s", resp.Status)
}
// 更精准的复现:直接测试有问题的分支
func TestTLSDialWithoutServerName(t *testing.T) {
server, pool := newTrustedIPv4TLSServer(t, "custom-ip.test", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer server.Close()
client, err := NewClient()
if err != nil {
t.Fatal(err)
}
// 使用 WithCustomIP 也能触发 defaultDialTLSFunc
req, err := client.NewRequest("https://www.google.com", "GET",
req, err := client.NewRequest(httpsURLForHost(t, server, "custom-ip.test"), "GET",
WithTimeout(10*time.Second),
WithCustomIP([]string{"142.250.185.46"}), // Google 的一个 IP
WithCustomIP([]string{"127.0.0.1"}),
WithTLSConfig(&tls.Config{RootCAs: pool}),
)
if err != nil {
t.Fatal(err)
@ -206,14 +269,21 @@ func TestTLSDialWithoutServerName(t *testing.T) {
// 最小复现:只要触发 needsDynamicTransport 即可
func TestMinimalTLSBug(t *testing.T) {
server, pool := newTrustedIPv4TLSServer(t, "localhost", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer server.Close()
client, err := NewClient()
if err != nil {
t.Fatal(err)
}
// WithDialTimeout 也会触发动态 transport
req, err := client.NewRequest("https://www.baidu.com", "GET",
req, err := client.NewRequest(httpsURLForHost(t, server, "localhost"), "GET",
WithDialTimeout(5*time.Second),
WithTLSConfig(&tls.Config{RootCAs: pool}),
)
if err != nil {
t.Fatal(err)
@ -227,3 +297,40 @@ func TestMinimalTLSBug(t *testing.T) {
defer resp.Close()
t.Logf("Status: %s", resp.Status)
}
func TestTLSWithSOCKS5ProxyPath(t *testing.T) {
server, pool := newTrustedIPv4TLSServer(t, "socks5-proxy.test", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer server.Close()
proxy := newSOCKS5ProxyServer(t, nil)
defer proxy.Close()
client, err := NewClient()
if err != nil {
t.Fatal(err)
}
req, err := client.NewRequest(httpsURLForHost(t, server, "socks5-proxy.test"), "GET",
WithTimeout(10*time.Second),
WithProxy(proxy.URL()),
WithCustomIP([]string{"127.0.0.1"}),
WithTLSConfig(&tls.Config{RootCAs: pool}),
)
if err != nil {
t.Fatal(err)
}
resp, err := req.Do()
if err != nil {
t.Fatalf("Do error: %v", err)
}
defer resp.Close()
if targets := proxy.Targets(); len(targets) != 1 || targets[0] == "" {
t.Fatalf("socks5 targets=%v", targets)
}
t.Logf("Status: %s", resp.Status)
}

View File

@ -4,12 +4,13 @@ import (
"bytes"
"context"
"crypto/tls"
"encoding/binary"
"errors"
"io"
"net"
"sync"
"time"
"b612.me/starnet/internal/tlssniffercore"
)
// replayConn replays buffered bytes first, then reads from live conn.
@ -51,214 +52,35 @@ type TLSSniffer struct{}
// Sniff detects TLS and extracts SNI when possible.
func (s TLSSniffer) Sniff(conn net.Conn, maxBytes int) (SniffResult, error) {
if maxBytes <= 0 {
maxBytes = 64 * 1024
res, err := (tlssniffercore.Sniffer{}).Sniff(conn, maxBytes)
if err != nil {
return SniffResult{}, err
}
return convertCoreSniffResult(res), nil
}
var buf bytes.Buffer
limited := &io.LimitedReader{R: conn, N: int64(maxBytes)}
meta, isTLS := sniffClientHello(limited, &buf, conn)
func convertCoreSniffResult(res tlssniffercore.SniffResult) SniffResult {
out := SniffResult{
IsTLS: isTLS,
Buffer: bytes.NewBuffer(append([]byte(nil), buf.Bytes()...)),
IsTLS: res.IsTLS,
Buffer: res.Buffer,
}
if isTLS {
out.ClientHello = meta
if res.ClientHello != nil {
out.ClientHello = convertCoreClientHelloMeta(res.ClientHello)
}
return out, nil
return out
}
func sniffClientHello(r io.Reader, buf *bytes.Buffer, conn net.Conn) (*ClientHelloMeta, bool) {
meta := &ClientHelloMeta{
LocalAddr: conn.LocalAddr(),
RemoteAddr: conn.RemoteAddr(),
func convertCoreClientHelloMeta(meta *tlssniffercore.ClientHelloMeta) *ClientHelloMeta {
if meta == nil {
return nil
}
header, complete := readTLSRecordHeader(r, buf)
if len(header) < 3 {
return nil, false
}
isTLS := header[0] == 0x16 && header[1] == 0x03
if !isTLS {
return nil, false
}
if len(header) < 5 || !complete {
return meta, true
}
recordLen := int(binary.BigEndian.Uint16(header[3:5]))
recordBody, bodyOK := readBufferedBytes(r, buf, recordLen)
if !bodyOK {
return meta, true
}
if len(recordBody) < 4 || recordBody[0] != 0x01 {
return nil, false
}
helloLen := int(recordBody[1])<<16 | int(recordBody[2])<<8 | int(recordBody[3])
helloBytes := append([]byte(nil), recordBody[4:]...)
for len(helloBytes) < helloLen {
nextHeader, nextOK := readTLSRecordHeader(r, buf)
if len(nextHeader) < 5 || !nextOK {
return meta, true
}
if nextHeader[0] != 0x16 || nextHeader[1] != 0x03 {
return meta, true
}
nextLen := int(binary.BigEndian.Uint16(nextHeader[3:5]))
nextBody, nextBodyOK := readBufferedBytes(r, buf, nextLen)
if !nextBodyOK {
return meta, true
}
helloBytes = append(helloBytes, nextBody...)
}
parseClientHelloBody(meta, helloBytes[:helloLen])
return meta, true
}
func readTLSRecordHeader(r io.Reader, buf *bytes.Buffer) ([]byte, bool) {
return readBufferedBytes(r, buf, 5)
}
func readBufferedBytes(r io.Reader, buf *bytes.Buffer, n int) ([]byte, bool) {
if n <= 0 {
return nil, true
}
tmp := make([]byte, n)
readN, err := io.ReadFull(r, tmp)
if readN > 0 {
buf.Write(tmp[:readN])
}
return append([]byte(nil), tmp[:readN]...), err == nil
}
func parseClientHelloBody(meta *ClientHelloMeta, body []byte) {
if meta == nil || len(body) < 34 {
return
}
offset := 2 + 32
sessionIDLen := int(body[offset])
offset++
if offset+sessionIDLen > len(body) {
return
}
offset += sessionIDLen
if offset+2 > len(body) {
return
}
cipherSuitesLen := int(binary.BigEndian.Uint16(body[offset : offset+2]))
offset += 2
if offset+cipherSuitesLen > len(body) {
return
}
for i := 0; i+1 < cipherSuitesLen; i += 2 {
meta.CipherSuites = append(meta.CipherSuites, binary.BigEndian.Uint16(body[offset+i:offset+i+2]))
}
offset += cipherSuitesLen
if offset >= len(body) {
return
}
compressionMethodsLen := int(body[offset])
offset++
if offset+compressionMethodsLen > len(body) {
return
}
offset += compressionMethodsLen
if offset+2 > len(body) {
return
}
extensionsLen := int(binary.BigEndian.Uint16(body[offset : offset+2]))
offset += 2
if offset+extensionsLen > len(body) {
return
}
parseClientHelloExtensions(meta, body[offset:offset+extensionsLen])
}
func parseClientHelloExtensions(meta *ClientHelloMeta, exts []byte) {
for offset := 0; offset+4 <= len(exts); {
extType := binary.BigEndian.Uint16(exts[offset : offset+2])
extLen := int(binary.BigEndian.Uint16(exts[offset+2 : offset+4]))
offset += 4
if offset+extLen > len(exts) {
return
}
extData := exts[offset : offset+extLen]
offset += extLen
switch extType {
case 0:
parseServerNameExtension(meta, extData)
case 16:
parseALPNExtension(meta, extData)
case 43:
parseSupportedVersionsExtension(meta, extData)
}
}
}
func parseServerNameExtension(meta *ClientHelloMeta, data []byte) {
if len(data) < 2 {
return
}
listLen := int(binary.BigEndian.Uint16(data[:2]))
if listLen == 0 || 2+listLen > len(data) {
return
}
list := data[2 : 2+listLen]
for offset := 0; offset+3 <= len(list); {
nameType := list[offset]
nameLen := int(binary.BigEndian.Uint16(list[offset+1 : offset+3]))
offset += 3
if offset+nameLen > len(list) {
return
}
if nameType == 0 {
meta.ServerName = string(list[offset : offset+nameLen])
return
}
offset += nameLen
}
}
func parseALPNExtension(meta *ClientHelloMeta, data []byte) {
if len(data) < 2 {
return
}
listLen := int(binary.BigEndian.Uint16(data[:2]))
if listLen == 0 || 2+listLen > len(data) {
return
}
list := data[2 : 2+listLen]
for offset := 0; offset < len(list); {
nameLen := int(list[offset])
offset++
if offset+nameLen > len(list) {
return
}
meta.SupportedProtos = append(meta.SupportedProtos, string(list[offset:offset+nameLen]))
offset += nameLen
}
}
func parseSupportedVersionsExtension(meta *ClientHelloMeta, data []byte) {
if len(data) < 1 {
return
}
listLen := int(data[0])
if listLen == 0 || 1+listLen > len(data) {
return
}
list := data[1 : 1+listLen]
for offset := 0; offset+1 < len(list); offset += 2 {
meta.SupportedVersions = append(meta.SupportedVersions, binary.BigEndian.Uint16(list[offset:offset+2]))
return &ClientHelloMeta{
ServerName: meta.ServerName,
LocalAddr: meta.LocalAddr,
RemoteAddr: meta.RemoteAddr,
SupportedProtos: append([]string(nil), meta.SupportedProtos...),
SupportedVersions: append([]uint16(nil), meta.SupportedVersions...),
CipherSuites: append([]uint16(nil), meta.CipherSuites...),
}
}
@ -433,123 +255,11 @@ func (c *Conn) serverName() string {
}
func composeServerTLSConfig(base, selected *tls.Config) *tls.Config {
if base == nil {
return selected
}
if selected == nil {
return base
}
out := base.Clone()
applyServerTLSOverrides(out, selected)
return out
return tlssniffercore.ComposeServerTLSConfig(base, selected)
}
func applyServerTLSOverrides(dst, src *tls.Config) {
if dst == nil || src == nil {
return
}
if src.Rand != nil {
dst.Rand = src.Rand
}
if src.Time != nil {
dst.Time = src.Time
}
if len(src.Certificates) > 0 {
dst.Certificates = append([]tls.Certificate(nil), src.Certificates...)
}
if len(src.NameToCertificate) > 0 {
m := make(map[string]*tls.Certificate, len(src.NameToCertificate))
for k, v := range src.NameToCertificate {
m[k] = v
}
dst.NameToCertificate = m
}
if src.GetCertificate != nil {
dst.GetCertificate = src.GetCertificate
}
if src.GetClientCertificate != nil {
dst.GetClientCertificate = src.GetClientCertificate
}
if src.GetConfigForClient != nil {
dst.GetConfigForClient = src.GetConfigForClient
}
if src.VerifyPeerCertificate != nil {
dst.VerifyPeerCertificate = src.VerifyPeerCertificate
}
if src.VerifyConnection != nil {
dst.VerifyConnection = src.VerifyConnection
}
if src.RootCAs != nil {
dst.RootCAs = src.RootCAs
}
if len(src.NextProtos) > 0 {
dst.NextProtos = append([]string(nil), src.NextProtos...)
}
if src.ServerName != "" {
dst.ServerName = src.ServerName
}
if src.ClientAuth > dst.ClientAuth {
dst.ClientAuth = src.ClientAuth
}
if src.ClientCAs != nil {
dst.ClientCAs = src.ClientCAs
}
if src.InsecureSkipVerify {
dst.InsecureSkipVerify = true
}
if len(src.CipherSuites) > 0 {
dst.CipherSuites = append([]uint16(nil), src.CipherSuites...)
}
if src.PreferServerCipherSuites {
dst.PreferServerCipherSuites = true
}
if src.SessionTicketsDisabled {
dst.SessionTicketsDisabled = true
}
if src.SessionTicketKey != ([32]byte{}) {
dst.SessionTicketKey = src.SessionTicketKey
}
if src.ClientSessionCache != nil {
dst.ClientSessionCache = src.ClientSessionCache
}
if src.UnwrapSession != nil {
dst.UnwrapSession = src.UnwrapSession
}
if src.WrapSession != nil {
dst.WrapSession = src.WrapSession
}
if src.MinVersion != 0 && (dst.MinVersion == 0 || src.MinVersion > dst.MinVersion) {
dst.MinVersion = src.MinVersion
}
if src.MaxVersion != 0 && (dst.MaxVersion == 0 || src.MaxVersion < dst.MaxVersion) {
dst.MaxVersion = src.MaxVersion
}
if len(src.CurvePreferences) > 0 {
dst.CurvePreferences = append([]tls.CurveID(nil), src.CurvePreferences...)
}
if src.DynamicRecordSizingDisabled {
dst.DynamicRecordSizingDisabled = true
}
if src.Renegotiation != 0 {
dst.Renegotiation = src.Renegotiation
}
if src.KeyLogWriter != nil {
dst.KeyLogWriter = src.KeyLogWriter
}
if len(src.EncryptedClientHelloConfigList) > 0 {
dst.EncryptedClientHelloConfigList = append([]byte(nil), src.EncryptedClientHelloConfigList...)
}
if src.EncryptedClientHelloRejectionVerify != nil {
dst.EncryptedClientHelloRejectionVerify = src.EncryptedClientHelloRejectionVerify
}
if src.GetEncryptedClientHelloKeys != nil {
dst.GetEncryptedClientHelloKeys = src.GetEncryptedClientHelloKeys
}
if len(src.EncryptedClientHelloKeys) > 0 {
dst.EncryptedClientHelloKeys = append([]tls.EncryptedClientHelloKey(nil), src.EncryptedClientHelloKeys...)
}
tlssniffercore.ApplyServerTLSOverrides(dst, src)
}
func (c *Conn) IsTLS() bool {

340
trace.go Normal file
View File

@ -0,0 +1,340 @@
package starnet
import (
"context"
"crypto/tls"
"net"
"net/http/httptrace"
"sync/atomic"
"time"
)
type traceContextKey struct{}
// TraceHooks defines optional callbacks for network lifecycle events.
// Hooks may be called concurrently.
type TraceHooks struct {
GetConn func(TraceGetConnInfo)
GotConn func(TraceGotConnInfo)
PutIdleConn func(TracePutIdleConnInfo)
DNSStart func(TraceDNSStartInfo)
DNSDone func(TraceDNSDoneInfo)
ConnectStart func(TraceConnectStartInfo)
ConnectDone func(TraceConnectDoneInfo)
TLSHandshakeStart func(TraceTLSHandshakeStartInfo)
TLSHandshakeDone func(TraceTLSHandshakeDoneInfo)
WroteHeaderField func(TraceWroteHeaderFieldInfo)
WroteHeaders func()
WroteRequest func(TraceWroteRequestInfo)
GotFirstResponseByte func()
RetryAttemptStart func(TraceRetryAttemptStartInfo)
RetryAttemptDone func(TraceRetryAttemptDoneInfo)
RetryBackoff func(TraceRetryBackoffInfo)
}
type TraceGetConnInfo struct {
Addr string
}
type TraceGotConnInfo struct {
Conn net.Conn
Reused bool
WasIdle bool
IdleTime time.Duration
}
type TracePutIdleConnInfo struct {
Err error
}
type TraceDNSStartInfo struct {
Host string
}
type TraceDNSDoneInfo struct {
Addrs []net.IPAddr
Coalesced bool
Err error
}
type TraceConnectStartInfo struct {
Network string
Addr string
}
type TraceConnectDoneInfo struct {
Network string
Addr string
Err error
}
type TraceTLSHandshakeStartInfo struct {
Network string
Addr string
ServerName string
}
type TraceTLSHandshakeDoneInfo struct {
Network string
Addr string
ServerName string
ConnectionState tls.ConnectionState
Err error
}
type TraceWroteHeaderFieldInfo struct {
Key string
Values []string
}
type TraceWroteRequestInfo struct {
Err error
}
type TraceRetryAttemptStartInfo struct {
Attempt int
MaxAttempts int
}
type TraceRetryAttemptDoneInfo struct {
Attempt int
MaxAttempts int
StatusCode int
Err error
WillRetry bool
}
type TraceRetryBackoffInfo struct {
Attempt int
Delay time.Duration
}
type traceState struct {
hooks *TraceHooks
customTLS atomic.Uint32
manualDNSRefs atomic.Int32
}
func newTraceState(hooks *TraceHooks) *traceState {
if hooks == nil {
return nil
}
return &traceState{hooks: hooks}
}
func withTraceState(ctx context.Context, state *traceState) context.Context {
if state == nil {
return ctx
}
return context.WithValue(ctx, traceContextKey{}, state)
}
func getTraceState(ctx context.Context) *traceState {
if ctx == nil {
return nil
}
state, _ := ctx.Value(traceContextKey{}).(*traceState)
return state
}
func (t *traceState) needsHTTPTrace() bool {
if t == nil || t.hooks == nil {
return false
}
h := t.hooks
return h.GetConn != nil ||
h.GotConn != nil ||
h.PutIdleConn != nil ||
h.DNSStart != nil ||
h.DNSDone != nil ||
h.ConnectStart != nil ||
h.ConnectDone != nil ||
h.TLSHandshakeStart != nil ||
h.TLSHandshakeDone != nil ||
h.WroteHeaderField != nil ||
h.WroteHeaders != nil ||
h.WroteRequest != nil ||
h.GotFirstResponseByte != nil
}
func (t *traceState) clientTrace() *httptrace.ClientTrace {
if !t.needsHTTPTrace() {
return nil
}
h := t.hooks
trace := &httptrace.ClientTrace{}
if h.GetConn != nil {
trace.GetConn = func(hostPort string) {
h.GetConn(TraceGetConnInfo{Addr: hostPort})
}
}
if h.GotConn != nil {
trace.GotConn = func(info httptrace.GotConnInfo) {
h.GotConn(TraceGotConnInfo{
Conn: info.Conn,
Reused: info.Reused,
WasIdle: info.WasIdle,
IdleTime: info.IdleTime,
})
}
}
if h.PutIdleConn != nil {
trace.PutIdleConn = func(err error) {
h.PutIdleConn(TracePutIdleConnInfo{Err: err})
}
}
if h.DNSStart != nil {
trace.DNSStart = func(info httptrace.DNSStartInfo) {
if t.usesManualDNS() {
return
}
h.DNSStart(TraceDNSStartInfo{Host: info.Host})
}
}
if h.DNSDone != nil {
trace.DNSDone = func(info httptrace.DNSDoneInfo) {
if t.usesManualDNS() {
return
}
h.DNSDone(TraceDNSDoneInfo{
Addrs: append([]net.IPAddr(nil), info.Addrs...),
Coalesced: info.Coalesced,
Err: info.Err,
})
}
}
if h.ConnectStart != nil {
trace.ConnectStart = func(network, addr string) {
h.ConnectStart(TraceConnectStartInfo{Network: network, Addr: addr})
}
}
if h.ConnectDone != nil {
trace.ConnectDone = func(network, addr string, err error) {
h.ConnectDone(TraceConnectDoneInfo{Network: network, Addr: addr, Err: err})
}
}
if h.TLSHandshakeStart != nil {
trace.TLSHandshakeStart = func() {
if t.usesCustomTLS() {
return
}
h.TLSHandshakeStart(TraceTLSHandshakeStartInfo{})
}
}
if h.TLSHandshakeDone != nil {
trace.TLSHandshakeDone = func(state tls.ConnectionState, err error) {
if t.usesCustomTLS() {
return
}
h.TLSHandshakeDone(TraceTLSHandshakeDoneInfo{
ConnectionState: state,
Err: err,
})
}
}
if h.WroteHeaderField != nil {
trace.WroteHeaderField = func(key string, value []string) {
h.WroteHeaderField(TraceWroteHeaderFieldInfo{
Key: key,
Values: value,
})
}
}
if h.WroteHeaders != nil {
trace.WroteHeaders = h.WroteHeaders
}
if h.WroteRequest != nil {
trace.WroteRequest = func(info httptrace.WroteRequestInfo) {
h.WroteRequest(TraceWroteRequestInfo{Err: info.Err})
}
}
if h.GotFirstResponseByte != nil {
trace.GotFirstResponseByte = h.GotFirstResponseByte
}
return trace
}
func (t *traceState) markCustomTLS() {
if t == nil {
return
}
t.customTLS.Store(1)
}
func (t *traceState) usesCustomTLS() bool {
if t == nil {
return false
}
return t.customTLS.Load() != 0
}
func (t *traceState) beginManualDNS() {
if t == nil {
return
}
t.manualDNSRefs.Add(1)
}
func (t *traceState) endManualDNS() {
if t == nil {
return
}
t.manualDNSRefs.Add(-1)
}
func (t *traceState) usesManualDNS() bool {
if t == nil {
return false
}
return t.manualDNSRefs.Load() > 0
}
func (t *traceState) tlsHandshakeStart(info TraceTLSHandshakeStartInfo) {
if t == nil || t.hooks == nil || t.hooks.TLSHandshakeStart == nil {
return
}
t.hooks.TLSHandshakeStart(info)
}
func (t *traceState) tlsHandshakeDone(info TraceTLSHandshakeDoneInfo) {
if t == nil || t.hooks == nil || t.hooks.TLSHandshakeDone == nil {
return
}
t.hooks.TLSHandshakeDone(info)
}
func (t *traceState) dnsStart(info TraceDNSStartInfo) {
if t == nil || t.hooks == nil || t.hooks.DNSStart == nil {
return
}
t.hooks.DNSStart(info)
}
func (t *traceState) dnsDone(info TraceDNSDoneInfo) {
if t == nil || t.hooks == nil || t.hooks.DNSDone == nil {
return
}
t.hooks.DNSDone(info)
}
func emitRetryAttemptStart(hooks *TraceHooks, info TraceRetryAttemptStartInfo) {
if hooks == nil || hooks.RetryAttemptStart == nil {
return
}
hooks.RetryAttemptStart(info)
}
func emitRetryAttemptDone(hooks *TraceHooks, info TraceRetryAttemptDoneInfo) {
if hooks == nil || hooks.RetryAttemptDone == nil {
return
}
hooks.RetryAttemptDone(info)
}
func emitRetryBackoff(hooks *TraceHooks, info TraceRetryBackoffInfo) {
if hooks == nil || hooks.RetryBackoff == nil {
return
}
hooks.RetryBackoff(info)
}

324
trace_test.go Normal file
View File

@ -0,0 +1,324 @@
package starnet
import (
"context"
"errors"
"net"
"net/http"
"net/http/httptest"
"strconv"
"sync"
"testing"
"time"
)
func TestTraceHooksStandardHTTPSPath(t *testing.T) {
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer server.Close()
var mu sync.Mutex
events := map[string]int{}
hooks := &TraceHooks{
GetConn: func(info TraceGetConnInfo) {
mu.Lock()
events["get_conn"]++
mu.Unlock()
},
GotConn: func(info TraceGotConnInfo) {
mu.Lock()
events["got_conn"]++
mu.Unlock()
},
TLSHandshakeStart: func(info TraceTLSHandshakeStartInfo) {
mu.Lock()
events["tls_start"]++
mu.Unlock()
},
TLSHandshakeDone: func(info TraceTLSHandshakeDoneInfo) {
mu.Lock()
events["tls_done"]++
mu.Unlock()
if info.Err != nil {
t.Errorf("unexpected tls handshake error: %v", info.Err)
}
},
WroteHeaders: func() {
mu.Lock()
events["wrote_headers"]++
mu.Unlock()
},
WroteRequest: func(info TraceWroteRequestInfo) {
mu.Lock()
events["wrote_request"]++
mu.Unlock()
if info.Err != nil {
t.Errorf("unexpected write error: %v", info.Err)
}
},
GotFirstResponseByte: func() {
mu.Lock()
events["first_byte"]++
mu.Unlock()
},
}
resp, err := NewSimpleRequest(server.URL, http.MethodGet).
SetSkipTLSVerify(true).
SetTraceHooks(hooks).
Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
mu.Lock()
defer mu.Unlock()
for _, key := range []string{"get_conn", "got_conn", "tls_start", "tls_done", "wrote_headers", "wrote_request", "first_byte"} {
if events[key] == 0 {
t.Fatalf("expected trace event %q", key)
}
}
}
func TestTraceHooksDynamicHTTPSPathDoesNotDuplicateTLSHandshake(t *testing.T) {
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
var mu sync.Mutex
tlsStartCount := 0
tlsDoneCount := 0
var lastInfo TraceTLSHandshakeDoneInfo
hooks := &TraceHooks{
TLSHandshakeStart: func(info TraceTLSHandshakeStartInfo) {
mu.Lock()
tlsStartCount++
mu.Unlock()
},
TLSHandshakeDone: func(info TraceTLSHandshakeDoneInfo) {
mu.Lock()
tlsDoneCount++
lastInfo = info
mu.Unlock()
},
}
resp, err := NewSimpleRequest(server.URL, http.MethodGet).
SetSkipTLSVerify(true).
SetDialTimeout(1500 * time.Millisecond).
SetTraceHooks(hooks).
Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
mu.Lock()
defer mu.Unlock()
if tlsStartCount != 1 {
t.Fatalf("tlsStartCount=%d", tlsStartCount)
}
if tlsDoneCount != 1 {
t.Fatalf("tlsDoneCount=%d", tlsDoneCount)
}
if lastInfo.Err != nil {
t.Fatalf("unexpected tls handshake error: %v", lastInfo.Err)
}
if lastInfo.ConnectionState.Version == 0 {
t.Fatal("expected tls connection state")
}
}
func TestTraceHooksCustomLookupFuncEmitsDNSEvents(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
addr, err := net.ResolveTCPAddr("tcp", server.Listener.Addr().String())
if err != nil {
t.Fatalf("ResolveTCPAddr() error: %v", err)
}
var mu sync.Mutex
dnsStartCount := 0
dnsDoneCount := 0
var dnsStartHost string
hooks := &TraceHooks{
DNSStart: func(info TraceDNSStartInfo) {
mu.Lock()
dnsStartCount++
dnsStartHost = info.Host
mu.Unlock()
},
DNSDone: func(info TraceDNSDoneInfo) {
mu.Lock()
dnsDoneCount++
mu.Unlock()
if info.Err != nil {
t.Errorf("unexpected dns error: %v", info.Err)
}
},
}
url := "http://trace.example.test:" + strconv.Itoa(addr.Port)
resp, err := NewSimpleRequest(url, http.MethodGet).
SetLookupFunc(func(ctx context.Context, host string) ([]net.IPAddr, error) {
return []net.IPAddr{{IP: addr.IP}}, nil
}).
SetTraceHooks(hooks).
Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
mu.Lock()
defer mu.Unlock()
if dnsStartCount != 1 {
t.Fatalf("dnsStartCount=%d", dnsStartCount)
}
if dnsDoneCount != 1 {
t.Fatalf("dnsDoneCount=%d", dnsDoneCount)
}
if dnsStartHost != "trace.example.test" {
t.Fatalf("dnsStartHost=%q", dnsStartHost)
}
}
func TestTraceHooksCustomDialFuncEmitsConnectEvents(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
var mu sync.Mutex
connectStartCount := 0
connectDoneCount := 0
hooks := &TraceHooks{
ConnectStart: func(info TraceConnectStartInfo) {
mu.Lock()
connectStartCount++
mu.Unlock()
},
ConnectDone: func(info TraceConnectDoneInfo) {
mu.Lock()
connectDoneCount++
mu.Unlock()
if info.Err != nil {
t.Errorf("unexpected connect error: %v", info.Err)
}
},
}
resp, err := NewSimpleRequest(server.URL, http.MethodGet).
SetDialFunc(func(ctx context.Context, network, addr string) (net.Conn, error) {
var dialer net.Dialer
return dialer.DialContext(context.Background(), network, addr)
}).
SetTraceHooks(hooks).
Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
mu.Lock()
defer mu.Unlock()
if connectStartCount != 1 {
t.Fatalf("connectStartCount=%d", connectStartCount)
}
if connectDoneCount != 1 {
t.Fatalf("connectDoneCount=%d", connectDoneCount)
}
}
func TestTraceHooksRetryEvents(t *testing.T) {
var hits int
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
hits++
if hits == 1 {
w.WriteHeader(http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
var mu sync.Mutex
starts := 0
dones := 0
backoffs := 0
var finalDone TraceRetryAttemptDoneInfo
hooks := &TraceHooks{
RetryAttemptStart: func(info TraceRetryAttemptStartInfo) {
mu.Lock()
starts++
mu.Unlock()
},
RetryAttemptDone: func(info TraceRetryAttemptDoneInfo) {
mu.Lock()
dones++
finalDone = info
mu.Unlock()
},
RetryBackoff: func(info TraceRetryBackoffInfo) {
mu.Lock()
backoffs++
mu.Unlock()
},
}
resp, err := NewSimpleRequest(server.URL, http.MethodGet).
SetRetry(1, WithRetryBackoff(time.Millisecond, time.Millisecond, 1), WithRetryJitter(0)).
SetTraceHooks(hooks).
Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
mu.Lock()
defer mu.Unlock()
if starts != 2 {
t.Fatalf("starts=%d", starts)
}
if dones != 2 {
t.Fatalf("dones=%d", dones)
}
if backoffs != 1 {
t.Fatalf("backoffs=%d", backoffs)
}
if finalDone.WillRetry {
t.Fatal("expected final attempt not to retry")
}
if finalDone.StatusCode != http.StatusOK {
t.Fatalf("final status=%d", finalDone.StatusCode)
}
}
func TestTraceHooksCustomLookupFuncPropagatesDNSError(t *testing.T) {
var gotErr error
hooks := &TraceHooks{
DNSDone: func(info TraceDNSDoneInfo) {
gotErr = info.Err
},
}
_, err := NewSimpleRequest("http://trace.example.test:80", http.MethodGet).
SetLookupFunc(func(ctx context.Context, host string) ([]net.IPAddr, error) {
return nil, errors.New("lookup failed")
}).
SetTraceHooks(hooks).
Do()
if err == nil {
t.Fatal("expected request error")
}
if gotErr == nil || gotErr.Error() != "lookup failed" {
t.Fatalf("gotErr=%v", gotErr)
}
}

View File

@ -1,25 +1,78 @@
package starnet
import (
"context"
"crypto/tls"
"net"
"net/http"
"net/url"
"strings"
"sync"
"time"
)
const dynamicTransportCacheMaxEntries = 64
type dynamicTransportCacheKey struct {
proxyKey string
dialTimeout time.Duration
customIPs string
customDNS string
tlsServerName string
skipVerify bool
}
// Transport 自定义 Transport支持请求级配置
type Transport struct {
base *http.Transport
dynamicCache map[dynamicTransportCacheKey]*http.Transport
dynamicCacheOrder []dynamicTransportCacheKey
mu sync.RWMutex
}
// RoundTrip 实现 http.RoundTripper 接口
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
// 确保 base 已初始化
if t.base == nil {
t.mu.Lock()
if t.base == nil {
t.base = &http.Transport{
t.ensureBase()
// 提取请求级别的配置
reqCtx := getRequestContext(req.Context())
traceState := getTraceState(req.Context())
execReq := req
execReqCtx := reqCtx
var targetAddrs []string
// 优先级1完全自定义的 transport
if execReqCtx.Transport != nil {
return execReqCtx.Transport.RoundTrip(execReq)
}
var err error
execReq, execReqCtx, targetAddrs, err = prepareProxyTargetRequest(execReq, execReqCtx, traceState)
if err != nil {
return nil, err
}
// 优先级2需要动态配置
if needsDynamicTransport(execReqCtx) {
dynamicTransport := t.getDynamicTransport(execReqCtx, traceState)
if len(targetAddrs) > 0 {
return roundTripResolvedTargets(dynamicTransport, execReq, targetAddrs)
}
return dynamicTransport.RoundTrip(execReq)
}
// 优先级3使用基础 transport
t.mu.RLock()
baseTransport := t.base
t.mu.RUnlock()
if len(targetAddrs) > 0 {
return roundTripResolvedTargets(baseTransport, execReq, targetAddrs)
}
return baseTransport.RoundTrip(execReq)
}
func newBaseHTTPTransport() *http.Transport {
return &http.Transport{
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
@ -27,35 +80,141 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
}
t.mu.Unlock()
}
}
// 提取请求级别的配置
reqCtx := getRequestContext(req.Context())
// 优先级1完全自定义的 transport
if reqCtx.Transport != nil {
return reqCtx.Transport.RoundTrip(req)
func (t *Transport) ensureBase() {
if t.base != nil {
return
}
t.mu.Lock()
defer t.mu.Unlock()
t.ensureBaseLocked()
}
// 优先级2需要动态配置
if needsDynamicTransport(reqCtx) {
dynamicTransport := t.buildDynamicTransport(reqCtx)
return dynamicTransport.RoundTrip(req)
func (t *Transport) ensureBaseLocked() {
if t.base == nil {
t.base = newBaseHTTPTransport()
}
}
// 优先级3使用基础 transport
func (t *Transport) getDynamicTransport(rc *RequestContext, traceState *traceState) *http.Transport {
if key, ok := newDynamicTransportCacheKey(rc); ok {
return t.getOrCreateCachedDynamicTransport(key, rc)
}
return t.buildDynamicTransport(rc, traceState)
}
func (t *Transport) getOrCreateCachedDynamicTransport(key dynamicTransportCacheKey, rc *RequestContext) *http.Transport {
t.mu.RLock()
defer t.mu.RUnlock()
return t.base.RoundTrip(req)
if transport := t.dynamicCache[key]; transport != nil {
t.mu.RUnlock()
return transport
}
t.mu.RUnlock()
t.mu.Lock()
defer t.mu.Unlock()
t.ensureBaseLocked()
if transport := t.dynamicCache[key]; transport != nil {
return transport
}
transport := buildDynamicTransportFromBase(t.base, rc, nil)
if t.dynamicCache == nil {
t.dynamicCache = make(map[dynamicTransportCacheKey]*http.Transport)
}
if len(t.dynamicCacheOrder) >= dynamicTransportCacheMaxEntries {
oldestKey := t.dynamicCacheOrder[0]
t.dynamicCacheOrder = t.dynamicCacheOrder[1:]
if oldest := t.dynamicCache[oldestKey]; oldest != nil {
oldest.CloseIdleConnections()
delete(t.dynamicCache, oldestKey)
}
}
t.dynamicCache[key] = transport
t.dynamicCacheOrder = append(t.dynamicCacheOrder, key)
return transport
}
func (t *Transport) resetDynamicTransportCacheLocked() {
for _, key := range t.dynamicCacheOrder {
if transport := t.dynamicCache[key]; transport != nil {
transport.CloseIdleConnections()
}
}
t.dynamicCache = nil
t.dynamicCacheOrder = nil
}
func newDynamicTransportCacheKey(rc *RequestContext) (dynamicTransportCacheKey, bool) {
if rc == nil {
return dynamicTransportCacheKey{}, false
}
if rc.Transport != nil || rc.DialFn != nil || rc.LookupIPFn != nil {
return dynamicTransportCacheKey{}, false
}
if rc.TLSConfig != nil && !rc.TLSConfigCacheable {
return dynamicTransportCacheKey{}, false
}
key := dynamicTransportCacheKey{
proxyKey: normalizeProxyCacheKey(rc.Proxy),
dialTimeout: rc.DialTimeout,
customIPs: serializeTransportCacheList(rc.CustomIP),
customDNS: serializeTransportCacheList(rc.CustomDNS),
tlsServerName: effectiveTLSServerName(rc),
}
if rc.TLSConfig != nil {
key.skipVerify = rc.TLSConfig.InsecureSkipVerify
}
return key, true
}
func normalizeProxyCacheKey(proxy string) string {
if proxy == "" {
return ""
}
proxyURL, err := parseProxyURL(proxy)
if err != nil {
return "\x00invalid:" + proxy
}
return proxyURL.String()
}
func serializeTransportCacheList(values []string) string {
if len(values) == 0 {
return ""
}
var builder strings.Builder
for _, value := range values {
builder.WriteString(value)
builder.WriteByte(0)
}
return builder.String()
}
func effectiveTLSServerName(rc *RequestContext) string {
if rc == nil {
return ""
}
if rc.TLSConfig != nil && rc.TLSConfig.ServerName != "" {
return rc.TLSConfig.ServerName
}
return rc.TLSServerName
}
// buildDynamicTransport 构建动态 Transport
func (t *Transport) buildDynamicTransport(rc *RequestContext) *http.Transport {
func (t *Transport) buildDynamicTransport(rc *RequestContext, traceState *traceState) *http.Transport {
t.ensureBase()
t.mu.RLock()
transport := t.base.Clone()
baseTransport := t.base
t.mu.RUnlock()
return buildDynamicTransportFromBase(baseTransport, rc, traceState)
}
func buildDynamicTransportFromBase(baseTransport *http.Transport, rc *RequestContext, traceState *traceState) *http.Transport {
transport := baseTransport.Clone()
// 应用 TLS 配置(即使为 nil 也要检查 SkipVerify
if rc.TLSConfig != nil {
@ -64,15 +223,33 @@ func (t *Transport) buildDynamicTransport(rc *RequestContext) *http.Transport {
// 应用代理配置
if rc.Proxy != "" {
proxyURL, err := url.Parse(rc.Proxy)
if err == nil {
proxyURL, err := parseProxyURL(rc.Proxy)
if err != nil {
transport.Proxy = func(*http.Request) (*url.URL, error) {
return nil, err
}
} else {
transport.Proxy = http.ProxyURL(proxyURL)
}
}
// 应用自定义 Dial 函数
if rc.DialFn != nil {
if traceState != nil && traceState.hooks != nil && (traceState.hooks.ConnectStart != nil || traceState.hooks.ConnectDone != nil) {
dialFn := rc.DialFn
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
if traceState.hooks.ConnectStart != nil {
traceState.hooks.ConnectStart(TraceConnectStartInfo{Network: network, Addr: addr})
}
conn, err := dialFn(ctx, network, addr)
if traceState.hooks.ConnectDone != nil {
traceState.hooks.ConnectDone(TraceConnectDoneInfo{Network: network, Addr: addr, Err: err})
}
return conn, err
}
} else {
transport.DialContext = rc.DialFn
}
} else if len(rc.CustomIP) > 0 || len(rc.CustomDNS) > 0 || rc.DialTimeout > 0 || rc.LookupIPFn != nil {
// 使用默认 Dial 函数(会从 context 读取配置)
transport.DialContext = defaultDialFunc
@ -93,5 +270,147 @@ func (t *Transport) Base() *http.Transport {
func (t *Transport) SetBase(base *http.Transport) {
t.mu.Lock()
t.base = base
t.resetDynamicTransportCacheLocked()
t.mu.Unlock()
}
func prepareProxyTargetRequest(req *http.Request, reqCtx *RequestContext, traceState *traceState) (*http.Request, *RequestContext, []string, error) {
if req == nil || req.URL == nil || reqCtx == nil {
return req, reqCtx, nil, nil
}
if reqCtx.Proxy == "" || reqCtx.DialFn != nil {
return req, reqCtx, nil, nil
}
if len(reqCtx.CustomIP) == 0 && len(reqCtx.CustomDNS) == 0 && reqCtx.LookupIPFn == nil {
return req, reqCtx, nil, nil
}
host := req.URL.Hostname()
if host == "" {
return req, reqCtx, nil, nil
}
targetAddrs, err := resolveDialAddresses(req.Context(), reqCtx, host, req.URL.Port(), traceState)
if err != nil {
return nil, nil, nil, err
}
if len(targetAddrs) == 0 {
return req, reqCtx, nil, nil
}
execReqCtx := *reqCtx
execReqCtx.CustomIP = nil
execReqCtx.CustomDNS = nil
execReqCtx.LookupIPFn = nil
if req.URL.Scheme == "https" {
execReqCtx.TLSConfig = withDefaultServerName(execReqCtx.TLSConfig, host)
if execReqCtx.TLSConfigCacheable || reqCtx.TLSConfig == nil {
execReqCtx.TLSConfigCacheable = true
}
}
execCtx := clearTargetResolutionContext(req.Context())
execReq := req.Clone(execCtx)
execReq.Host = req.Host
if len(targetAddrs) == 1 {
execReq.URL.Host = targetAddrs[0]
return execReq, &execReqCtx, nil, nil
}
return execReq, &execReqCtx, targetAddrs, nil
}
func clearTargetResolutionContext(ctx context.Context) context.Context {
if v := ctx.Value(ctxKeyRequestContext); v != nil {
if rc, ok := v.(*RequestContext); ok && rc != nil {
cloned := cloneRequestContext(rc)
cloned.CustomIP = nil
cloned.CustomDNS = nil
cloned.LookupIPFn = nil
ctx = context.WithValue(ctx, ctxKeyRequestContext, cloned)
}
}
ctx = context.WithValue(ctx, ctxKeyCustomIP, []string(nil))
ctx = context.WithValue(ctx, ctxKeyCustomDNS, []string(nil))
ctx = context.WithValue(ctx, ctxKeyLookupIP, (func(context.Context, string) ([]net.IPAddr, error))(nil))
return ctx
}
func withDefaultServerName(cfg *tls.Config, serverName string) *tls.Config {
if serverName == "" {
return cfg
}
if cfg != nil {
if cfg.ServerName != "" {
return cfg
}
cloned := cfg.Clone()
cloned.ServerName = serverName
return cloned
}
return &tls.Config{
ServerName: serverName,
NextProtos: []string{"h2", "http/1.1"},
}
}
func roundTripResolvedTargets(rt http.RoundTripper, baseReq *http.Request, targetAddrs []string) (*http.Response, error) {
if rt == nil || baseReq == nil || len(targetAddrs) == 0 {
return rt.RoundTrip(baseReq)
}
if !requestAllowsResolvedTargetFallback(baseReq) && len(targetAddrs) > 1 {
targetAddrs = targetAddrs[:1]
}
var lastErr error
for _, targetAddr := range targetAddrs {
attemptReq, err := cloneRequestForResolvedTarget(baseReq, targetAddr)
if err != nil {
return nil, err
}
resp, err := rt.RoundTrip(attemptReq)
if err == nil {
return resp, nil
}
lastErr = err
}
return nil, lastErr
}
func requestAllowsResolvedTargetFallback(req *http.Request) bool {
if req == nil {
return false
}
if !isIdempotentMethod(req.Method) {
return false
}
if req.Body == nil || req.Body == http.NoBody {
return true
}
return req.GetBody != nil
}
func cloneRequestForResolvedTarget(baseReq *http.Request, targetAddr string) (*http.Request, error) {
req := baseReq.Clone(baseReq.Context())
switch {
case baseReq.Body == nil || baseReq.Body == http.NoBody:
req.Body = baseReq.Body
case baseReq.GetBody != nil:
body, err := baseReq.GetBody()
if err != nil {
return nil, wrapError(err, "clone request body for resolved target")
}
req.Body = body
default:
req.Body = baseReq.Body
}
req.URL.Host = targetAddr
req.Host = baseReq.Host
return req, nil
}

224
transport_cache_test.go Normal file
View File

@ -0,0 +1,224 @@
package starnet
import (
"crypto/tls"
"net"
"net/http"
"strconv"
"sync"
"testing"
"time"
)
func TestTransportDynamicCacheReusesSafeProfile(t *testing.T) {
transport := &Transport{base: newBaseHTTPTransport()}
first := transport.getDynamicTransport(&RequestContext{
Proxy: "http://127.0.0.1:8080",
DialTimeout: 2 * time.Second,
CustomIP: []string{"127.0.0.1"},
TLSServerName: "cache.test",
}, nil)
second := transport.getDynamicTransport(&RequestContext{
Proxy: "http://127.0.0.1:8080",
DialTimeout: 2 * time.Second,
CustomIP: []string{"127.0.0.1"},
TLSServerName: "cache.test",
}, nil)
if first != second {
t.Fatal("expected cached dynamic transport to be reused")
}
if got := len(transport.dynamicCache); got != 1 {
t.Fatalf("dynamic cache size=%d; want 1", got)
}
}
func TestTransportDynamicCacheSeparatesTLSServerName(t *testing.T) {
transport := &Transport{base: newBaseHTTPTransport()}
first := transport.getDynamicTransport(&RequestContext{
CustomIP: []string{"127.0.0.1"},
TLSServerName: "first.test",
}, nil)
second := transport.getDynamicTransport(&RequestContext{
CustomIP: []string{"127.0.0.1"},
TLSServerName: "second.test",
}, nil)
if first == second {
t.Fatal("expected distinct tls server names to use different transports")
}
if got := len(transport.dynamicCache); got != 2 {
t.Fatalf("dynamic cache size=%d; want 2", got)
}
}
func TestTransportDynamicCacheSkipsUserTLSConfig(t *testing.T) {
transport := &Transport{base: newBaseHTTPTransport()}
reqCtx := &RequestContext{
CustomIP: []string{"127.0.0.1"},
TLSConfig: &tls.Config{InsecureSkipVerify: true},
}
first := transport.getDynamicTransport(reqCtx, nil)
second := transport.getDynamicTransport(reqCtx, nil)
if first == second {
t.Fatal("expected user tls config to bypass dynamic transport cache")
}
if got := len(transport.dynamicCache); got != 0 {
t.Fatalf("dynamic cache size=%d; want 0", got)
}
}
func TestTransportDynamicCacheResetOnDefaultTLSChange(t *testing.T) {
client := NewClientNoErr()
transport, ok := client.HTTPClient().Transport.(*Transport)
if !ok {
t.Fatalf("transport type=%T; want *Transport", client.HTTPClient().Transport)
}
reqCtx := &RequestContext{CustomIP: []string{"127.0.0.1"}}
first := transport.getDynamicTransport(reqCtx, nil)
if got := len(transport.dynamicCache); got != 1 {
t.Fatalf("dynamic cache size=%d; want 1 before reset", got)
}
client.SetDefaultSkipTLSVerify(true)
if got := len(transport.dynamicCache); got != 0 {
t.Fatalf("dynamic cache size=%d; want 0 after reset", got)
}
second := transport.getDynamicTransport(reqCtx, nil)
if first == second {
t.Fatal("expected cache reset after default tls change")
}
}
func TestDynamicTransportCacheReusesConnectionForCustomIP(t *testing.T) {
server := newIPv4Server(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer server.Close()
addr, err := net.ResolveTCPAddr("tcp", server.Listener.Addr().String())
if err != nil {
t.Fatalf("ResolveTCPAddr() error: %v", err)
}
client := NewClientNoErr()
targetURL := "http://cache-reuse.test:" + strconv.Itoa(addr.Port)
runRequest := func() bool {
var (
mu sync.Mutex
gotConn bool
reused bool
)
resp, err := client.NewSimpleRequest(targetURL, http.MethodGet).
SetCustomIP([]string{"127.0.0.1"}).
SetTraceHooks(&TraceHooks{
GotConn: func(info TraceGotConnInfo) {
mu.Lock()
gotConn = true
reused = info.Reused
mu.Unlock()
},
}).
Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
if _, err := resp.Body().Bytes(); err != nil {
t.Fatalf("Body().Bytes() error: %v", err)
}
mu.Lock()
defer mu.Unlock()
if !gotConn {
t.Fatal("expected GotConn trace event")
}
return reused
}
if runRequest() {
t.Fatal("first request unexpectedly reused a connection")
}
if !runRequest() {
t.Fatal("second request did not reuse cached dynamic transport connection")
}
transport, ok := client.HTTPClient().Transport.(*Transport)
if !ok {
t.Fatalf("transport type=%T; want *Transport", client.HTTPClient().Transport)
}
if got := len(transport.dynamicCache); got != 1 {
t.Fatalf("dynamic cache size=%d; want 1", got)
}
}
func TestPrepareProxyTargetRequestSingleTargetRewritesExecRequest(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, "https://proxy-single.test:8443/path", nil)
if err != nil {
t.Fatalf("http.NewRequest() error: %v", err)
}
req.Host = req.URL.Host
execReq, execReqCtx, targetAddrs, err := prepareProxyTargetRequest(req, &RequestContext{
Proxy: "http://127.0.0.1:8080",
CustomIP: []string{"127.0.0.1"},
}, nil)
if err != nil {
t.Fatalf("prepareProxyTargetRequest() error: %v", err)
}
if execReq == req {
t.Fatal("expected cloned request for proxy target preparation")
}
if got := execReq.URL.Host; got != "127.0.0.1:8443" {
t.Fatalf("execReq.URL.Host=%q; want %q", got, "127.0.0.1:8443")
}
if got := req.URL.Host; got != "proxy-single.test:8443" {
t.Fatalf("original req.URL.Host=%q; want %q", got, "proxy-single.test:8443")
}
if len(targetAddrs) != 0 {
t.Fatalf("targetAddrs=%v; want empty after single target rewrite", targetAddrs)
}
if execReqCtx == nil || execReqCtx.TLSConfig == nil {
t.Fatal("expected synthesized tls config for single target proxy request")
}
if got := execReqCtx.TLSConfig.ServerName; got != "proxy-single.test" {
t.Fatalf("tls server name=%q; want %q", got, "proxy-single.test")
}
}
func TestPrepareProxyTargetRequestMultiTargetPreservesFallbackList(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, "https://proxy-multi.test:9443/path", nil)
if err != nil {
t.Fatalf("http.NewRequest() error: %v", err)
}
req.Host = req.URL.Host
execReq, _, targetAddrs, err := prepareProxyTargetRequest(req, &RequestContext{
Proxy: "http://127.0.0.1:8080",
CustomIP: []string{"127.0.0.1", "127.0.0.2"},
}, nil)
if err != nil {
t.Fatalf("prepareProxyTargetRequest() error: %v", err)
}
if got := execReq.URL.Host; got != "proxy-multi.test:9443" {
t.Fatalf("execReq.URL.Host=%q; want original host", got)
}
if len(targetAddrs) != 2 {
t.Fatalf("targetAddrs=%v; want 2 targets", targetAddrs)
}
if targetAddrs[0] != "127.0.0.1:9443" || targetAddrs[1] != "127.0.0.2:9443" {
t.Fatalf("targetAddrs=%v; want ordered fallback targets", targetAddrs)
}
}

View File

@ -53,6 +53,7 @@ type NetworkConfig struct {
type TLSConfig struct {
Config *tls.Config // TLS 配置
SkipVerify bool // 跳过证书验证
ServerName string // 显式 TLS ServerName/SNI 覆盖
}
// DNSConfig DNS 配置
@ -62,8 +63,19 @@ type DNSConfig struct {
LookupFunc func(ctx context.Context, host string) ([]net.IPAddr, error) // 自定义解析函数
}
type bodyMode uint8
const (
bodyModeUnset bodyMode = iota
bodyModeBytes
bodyModeReader
bodyModeForm
bodyModeMultipart
)
// BodyConfig 请求体配置
type BodyConfig struct {
Mode bodyMode // 当前 body 来源模式
Bytes []byte // 原始字节
Reader io.Reader // 数据流
FormData map[string][]string // 表单数据
@ -82,6 +94,7 @@ type RequestConfig struct {
// 其他配置
BasicAuth [2]string // Basic 认证
Host string // 显式 Host 头覆盖
ContentLength int64 // 手动设置的 Content-Length
AutoCalcContentLength bool // 自动计算 Content-Length
MaxRespBodyBytes int64 // 响应体最大读取字节数(<=0 表示不限制)
@ -104,6 +117,7 @@ func (c *RequestConfig) Clone() *RequestConfig {
TLS: TLSConfig{
Config: cloneTLSConfig(c.TLS.Config),
SkipVerify: c.TLS.SkipVerify,
ServerName: c.TLS.ServerName,
},
DNS: DNSConfig{
CustomIP: cloneStringSlice(c.DNS.CustomIP),
@ -111,6 +125,7 @@ func (c *RequestConfig) Clone() *RequestConfig {
LookupFunc: c.DNS.LookupFunc,
},
Body: BodyConfig{
Mode: c.Body.Mode,
Bytes: cloneBytes(c.Body.Bytes),
Reader: c.Body.Reader, // Reader 不可克隆
FormData: cloneStringMapSlice(c.Body.FormData),
@ -120,6 +135,7 @@ func (c *RequestConfig) Clone() *RequestConfig {
Cookies: cloneCookies(c.Cookies),
Queries: cloneStringMapSlice(c.Queries),
BasicAuth: c.BasicAuth,
Host: c.Host,
ContentLength: c.ContentLength,
AutoCalcContentLength: c.AutoCalcContentLength,
MaxRespBodyBytes: c.MaxRespBodyBytes,

View File

@ -101,24 +101,31 @@ func cloneCookies(cookies []*http.Cookie) []*http.Cookie {
}
newCookies := make([]*http.Cookie, len(cookies))
for i, c := range cookies {
newCookies[i] = &http.Cookie{
Name: c.Name,
Value: c.Value,
Path: c.Path,
Domain: c.Domain,
Expires: c.Expires,
RawExpires: c.RawExpires,
MaxAge: c.MaxAge,
Secure: c.Secure,
HttpOnly: c.HttpOnly,
SameSite: c.SameSite,
Raw: c.Raw,
Unparsed: append([]string(nil), c.Unparsed...),
}
newCookies[i] = cloneCookie(c)
}
return newCookies
}
func cloneCookie(cookie *http.Cookie) *http.Cookie {
if cookie == nil {
return nil
}
return &http.Cookie{
Name: cookie.Name,
Value: cookie.Value,
Path: cookie.Path,
Domain: cookie.Domain,
Expires: cookie.Expires,
RawExpires: cookie.RawExpires,
MaxAge: cookie.MaxAge,
Secure: cookie.Secure,
HttpOnly: cookie.HttpOnly,
SameSite: cookie.SameSite,
Raw: cookie.Raw,
Unparsed: append([]string(nil), cookie.Unparsed...),
}
}
// cloneStringMapSlice 克隆 map[string][]string
func cloneStringMapSlice(m map[string][]string) map[string][]string {
if m == nil {
@ -171,8 +178,8 @@ func cloneTLSConfig(cfg *tls.Config) *tls.Config {
// copyWithProgress 带进度的复制
func copyWithProgress(ctx context.Context, dst io.Writer, src io.Reader, filename string, total int64, progress UploadProgressFunc) (int64, error) {
if progress == nil {
return io.Copy(dst, src)
if ctx == nil {
ctx = context.Background()
}
var written int64
@ -190,9 +197,11 @@ func copyWithProgress(ctx context.Context, dst io.Writer, src io.Reader, filenam
nw, ew := dst.Write(buf[:nr])
if nw > 0 {
written += int64(nw)
if progress != nil {
// 同步调用进度回调(不使用 goroutine
progress(filename, written, total)
}
}
if ew != nil {
return written, ew
}
@ -202,8 +211,10 @@ func copyWithProgress(ctx context.Context, dst io.Writer, src io.Reader, filenam
}
if err != nil {
if err == io.EOF {
if progress != nil {
// 最后一次进度回调
progress(filename, written, total)
}
return written, nil
}
return written, err