fix(starnet): 重构请求执行链路并补齐代理/重试/trace边界
- 分离 Request 的配置态与执行态,修复二次 Do、raw 模式网络配置失效和 body 来源互斥问题 - 新增 starnet trace 抽象,补齐 DNS/连接/TLS/重试事件,并优化动态 transport 缓存与代理解析路径 - 收紧非法代理为 fail-fast,多目标目标回退仅限幂等请求,修复 Host/TLS/SNI 等语义边界 - 补充防御性拷贝、专项回归测试、本地代理/TLS 用例与 README 行为说明
This commit is contained in:
parent
9ac9b65bc5
commit
732e81316c
2
.gitignore
vendored
2
.gitignore
vendored
@ -2,3 +2,5 @@
|
|||||||
.sentrux/
|
.sentrux/
|
||||||
agent_readme.md
|
agent_readme.md
|
||||||
target.md
|
target.md
|
||||||
|
agents.md
|
||||||
|
.codex
|
||||||
89
README.md
89
README.md
@ -1,58 +1,58 @@
|
|||||||
# starnet
|
# starnet
|
||||||
|
|
||||||
`starnet` is a Go network toolkit focused on practical HTTP request control, TLS sniff utilities, and ICMP ping capabilities.
|
`starnet` 是一个面向 Go 的网络工具库,提供 HTTP 请求控制、TLS 嗅探和 ICMP Ping 能力。
|
||||||
|
|
||||||
## Highlights
|
## 功能概览
|
||||||
|
|
||||||
- Request-level timeout by context (without mutating shared `http.Client` timeout)
|
- 基于 `context` 的请求级超时控制,不修改共享 `http.Client` 的全局超时
|
||||||
- Fine-grained network controls: custom DNS/IP, dial timeout, proxy, TLS config
|
- 请求级网络控制:代理、自定义 IP / DNS、拨号超时、TLS 配置
|
||||||
- Built-in retry with replay safety checks and configurable backoff/jitter/statuses
|
- 内置重试机制,支持重试次数、退避、抖动、状态码白名单和自定义错误判定
|
||||||
- Response body safety guard via max body bytes limit
|
- 响应体大小限制,避免一次性读取过大内容
|
||||||
- Error classification helpers (`ClassifyError`, `IsTimeout`, `IsDNS`, `IsTLS`, `IsProxy`, `IsCanceled`)
|
- 错误分类辅助:`ClassifyError`、`IsTimeout`、`IsDNS`、`IsTLS`、`IsProxy`、`IsCanceled`
|
||||||
- TLS sniffer listener/dialer utilities for mixed TLS/plain traffic scenarios
|
- TLS 嗅探监听 / 拨号工具,适用于 TLS 与明文混合场景
|
||||||
- ICMP ping with IPv4/IPv6 target handling and option-based probing API
|
- ICMP Ping,支持 IPv4 / IPv6 目标和选项化探测
|
||||||
|
|
||||||
## Main Features
|
## 主要能力
|
||||||
|
|
||||||
### HTTP Client and Request
|
### HTTP 客户端与请求构建
|
||||||
|
|
||||||
- Fluent APIs with both `WithXxx` options and `SetXxx` chain methods
|
- 同时提供 `WithXxx` 选项和 `SetXxx` 链式调用两套接口
|
||||||
- Methods: `Get/Post/Put/Delete/Head/Patch/Options/Trace/Connect`
|
- 支持 `Get`、`Post`、`Put`、`Delete`、`Head`、`Patch`、`Options`、`Trace`、`Connect`
|
||||||
- Request body helpers: JSON, form data, multipart file upload, stream body
|
- 支持 JSON、表单、`multipart/form-data`、流式请求体等常见请求体形态
|
||||||
- Header/cookie/query helpers with defensive copy on key setters
|
- Header、Cookie、Query 等输入在关键路径上做防御性拷贝,降低外部可变状态污染风险
|
||||||
- Request cloning for safe reuse in concurrent or variant calls
|
- `Request.Clone()` 可用于并发场景或同一基础请求的变体构造
|
||||||
|
|
||||||
### Timeout and Retry
|
### 超时与重试
|
||||||
|
|
||||||
- Request timeout is applied by context deadline, not global client timeout
|
- 请求超时通过 `context` 截止时间控制,不污染共享客户端配置
|
||||||
- Retry supports:
|
- 重试支持:
|
||||||
- max attempts
|
- 最大尝试次数
|
||||||
- backoff factor/base/max
|
- 基础退避、最大退避和退避因子
|
||||||
- jitter
|
- 抖动比例
|
||||||
- retry status whitelist
|
- 可重试状态码集合
|
||||||
- idempotent-only guard
|
- 仅幂等方法重试
|
||||||
- custom retry-on-error callback
|
- 自定义错误判定函数
|
||||||
- Retry keeps original request pointer in final response for consistency
|
- 重试成功后返回的 `Response` 仍保持对原始 `Request` 的引用
|
||||||
|
|
||||||
### Response Handling
|
### 响应处理
|
||||||
|
|
||||||
- `Bytes/String/JSON/Reader` helpers
|
- 提供 `Bytes`、`String`、`JSON`、`Reader` 等响应体读取接口
|
||||||
- optional auto-fetch mode
|
- 支持自动预取响应体
|
||||||
- configurable max response body bytes to prevent oversized reads
|
- 支持按字节数限制响应体读取上限
|
||||||
|
|
||||||
### Ping Module
|
### Ping 模块
|
||||||
|
|
||||||
- `Ping`, `PingWithContext`, `Pingable`, and compatibility helper `IsIpPingable`
|
- 提供 `Ping`、`PingWithContext`、`Pingable` 以及兼容函数 `IsIpPingable`
|
||||||
- `PingOptions` for count/timeout/interval/deadline/address preference/source IP/payload size
|
- `PingOptions` 支持次数、超时、间隔、截止时间、地址族偏好、源地址、负载长度等参数
|
||||||
- explicit error semantics for permission/protocol/timeout/resolve failures
|
- 对权限不足、协议不支持、超时、解析失败等情况提供明确错误语义
|
||||||
|
|
||||||
## Install
|
## 安装
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
go get b612.me/starnet
|
go get b612.me/starnet
|
||||||
```
|
```
|
||||||
|
|
||||||
## Quick Example
|
## 快速示例
|
||||||
|
|
||||||
```go
|
```go
|
||||||
package main
|
package main
|
||||||
@ -94,13 +94,18 @@ func main() {
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
## Stability Notes
|
## 行为说明
|
||||||
|
|
||||||
- Raw ICMP ping may require elevated privileges on some systems.
|
- `NewClient`、`NewRequest` 以及请求构造相关接口在遇到非法选项时会直接返回错误,例如格式不合法的代理地址。
|
||||||
- Integration tests that rely on external network are environment-dependent.
|
- `NewClientNoErr` 是便利构造函数;如果选项校验失败,仍可能返回一个占位 `Client`,需要严格校验配置时应优先使用 `NewClient`。
|
||||||
|
- 重试默认仅对幂等方法生效。即使显式关闭“仅幂等方法重试”,通过 `SetBodyReader` 或 `WithBodyReader` 构造的请求在非幂等方法上仍不会自动重试。
|
||||||
|
- 当同时使用 `proxy + custom IP/DNS` 且解析出多个目标地址时,自动目标回退仅对幂等请求生效,以避免重复写入。
|
||||||
|
|
||||||
## License
|
## 稳定性说明
|
||||||
|
|
||||||
This project is licensed under the Apache License 2.0.
|
- 原始 ICMP Ping 在部分系统上需要额外权限。
|
||||||
See [LICENSE](./LICENSE).
|
- 依赖外部网络环境的集成测试结果可能受运行环境影响。
|
||||||
|
|
||||||
|
## 许可证
|
||||||
|
|
||||||
|
本项目采用 Apache License 2.0,详见 [LICENSE](./LICENSE)。
|
||||||
|
|||||||
@ -148,6 +148,33 @@ func BenchmarkRequestCreation(b *testing.B) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func BenchmarkRequestPrepareDefaultPath(b *testing.B) {
|
||||||
|
b.ResetTimer()
|
||||||
|
b.ReportAllocs()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
req := NewSimpleRequest("https://example.com", "GET")
|
||||||
|
if err := req.prepare(); err != nil {
|
||||||
|
b.Fatalf("prepare() error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkRequestPrepareDynamicPath(b *testing.B) {
|
||||||
|
b.ResetTimer()
|
||||||
|
b.ReportAllocs()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
req := NewSimpleRequest("https://example.com", "GET",
|
||||||
|
WithCustomIP([]string{"127.0.0.1"}),
|
||||||
|
WithSkipTLSVerify(true),
|
||||||
|
)
|
||||||
|
if err := req.prepare(); err != nil {
|
||||||
|
b.Fatalf("prepare() error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func BenchmarkResponseBodyRead(b *testing.B) {
|
func BenchmarkResponseBodyRead(b *testing.B) {
|
||||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Write([]byte("test response data"))
|
w.Write([]byte("test response data"))
|
||||||
|
|||||||
24
client.go
24
client.go
@ -6,7 +6,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Client HTTP 客户端封装
|
// Client HTTP 客户端封装
|
||||||
@ -19,14 +18,7 @@ type Client struct {
|
|||||||
// NewClient 创建新的 Client
|
// NewClient 创建新的 Client
|
||||||
func NewClient(opts ...RequestOpt) (*Client, error) {
|
func NewClient(opts ...RequestOpt) (*Client, error) {
|
||||||
// 创建基础 Transport
|
// 创建基础 Transport
|
||||||
baseTransport := &http.Transport{
|
baseTransport := newBaseHTTPTransport()
|
||||||
ForceAttemptHTTP2: true,
|
|
||||||
MaxIdleConns: 100,
|
|
||||||
MaxIdleConnsPerHost: 10,
|
|
||||||
IdleConnTimeout: 90 * time.Second,
|
|
||||||
TLSHandshakeTimeout: 10 * time.Second,
|
|
||||||
ExpectContinueTimeout: 1 * time.Second,
|
|
||||||
}
|
|
||||||
|
|
||||||
httpClient := &http.Client{
|
httpClient := &http.Client{
|
||||||
Transport: &Transport{base: baseTransport},
|
Transport: &Transport{base: baseTransport},
|
||||||
@ -40,6 +32,9 @@ func NewClient(opts ...RequestOpt) (*Client, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, wrapError(err, "create client")
|
return nil, wrapError(err, "create client")
|
||||||
}
|
}
|
||||||
|
if req.err != nil {
|
||||||
|
return nil, wrapError(req.err, "create client")
|
||||||
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
// 如果选项中有自定义配置,应用到 httpClient
|
// 如果选项中有自定义配置,应用到 httpClient
|
||||||
@ -61,7 +56,9 @@ func NewClient(opts ...RequestOpt) (*Client, error) {
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClientNoErr 创建新的 Client(忽略错误)
|
// NewClientNoErr 创建新的 Client(忽略错误)。
|
||||||
|
// 当 opts 校验失败时,它仍会返回一个可用的 Client 占位对象;
|
||||||
|
// 如果调用方需要感知选项错误或依赖默认 starnet Transport 行为,应优先使用 NewClient。
|
||||||
func NewClientNoErr(opts ...RequestOpt) *Client {
|
func NewClientNoErr(opts ...RequestOpt) *Client {
|
||||||
client, _ := NewClient(opts...)
|
client, _ := NewClient(opts...)
|
||||||
if client == nil {
|
if client == nil {
|
||||||
@ -172,11 +169,13 @@ func (c *Client) Clone() *Client {
|
|||||||
func (c *Client) SetDefaultTLSConfig(tlsConfig *tls.Config) *Client {
|
func (c *Client) SetDefaultTLSConfig(tlsConfig *tls.Config) *Client {
|
||||||
if transport, ok := c.client.Transport.(*Transport); ok {
|
if transport, ok := c.client.Transport.(*Transport); ok {
|
||||||
transport.mu.Lock()
|
transport.mu.Lock()
|
||||||
|
transport.ensureBaseLocked()
|
||||||
if tlsConfig != nil {
|
if tlsConfig != nil {
|
||||||
transport.base.TLSClientConfig = tlsConfig.Clone()
|
transport.base.TLSClientConfig = tlsConfig.Clone()
|
||||||
} else {
|
} else {
|
||||||
transport.base.TLSClientConfig = nil
|
transport.base.TLSClientConfig = nil
|
||||||
}
|
}
|
||||||
|
transport.resetDynamicTransportCacheLocked()
|
||||||
transport.mu.Unlock()
|
transport.mu.Unlock()
|
||||||
}
|
}
|
||||||
return c
|
return c
|
||||||
@ -186,12 +185,14 @@ func (c *Client) SetDefaultTLSConfig(tlsConfig *tls.Config) *Client {
|
|||||||
func (c *Client) SetDefaultSkipTLSVerify(skip bool) *Client {
|
func (c *Client) SetDefaultSkipTLSVerify(skip bool) *Client {
|
||||||
if transport, ok := c.client.Transport.(*Transport); ok {
|
if transport, ok := c.client.Transport.(*Transport); ok {
|
||||||
transport.mu.Lock()
|
transport.mu.Lock()
|
||||||
|
transport.ensureBaseLocked()
|
||||||
if transport.base.TLSClientConfig == nil {
|
if transport.base.TLSClientConfig == nil {
|
||||||
transport.base.TLSClientConfig = &tls.Config{}
|
transport.base.TLSClientConfig = &tls.Config{}
|
||||||
} else {
|
} else {
|
||||||
transport.base.TLSClientConfig = transport.base.TLSClientConfig.Clone()
|
transport.base.TLSClientConfig = transport.base.TLSClientConfig.Clone()
|
||||||
}
|
}
|
||||||
transport.base.TLSClientConfig.InsecureSkipVerify = skip
|
transport.base.TLSClientConfig.InsecureSkipVerify = skip
|
||||||
|
transport.resetDynamicTransportCacheLocked()
|
||||||
transport.mu.Unlock()
|
transport.mu.Unlock()
|
||||||
}
|
}
|
||||||
return c
|
return c
|
||||||
@ -227,6 +228,9 @@ func (c *Client) NewRequestWithContext(ctx context.Context, url, method string,
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if req.err != nil {
|
||||||
|
return nil, req.err
|
||||||
|
}
|
||||||
|
|
||||||
req.client = c
|
req.client = c
|
||||||
req.httpClient = c.client
|
req.httpClient = c.client
|
||||||
|
|||||||
167
context.go
167
context.go
@ -14,6 +14,8 @@ type contextKey int
|
|||||||
const (
|
const (
|
||||||
ctxKeyTransport contextKey = iota
|
ctxKeyTransport contextKey = iota
|
||||||
ctxKeyTLSConfig
|
ctxKeyTLSConfig
|
||||||
|
ctxKeyTLSConfigCacheable
|
||||||
|
ctxKeyTLSServerName
|
||||||
ctxKeyProxy
|
ctxKeyProxy
|
||||||
ctxKeyCustomIP
|
ctxKeyCustomIP
|
||||||
ctxKeyCustomDNS
|
ctxKeyCustomDNS
|
||||||
@ -21,58 +23,95 @@ const (
|
|||||||
ctxKeyTimeout
|
ctxKeyTimeout
|
||||||
ctxKeyLookupIP
|
ctxKeyLookupIP
|
||||||
ctxKeyDialFunc
|
ctxKeyDialFunc
|
||||||
|
ctxKeyRequestContext
|
||||||
)
|
)
|
||||||
|
|
||||||
// RequestContext 从 context 中提取的请求配置
|
// RequestContext 从 context 中提取的请求配置
|
||||||
type RequestContext struct {
|
type RequestContext struct {
|
||||||
Transport *http.Transport
|
Transport *http.Transport
|
||||||
TLSConfig *tls.Config
|
TLSConfig *tls.Config
|
||||||
Proxy string
|
TLSConfigCacheable bool
|
||||||
CustomIP []string
|
TLSServerName string
|
||||||
CustomDNS []string
|
Proxy string
|
||||||
DialTimeout time.Duration
|
CustomIP []string
|
||||||
Timeout time.Duration
|
CustomDNS []string
|
||||||
LookupIPFn func(ctx context.Context, host string) ([]net.IPAddr, error)
|
DialTimeout time.Duration
|
||||||
DialFn func(ctx context.Context, network, addr string) (net.Conn, error)
|
Timeout time.Duration
|
||||||
|
LookupIPFn func(ctx context.Context, host string) ([]net.IPAddr, error)
|
||||||
|
DialFn func(ctx context.Context, network, addr string) (net.Conn, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var emptyRequestContext = &RequestContext{}
|
||||||
|
|
||||||
// getRequestContext 从 context 中提取请求配置
|
// getRequestContext 从 context 中提取请求配置
|
||||||
func getRequestContext(ctx context.Context) *RequestContext {
|
func getRequestContext(ctx context.Context) *RequestContext {
|
||||||
rc := &RequestContext{}
|
if v := ctx.Value(ctxKeyRequestContext); v != nil {
|
||||||
|
if rc, ok := v.(*RequestContext); ok && rc != nil {
|
||||||
|
return rc
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var rc *RequestContext
|
||||||
|
ensure := func() *RequestContext {
|
||||||
|
if rc == nil {
|
||||||
|
rc = &RequestContext{}
|
||||||
|
}
|
||||||
|
return rc
|
||||||
|
}
|
||||||
if v := ctx.Value(ctxKeyTransport); v != nil {
|
if v := ctx.Value(ctxKeyTransport); v != nil {
|
||||||
rc.Transport, _ = v.(*http.Transport)
|
ensure().Transport, _ = v.(*http.Transport)
|
||||||
}
|
}
|
||||||
if v := ctx.Value(ctxKeyTLSConfig); v != nil {
|
if v := ctx.Value(ctxKeyTLSConfig); v != nil {
|
||||||
rc.TLSConfig, _ = v.(*tls.Config)
|
ensure().TLSConfig, _ = v.(*tls.Config)
|
||||||
|
}
|
||||||
|
if v := ctx.Value(ctxKeyTLSConfigCacheable); v != nil {
|
||||||
|
ensure().TLSConfigCacheable, _ = v.(bool)
|
||||||
|
}
|
||||||
|
if v := ctx.Value(ctxKeyTLSServerName); v != nil {
|
||||||
|
ensure().TLSServerName, _ = v.(string)
|
||||||
}
|
}
|
||||||
if v := ctx.Value(ctxKeyProxy); v != nil {
|
if v := ctx.Value(ctxKeyProxy); v != nil {
|
||||||
rc.Proxy, _ = v.(string)
|
ensure().Proxy, _ = v.(string)
|
||||||
}
|
}
|
||||||
if v := ctx.Value(ctxKeyCustomIP); v != nil {
|
if v := ctx.Value(ctxKeyCustomIP); v != nil {
|
||||||
rc.CustomIP, _ = v.([]string)
|
ensure().CustomIP, _ = v.([]string)
|
||||||
}
|
}
|
||||||
if v := ctx.Value(ctxKeyCustomDNS); v != nil {
|
if v := ctx.Value(ctxKeyCustomDNS); v != nil {
|
||||||
rc.CustomDNS, _ = v.([]string)
|
ensure().CustomDNS, _ = v.([]string)
|
||||||
}
|
}
|
||||||
if v := ctx.Value(ctxKeyDialTimeout); v != nil {
|
if v := ctx.Value(ctxKeyDialTimeout); v != nil {
|
||||||
rc.DialTimeout, _ = v.(time.Duration)
|
ensure().DialTimeout, _ = v.(time.Duration)
|
||||||
}
|
}
|
||||||
if v := ctx.Value(ctxKeyTimeout); v != nil {
|
if v := ctx.Value(ctxKeyTimeout); v != nil {
|
||||||
rc.Timeout, _ = v.(time.Duration)
|
ensure().Timeout, _ = v.(time.Duration)
|
||||||
}
|
}
|
||||||
if v := ctx.Value(ctxKeyLookupIP); v != nil {
|
if v := ctx.Value(ctxKeyLookupIP); v != nil {
|
||||||
rc.LookupIPFn, _ = v.(func(context.Context, string) ([]net.IPAddr, error))
|
ensure().LookupIPFn, _ = v.(func(context.Context, string) ([]net.IPAddr, error))
|
||||||
}
|
}
|
||||||
if v := ctx.Value(ctxKeyDialFunc); v != nil {
|
if v := ctx.Value(ctxKeyDialFunc); v != nil {
|
||||||
rc.DialFn, _ = v.(func(context.Context, string, string) (net.Conn, error))
|
ensure().DialFn, _ = v.(func(context.Context, string, string) (net.Conn, error))
|
||||||
|
}
|
||||||
|
if rc == nil {
|
||||||
|
return emptyRequestContext
|
||||||
}
|
}
|
||||||
|
|
||||||
return rc
|
return rc
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func cloneRequestContext(rc *RequestContext) *RequestContext {
|
||||||
|
if rc == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
cloned := *rc
|
||||||
|
cloned.CustomIP = cloneStringSlice(rc.CustomIP)
|
||||||
|
cloned.CustomDNS = cloneStringSlice(rc.CustomDNS)
|
||||||
|
return &cloned
|
||||||
|
}
|
||||||
|
|
||||||
// needsDynamicTransport 判断是否需要动态 Transport
|
// needsDynamicTransport 判断是否需要动态 Transport
|
||||||
func needsDynamicTransport(rc *RequestContext) bool {
|
func needsDynamicTransport(rc *RequestContext) bool {
|
||||||
|
if rc == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
return rc.Transport != nil ||
|
return rc.Transport != nil ||
|
||||||
rc.TLSConfig != nil ||
|
rc.TLSConfig != nil ||
|
||||||
rc.Proxy != "" ||
|
rc.Proxy != "" ||
|
||||||
@ -83,63 +122,67 @@ func needsDynamicTransport(rc *RequestContext) bool {
|
|||||||
rc.LookupIPFn != nil
|
rc.LookupIPFn != nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// injectRequestConfig 将请求配置注入到 context
|
func buildRequestContext(config *RequestConfig, defaultTLSServerName string) *RequestContext {
|
||||||
func injectRequestConfig(ctx context.Context, config *RequestConfig) context.Context {
|
if config == nil {
|
||||||
execCtx := ctx
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
rc := &RequestContext{
|
||||||
|
DialTimeout: config.Network.DialTimeout,
|
||||||
|
Timeout: config.Network.Timeout,
|
||||||
|
}
|
||||||
|
|
||||||
// 处理 TLS 配置
|
// 处理 TLS 配置
|
||||||
var tlsConfig *tls.Config
|
var tlsConfig *tls.Config
|
||||||
|
tlsConfigCacheable := false
|
||||||
|
|
||||||
if config.TLS.Config != nil {
|
if config.TLS.Config != nil {
|
||||||
tlsConfig = config.TLS.Config.Clone()
|
tlsConfig = config.TLS.Config.Clone()
|
||||||
if config.TLS.SkipVerify {
|
} else if config.TLS.SkipVerify || config.TLS.ServerName != "" {
|
||||||
tlsConfig.InsecureSkipVerify = true
|
|
||||||
}
|
|
||||||
} else if config.TLS.SkipVerify {
|
|
||||||
tlsConfig = &tls.Config{
|
tlsConfig = &tls.Config{
|
||||||
NextProtos: []string{"h2", "http/1.1"},
|
NextProtos: []string{"h2", "http/1.1"},
|
||||||
InsecureSkipVerify: true,
|
|
||||||
}
|
}
|
||||||
|
tlsConfigCacheable = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if config.TLS.SkipVerify && tlsConfig != nil {
|
||||||
|
tlsConfig.InsecureSkipVerify = true
|
||||||
|
}
|
||||||
|
if config.TLS.ServerName != "" && tlsConfig != nil {
|
||||||
|
tlsConfig.ServerName = config.TLS.ServerName
|
||||||
|
}
|
||||||
if tlsConfig != nil {
|
if tlsConfig != nil {
|
||||||
execCtx = context.WithValue(execCtx, ctxKeyTLSConfig, tlsConfig)
|
rc.TLSConfig = tlsConfig
|
||||||
|
rc.TLSConfigCacheable = tlsConfigCacheable
|
||||||
|
}
|
||||||
|
if config.TLS.ServerName != "" {
|
||||||
|
rc.TLSServerName = config.TLS.ServerName
|
||||||
|
} else if defaultTLSServerName != "" {
|
||||||
|
rc.TLSServerName = defaultTLSServerName
|
||||||
}
|
}
|
||||||
|
|
||||||
// 注入代理
|
rc.Proxy = config.Network.Proxy
|
||||||
if config.Network.Proxy != "" {
|
rc.CustomIP = cloneStringSlice(config.DNS.CustomIP)
|
||||||
execCtx = context.WithValue(execCtx, ctxKeyProxy, config.Network.Proxy)
|
rc.CustomDNS = cloneStringSlice(config.DNS.CustomDNS)
|
||||||
}
|
rc.LookupIPFn = config.DNS.LookupFunc
|
||||||
|
rc.DialFn = config.Network.DialFunc
|
||||||
|
|
||||||
// 注入自定义 IP
|
|
||||||
if len(config.DNS.CustomIP) > 0 {
|
|
||||||
execCtx = context.WithValue(execCtx, ctxKeyCustomIP, config.DNS.CustomIP)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 注入自定义 DNS
|
|
||||||
if len(config.DNS.CustomDNS) > 0 {
|
|
||||||
execCtx = context.WithValue(execCtx, ctxKeyCustomDNS, config.DNS.CustomDNS)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 总是注入 DialTimeout(与原始代码一致)
|
|
||||||
if config.Network.DialTimeout > 0 {
|
|
||||||
execCtx = context.WithValue(execCtx, ctxKeyDialTimeout, config.Network.DialTimeout)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 注入 DNS 解析函数
|
|
||||||
if config.DNS.LookupFunc != nil {
|
|
||||||
execCtx = context.WithValue(execCtx, ctxKeyLookupIP, config.DNS.LookupFunc)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 注入 Dial 函数
|
|
||||||
if config.Network.DialFunc != nil {
|
|
||||||
execCtx = context.WithValue(execCtx, ctxKeyDialFunc, config.Network.DialFunc)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 注入自定义 Transport
|
|
||||||
if config.CustomTransport && config.Transport != nil {
|
if config.CustomTransport && config.Transport != nil {
|
||||||
execCtx = context.WithValue(execCtx, ctxKeyTransport, config.Transport)
|
rc.Transport = config.Transport
|
||||||
}
|
}
|
||||||
|
|
||||||
return execCtx
|
if !needsDynamicTransport(rc) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return rc
|
||||||
|
}
|
||||||
|
|
||||||
|
// injectRequestConfig 将请求配置注入到 context
|
||||||
|
func injectRequestConfig(ctx context.Context, config *RequestConfig, defaultTLSServerName string) context.Context {
|
||||||
|
rc := buildRequestContext(config, defaultTLSServerName)
|
||||||
|
if rc == nil {
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
return context.WithValue(ctx, ctxKeyRequestContext, rc)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -57,3 +57,55 @@ func TestSetFormDataDefensiveCopy(t *testing.T) {
|
|||||||
t.Fatalf("form data mutated by external map change: got=%v want=[alice]", got)
|
t.Fatalf("form data mutated by external map change: got=%v want=[alice]", got)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWithBodyDefensiveCopy(t *testing.T) {
|
||||||
|
body := []byte("hello")
|
||||||
|
|
||||||
|
req, err := NewRequest("http://example.com", "POST", WithBody(body))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewRequest() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
body[0] = 'j'
|
||||||
|
if string(req.config.Body.Bytes) != "hello" {
|
||||||
|
t.Fatalf("body mutated by external slice change: got=%q want=%q", string(req.config.Body.Bytes), "hello")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithFormDataDefensiveCopy(t *testing.T) {
|
||||||
|
form := map[string][]string{
|
||||||
|
"name": []string{"alice"},
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := NewRequest("http://example.com", "POST", WithFormData(form))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewRequest() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
form["name"][0] = "bob"
|
||||||
|
form["name"] = append(form["name"], "carol")
|
||||||
|
got := req.config.Body.FormData["name"]
|
||||||
|
if len(got) != 1 || got[0] != "alice" {
|
||||||
|
t.Fatalf("form data mutated by external map change: got=%v want=[alice]", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetCustomIPDefensiveCopy(t *testing.T) {
|
||||||
|
ips := []string{"1.1.1.1", "8.8.8.8"}
|
||||||
|
req := NewSimpleRequest("http://example.com", "GET").SetCustomIP(ips)
|
||||||
|
|
||||||
|
ips[0] = "9.9.9.9"
|
||||||
|
if got := req.config.DNS.CustomIP[0]; got != "1.1.1.1" {
|
||||||
|
t.Fatalf("custom ip mutated by external slice change: got=%q want=%q", got, "1.1.1.1")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetCustomDNSDefensiveCopy(t *testing.T) {
|
||||||
|
servers := []string{"8.8.8.8", "1.1.1.1"}
|
||||||
|
req := NewSimpleRequest("http://example.com", "GET").SetCustomDNS(servers)
|
||||||
|
|
||||||
|
servers[0] = "9.9.9.9"
|
||||||
|
if got := req.config.DNS.CustomDNS[0]; got != "8.8.8.8" {
|
||||||
|
t.Fatalf("custom dns mutated by external slice change: got=%q want=%q", got, "8.8.8.8")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
176
dialer.go
176
dialer.go
@ -9,10 +9,100 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func traceDNSLookup(traceState *traceState, host string, lookup func() ([]net.IPAddr, error)) ([]net.IPAddr, error) {
|
||||||
|
if traceState != nil {
|
||||||
|
traceState.beginManualDNS()
|
||||||
|
defer traceState.endManualDNS()
|
||||||
|
traceState.dnsStart(TraceDNSStartInfo{Host: host})
|
||||||
|
}
|
||||||
|
ipAddrs, err := lookup()
|
||||||
|
if traceState != nil {
|
||||||
|
traceState.dnsDone(TraceDNSDoneInfo{
|
||||||
|
Addrs: append([]net.IPAddr(nil), ipAddrs...),
|
||||||
|
Err: err,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return ipAddrs, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveDialAddresses(ctx context.Context, reqCtx *RequestContext, host, port string, traceState *traceState) ([]string, error) {
|
||||||
|
if reqCtx == nil {
|
||||||
|
reqCtx = &RequestContext{}
|
||||||
|
}
|
||||||
|
|
||||||
|
var addrs []string
|
||||||
|
|
||||||
|
if len(reqCtx.CustomIP) > 0 {
|
||||||
|
for _, ip := range reqCtx.CustomIP {
|
||||||
|
addrs = append(addrs, joinResolvedHostPort(ip, port))
|
||||||
|
}
|
||||||
|
return addrs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
ipAddrs []net.IPAddr
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
|
if reqCtx.LookupIPFn != nil {
|
||||||
|
ipAddrs, err = traceDNSLookup(traceState, host, func() ([]net.IPAddr, error) {
|
||||||
|
return reqCtx.LookupIPFn(ctx, host)
|
||||||
|
})
|
||||||
|
} else if len(reqCtx.CustomDNS) > 0 {
|
||||||
|
dialTimeout := reqCtx.DialTimeout
|
||||||
|
if dialTimeout == 0 {
|
||||||
|
dialTimeout = DefaultDialTimeout
|
||||||
|
}
|
||||||
|
dialer := &net.Dialer{Timeout: dialTimeout}
|
||||||
|
resolver := &net.Resolver{
|
||||||
|
PreferGo: true,
|
||||||
|
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
var lastErr error
|
||||||
|
for _, dnsServer := range reqCtx.CustomDNS {
|
||||||
|
conn, err := dialer.DialContext(ctx, "udp", net.JoinHostPort(dnsServer, "53"))
|
||||||
|
if err != nil {
|
||||||
|
lastErr = err
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
return nil, lastErr
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ipAddrs, err = traceDNSLookup(traceState, host, func() ([]net.IPAddr, error) {
|
||||||
|
return resolver.LookupIPAddr(ctx, host)
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
ipAddrs, err = traceDNSLookup(traceState, host, func() ([]net.IPAddr, error) {
|
||||||
|
return net.DefaultResolver.LookupIPAddr(ctx, host)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, wrapError(err, "lookup ip")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ipAddr := range ipAddrs {
|
||||||
|
addrs = append(addrs, joinResolvedHostPort(ipAddr.String(), port))
|
||||||
|
}
|
||||||
|
return addrs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func joinResolvedHostPort(host, port string) string {
|
||||||
|
if port == "" {
|
||||||
|
if ip := net.ParseIP(host); ip != nil && ip.To4() == nil {
|
||||||
|
return "[" + host + "]"
|
||||||
|
}
|
||||||
|
return host
|
||||||
|
}
|
||||||
|
return net.JoinHostPort(host, port)
|
||||||
|
}
|
||||||
|
|
||||||
// defaultDialFunc 默认 Dial 函数(支持自定义 IP 和 DNS)
|
// defaultDialFunc 默认 Dial 函数(支持自定义 IP 和 DNS)
|
||||||
func defaultDialFunc(ctx context.Context, network, addr string) (net.Conn, error) {
|
func defaultDialFunc(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
// 提取配置
|
// 提取配置
|
||||||
reqCtx := getRequestContext(ctx)
|
reqCtx := getRequestContext(ctx)
|
||||||
|
traceState := getTraceState(ctx)
|
||||||
|
|
||||||
dialTimeout := reqCtx.DialTimeout
|
dialTimeout := reqCtx.DialTimeout
|
||||||
if dialTimeout == 0 {
|
if dialTimeout == 0 {
|
||||||
@ -25,52 +115,9 @@ func defaultDialFunc(ctx context.Context, network, addr string) (net.Conn, error
|
|||||||
return nil, wrapError(err, "split host port")
|
return nil, wrapError(err, "split host port")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取 IP 地址列表
|
addrs, err := resolveDialAddresses(ctx, reqCtx, host, port, traceState)
|
||||||
var addrs []string
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
// 优先级1:直接指定的 IP
|
|
||||||
if len(reqCtx.CustomIP) > 0 {
|
|
||||||
for _, ip := range reqCtx.CustomIP {
|
|
||||||
addrs = append(addrs, net.JoinHostPort(ip, port))
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// 优先级2:DNS 解析
|
|
||||||
var ipAddrs []net.IPAddr
|
|
||||||
|
|
||||||
// 使用自定义解析函数
|
|
||||||
if reqCtx.LookupIPFn != nil {
|
|
||||||
ipAddrs, err = reqCtx.LookupIPFn(ctx, host)
|
|
||||||
} else if len(reqCtx.CustomDNS) > 0 {
|
|
||||||
// 使用自定义 DNS 服务器
|
|
||||||
dialer := &net.Dialer{Timeout: dialTimeout}
|
|
||||||
resolver := &net.Resolver{
|
|
||||||
PreferGo: true,
|
|
||||||
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
|
|
||||||
var lastErr error
|
|
||||||
for _, dnsServer := range reqCtx.CustomDNS {
|
|
||||||
conn, err := dialer.DialContext(ctx, "udp", net.JoinHostPort(dnsServer, "53"))
|
|
||||||
if err != nil {
|
|
||||||
lastErr = err
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
return conn, nil
|
|
||||||
}
|
|
||||||
return nil, lastErr
|
|
||||||
},
|
|
||||||
}
|
|
||||||
ipAddrs, err = resolver.LookupIPAddr(ctx, host)
|
|
||||||
} else {
|
|
||||||
// 使用默认解析器
|
|
||||||
ipAddrs, err = net.DefaultResolver.LookupIPAddr(ctx, host)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, wrapError(err, "lookup ip")
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, ipAddr := range ipAddrs {
|
|
||||||
addrs = append(addrs, net.JoinHostPort(ipAddr.String(), port))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 尝试连接所有地址
|
// 尝试连接所有地址
|
||||||
@ -103,13 +150,17 @@ func defaultDialTLSFunc(ctx context.Context, network, addr string) (net.Conn, er
|
|||||||
|
|
||||||
// 提取 TLS 配置
|
// 提取 TLS 配置
|
||||||
reqCtx := getRequestContext(ctx)
|
reqCtx := getRequestContext(ctx)
|
||||||
|
traceState := getTraceState(ctx)
|
||||||
tlsConfig := reqCtx.TLSConfig
|
tlsConfig := reqCtx.TLSConfig
|
||||||
if tlsConfig == nil {
|
if tlsConfig == nil {
|
||||||
tlsConfig = &tls.Config{}
|
tlsConfig = &tls.Config{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ← 新增:如果 ServerName 为空且没有 InsecureSkipVerify,自动设置
|
serverName := tlsConfig.ServerName
|
||||||
if tlsConfig.ServerName == "" && !tlsConfig.InsecureSkipVerify {
|
if serverName == "" {
|
||||||
|
serverName = reqCtx.TLSServerName
|
||||||
|
}
|
||||||
|
if serverName == "" && !tlsConfig.InsecureSkipVerify {
|
||||||
host, _, err := net.SplitHostPort(addr)
|
host, _, err := net.SplitHostPort(addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if idx := strings.LastIndex(addr, ":"); idx > 0 {
|
if idx := strings.LastIndex(addr, ":"); idx > 0 {
|
||||||
@ -118,8 +169,19 @@ func defaultDialTLSFunc(ctx context.Context, network, addr string) (net.Conn, er
|
|||||||
host = addr
|
host = addr
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
serverName = host
|
||||||
|
}
|
||||||
|
if serverName != "" && tlsConfig.ServerName != serverName {
|
||||||
tlsConfig = tlsConfig.Clone() // 避免修改原 config
|
tlsConfig = tlsConfig.Clone() // 避免修改原 config
|
||||||
tlsConfig.ServerName = host
|
tlsConfig.ServerName = serverName
|
||||||
|
}
|
||||||
|
if traceState != nil {
|
||||||
|
traceState.markCustomTLS()
|
||||||
|
traceState.tlsHandshakeStart(TraceTLSHandshakeStartInfo{
|
||||||
|
Network: network,
|
||||||
|
Addr: addr,
|
||||||
|
ServerName: serverName,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// 执行 TLS 握手
|
// 执行 TLS 握手
|
||||||
@ -130,9 +192,25 @@ func defaultDialTLSFunc(ctx context.Context, network, addr string) (net.Conn, er
|
|||||||
|
|
||||||
tlsConn := tls.Client(conn, tlsConfig)
|
tlsConn := tls.Client(conn, tlsConfig)
|
||||||
if err := tlsConn.Handshake(); err != nil {
|
if err := tlsConn.Handshake(); err != nil {
|
||||||
|
if traceState != nil {
|
||||||
|
traceState.tlsHandshakeDone(TraceTLSHandshakeDoneInfo{
|
||||||
|
Network: network,
|
||||||
|
Addr: addr,
|
||||||
|
ServerName: serverName,
|
||||||
|
Err: err,
|
||||||
|
})
|
||||||
|
}
|
||||||
conn.Close()
|
conn.Close()
|
||||||
return nil, wrapError(err, "tls handshake")
|
return nil, wrapError(err, "tls handshake")
|
||||||
}
|
}
|
||||||
|
if traceState != nil {
|
||||||
|
traceState.tlsHandshakeDone(TraceTLSHandshakeDoneInfo{
|
||||||
|
Network: network,
|
||||||
|
Addr: addr,
|
||||||
|
ServerName: serverName,
|
||||||
|
ConnectionState: tlsConn.ConnectionState(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
return tlsConn, nil
|
return tlsConn, nil
|
||||||
}
|
}
|
||||||
|
|||||||
144
dynamic_transport_benchmark_test.go
Normal file
144
dynamic_transport_benchmark_test.go
Normal file
@ -0,0 +1,144 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func BenchmarkDynamicTransportCustomIP(b *testing.B) {
|
||||||
|
server := newIPv4Server(b, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("OK"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
targetURL := benchmarkTargetURL(b, server.URL, "bench-custom-ip.test")
|
||||||
|
client := NewClientNoErr()
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
b.ReportAllocs()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
resp, err := client.Get(targetURL, WithCustomIP([]string{"127.0.0.1"}))
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("Get() error: %v", err)
|
||||||
|
}
|
||||||
|
_, _ = resp.Body().Bytes()
|
||||||
|
resp.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkDynamicTransportProxyTLSCacheable(b *testing.B) {
|
||||||
|
server := newIPv4TLSServer(b, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("OK"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
proxy := newIPv4ConnectProxyServer(b, nil)
|
||||||
|
defer proxy.Close()
|
||||||
|
|
||||||
|
targetURL := httpsURLForHost(b, server, "bench-proxy-cacheable.test")
|
||||||
|
client := NewClientNoErr()
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
b.ReportAllocs()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
resp, err := client.Get(targetURL,
|
||||||
|
WithProxy(proxy.URL),
|
||||||
|
WithCustomIP([]string{"127.0.0.1"}),
|
||||||
|
WithSkipTLSVerify(true),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("Get() error: %v", err)
|
||||||
|
}
|
||||||
|
_, _ = resp.Body().Bytes()
|
||||||
|
resp.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkDynamicTransportCustomIPTLSCacheable(b *testing.B) {
|
||||||
|
server := newIPv4TLSServer(b, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("OK"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
targetURL := httpsURLForHost(b, server, "bench-custom-ip-cacheable.test")
|
||||||
|
client := NewClientNoErr()
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
b.ReportAllocs()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
resp, err := client.Get(targetURL,
|
||||||
|
WithCustomIP([]string{"127.0.0.1"}),
|
||||||
|
WithSkipTLSVerify(true),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("Get() error: %v", err)
|
||||||
|
}
|
||||||
|
_, _ = resp.Body().Bytes()
|
||||||
|
resp.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkDynamicTransportCustomIPUserTLSConfig(b *testing.B) {
|
||||||
|
server := newIPv4TLSServer(b, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("OK"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
targetURL := httpsURLForHost(b, server, "bench-user-tls.test")
|
||||||
|
client := NewClientNoErr()
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
b.ReportAllocs()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
resp, err := client.Get(targetURL,
|
||||||
|
WithCustomIP([]string{"127.0.0.1"}),
|
||||||
|
WithTLSConfig(&tls.Config{InsecureSkipVerify: true}),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("Get() error: %v", err)
|
||||||
|
}
|
||||||
|
_, _ = resp.Body().Bytes()
|
||||||
|
resp.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func benchmarkTargetURL(tb testing.TB, rawURL, host string) string {
|
||||||
|
tb.Helper()
|
||||||
|
|
||||||
|
parsed, err := url.Parse(rawURL)
|
||||||
|
if err != nil {
|
||||||
|
tb.Fatalf("url.Parse() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
port := parsed.Port()
|
||||||
|
if port == "" {
|
||||||
|
switch parsed.Scheme {
|
||||||
|
case "https":
|
||||||
|
port = "443"
|
||||||
|
default:
|
||||||
|
port = "80"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return parsed.Scheme + "://" + host + ":" + port + pathWithQuery(parsed.Path, parsed.RawQuery)
|
||||||
|
}
|
||||||
|
|
||||||
|
func pathWithQuery(path, rawQuery string) string {
|
||||||
|
if path == "" {
|
||||||
|
path = "/"
|
||||||
|
}
|
||||||
|
if rawQuery == "" {
|
||||||
|
return path
|
||||||
|
}
|
||||||
|
return path + "?" + rawQuery
|
||||||
|
}
|
||||||
150
host_tls_regression_test.go
Normal file
150
host_tls_regression_test.go
Normal file
@ -0,0 +1,150 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRequestSetURLDoesNotMutateProvidedTLSConfig(t *testing.T) {
|
||||||
|
cfg := &tls.Config{}
|
||||||
|
|
||||||
|
req := NewSimpleRequest("https://example.com", http.MethodGet).
|
||||||
|
SetTLSConfig(cfg).
|
||||||
|
SetURL("https://other.example")
|
||||||
|
|
||||||
|
if req.Err() != nil {
|
||||||
|
t.Fatalf("unexpected request error: %v", req.Err())
|
||||||
|
}
|
||||||
|
if cfg.ServerName != "" {
|
||||||
|
t.Fatalf("provided tls.Config was mutated, ServerName=%q", cfg.ServerName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestPrepareSetTLSServerNameDoesNotMutateProvidedTLSConfig(t *testing.T) {
|
||||||
|
cfg := &tls.Config{InsecureSkipVerify: true}
|
||||||
|
|
||||||
|
req := NewSimpleRequest("https://example.com", http.MethodGet).
|
||||||
|
SetTLSConfig(cfg).
|
||||||
|
SetTLSServerName("override.example")
|
||||||
|
|
||||||
|
if err := req.prepare(); err != nil {
|
||||||
|
t.Fatalf("prepare error: %v", err)
|
||||||
|
}
|
||||||
|
if cfg.ServerName != "" {
|
||||||
|
t.Fatalf("provided tls.Config was mutated, ServerName=%q", cfg.ServerName)
|
||||||
|
}
|
||||||
|
|
||||||
|
rc := getRequestContext(req.execCtx)
|
||||||
|
if rc.TLSConfig == nil {
|
||||||
|
t.Fatal("expected injected tls config")
|
||||||
|
}
|
||||||
|
if rc.TLSConfig == cfg {
|
||||||
|
t.Fatal("expected injected tls config to be cloned")
|
||||||
|
}
|
||||||
|
if rc.TLSConfig.ServerName != "override.example" {
|
||||||
|
t.Fatalf("injected ServerName=%q", rc.TLSConfig.ServerName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestPrepareWithTLSServerNameWithoutTLSConfig(t *testing.T) {
|
||||||
|
req := NewSimpleRequest("https://example.com", http.MethodGet).
|
||||||
|
SetTLSServerName("override.example")
|
||||||
|
|
||||||
|
if err := req.prepare(); err != nil {
|
||||||
|
t.Fatalf("prepare error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rc := getRequestContext(req.execCtx)
|
||||||
|
if rc.TLSConfig == nil {
|
||||||
|
t.Fatal("expected injected tls config")
|
||||||
|
}
|
||||||
|
if rc.TLSConfig.ServerName != "override.example" {
|
||||||
|
t.Fatalf("injected ServerName=%q", rc.TLSConfig.ServerName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestPrepareDefaultPathSkipsRequestContextInjection(t *testing.T) {
|
||||||
|
req := NewSimpleRequest("https://example.com", http.MethodGet)
|
||||||
|
|
||||||
|
if err := req.prepare(); err != nil {
|
||||||
|
t.Fatalf("prepare error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := req.execCtx.Value(ctxKeyRequestContext); got != nil {
|
||||||
|
t.Fatalf("unexpected request context injection: %#v", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
rc := getRequestContext(req.execCtx)
|
||||||
|
if needsDynamicTransport(rc) {
|
||||||
|
t.Fatalf("default path unexpectedly marked dynamic: %#v", rc)
|
||||||
|
}
|
||||||
|
if rc.TLSServerName != "" {
|
||||||
|
t.Fatalf("default path unexpectedly injected tls server name: %q", rc.TLSServerName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestPrepareDynamicPathInjectsAggregatedRequestContext(t *testing.T) {
|
||||||
|
req := NewSimpleRequest("https://example.com", http.MethodGet).
|
||||||
|
SetCustomIP([]string{"127.0.0.1"}).
|
||||||
|
SetSkipTLSVerify(true)
|
||||||
|
|
||||||
|
if err := req.prepare(); err != nil {
|
||||||
|
t.Fatalf("prepare error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
raw := req.execCtx.Value(ctxKeyRequestContext)
|
||||||
|
rc, ok := raw.(*RequestContext)
|
||||||
|
if !ok || rc == nil {
|
||||||
|
t.Fatalf("expected aggregated request context, got %#v", raw)
|
||||||
|
}
|
||||||
|
if len(rc.CustomIP) != 1 || rc.CustomIP[0] != "127.0.0.1" {
|
||||||
|
t.Fatalf("custom ip=%v", rc.CustomIP)
|
||||||
|
}
|
||||||
|
if rc.TLSConfig == nil || !rc.TLSConfig.InsecureSkipVerify {
|
||||||
|
t.Fatal("expected tls config with skip verify")
|
||||||
|
}
|
||||||
|
if rc.TLSServerName != "example.com" {
|
||||||
|
t.Fatalf("default tls server name=%q", rc.TLSServerName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestSetHostOverridesRequestHost(t *testing.T) {
|
||||||
|
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Host != "override.example" {
|
||||||
|
t.Fatalf("host=%q", r.Host)
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
resp, err := NewSimpleRequest(s.URL, http.MethodGet).
|
||||||
|
SetHost("override.example").
|
||||||
|
Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithHostOverridesRequestHost(t *testing.T) {
|
||||||
|
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Host != "option.example" {
|
||||||
|
t.Fatalf("host=%q", r.Host)
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
resp, err := NewRequest(s.URL, http.MethodGet, WithHost("option.example"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewRequest() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := resp.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer got.Close()
|
||||||
|
}
|
||||||
230
internal/pingcore/core.go
Normal file
230
internal/pingcore/core.go
Normal file
@ -0,0 +1,230 @@
|
|||||||
|
package pingcore
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const icmpHeaderLen = 8
|
||||||
|
|
||||||
|
type ICMP struct {
|
||||||
|
Type uint8
|
||||||
|
Code uint8
|
||||||
|
CheckSum uint16
|
||||||
|
Identifier uint16
|
||||||
|
SequenceNum uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
type Options struct {
|
||||||
|
Count int
|
||||||
|
Timeout time.Duration
|
||||||
|
Interval time.Duration
|
||||||
|
Deadline time.Time
|
||||||
|
PreferIPv4 bool
|
||||||
|
PreferIPv6 bool
|
||||||
|
SourceIP net.IP
|
||||||
|
PayloadSize int
|
||||||
|
}
|
||||||
|
|
||||||
|
type Result struct {
|
||||||
|
Duration time.Duration
|
||||||
|
RecvCount int
|
||||||
|
RemoteIP string
|
||||||
|
}
|
||||||
|
|
||||||
|
var identifierSeed uint32
|
||||||
|
|
||||||
|
func NextIdentifier() uint16 {
|
||||||
|
pid := uint32(os.Getpid() & 0xffff)
|
||||||
|
n := atomic.AddUint32(&identifierSeed, 1)
|
||||||
|
return uint16((pid + n) & 0xffff)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Payload(size int) []byte {
|
||||||
|
if size <= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
payload := make([]byte, size)
|
||||||
|
for index := 0; index < len(payload); index++ {
|
||||||
|
payload[index] = byte(index)
|
||||||
|
}
|
||||||
|
return payload
|
||||||
|
}
|
||||||
|
|
||||||
|
func BuildICMP(seq, identifier uint16, typ uint8, payload []byte) ICMP {
|
||||||
|
icmp := ICMP{
|
||||||
|
Type: typ,
|
||||||
|
Code: 0,
|
||||||
|
CheckSum: 0,
|
||||||
|
Identifier: identifier,
|
||||||
|
SequenceNum: seq,
|
||||||
|
}
|
||||||
|
buf := MarshalPacket(icmp, payload)
|
||||||
|
icmp.CheckSum = Checksum(buf)
|
||||||
|
return icmp
|
||||||
|
}
|
||||||
|
|
||||||
|
func Checksum(data []byte) uint16 {
|
||||||
|
var (
|
||||||
|
sum uint32
|
||||||
|
length = len(data)
|
||||||
|
index int
|
||||||
|
)
|
||||||
|
for length > 1 {
|
||||||
|
sum += uint32(data[index])<<8 + uint32(data[index+1])
|
||||||
|
index += 2
|
||||||
|
length -= 2
|
||||||
|
}
|
||||||
|
if length > 0 {
|
||||||
|
sum += uint32(data[index]) << 8
|
||||||
|
}
|
||||||
|
for sum>>16 != 0 {
|
||||||
|
sum = (sum & 0xffff) + (sum >> 16)
|
||||||
|
}
|
||||||
|
return uint16(^sum)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Marshal(icmp ICMP) []byte {
|
||||||
|
return MarshalPacket(icmp, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func MarshalPacket(icmp ICMP, payload []byte) []byte {
|
||||||
|
buf := make([]byte, icmpHeaderLen+len(payload))
|
||||||
|
buf[0] = icmp.Type
|
||||||
|
buf[1] = icmp.Code
|
||||||
|
binary.BigEndian.PutUint16(buf[2:], icmp.CheckSum)
|
||||||
|
binary.BigEndian.PutUint16(buf[4:], icmp.Identifier)
|
||||||
|
binary.BigEndian.PutUint16(buf[6:], icmp.SequenceNum)
|
||||||
|
copy(buf[icmpHeaderLen:], payload)
|
||||||
|
return buf
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsExpectedEchoReply(packet []byte, family int, expectedType uint8, identifier, seq uint16) bool {
|
||||||
|
for _, offset := range CandidateICMPOffsets(packet, family) {
|
||||||
|
if offset < 0 || offset+icmpHeaderLen > len(packet) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if packet[offset] != expectedType || packet[offset+1] != 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if binary.BigEndian.Uint16(packet[offset+4:offset+6]) != identifier {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if binary.BigEndian.Uint16(packet[offset+6:offset+8]) != seq {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func CandidateICMPOffsets(packet []byte, family int) []int {
|
||||||
|
offsets := []int{0}
|
||||||
|
if len(packet) == 0 {
|
||||||
|
return offsets
|
||||||
|
}
|
||||||
|
|
||||||
|
version := packet[0] >> 4
|
||||||
|
if version == 4 && len(packet) >= 20 {
|
||||||
|
ihl := int(packet[0]&0x0f) * 4
|
||||||
|
if ihl >= 20 && ihl <= len(packet)-icmpHeaderLen {
|
||||||
|
offsets = append(offsets, ihl)
|
||||||
|
}
|
||||||
|
} else if version == 6 && len(packet) >= 40+icmpHeaderLen {
|
||||||
|
offsets = append(offsets, 40)
|
||||||
|
}
|
||||||
|
|
||||||
|
if family == 4 && len(packet) >= 20+icmpHeaderLen {
|
||||||
|
offsets = append(offsets, 20)
|
||||||
|
}
|
||||||
|
if family == 6 && len(packet) >= 40+icmpHeaderLen {
|
||||||
|
offsets = append(offsets, 40)
|
||||||
|
}
|
||||||
|
|
||||||
|
return DedupOffsets(offsets)
|
||||||
|
}
|
||||||
|
|
||||||
|
func DedupOffsets(offsets []int) []int {
|
||||||
|
if len(offsets) <= 1 {
|
||||||
|
return offsets
|
||||||
|
}
|
||||||
|
seen := make(map[int]struct{}, len(offsets))
|
||||||
|
out := make([]int, 0, len(offsets))
|
||||||
|
for _, offset := range offsets {
|
||||||
|
if _, ok := seen[offset]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[offset] = struct{}{}
|
||||||
|
out = append(out, offset)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func ResolveTargets(host string, preferIPv4, preferIPv6 bool) ([]*net.IPAddr, error) {
|
||||||
|
if parsed := net.ParseIP(host); parsed != nil {
|
||||||
|
return []*net.IPAddr{{IP: parsed}}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var targets []*net.IPAddr
|
||||||
|
var err4 error
|
||||||
|
var err6 error
|
||||||
|
|
||||||
|
if ip4, err := net.ResolveIPAddr("ip4", host); err == nil && ip4 != nil && ip4.IP != nil {
|
||||||
|
targets = append(targets, ip4)
|
||||||
|
} else {
|
||||||
|
err4 = err
|
||||||
|
}
|
||||||
|
|
||||||
|
if ip6, err := net.ResolveIPAddr("ip6", host); err == nil && ip6 != nil && ip6.IP != nil {
|
||||||
|
targets = append(targets, ip6)
|
||||||
|
} else {
|
||||||
|
err6 = err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(targets) > 0 {
|
||||||
|
return OrderTargets(targets, preferIPv4, preferIPv6), nil
|
||||||
|
}
|
||||||
|
if err4 != nil {
|
||||||
|
return nil, err4
|
||||||
|
}
|
||||||
|
if err6 != nil {
|
||||||
|
return nil, err6
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func OrderTargets(targets []*net.IPAddr, preferIPv4, preferIPv6 bool) []*net.IPAddr {
|
||||||
|
if len(targets) <= 1 || preferIPv4 == preferIPv6 {
|
||||||
|
return targets
|
||||||
|
}
|
||||||
|
|
||||||
|
ordered := make([]*net.IPAddr, 0, len(targets))
|
||||||
|
if preferIPv4 {
|
||||||
|
for _, target := range targets {
|
||||||
|
if target != nil && target.IP != nil && target.IP.To4() != nil {
|
||||||
|
ordered = append(ordered, target)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, target := range targets {
|
||||||
|
if target != nil && target.IP != nil && target.IP.To4() == nil {
|
||||||
|
ordered = append(ordered, target)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ordered
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, target := range targets {
|
||||||
|
if target != nil && target.IP != nil && target.IP.To4() == nil {
|
||||||
|
ordered = append(ordered, target)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, target := range targets {
|
||||||
|
if target != nil && target.IP != nil && target.IP.To4() != nil {
|
||||||
|
ordered = append(ordered, target)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ordered
|
||||||
|
}
|
||||||
123
internal/tlssniffercore/config.go
Normal file
123
internal/tlssniffercore/config.go
Normal file
@ -0,0 +1,123 @@
|
|||||||
|
package tlssniffercore
|
||||||
|
|
||||||
|
import "crypto/tls"
|
||||||
|
|
||||||
|
func ComposeServerTLSConfig(base, selected *tls.Config) *tls.Config {
|
||||||
|
if base == nil {
|
||||||
|
return selected
|
||||||
|
}
|
||||||
|
if selected == nil {
|
||||||
|
return base
|
||||||
|
}
|
||||||
|
|
||||||
|
out := base.Clone()
|
||||||
|
ApplyServerTLSOverrides(out, selected)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func ApplyServerTLSOverrides(dst, src *tls.Config) {
|
||||||
|
if dst == nil || src == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if src.Rand != nil {
|
||||||
|
dst.Rand = src.Rand
|
||||||
|
}
|
||||||
|
if src.Time != nil {
|
||||||
|
dst.Time = src.Time
|
||||||
|
}
|
||||||
|
if len(src.Certificates) > 0 {
|
||||||
|
dst.Certificates = append([]tls.Certificate(nil), src.Certificates...)
|
||||||
|
}
|
||||||
|
if len(src.NameToCertificate) > 0 {
|
||||||
|
copied := make(map[string]*tls.Certificate, len(src.NameToCertificate))
|
||||||
|
for name, cert := range src.NameToCertificate {
|
||||||
|
copied[name] = cert
|
||||||
|
}
|
||||||
|
dst.NameToCertificate = copied
|
||||||
|
}
|
||||||
|
if src.GetCertificate != nil {
|
||||||
|
dst.GetCertificate = src.GetCertificate
|
||||||
|
}
|
||||||
|
if src.GetClientCertificate != nil {
|
||||||
|
dst.GetClientCertificate = src.GetClientCertificate
|
||||||
|
}
|
||||||
|
if src.GetConfigForClient != nil {
|
||||||
|
dst.GetConfigForClient = src.GetConfigForClient
|
||||||
|
}
|
||||||
|
if src.VerifyPeerCertificate != nil {
|
||||||
|
dst.VerifyPeerCertificate = src.VerifyPeerCertificate
|
||||||
|
}
|
||||||
|
if src.VerifyConnection != nil {
|
||||||
|
dst.VerifyConnection = src.VerifyConnection
|
||||||
|
}
|
||||||
|
if src.RootCAs != nil {
|
||||||
|
dst.RootCAs = src.RootCAs
|
||||||
|
}
|
||||||
|
if len(src.NextProtos) > 0 {
|
||||||
|
dst.NextProtos = append([]string(nil), src.NextProtos...)
|
||||||
|
}
|
||||||
|
if src.ServerName != "" {
|
||||||
|
dst.ServerName = src.ServerName
|
||||||
|
}
|
||||||
|
if src.ClientAuth > dst.ClientAuth {
|
||||||
|
dst.ClientAuth = src.ClientAuth
|
||||||
|
}
|
||||||
|
if src.ClientCAs != nil {
|
||||||
|
dst.ClientCAs = src.ClientCAs
|
||||||
|
}
|
||||||
|
if src.InsecureSkipVerify {
|
||||||
|
dst.InsecureSkipVerify = true
|
||||||
|
}
|
||||||
|
if len(src.CipherSuites) > 0 {
|
||||||
|
dst.CipherSuites = append([]uint16(nil), src.CipherSuites...)
|
||||||
|
}
|
||||||
|
if src.PreferServerCipherSuites {
|
||||||
|
dst.PreferServerCipherSuites = true
|
||||||
|
}
|
||||||
|
if src.SessionTicketsDisabled {
|
||||||
|
dst.SessionTicketsDisabled = true
|
||||||
|
}
|
||||||
|
if src.SessionTicketKey != ([32]byte{}) {
|
||||||
|
dst.SessionTicketKey = src.SessionTicketKey
|
||||||
|
}
|
||||||
|
if src.ClientSessionCache != nil {
|
||||||
|
dst.ClientSessionCache = src.ClientSessionCache
|
||||||
|
}
|
||||||
|
if src.UnwrapSession != nil {
|
||||||
|
dst.UnwrapSession = src.UnwrapSession
|
||||||
|
}
|
||||||
|
if src.WrapSession != nil {
|
||||||
|
dst.WrapSession = src.WrapSession
|
||||||
|
}
|
||||||
|
if src.MinVersion != 0 && (dst.MinVersion == 0 || src.MinVersion > dst.MinVersion) {
|
||||||
|
dst.MinVersion = src.MinVersion
|
||||||
|
}
|
||||||
|
if src.MaxVersion != 0 && (dst.MaxVersion == 0 || src.MaxVersion < dst.MaxVersion) {
|
||||||
|
dst.MaxVersion = src.MaxVersion
|
||||||
|
}
|
||||||
|
if len(src.CurvePreferences) > 0 {
|
||||||
|
dst.CurvePreferences = append([]tls.CurveID(nil), src.CurvePreferences...)
|
||||||
|
}
|
||||||
|
if src.DynamicRecordSizingDisabled {
|
||||||
|
dst.DynamicRecordSizingDisabled = true
|
||||||
|
}
|
||||||
|
if src.Renegotiation != 0 {
|
||||||
|
dst.Renegotiation = src.Renegotiation
|
||||||
|
}
|
||||||
|
if src.KeyLogWriter != nil {
|
||||||
|
dst.KeyLogWriter = src.KeyLogWriter
|
||||||
|
}
|
||||||
|
if len(src.EncryptedClientHelloConfigList) > 0 {
|
||||||
|
dst.EncryptedClientHelloConfigList = append([]byte(nil), src.EncryptedClientHelloConfigList...)
|
||||||
|
}
|
||||||
|
if src.EncryptedClientHelloRejectionVerify != nil {
|
||||||
|
dst.EncryptedClientHelloRejectionVerify = src.EncryptedClientHelloRejectionVerify
|
||||||
|
}
|
||||||
|
if src.GetEncryptedClientHelloKeys != nil {
|
||||||
|
dst.GetEncryptedClientHelloKeys = src.GetEncryptedClientHelloKeys
|
||||||
|
}
|
||||||
|
if len(src.EncryptedClientHelloKeys) > 0 {
|
||||||
|
dst.EncryptedClientHelloKeys = append([]tls.EncryptedClientHelloKey(nil), src.EncryptedClientHelloKeys...)
|
||||||
|
}
|
||||||
|
}
|
||||||
237
internal/tlssniffercore/parser.go
Normal file
237
internal/tlssniffercore/parser.go
Normal file
@ -0,0 +1,237 @@
|
|||||||
|
package tlssniffercore
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ClientHelloMeta struct {
|
||||||
|
ServerName string
|
||||||
|
LocalAddr net.Addr
|
||||||
|
RemoteAddr net.Addr
|
||||||
|
SupportedProtos []string
|
||||||
|
SupportedVersions []uint16
|
||||||
|
CipherSuites []uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
type SniffResult struct {
|
||||||
|
IsTLS bool
|
||||||
|
ClientHello *ClientHelloMeta
|
||||||
|
Buffer *bytes.Buffer
|
||||||
|
}
|
||||||
|
|
||||||
|
type Sniffer struct{}
|
||||||
|
|
||||||
|
func (s Sniffer) Sniff(conn net.Conn, maxBytes int) (SniffResult, error) {
|
||||||
|
if maxBytes <= 0 {
|
||||||
|
maxBytes = 64 * 1024
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
limited := &io.LimitedReader{R: conn, N: int64(maxBytes)}
|
||||||
|
meta, isTLS := sniffClientHello(limited, &buf, conn)
|
||||||
|
|
||||||
|
out := SniffResult{
|
||||||
|
IsTLS: isTLS,
|
||||||
|
Buffer: bytes.NewBuffer(append([]byte(nil), buf.Bytes()...)),
|
||||||
|
}
|
||||||
|
if isTLS {
|
||||||
|
out.ClientHello = meta
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func sniffClientHello(reader io.Reader, buf *bytes.Buffer, conn net.Conn) (*ClientHelloMeta, bool) {
|
||||||
|
meta := &ClientHelloMeta{
|
||||||
|
LocalAddr: conn.LocalAddr(),
|
||||||
|
RemoteAddr: conn.RemoteAddr(),
|
||||||
|
}
|
||||||
|
|
||||||
|
header, complete := readTLSRecordHeader(reader, buf)
|
||||||
|
if len(header) < 3 {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
isTLS := header[0] == 0x16 && header[1] == 0x03
|
||||||
|
if !isTLS {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
if len(header) < 5 || !complete {
|
||||||
|
return meta, true
|
||||||
|
}
|
||||||
|
|
||||||
|
recordLen := int(binary.BigEndian.Uint16(header[3:5]))
|
||||||
|
recordBody, bodyOK := readBufferedBytes(reader, buf, recordLen)
|
||||||
|
if !bodyOK {
|
||||||
|
return meta, true
|
||||||
|
}
|
||||||
|
if len(recordBody) < 4 || recordBody[0] != 0x01 {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
helloLen := int(recordBody[1])<<16 | int(recordBody[2])<<8 | int(recordBody[3])
|
||||||
|
helloBytes := append([]byte(nil), recordBody[4:]...)
|
||||||
|
for len(helloBytes) < helloLen {
|
||||||
|
nextHeader, ok := readTLSRecordHeader(reader, buf)
|
||||||
|
if len(nextHeader) < 5 || !ok {
|
||||||
|
return meta, true
|
||||||
|
}
|
||||||
|
if nextHeader[0] != 0x16 || nextHeader[1] != 0x03 {
|
||||||
|
return meta, true
|
||||||
|
}
|
||||||
|
nextLen := int(binary.BigEndian.Uint16(nextHeader[3:5]))
|
||||||
|
nextBody, bodyOK := readBufferedBytes(reader, buf, nextLen)
|
||||||
|
if !bodyOK {
|
||||||
|
return meta, true
|
||||||
|
}
|
||||||
|
helloBytes = append(helloBytes, nextBody...)
|
||||||
|
}
|
||||||
|
|
||||||
|
parseClientHelloBody(meta, helloBytes[:helloLen])
|
||||||
|
return meta, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func readTLSRecordHeader(reader io.Reader, buf *bytes.Buffer) ([]byte, bool) {
|
||||||
|
return readBufferedBytes(reader, buf, 5)
|
||||||
|
}
|
||||||
|
|
||||||
|
func readBufferedBytes(reader io.Reader, buf *bytes.Buffer, count int) ([]byte, bool) {
|
||||||
|
if count <= 0 {
|
||||||
|
return nil, true
|
||||||
|
}
|
||||||
|
tmp := make([]byte, count)
|
||||||
|
readN, err := io.ReadFull(reader, tmp)
|
||||||
|
if readN > 0 {
|
||||||
|
buf.Write(tmp[:readN])
|
||||||
|
}
|
||||||
|
return append([]byte(nil), tmp[:readN]...), err == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseClientHelloBody(meta *ClientHelloMeta, body []byte) {
|
||||||
|
if meta == nil || len(body) < 34 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
offset := 2 + 32
|
||||||
|
sessionIDLen := int(body[offset])
|
||||||
|
offset++
|
||||||
|
if offset+sessionIDLen > len(body) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
offset += sessionIDLen
|
||||||
|
|
||||||
|
if offset+2 > len(body) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cipherSuitesLen := int(binary.BigEndian.Uint16(body[offset : offset+2]))
|
||||||
|
offset += 2
|
||||||
|
if offset+cipherSuitesLen > len(body) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for index := 0; index+1 < cipherSuitesLen; index += 2 {
|
||||||
|
meta.CipherSuites = append(meta.CipherSuites, binary.BigEndian.Uint16(body[offset+index:offset+index+2]))
|
||||||
|
}
|
||||||
|
offset += cipherSuitesLen
|
||||||
|
|
||||||
|
if offset >= len(body) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
compressionMethodsLen := int(body[offset])
|
||||||
|
offset++
|
||||||
|
if offset+compressionMethodsLen > len(body) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
offset += compressionMethodsLen
|
||||||
|
|
||||||
|
if offset+2 > len(body) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
extensionsLen := int(binary.BigEndian.Uint16(body[offset : offset+2]))
|
||||||
|
offset += 2
|
||||||
|
if offset+extensionsLen > len(body) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
parseClientHelloExtensions(meta, body[offset:offset+extensionsLen])
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseClientHelloExtensions(meta *ClientHelloMeta, exts []byte) {
|
||||||
|
for offset := 0; offset+4 <= len(exts); {
|
||||||
|
extType := binary.BigEndian.Uint16(exts[offset : offset+2])
|
||||||
|
extLen := int(binary.BigEndian.Uint16(exts[offset+2 : offset+4]))
|
||||||
|
offset += 4
|
||||||
|
if offset+extLen > len(exts) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
extData := exts[offset : offset+extLen]
|
||||||
|
offset += extLen
|
||||||
|
|
||||||
|
switch extType {
|
||||||
|
case 0:
|
||||||
|
parseServerNameExtension(meta, extData)
|
||||||
|
case 16:
|
||||||
|
parseALPNExtension(meta, extData)
|
||||||
|
case 43:
|
||||||
|
parseSupportedVersionsExtension(meta, extData)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseServerNameExtension(meta *ClientHelloMeta, data []byte) {
|
||||||
|
if len(data) < 2 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
listLen := int(binary.BigEndian.Uint16(data[:2]))
|
||||||
|
if listLen == 0 || 2+listLen > len(data) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
list := data[2 : 2+listLen]
|
||||||
|
for offset := 0; offset+3 <= len(list); {
|
||||||
|
nameType := list[offset]
|
||||||
|
nameLen := int(binary.BigEndian.Uint16(list[offset+1 : offset+3]))
|
||||||
|
offset += 3
|
||||||
|
if offset+nameLen > len(list) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if nameType == 0 {
|
||||||
|
meta.ServerName = string(list[offset : offset+nameLen])
|
||||||
|
return
|
||||||
|
}
|
||||||
|
offset += nameLen
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseALPNExtension(meta *ClientHelloMeta, data []byte) {
|
||||||
|
if len(data) < 2 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
listLen := int(binary.BigEndian.Uint16(data[:2]))
|
||||||
|
if listLen == 0 || 2+listLen > len(data) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
list := data[2 : 2+listLen]
|
||||||
|
for offset := 0; offset < len(list); {
|
||||||
|
nameLen := int(list[offset])
|
||||||
|
offset++
|
||||||
|
if offset+nameLen > len(list) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
meta.SupportedProtos = append(meta.SupportedProtos, string(list[offset:offset+nameLen]))
|
||||||
|
offset += nameLen
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseSupportedVersionsExtension(meta *ClientHelloMeta, data []byte) {
|
||||||
|
if len(data) < 1 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
listLen := int(data[0])
|
||||||
|
if listLen == 0 || 1+listLen > len(data) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
list := data[1 : 1+listLen]
|
||||||
|
for offset := 0; offset+1 < len(list); offset += 2 {
|
||||||
|
meta.SupportedVersions = append(meta.SupportedVersions, binary.BigEndian.Uint16(list[offset:offset+2]))
|
||||||
|
}
|
||||||
|
}
|
||||||
405
options.go
405
options.go
@ -1,405 +0,0 @@
|
|||||||
package starnet
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto/tls"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"os"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// WithTimeout 设置请求总超时时间
|
|
||||||
// timeout > 0: 为本次请求注入 context 超时
|
|
||||||
// timeout = 0: 不额外设置请求总超时
|
|
||||||
// timeout < 0: 禁用 starnet 默认总超时
|
|
||||||
func WithTimeout(timeout time.Duration) RequestOpt {
|
|
||||||
return func(r *Request) error {
|
|
||||||
r.config.Network.Timeout = timeout
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithDialTimeout 设置连接超时时间
|
|
||||||
func WithDialTimeout(timeout time.Duration) RequestOpt {
|
|
||||||
return func(r *Request) error {
|
|
||||||
r.config.Network.DialTimeout = timeout
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithProxy 设置代理
|
|
||||||
func WithProxy(proxy string) RequestOpt {
|
|
||||||
return func(r *Request) error {
|
|
||||||
r.config.Network.Proxy = proxy
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithDialFunc 设置自定义 Dial 函数
|
|
||||||
func WithDialFunc(fn func(ctx context.Context, network, addr string) (net.Conn, error)) RequestOpt {
|
|
||||||
return func(r *Request) error {
|
|
||||||
r.config.Network.DialFunc = fn
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithTLSConfig 设置 TLS 配置
|
|
||||||
func WithTLSConfig(tlsConfig *tls.Config) RequestOpt {
|
|
||||||
return func(r *Request) error {
|
|
||||||
r.config.TLS.Config = tlsConfig
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithSkipTLSVerify 设置是否跳过 TLS 验证
|
|
||||||
func WithSkipTLSVerify(skip bool) RequestOpt {
|
|
||||||
return func(r *Request) error {
|
|
||||||
r.config.TLS.SkipVerify = skip
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithCustomIP 设置自定义 IP
|
|
||||||
func WithCustomIP(ips []string) RequestOpt {
|
|
||||||
return func(r *Request) error {
|
|
||||||
for _, ip := range ips {
|
|
||||||
if net.ParseIP(ip) == nil {
|
|
||||||
return wrapError(ErrInvalidIP, "ip: %s", ip)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
r.config.DNS.CustomIP = ips
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithAddCustomIP 添加自定义 IP
|
|
||||||
func WithAddCustomIP(ip string) RequestOpt {
|
|
||||||
return func(r *Request) error {
|
|
||||||
if net.ParseIP(ip) == nil {
|
|
||||||
return wrapError(ErrInvalidIP, "ip: %s", ip)
|
|
||||||
}
|
|
||||||
r.config.DNS.CustomIP = append(r.config.DNS.CustomIP, ip)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithCustomDNS 设置自定义 DNS 服务器
|
|
||||||
func WithCustomDNS(dnsServers []string) RequestOpt {
|
|
||||||
return func(r *Request) error {
|
|
||||||
for _, dns := range dnsServers {
|
|
||||||
if net.ParseIP(dns) == nil {
|
|
||||||
return wrapError(ErrInvalidDNS, "dns: %s", dns)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
r.config.DNS.CustomDNS = dnsServers
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithAddCustomDNS 添加自定义 DNS 服务器
|
|
||||||
func WithAddCustomDNS(dns string) RequestOpt {
|
|
||||||
return func(r *Request) error {
|
|
||||||
if net.ParseIP(dns) == nil {
|
|
||||||
return wrapError(ErrInvalidDNS, "dns: %s", dns)
|
|
||||||
}
|
|
||||||
r.config.DNS.CustomDNS = append(r.config.DNS.CustomDNS, dns)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithLookupFunc 设置自定义 DNS 解析函数
|
|
||||||
func WithLookupFunc(fn func(ctx context.Context, host string) ([]net.IPAddr, error)) RequestOpt {
|
|
||||||
return func(r *Request) error {
|
|
||||||
r.config.DNS.LookupFunc = fn
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithHeader 设置 Header
|
|
||||||
func WithHeader(key, value string) RequestOpt {
|
|
||||||
return func(r *Request) error {
|
|
||||||
r.config.Headers.Set(key, value)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithHeaders 批量设置 Headers
|
|
||||||
func WithHeaders(headers map[string]string) RequestOpt {
|
|
||||||
return func(r *Request) error {
|
|
||||||
for k, v := range headers {
|
|
||||||
r.config.Headers.Set(k, v)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithContentType 设置 Content-Type
|
|
||||||
func WithContentType(contentType string) RequestOpt {
|
|
||||||
return func(r *Request) error {
|
|
||||||
r.config.Headers.Set("Content-Type", contentType)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithUserAgent 设置 User-Agent
|
|
||||||
func WithUserAgent(userAgent string) RequestOpt {
|
|
||||||
return func(r *Request) error {
|
|
||||||
r.config.Headers.Set("User-Agent", userAgent)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithBearerToken 设置 Bearer Token
|
|
||||||
func WithBearerToken(token string) RequestOpt {
|
|
||||||
return func(r *Request) error {
|
|
||||||
r.config.Headers.Set("Authorization", "Bearer "+token)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithBasicAuth 设置 Basic 认证
|
|
||||||
func WithBasicAuth(username, password string) RequestOpt {
|
|
||||||
return func(r *Request) error {
|
|
||||||
r.config.BasicAuth = [2]string{username, password}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithCookie 添加 Cookie
|
|
||||||
func WithCookie(name, value, path string) RequestOpt {
|
|
||||||
return func(r *Request) error {
|
|
||||||
r.config.Cookies = append(r.config.Cookies, &http.Cookie{
|
|
||||||
Name: name,
|
|
||||||
Value: value,
|
|
||||||
Path: path,
|
|
||||||
})
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithSimpleCookie 添加简单 Cookie(path 为 /)
|
|
||||||
func WithSimpleCookie(name, value string) RequestOpt {
|
|
||||||
return func(r *Request) error {
|
|
||||||
r.config.Cookies = append(r.config.Cookies, &http.Cookie{
|
|
||||||
Name: name,
|
|
||||||
Value: value,
|
|
||||||
Path: "/",
|
|
||||||
})
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithCookies 批量添加 Cookies
|
|
||||||
func WithCookies(cookies map[string]string) RequestOpt {
|
|
||||||
return func(r *Request) error {
|
|
||||||
for name, value := range cookies {
|
|
||||||
r.config.Cookies = append(r.config.Cookies, &http.Cookie{
|
|
||||||
Name: name,
|
|
||||||
Value: value,
|
|
||||||
Path: "/",
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithBody 设置请求体(字节)
|
|
||||||
func WithBody(body []byte) RequestOpt {
|
|
||||||
return func(r *Request) error {
|
|
||||||
r.config.Body.Bytes = body
|
|
||||||
r.config.Body.Reader = nil
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithBodyString 设置请求体(字符串)
|
|
||||||
func WithBodyString(body string) RequestOpt {
|
|
||||||
return func(r *Request) error {
|
|
||||||
r.config.Body.Bytes = []byte(body)
|
|
||||||
r.config.Body.Reader = nil
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithBodyReader 设置请求体(Reader)
|
|
||||||
func WithBodyReader(reader io.Reader) RequestOpt {
|
|
||||||
return func(r *Request) error {
|
|
||||||
r.config.Body.Reader = reader
|
|
||||||
r.config.Body.Bytes = nil
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithJSON 设置 JSON 请求体
|
|
||||||
func WithJSON(v interface{}) RequestOpt {
|
|
||||||
return func(r *Request) error {
|
|
||||||
data, err := json.Marshal(v)
|
|
||||||
if err != nil {
|
|
||||||
return wrapError(err, "marshal json")
|
|
||||||
}
|
|
||||||
r.config.Headers.Set("Content-Type", ContentTypeJSON)
|
|
||||||
r.config.Body.Bytes = data
|
|
||||||
r.config.Body.Reader = nil
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithFormData 设置表单数据
|
|
||||||
func WithFormData(data map[string][]string) RequestOpt {
|
|
||||||
return func(r *Request) error {
|
|
||||||
r.config.Body.FormData = data
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithFormDataMap 设置表单数据(简化版)
|
|
||||||
func WithFormDataMap(data map[string]string) RequestOpt {
|
|
||||||
return func(r *Request) error {
|
|
||||||
for k, v := range data {
|
|
||||||
r.config.Body.FormData[k] = []string{v}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithAddFormData 添加表单数据
|
|
||||||
func WithAddFormData(key, value string) RequestOpt {
|
|
||||||
return func(r *Request) error {
|
|
||||||
r.config.Body.FormData[key] = append(r.config.Body.FormData[key], value)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithFile 添加文件
|
|
||||||
func WithFile(formName, filePath string) RequestOpt {
|
|
||||||
return func(r *Request) error {
|
|
||||||
stat, err := os.Stat(filePath)
|
|
||||||
if err != nil {
|
|
||||||
return wrapError(ErrFileNotFound, "file: %s", filePath)
|
|
||||||
}
|
|
||||||
|
|
||||||
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
|
|
||||||
FormName: formName,
|
|
||||||
FileName: stat.Name(),
|
|
||||||
FilePath: filePath,
|
|
||||||
FileSize: stat.Size(),
|
|
||||||
FileType: ContentTypeOctetStream,
|
|
||||||
})
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithFileStream 添加文件流
|
|
||||||
func WithFileStream(formName, fileName string, size int64, reader io.Reader) RequestOpt {
|
|
||||||
return func(r *Request) error {
|
|
||||||
if reader == nil {
|
|
||||||
return ErrNilReader
|
|
||||||
}
|
|
||||||
|
|
||||||
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
|
|
||||||
FormName: formName,
|
|
||||||
FileName: fileName,
|
|
||||||
FileData: reader,
|
|
||||||
FileSize: size,
|
|
||||||
FileType: ContentTypeOctetStream,
|
|
||||||
})
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithQuery 添加查询参数
|
|
||||||
func WithQuery(key, value string) RequestOpt {
|
|
||||||
return func(r *Request) error {
|
|
||||||
r.config.Queries[key] = append(r.config.Queries[key], value)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithQueries 批量添加查询参数
|
|
||||||
func WithQueries(queries map[string]string) RequestOpt {
|
|
||||||
return func(r *Request) error {
|
|
||||||
for k, v := range queries {
|
|
||||||
r.config.Queries[k] = append(r.config.Queries[k], v)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithContentLength 设置 Content-Length
|
|
||||||
func WithContentLength(length int64) RequestOpt {
|
|
||||||
return func(r *Request) error {
|
|
||||||
r.config.ContentLength = length
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithAutoCalcContentLength 设置是否自动计算 Content-Length
|
|
||||||
func WithAutoCalcContentLength(auto bool) RequestOpt {
|
|
||||||
return func(r *Request) error {
|
|
||||||
r.config.AutoCalcContentLength = auto
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithUploadProgress 设置文件上传进度回调
|
|
||||||
func WithUploadProgress(fn UploadProgressFunc) RequestOpt {
|
|
||||||
return func(r *Request) error {
|
|
||||||
r.config.UploadProgress = fn
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithTransport 设置自定义 Transport
|
|
||||||
func WithTransport(transport *http.Transport) RequestOpt {
|
|
||||||
return func(r *Request) error {
|
|
||||||
r.config.Transport = transport
|
|
||||||
r.config.CustomTransport = true
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithAutoFetch 设置是否自动获取响应体
|
|
||||||
func WithAutoFetch(auto bool) RequestOpt {
|
|
||||||
return func(r *Request) error {
|
|
||||||
r.autoFetch = auto
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithMaxRespBodyBytes 设置响应体最大读取字节数(<=0 表示不限制)
|
|
||||||
func WithMaxRespBodyBytes(maxBytes int64) RequestOpt {
|
|
||||||
return func(r *Request) error {
|
|
||||||
if maxBytes < 0 {
|
|
||||||
return fmt.Errorf("max response body bytes must be >= 0")
|
|
||||||
}
|
|
||||||
r.config.MaxRespBodyBytes = maxBytes
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithRawRequest 设置原始请求
|
|
||||||
func WithRawRequest(httpReq *http.Request) RequestOpt {
|
|
||||||
return func(r *Request) error {
|
|
||||||
if httpReq == nil {
|
|
||||||
return fmt.Errorf("httpReq cannot be nil")
|
|
||||||
}
|
|
||||||
r.httpReq = httpReq
|
|
||||||
r.doRaw = true
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithContext 设置 context
|
|
||||||
func WithContext(ctx context.Context) RequestOpt {
|
|
||||||
return func(r *Request) error {
|
|
||||||
r.ctx = ctx
|
|
||||||
r.httpReq = r.httpReq.WithContext(ctx)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
112
options_body.go
Normal file
112
options_body.go
Normal file
@ -0,0 +1,112 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
)
|
||||||
|
|
||||||
|
// WithBody 设置请求体(字节)
|
||||||
|
func WithBody(body []byte) RequestOpt {
|
||||||
|
return func(r *Request) error {
|
||||||
|
setBytesBodyConfig(&r.config.Body, body)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithBodyString 设置请求体(字符串)
|
||||||
|
func WithBodyString(body string) RequestOpt {
|
||||||
|
return func(r *Request) error {
|
||||||
|
setBytesBodyConfig(&r.config.Body, []byte(body))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithBodyReader 设置请求体(Reader)。
|
||||||
|
// 出于避免重复写的保守策略,Reader 形态的 body 在非幂等方法上不会自动参与 retry。
|
||||||
|
func WithBodyReader(reader io.Reader) RequestOpt {
|
||||||
|
return func(r *Request) error {
|
||||||
|
setReaderBodyConfig(&r.config.Body, reader)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithJSON 设置 JSON 请求体
|
||||||
|
func WithJSON(v interface{}) RequestOpt {
|
||||||
|
return func(r *Request) error {
|
||||||
|
data, err := json.Marshal(v)
|
||||||
|
if err != nil {
|
||||||
|
return wrapError(err, "marshal json")
|
||||||
|
}
|
||||||
|
r.config.Headers.Set("Content-Type", ContentTypeJSON)
|
||||||
|
setBytesBodyConfig(&r.config.Body, data)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithFormData 设置表单数据
|
||||||
|
func WithFormData(data map[string][]string) RequestOpt {
|
||||||
|
return func(r *Request) error {
|
||||||
|
setFormBodyConfig(&r.config.Body, data)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithFormDataMap 设置表单数据(简化版)
|
||||||
|
func WithFormDataMap(data map[string]string) RequestOpt {
|
||||||
|
return func(r *Request) error {
|
||||||
|
setFormBodyConfig(&r.config.Body, nil)
|
||||||
|
for key, value := range data {
|
||||||
|
r.config.Body.FormData[key] = []string{value}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithAddFormData 添加表单数据
|
||||||
|
func WithAddFormData(key, value string) RequestOpt {
|
||||||
|
return func(r *Request) error {
|
||||||
|
ensureFormMode(&r.config.Body)
|
||||||
|
r.config.Body.FormData[key] = append(r.config.Body.FormData[key], value)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithFile 添加文件
|
||||||
|
func WithFile(formName, filePath string) RequestOpt {
|
||||||
|
return func(r *Request) error {
|
||||||
|
stat, err := os.Stat(filePath)
|
||||||
|
if err != nil {
|
||||||
|
return wrapError(ErrFileNotFound, "file: %s", filePath)
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureMultipartMode(&r.config.Body)
|
||||||
|
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
|
||||||
|
FormName: formName,
|
||||||
|
FileName: stat.Name(),
|
||||||
|
FilePath: filePath,
|
||||||
|
FileSize: stat.Size(),
|
||||||
|
FileType: ContentTypeOctetStream,
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithFileStream 添加文件流
|
||||||
|
func WithFileStream(formName, fileName string, size int64, reader io.Reader) RequestOpt {
|
||||||
|
return func(r *Request) error {
|
||||||
|
if reader == nil {
|
||||||
|
return ErrNilReader
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureMultipartMode(&r.config.Body)
|
||||||
|
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
|
||||||
|
FormName: formName,
|
||||||
|
FileName: fileName,
|
||||||
|
FileData: reader,
|
||||||
|
FileSize: size,
|
||||||
|
FileType: ContentTypeOctetStream,
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
132
options_config.go
Normal file
132
options_config.go
Normal file
@ -0,0 +1,132 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// WithTimeout 设置请求总超时时间
|
||||||
|
// timeout > 0: 为本次请求注入 context 超时
|
||||||
|
// timeout = 0: 不额外设置请求总超时
|
||||||
|
// timeout < 0: 禁用 starnet 默认总超时
|
||||||
|
func WithTimeout(timeout time.Duration) RequestOpt {
|
||||||
|
return requestOptFromMutation(mutateTimeout(timeout))
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithDialTimeout 设置连接超时时间
|
||||||
|
func WithDialTimeout(timeout time.Duration) RequestOpt {
|
||||||
|
return requestOptFromMutation(mutateDialTimeout(timeout))
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithProxy 设置代理
|
||||||
|
func WithProxy(proxy string) RequestOpt {
|
||||||
|
return requestOptFromMutation(mutateProxy(proxy))
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithDialFunc 设置自定义 Dial 函数
|
||||||
|
func WithDialFunc(fn func(ctx context.Context, network, addr string) (net.Conn, error)) RequestOpt {
|
||||||
|
return requestOptFromMutation(mutateDialFunc(fn))
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithTLSConfig 设置 TLS 配置
|
||||||
|
func WithTLSConfig(tlsConfig *tls.Config) RequestOpt {
|
||||||
|
return requestOptFromMutation(mutateTLSConfig(tlsConfig))
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithTLSServerName 设置显式 TLS ServerName/SNI。
|
||||||
|
func WithTLSServerName(serverName string) RequestOpt {
|
||||||
|
return requestOptFromMutation(mutateTLSServerName(serverName))
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithTraceHooks 设置请求 trace 回调。
|
||||||
|
func WithTraceHooks(hooks *TraceHooks) RequestOpt {
|
||||||
|
return requestOptFromMutation(mutateTraceHooks(hooks))
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithSkipTLSVerify 设置是否跳过 TLS 验证
|
||||||
|
func WithSkipTLSVerify(skip bool) RequestOpt {
|
||||||
|
return requestOptFromMutation(mutateSkipTLSVerify(skip))
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithCustomIP 设置自定义 IP
|
||||||
|
func WithCustomIP(ips []string) RequestOpt {
|
||||||
|
return requestOptFromMutation(mutateCustomIP(ips))
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithAddCustomIP 添加自定义 IP
|
||||||
|
func WithAddCustomIP(ip string) RequestOpt {
|
||||||
|
return requestOptFromMutation(mutateAddCustomIP(ip))
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithCustomDNS 设置自定义 DNS 服务器
|
||||||
|
func WithCustomDNS(dnsServers []string) RequestOpt {
|
||||||
|
return requestOptFromMutation(mutateCustomDNS(dnsServers))
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithAddCustomDNS 添加自定义 DNS 服务器
|
||||||
|
func WithAddCustomDNS(dns string) RequestOpt {
|
||||||
|
return requestOptFromMutation(mutateAddCustomDNS(dns))
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithLookupFunc 设置自定义 DNS 解析函数
|
||||||
|
func WithLookupFunc(fn func(ctx context.Context, host string) ([]net.IPAddr, error)) RequestOpt {
|
||||||
|
return requestOptFromMutation(mutateLookupFunc(fn))
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithBasicAuth 设置 Basic 认证
|
||||||
|
func WithBasicAuth(username, password string) RequestOpt {
|
||||||
|
return requestOptFromMutation(mutateBasicAuth(username, password))
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithQuery 添加查询参数
|
||||||
|
func WithQuery(key, value string) RequestOpt {
|
||||||
|
return requestOptFromMutation(mutateAddQuery(key, value))
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithQueries 批量添加查询参数
|
||||||
|
func WithQueries(queries map[string]string) RequestOpt {
|
||||||
|
return requestOptFromMutation(mutateAddQueries(queries))
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithContentLength 设置 Content-Length
|
||||||
|
func WithContentLength(length int64) RequestOpt {
|
||||||
|
return requestOptFromMutation(mutateContentLength(length))
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithAutoCalcContentLength 设置是否自动计算 Content-Length
|
||||||
|
func WithAutoCalcContentLength(auto bool) RequestOpt {
|
||||||
|
return requestOptFromMutation(mutateAutoCalcContentLength(auto))
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithUploadProgress 设置文件上传进度回调
|
||||||
|
func WithUploadProgress(fn UploadProgressFunc) RequestOpt {
|
||||||
|
return requestOptFromMutation(mutateUploadProgress(fn))
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithTransport 设置自定义 Transport
|
||||||
|
func WithTransport(transport *http.Transport) RequestOpt {
|
||||||
|
return requestOptFromMutation(mutateTransport(transport))
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithAutoFetch 设置是否自动获取响应体
|
||||||
|
func WithAutoFetch(auto bool) RequestOpt {
|
||||||
|
return requestOptFromMutation(mutateAutoFetch(auto))
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithMaxRespBodyBytes 设置响应体最大读取字节数(<=0 表示不限制)
|
||||||
|
func WithMaxRespBodyBytes(maxBytes int64) RequestOpt {
|
||||||
|
return requestOptFromMutation(mutateMaxRespBodyBytes(maxBytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithRawRequest 设置原始请求
|
||||||
|
func WithRawRequest(httpReq *http.Request) RequestOpt {
|
||||||
|
return requestOptFromMutation(mutateRawRequest(httpReq))
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithContext 设置 context
|
||||||
|
func WithContext(ctx context.Context) RequestOpt {
|
||||||
|
return requestOptFromMutation(mutateContext(ctx))
|
||||||
|
}
|
||||||
99
options_header.go
Normal file
99
options_header.go
Normal file
@ -0,0 +1,99 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import "net/http"
|
||||||
|
|
||||||
|
// WithHeader 设置 Header
|
||||||
|
func WithHeader(key, value string) RequestOpt {
|
||||||
|
return func(r *Request) error {
|
||||||
|
if isHostHeaderKey(key) {
|
||||||
|
setRequestHostConfig(r.config, value)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
r.config.Headers.Set(key, value)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithHost 设置显式 Host 头覆盖。
|
||||||
|
func WithHost(host string) RequestOpt {
|
||||||
|
return func(r *Request) error {
|
||||||
|
setRequestHostConfig(r.config, host)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithHeaders 批量设置 Headers
|
||||||
|
func WithHeaders(headers map[string]string) RequestOpt {
|
||||||
|
return func(r *Request) error {
|
||||||
|
for key, value := range headers {
|
||||||
|
if isHostHeaderKey(key) {
|
||||||
|
setRequestHostConfig(r.config, value)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
r.config.Headers.Set(key, value)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithContentType 设置 Content-Type
|
||||||
|
func WithContentType(contentType string) RequestOpt {
|
||||||
|
return func(r *Request) error {
|
||||||
|
r.config.Headers.Set("Content-Type", contentType)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithUserAgent 设置 User-Agent
|
||||||
|
func WithUserAgent(userAgent string) RequestOpt {
|
||||||
|
return func(r *Request) error {
|
||||||
|
r.config.Headers.Set("User-Agent", userAgent)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithBearerToken 设置 Bearer Token
|
||||||
|
func WithBearerToken(token string) RequestOpt {
|
||||||
|
return func(r *Request) error {
|
||||||
|
r.config.Headers.Set("Authorization", "Bearer "+token)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithCookie 添加 Cookie
|
||||||
|
func WithCookie(name, value, path string) RequestOpt {
|
||||||
|
return func(r *Request) error {
|
||||||
|
r.config.Cookies = append(r.config.Cookies, &http.Cookie{
|
||||||
|
Name: name,
|
||||||
|
Value: value,
|
||||||
|
Path: path,
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithSimpleCookie 添加简单 Cookie(path 为 /)
|
||||||
|
func WithSimpleCookie(name, value string) RequestOpt {
|
||||||
|
return func(r *Request) error {
|
||||||
|
r.config.Cookies = append(r.config.Cookies, &http.Cookie{
|
||||||
|
Name: name,
|
||||||
|
Value: value,
|
||||||
|
Path: "/",
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithCookies 批量添加 Cookies
|
||||||
|
func WithCookies(cookies map[string]string) RequestOpt {
|
||||||
|
return func(r *Request) error {
|
||||||
|
for name, value := range cookies {
|
||||||
|
r.config.Cookies = append(r.config.Cookies, &http.Cookie{
|
||||||
|
Name: name,
|
||||||
|
Value: value,
|
||||||
|
Path: "/",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
221
ping.go
221
ping.go
@ -2,14 +2,14 @@ package starnet
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/binary"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"b612.me/starnet/internal/pingcore"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -18,7 +18,6 @@ const (
|
|||||||
icmpTypeEchoRequestV6 = 128
|
icmpTypeEchoRequestV6 = 128
|
||||||
icmpTypeEchoReplyV6 = 129
|
icmpTypeEchoReplyV6 = 129
|
||||||
|
|
||||||
icmpHeaderLen = 8
|
|
||||||
icmpReadBufSz = 1500
|
icmpReadBufSz = 1500
|
||||||
|
|
||||||
defaultPingAttemptTimeout = 2 * time.Second
|
defaultPingAttemptTimeout = 2 * time.Second
|
||||||
@ -26,13 +25,7 @@ const (
|
|||||||
maxPingPayloadSize = 65499 // 65507 - ICMP header(8)
|
maxPingPayloadSize = 65499 // 65507 - ICMP header(8)
|
||||||
)
|
)
|
||||||
|
|
||||||
type ICMP struct {
|
type ICMP = pingcore.ICMP
|
||||||
Type uint8
|
|
||||||
Code uint8
|
|
||||||
CheckSum uint16
|
|
||||||
Identifier uint16
|
|
||||||
SequenceNum uint16
|
|
||||||
}
|
|
||||||
|
|
||||||
type pingSocketSpec struct {
|
type pingSocketSpec struct {
|
||||||
network string
|
network string
|
||||||
@ -42,53 +35,20 @@ type pingSocketSpec struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// PingOptions controls ping probing behavior.
|
// PingOptions controls ping probing behavior.
|
||||||
type PingOptions struct {
|
type PingOptions = pingcore.Options
|
||||||
Count int // ping attempts for Pingable, default 3
|
|
||||||
Timeout time.Duration // per-attempt timeout, default 2s
|
|
||||||
Interval time.Duration // delay between attempts, default 0
|
|
||||||
Deadline time.Time // overall deadline for Pingable/PingWithContext
|
|
||||||
PreferIPv4 bool // prefer IPv4 targets
|
|
||||||
PreferIPv6 bool // prefer IPv6 targets
|
|
||||||
SourceIP net.IP // optional source IP for raw socket bind
|
|
||||||
PayloadSize int // ICMP payload bytes, default 0
|
|
||||||
}
|
|
||||||
|
|
||||||
type PingResult struct {
|
type PingResult = pingcore.Result
|
||||||
Duration time.Duration
|
|
||||||
RecvCount int
|
|
||||||
RemoteIP string
|
|
||||||
}
|
|
||||||
|
|
||||||
var pingIdentifierSeed uint32
|
|
||||||
|
|
||||||
func nextPingIdentifier() uint16 {
|
func nextPingIdentifier() uint16 {
|
||||||
pid := uint32(os.Getpid() & 0xffff)
|
return pingcore.NextIdentifier()
|
||||||
n := atomic.AddUint32(&pingIdentifierSeed, 1)
|
|
||||||
return uint16((pid + n) & 0xffff)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func pingPayload(size int) []byte {
|
func pingPayload(size int) []byte {
|
||||||
if size <= 0 {
|
return pingcore.Payload(size)
|
||||||
return nil
|
|
||||||
}
|
|
||||||
payload := make([]byte, size)
|
|
||||||
for i := 0; i < len(payload); i++ {
|
|
||||||
payload[i] = byte(i)
|
|
||||||
}
|
|
||||||
return payload
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func getICMP(seq, identifier uint16, typ uint8, payload []byte) ICMP {
|
func getICMP(seq, identifier uint16, typ uint8, payload []byte) ICMP {
|
||||||
icmp := ICMP{
|
return pingcore.BuildICMP(seq, identifier, typ, payload)
|
||||||
Type: typ,
|
|
||||||
Code: 0,
|
|
||||||
CheckSum: 0,
|
|
||||||
Identifier: identifier,
|
|
||||||
SequenceNum: seq,
|
|
||||||
}
|
|
||||||
buf := marshalICMPPacket(icmp, payload)
|
|
||||||
icmp.CheckSum = checkSum(buf)
|
|
||||||
return icmp
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func sendICMPRequest(ctx context.Context, icmp ICMP, payload []byte, destAddr *net.IPAddr, sourceIP net.IP, spec pingSocketSpec, timeout time.Duration) (PingResult, error) {
|
func sendICMPRequest(ctx context.Context, icmp ICMP, payload []byte, destAddr *net.IPAddr, sourceIP net.IP, spec pingSocketSpec, timeout time.Duration) (PingResult, error) {
|
||||||
@ -120,8 +80,8 @@ func sendICMPRequest(ctx context.Context, icmp ICMP, payload []byte, destAddr *n
|
|||||||
return res, wrapError(err, "ping write request")
|
return res, wrapError(err, "ping write request")
|
||||||
}
|
}
|
||||||
|
|
||||||
tStart := time.Now()
|
startedAt := time.Now()
|
||||||
deadline := tStart.Add(timeout)
|
deadline := startedAt.Add(timeout)
|
||||||
if d, ok := ctx.Deadline(); ok && d.Before(deadline) {
|
if d, ok := ctx.Deadline(); ok && d.Before(deadline) {
|
||||||
deadline = d
|
deadline = d
|
||||||
}
|
}
|
||||||
@ -150,108 +110,34 @@ func sendICMPRequest(ctx context.Context, icmp ICMP, payload []byte, destAddr *n
|
|||||||
}
|
}
|
||||||
if isExpectedEchoReply(recv[:n], spec.family, spec.replyType, icmp.Identifier, icmp.SequenceNum) {
|
if isExpectedEchoReply(recv[:n], spec.family, spec.replyType, icmp.Identifier, icmp.SequenceNum) {
|
||||||
res.RecvCount = n
|
res.RecvCount = n
|
||||||
res.Duration = time.Since(tStart)
|
res.Duration = time.Since(startedAt)
|
||||||
return res, nil
|
return res, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkSum(data []byte) uint16 {
|
func checkSum(data []byte) uint16 {
|
||||||
var (
|
return pingcore.Checksum(data)
|
||||||
sum uint32
|
|
||||||
length int = len(data)
|
|
||||||
index int
|
|
||||||
)
|
|
||||||
for length > 1 {
|
|
||||||
sum += uint32(data[index])<<8 + uint32(data[index+1])
|
|
||||||
index += 2
|
|
||||||
length -= 2
|
|
||||||
}
|
|
||||||
if length > 0 {
|
|
||||||
sum += uint32(data[index]) << 8
|
|
||||||
}
|
|
||||||
for sum>>16 != 0 {
|
|
||||||
sum = (sum & 0xffff) + (sum >> 16)
|
|
||||||
}
|
|
||||||
|
|
||||||
return uint16(^sum)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func marshalICMP(icmp ICMP) []byte {
|
func marshalICMP(icmp ICMP) []byte {
|
||||||
return marshalICMPPacket(icmp, nil)
|
return pingcore.Marshal(icmp)
|
||||||
}
|
}
|
||||||
|
|
||||||
func marshalICMPPacket(icmp ICMP, payload []byte) []byte {
|
func marshalICMPPacket(icmp ICMP, payload []byte) []byte {
|
||||||
buf := make([]byte, icmpHeaderLen+len(payload))
|
return pingcore.MarshalPacket(icmp, payload)
|
||||||
buf[0] = icmp.Type
|
|
||||||
buf[1] = icmp.Code
|
|
||||||
binary.BigEndian.PutUint16(buf[2:], icmp.CheckSum)
|
|
||||||
binary.BigEndian.PutUint16(buf[4:], icmp.Identifier)
|
|
||||||
binary.BigEndian.PutUint16(buf[6:], icmp.SequenceNum)
|
|
||||||
copy(buf[icmpHeaderLen:], payload)
|
|
||||||
return buf
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func isExpectedEchoReply(packet []byte, family int, expectedType uint8, identifier, seq uint16) bool {
|
func isExpectedEchoReply(packet []byte, family int, expectedType uint8, identifier, seq uint16) bool {
|
||||||
for _, off := range candidateICMPOffsets(packet, family) {
|
return pingcore.IsExpectedEchoReply(packet, family, expectedType, identifier, seq)
|
||||||
if off < 0 || off+icmpHeaderLen > len(packet) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if packet[off] != expectedType || packet[off+1] != 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if binary.BigEndian.Uint16(packet[off+4:off+6]) != identifier {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if binary.BigEndian.Uint16(packet[off+6:off+8]) != seq {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func candidateICMPOffsets(packet []byte, family int) []int {
|
func candidateICMPOffsets(packet []byte, family int) []int {
|
||||||
offsets := []int{0}
|
return pingcore.CandidateICMPOffsets(packet, family)
|
||||||
if len(packet) == 0 {
|
|
||||||
return offsets
|
|
||||||
}
|
|
||||||
|
|
||||||
ver := packet[0] >> 4
|
|
||||||
if ver == 4 && len(packet) >= 20 {
|
|
||||||
ihl := int(packet[0]&0x0f) * 4
|
|
||||||
if ihl >= 20 && ihl <= len(packet)-icmpHeaderLen {
|
|
||||||
offsets = append(offsets, ihl)
|
|
||||||
}
|
|
||||||
} else if ver == 6 && len(packet) >= 40+icmpHeaderLen {
|
|
||||||
offsets = append(offsets, 40)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 某些平台/内核可能回包含链路层头部,保守再尝试常见偏移。
|
|
||||||
if family == 4 && len(packet) >= 20+icmpHeaderLen {
|
|
||||||
offsets = append(offsets, 20)
|
|
||||||
}
|
|
||||||
if family == 6 && len(packet) >= 40+icmpHeaderLen {
|
|
||||||
offsets = append(offsets, 40)
|
|
||||||
}
|
|
||||||
|
|
||||||
return dedupOffsets(offsets)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func dedupOffsets(offsets []int) []int {
|
func dedupOffsets(offsets []int) []int {
|
||||||
if len(offsets) <= 1 {
|
return pingcore.DedupOffsets(offsets)
|
||||||
return offsets
|
|
||||||
}
|
|
||||||
m := make(map[int]struct{}, len(offsets))
|
|
||||||
out := make([]int, 0, len(offsets))
|
|
||||||
for _, off := range offsets {
|
|
||||||
if _, ok := m[off]; ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
m[off] = struct{}{}
|
|
||||||
out = append(out, off)
|
|
||||||
}
|
|
||||||
return out
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func socketSpecForIP(ip net.IP) (pingSocketSpec, error) {
|
func socketSpecForIP(ip net.IP) (pingSocketSpec, error) {
|
||||||
@ -297,70 +183,18 @@ func localIPAddrForFamily(sourceIP net.IP, family int) (*net.IPAddr, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func resolvePingTargets(host string, preferIPv4, preferIPv6 bool) ([]*net.IPAddr, error) {
|
func resolvePingTargets(host string, preferIPv4, preferIPv6 bool) ([]*net.IPAddr, error) {
|
||||||
if parsed := net.ParseIP(host); parsed != nil {
|
targets, err := pingcore.ResolveTargets(host, preferIPv4, preferIPv6)
|
||||||
return []*net.IPAddr{{IP: parsed}}, nil
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if len(targets) == 0 {
|
||||||
var targets []*net.IPAddr
|
return nil, ErrPingNoResolvedTarget
|
||||||
var err4 error
|
|
||||||
var err6 error
|
|
||||||
|
|
||||||
if ip4, e := net.ResolveIPAddr("ip4", host); e == nil && ip4 != nil && ip4.IP != nil {
|
|
||||||
targets = append(targets, ip4)
|
|
||||||
} else {
|
|
||||||
err4 = e
|
|
||||||
}
|
}
|
||||||
|
return targets, nil
|
||||||
if ip6, e := net.ResolveIPAddr("ip6", host); e == nil && ip6 != nil && ip6.IP != nil {
|
|
||||||
targets = append(targets, ip6)
|
|
||||||
} else {
|
|
||||||
err6 = e
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(targets) > 0 {
|
|
||||||
return orderPingTargets(targets, preferIPv4, preferIPv6), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if err4 != nil {
|
|
||||||
return nil, err4
|
|
||||||
}
|
|
||||||
if err6 != nil {
|
|
||||||
return nil, err6
|
|
||||||
}
|
|
||||||
return nil, ErrPingNoResolvedTarget
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func orderPingTargets(targets []*net.IPAddr, preferIPv4, preferIPv6 bool) []*net.IPAddr {
|
func orderPingTargets(targets []*net.IPAddr, preferIPv4, preferIPv6 bool) []*net.IPAddr {
|
||||||
if len(targets) <= 1 || preferIPv4 == preferIPv6 {
|
return pingcore.OrderTargets(targets, preferIPv4, preferIPv6)
|
||||||
return targets
|
|
||||||
}
|
|
||||||
|
|
||||||
ordered := make([]*net.IPAddr, 0, len(targets))
|
|
||||||
if preferIPv4 {
|
|
||||||
for _, t := range targets {
|
|
||||||
if t != nil && t.IP != nil && t.IP.To4() != nil {
|
|
||||||
ordered = append(ordered, t)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for _, t := range targets {
|
|
||||||
if t != nil && t.IP != nil && t.IP.To4() == nil {
|
|
||||||
ordered = append(ordered, t)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ordered
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, t := range targets {
|
|
||||||
if t != nil && t.IP != nil && t.IP.To4() == nil {
|
|
||||||
ordered = append(ordered, t)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for _, t := range targets {
|
|
||||||
if t != nil && t.IP != nil && t.IP.To4() != nil {
|
|
||||||
ordered = append(ordered, t)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ordered
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func normalizePingDialError(err error) error {
|
func normalizePingDialError(err error) error {
|
||||||
@ -450,7 +284,6 @@ func pingOnceWithOptions(ctx context.Context, host string, seq int, opts PingOpt
|
|||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// 权限问题通常与地址族无关,继续重试意义不大。
|
|
||||||
if errors.Is(err, ErrPingPermissionDenied) {
|
if errors.Is(err, ErrPingPermissionDenied) {
|
||||||
return res, err
|
return res, err
|
||||||
}
|
}
|
||||||
@ -501,8 +334,8 @@ func Pingable(host string, opts *PingOptions) (bool, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var lastErr error
|
var lastErr error
|
||||||
for i := 0; i < cfg.Count; i++ {
|
for index := 0; index < cfg.Count; index++ {
|
||||||
_, err := pingOnceWithOptions(ctx, host, 29+i, cfg)
|
_, err := pingOnceWithOptions(ctx, host, 29+index, cfg)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
@ -512,7 +345,7 @@ func Pingable(host string, opts *PingOptions) (bool, error) {
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
if i < cfg.Count-1 && cfg.Interval > 0 {
|
if index < cfg.Count-1 && cfg.Interval > 0 {
|
||||||
timer := time.NewTimer(cfg.Interval)
|
timer := time.NewTimer(cfg.Interval)
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
|
|||||||
110
proxy_custom_ip_test.go
Normal file
110
proxy_custom_ip_test.go
Normal file
@ -0,0 +1,110 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRequestProxyWithCustomIPTargetsOriginWithoutRewritingProxyDial(t *testing.T) {
|
||||||
|
tlsReqInfo := make(chan struct {
|
||||||
|
host string
|
||||||
|
sni string
|
||||||
|
}, 1)
|
||||||
|
tlsServer := newIPv4TLSServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
tlsReqInfo <- struct {
|
||||||
|
host string
|
||||||
|
sni string
|
||||||
|
}{
|
||||||
|
host: r.Host,
|
||||||
|
sni: r.TLS.ServerName,
|
||||||
|
}
|
||||||
|
_, _ = w.Write([]byte("ok"))
|
||||||
|
}))
|
||||||
|
defer tlsServer.Close()
|
||||||
|
|
||||||
|
_, port, err := net.SplitHostPort(tlsServer.Listener.Addr().String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("split tls server addr: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
proxyServer := newIPv4ConnectProxyServer(t, nil)
|
||||||
|
defer proxyServer.Close()
|
||||||
|
|
||||||
|
targetHost := "proxy-custom-ip.test"
|
||||||
|
reqURL := fmt.Sprintf("https://%s:%s", targetHost, port)
|
||||||
|
req := NewSimpleRequest(reqURL, http.MethodGet).
|
||||||
|
SetProxy(proxyServer.URL).
|
||||||
|
SetCustomIP([]string{"127.0.0.1"}).
|
||||||
|
SetSkipTLSVerify(true)
|
||||||
|
|
||||||
|
resp, err := req.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
targets := proxyServer.Targets()
|
||||||
|
if len(targets) != 1 {
|
||||||
|
t.Fatalf("connect targets=%v; want 1 target", targets)
|
||||||
|
}
|
||||||
|
gotConnectTarget := targets[0]
|
||||||
|
wantConnectTarget := net.JoinHostPort("127.0.0.1", port)
|
||||||
|
if gotConnectTarget != wantConnectTarget {
|
||||||
|
t.Fatalf("CONNECT target = %q; want %q", gotConnectTarget, wantConnectTarget)
|
||||||
|
}
|
||||||
|
|
||||||
|
gotTLS := <-tlsReqInfo
|
||||||
|
wantHost := net.JoinHostPort(targetHost, port)
|
||||||
|
if gotTLS.host != wantHost {
|
||||||
|
t.Fatalf("request host = %q; want %q", gotTLS.host, wantHost)
|
||||||
|
}
|
||||||
|
if gotTLS.sni != targetHost {
|
||||||
|
t.Fatalf("tls sni = %q; want %q", gotTLS.sni, targetHost)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestCustomIPPreservesOriginalHostAndSNI(t *testing.T) {
|
||||||
|
tlsReqInfo := make(chan struct {
|
||||||
|
host string
|
||||||
|
sni string
|
||||||
|
}, 1)
|
||||||
|
tlsServer := newIPv4TLSServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
tlsReqInfo <- struct {
|
||||||
|
host string
|
||||||
|
sni string
|
||||||
|
}{
|
||||||
|
host: r.Host,
|
||||||
|
sni: r.TLS.ServerName,
|
||||||
|
}
|
||||||
|
_, _ = w.Write([]byte("ok"))
|
||||||
|
}))
|
||||||
|
defer tlsServer.Close()
|
||||||
|
|
||||||
|
_, port, err := net.SplitHostPort(tlsServer.Listener.Addr().String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("split tls server addr: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
targetHost := "custom-ip-direct.test"
|
||||||
|
reqURL := fmt.Sprintf("https://%s:%s", targetHost, port)
|
||||||
|
req := NewSimpleRequest(reqURL, http.MethodGet).
|
||||||
|
SetCustomIP([]string{"127.0.0.1"}).
|
||||||
|
SetSkipTLSVerify(true)
|
||||||
|
|
||||||
|
resp, err := req.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
gotTLS := <-tlsReqInfo
|
||||||
|
wantHost := net.JoinHostPort(targetHost, port)
|
||||||
|
if gotTLS.host != wantHost {
|
||||||
|
t.Fatalf("request host = %q; want %q", gotTLS.host, wantHost)
|
||||||
|
}
|
||||||
|
if gotTLS.sni != targetHost {
|
||||||
|
t.Fatalf("tls sni = %q; want %q", gotTLS.sni, targetHost)
|
||||||
|
}
|
||||||
|
}
|
||||||
331
proxy_local_helpers_test.go
Normal file
331
proxy_local_helpers_test.go
Normal file
@ -0,0 +1,331 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type connectProxyServer struct {
|
||||||
|
*httptest.Server
|
||||||
|
mu sync.Mutex
|
||||||
|
targets []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newIPv4Server(t testing.TB, handler http.Handler) *httptest.Server {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
listener, err := net.Listen("tcp4", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("listen tcp4: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
server := httptest.NewUnstartedServer(handler)
|
||||||
|
server.Listener = listener
|
||||||
|
server.Start()
|
||||||
|
return server
|
||||||
|
}
|
||||||
|
|
||||||
|
func newIPv4TLSServer(t testing.TB, handler http.Handler) *httptest.Server {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
listener, err := net.Listen("tcp4", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("listen tcp4: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
server := httptest.NewUnstartedServer(handler)
|
||||||
|
server.Listener = listener
|
||||||
|
server.StartTLS()
|
||||||
|
return server
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTrustedIPv4TLSServer(t testing.TB, dnsName string, handler http.Handler) (*httptest.Server, *x509.CertPool) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
testT, ok := t.(*testing.T)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("newTrustedIPv4TLSServer requires *testing.T")
|
||||||
|
}
|
||||||
|
|
||||||
|
certPEM, keyPEM := genSelfSignedCertPEM(testT, dnsName)
|
||||||
|
cert, err := tls.X509KeyPair(certPEM, keyPEM)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("X509KeyPair: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
pool := x509.NewCertPool()
|
||||||
|
if !pool.AppendCertsFromPEM(certPEM) {
|
||||||
|
t.Fatal("AppendCertsFromPEM returned false")
|
||||||
|
}
|
||||||
|
|
||||||
|
server := httptest.NewUnstartedServer(handler)
|
||||||
|
listener, err := net.Listen("tcp4", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("listen tcp4: %v", err)
|
||||||
|
}
|
||||||
|
server.Listener = listener
|
||||||
|
server.TLS = &tls.Config{
|
||||||
|
Certificates: []tls.Certificate{cert},
|
||||||
|
}
|
||||||
|
server.StartTLS()
|
||||||
|
return server, pool
|
||||||
|
}
|
||||||
|
|
||||||
|
func httpsURLForHost(t testing.TB, server *httptest.Server, host string) string {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
_, port, err := net.SplitHostPort(server.Listener.Addr().String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("split host port: %v", err)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("https://%s:%s", host, port)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newIPv4ConnectProxyServer(t testing.TB, dialTarget func(target string) (net.Conn, error)) *connectProxyServer {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
proxy := &connectProxyServer{}
|
||||||
|
if dialTarget == nil {
|
||||||
|
dialTarget = func(target string) (net.Conn, error) {
|
||||||
|
return net.Dial("tcp", target)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
proxy.Server = newIPv4Server(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodConnect {
|
||||||
|
http.Error(w, "connect required", http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
proxy.mu.Lock()
|
||||||
|
proxy.targets = append(proxy.targets, r.Host)
|
||||||
|
proxy.mu.Unlock()
|
||||||
|
|
||||||
|
targetConn, err := dialTarget(r.Host)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusBadGateway)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
hijacker, ok := w.(http.Hijacker)
|
||||||
|
if !ok {
|
||||||
|
targetConn.Close()
|
||||||
|
t.Fatal("proxy response writer is not a hijacker")
|
||||||
|
}
|
||||||
|
|
||||||
|
clientConn, rw, err := hijacker.Hijack()
|
||||||
|
if err != nil {
|
||||||
|
targetConn.Close()
|
||||||
|
t.Fatalf("hijack proxy conn: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := rw.WriteString("HTTP/1.1 200 Connection Established\r\n\r\n"); err != nil {
|
||||||
|
clientConn.Close()
|
||||||
|
targetConn.Close()
|
||||||
|
t.Fatalf("write connect response: %v", err)
|
||||||
|
}
|
||||||
|
if err := rw.Flush(); err != nil {
|
||||||
|
clientConn.Close()
|
||||||
|
targetConn.Close()
|
||||||
|
t.Fatalf("flush connect response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
relayProxyConns(clientConn, targetConn)
|
||||||
|
}))
|
||||||
|
return proxy
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *connectProxyServer) Targets() []string {
|
||||||
|
p.mu.Lock()
|
||||||
|
defer p.mu.Unlock()
|
||||||
|
return append([]string(nil), p.targets...)
|
||||||
|
}
|
||||||
|
|
||||||
|
type socks5ProxyServer struct {
|
||||||
|
ln net.Listener
|
||||||
|
addr string
|
||||||
|
dial func(target string) (net.Conn, error)
|
||||||
|
stopCh chan struct{}
|
||||||
|
wg sync.WaitGroup
|
||||||
|
mu sync.Mutex
|
||||||
|
targets []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSOCKS5ProxyServer(t testing.TB, dialTarget func(target string) (net.Conn, error)) *socks5ProxyServer {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
if dialTarget == nil {
|
||||||
|
dialTarget = func(target string) (net.Conn, error) {
|
||||||
|
return net.Dial("tcp", target)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ln, err := net.Listen("tcp4", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("listen tcp4 socks5: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
proxy := &socks5ProxyServer{
|
||||||
|
ln: ln,
|
||||||
|
addr: ln.Addr().String(),
|
||||||
|
dial: dialTarget,
|
||||||
|
stopCh: make(chan struct{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
proxy.wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer proxy.wg.Done()
|
||||||
|
for {
|
||||||
|
conn, err := ln.Accept()
|
||||||
|
if err != nil {
|
||||||
|
select {
|
||||||
|
case <-proxy.stopCh:
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
proxy.wg.Add(1)
|
||||||
|
go func(c net.Conn) {
|
||||||
|
defer proxy.wg.Done()
|
||||||
|
proxy.handleConn(t, c)
|
||||||
|
}(conn)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return proxy
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *socks5ProxyServer) URL() string {
|
||||||
|
return "socks5://" + p.addr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *socks5ProxyServer) Targets() []string {
|
||||||
|
p.mu.Lock()
|
||||||
|
defer p.mu.Unlock()
|
||||||
|
return append([]string(nil), p.targets...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *socks5ProxyServer) Close() {
|
||||||
|
close(p.stopCh)
|
||||||
|
_ = p.ln.Close()
|
||||||
|
p.wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *socks5ProxyServer) handleConn(t testing.TB, conn net.Conn) {
|
||||||
|
t.Helper()
|
||||||
|
closeConn := true
|
||||||
|
defer func() {
|
||||||
|
if closeConn {
|
||||||
|
_ = conn.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
header := make([]byte, 2)
|
||||||
|
if _, err := io.ReadFull(conn, header); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if header[0] != 0x05 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
methods := make([]byte, int(header[1]))
|
||||||
|
if _, err := io.ReadFull(conn, methods); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, err := conn.Write([]byte{0x05, 0x00}); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
reqHeader := make([]byte, 4)
|
||||||
|
if _, err := io.ReadFull(conn, reqHeader); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if reqHeader[0] != 0x05 || reqHeader[1] != 0x01 {
|
||||||
|
_, _ = conn.Write([]byte{0x05, 0x07, 0x00, 0x01, 0, 0, 0, 0, 0, 0})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
host, err := readSOCKS5Addr(conn, reqHeader[3])
|
||||||
|
if err != nil {
|
||||||
|
_, _ = conn.Write([]byte{0x05, 0x08, 0x00, 0x01, 0, 0, 0, 0, 0, 0})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
portBytes := make([]byte, 2)
|
||||||
|
if _, err := io.ReadFull(conn, portBytes); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
target := net.JoinHostPort(host, fmt.Sprintf("%d", binary.BigEndian.Uint16(portBytes)))
|
||||||
|
|
||||||
|
p.mu.Lock()
|
||||||
|
p.targets = append(p.targets, target)
|
||||||
|
p.mu.Unlock()
|
||||||
|
|
||||||
|
targetConn, err := p.dial(target)
|
||||||
|
if err != nil {
|
||||||
|
_, _ = conn.Write([]byte{0x05, 0x05, 0x00, 0x01, 0, 0, 0, 0, 0, 0})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := conn.Write([]byte{0x05, 0x00, 0x00, 0x01, 0, 0, 0, 0, 0, 0}); err != nil {
|
||||||
|
targetConn.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
closeConn = false
|
||||||
|
relayProxyConns(conn, targetConn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func readSOCKS5Addr(r io.Reader, atyp byte) (string, error) {
|
||||||
|
switch atyp {
|
||||||
|
case 0x01:
|
||||||
|
buf := make([]byte, 4)
|
||||||
|
if _, err := io.ReadFull(r, buf); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return net.IP(buf).String(), nil
|
||||||
|
case 0x03:
|
||||||
|
var size [1]byte
|
||||||
|
if _, err := io.ReadFull(r, size[:]); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
buf := make([]byte, int(size[0]))
|
||||||
|
if _, err := io.ReadFull(r, buf); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return string(buf), nil
|
||||||
|
case 0x04:
|
||||||
|
buf := make([]byte, 16)
|
||||||
|
if _, err := io.ReadFull(r, buf); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return net.IP(buf).String(), nil
|
||||||
|
default:
|
||||||
|
return "", fmt.Errorf("unsupported atyp: %d", atyp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func relayProxyConns(left, right net.Conn) {
|
||||||
|
var once sync.Once
|
||||||
|
closeBoth := func() {
|
||||||
|
_ = left.Close()
|
||||||
|
_ = right.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
_, _ = io.Copy(left, right)
|
||||||
|
once.Do(closeBoth)
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
_, _ = io.Copy(right, left)
|
||||||
|
once.Do(closeBoth)
|
||||||
|
}()
|
||||||
|
}
|
||||||
208
request.go
208
request.go
@ -22,14 +22,131 @@ type Request struct {
|
|||||||
httpClient *http.Client
|
httpClient *http.Client
|
||||||
httpReq *http.Request
|
httpReq *http.Request
|
||||||
retry *retryPolicy
|
retry *retryPolicy
|
||||||
|
traceHooks *TraceHooks
|
||||||
|
traceState *traceState
|
||||||
|
|
||||||
applied bool // 是否已应用配置
|
applied bool // 是否已应用配置
|
||||||
doRaw bool // 是否使用原始请求(不修改)
|
doRaw bool // 是否使用原始请求(不修改)
|
||||||
autoFetch bool // 是否自动获取响应体
|
autoFetch bool // 是否自动获取响应体
|
||||||
|
|
||||||
|
rawSourceExternal bool // 是否由 SetRawRequest/WithRawRequest 注入外部 raw request
|
||||||
|
rawTemplate *http.Request
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeContext(ctx context.Context) context.Context {
|
||||||
|
if ctx != nil {
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
return context.Background()
|
||||||
|
}
|
||||||
|
|
||||||
|
func cloneRawHTTPRequest(httpReq *http.Request, ctx context.Context) (*http.Request, error) {
|
||||||
|
if httpReq == nil {
|
||||||
|
return nil, fmt.Errorf("http request is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
cloned := httpReq.Clone(normalizeContext(ctx))
|
||||||
|
switch {
|
||||||
|
case httpReq.Body == nil || httpReq.Body == http.NoBody:
|
||||||
|
cloned.Body = httpReq.Body
|
||||||
|
case httpReq.GetBody != nil:
|
||||||
|
body, err := httpReq.GetBody()
|
||||||
|
if err != nil {
|
||||||
|
return cloned, wrapError(err, "clone raw request body")
|
||||||
|
}
|
||||||
|
cloned.Body = body
|
||||||
|
default:
|
||||||
|
return cloned, fmt.Errorf("cannot clone raw request with non-replayable body")
|
||||||
|
}
|
||||||
|
|
||||||
|
return cloned, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Request) rawBaseRequest() *http.Request {
|
||||||
|
if r == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if r.rawTemplate != nil {
|
||||||
|
return r.rawTemplate
|
||||||
|
}
|
||||||
|
return r.httpReq
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Request) invalidatePreparedState() {
|
||||||
|
if r == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if r.cancel != nil {
|
||||||
|
r.cancel()
|
||||||
|
r.cancel = nil
|
||||||
|
}
|
||||||
|
r.execCtx = nil
|
||||||
|
r.traceState = nil
|
||||||
|
r.httpClient = nil
|
||||||
|
|
||||||
|
wasApplied := r.applied
|
||||||
|
r.applied = false
|
||||||
|
if !wasApplied || r.doRaw {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := r.rebuildPreparedRequestBase(); err != nil && r.err == nil {
|
||||||
|
r.err = err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Request) rebuildPreparedRequestBase() error {
|
||||||
|
if r == nil || r.doRaw {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
ctx := r.ctx
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
httpReq, err := http.NewRequestWithContext(ctx, r.method, r.url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return wrapError(err, "rebuild http request")
|
||||||
|
}
|
||||||
|
r.httpReq = httpReq
|
||||||
|
r.syncRequestHost()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Request) rebuildRawRequestBase() error {
|
||||||
|
if r == nil || !r.doRaw {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
baseReq := r.rawBaseRequest()
|
||||||
|
rawReq, err := cloneRawHTTPRequest(baseReq, normalizeContext(r.ctx))
|
||||||
|
if err != nil && baseReq != nil && baseReq == r.httpReq {
|
||||||
|
r.httpReq = baseReq.WithContext(normalizeContext(r.ctx))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if rawReq != nil {
|
||||||
|
r.httpReq = rawReq
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Request) rebuildExecutionRequestBase() error {
|
||||||
|
if r == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if r.cancel != nil {
|
||||||
|
r.cancel()
|
||||||
|
r.cancel = nil
|
||||||
|
}
|
||||||
|
r.execCtx = nil
|
||||||
|
r.traceState = nil
|
||||||
|
r.applied = false
|
||||||
|
if r.doRaw {
|
||||||
|
return r.rebuildRawRequestBase()
|
||||||
|
}
|
||||||
|
return r.rebuildPreparedRequestBase()
|
||||||
}
|
}
|
||||||
|
|
||||||
// newRequest 创建新请求(内部使用)
|
// newRequest 创建新请求(内部使用)
|
||||||
func newRequest(ctx context.Context, urlStr string, method string, opts ...RequestOpt) (*Request, error) {
|
func newRequest(ctx context.Context, urlStr string, method string, opts ...RequestOpt) (*Request, error) {
|
||||||
|
ctx = normalizeContext(ctx)
|
||||||
if method == "" {
|
if method == "" {
|
||||||
method = http.MethodGet
|
method = http.MethodGet
|
||||||
}
|
}
|
||||||
@ -133,6 +250,7 @@ func NewSimpleRequest(url, method string, opts ...RequestOpt) *Request {
|
|||||||
|
|
||||||
// NewSimpleRequestWithContext 创建新请求(带 context,忽略错误)
|
// NewSimpleRequestWithContext 创建新请求(带 context,忽略错误)
|
||||||
func NewSimpleRequestWithContext(ctx context.Context, url, method string, opts ...RequestOpt) *Request {
|
func NewSimpleRequestWithContext(ctx context.Context, url, method string, opts ...RequestOpt) *Request {
|
||||||
|
ctx = normalizeContext(ctx)
|
||||||
req, err := newRequest(ctx, url, method, opts...)
|
req, err := newRequest(ctx, url, method, opts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &Request{
|
return &Request{
|
||||||
@ -163,16 +281,24 @@ func (r *Request) Clone() *Request {
|
|||||||
client: r.client,
|
client: r.client,
|
||||||
httpClient: r.httpClient,
|
httpClient: r.httpClient,
|
||||||
retry: cloneRetryPolicy(r.retry),
|
retry: cloneRetryPolicy(r.retry),
|
||||||
|
traceHooks: r.traceHooks,
|
||||||
applied: false, // 重置应用状态
|
applied: false, // 重置应用状态
|
||||||
doRaw: r.doRaw,
|
doRaw: r.doRaw,
|
||||||
autoFetch: r.autoFetch,
|
autoFetch: r.autoFetch,
|
||||||
|
|
||||||
|
rawSourceExternal: r.rawSourceExternal,
|
||||||
}
|
}
|
||||||
|
|
||||||
// 重新创建 http.Request
|
// 重新创建 http.Request
|
||||||
if !r.doRaw {
|
if !r.doRaw {
|
||||||
cloned.httpReq, _ = http.NewRequestWithContext(cloned.ctx, cloned.method, cloned.url, nil)
|
cloned.httpReq, _ = http.NewRequestWithContext(cloned.ctx, cloned.method, cloned.url, nil)
|
||||||
} else {
|
} else {
|
||||||
cloned.httpReq = r.httpReq
|
rawTemplate, err := cloneRawHTTPRequest(r.rawBaseRequest(), cloned.ctx)
|
||||||
|
cloned.rawTemplate = rawTemplate
|
||||||
|
cloned.httpReq = rawTemplate
|
||||||
|
if err != nil && cloned.err == nil {
|
||||||
|
cloned.err = err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return cloned
|
return cloned
|
||||||
@ -190,12 +316,7 @@ func (r *Request) Context() context.Context {
|
|||||||
|
|
||||||
// SetContext 设置 context
|
// SetContext 设置 context
|
||||||
func (r *Request) SetContext(ctx context.Context) *Request {
|
func (r *Request) SetContext(ctx context.Context) *Request {
|
||||||
if r.err != nil {
|
return r.applyMutation(mutateContext(ctx))
|
||||||
return r
|
|
||||||
}
|
|
||||||
r.ctx = ctx
|
|
||||||
r.httpReq = r.httpReq.WithContext(ctx)
|
|
||||||
return r
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Method 获取 HTTP 方法
|
// Method 获取 HTTP 方法
|
||||||
@ -215,7 +336,13 @@ func (r *Request) SetMethod(method string) *Request {
|
|||||||
}
|
}
|
||||||
|
|
||||||
r.method = method
|
r.method = method
|
||||||
r.httpReq.Method = method
|
if r.httpReq != nil {
|
||||||
|
r.httpReq.Method = method
|
||||||
|
}
|
||||||
|
if r.doRaw && r.rawTemplate != nil {
|
||||||
|
r.rawTemplate.Method = method
|
||||||
|
}
|
||||||
|
r.invalidatePreparedState()
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -243,45 +370,74 @@ func (r *Request) SetURL(urlStr string) *Request {
|
|||||||
|
|
||||||
r.url = urlStr
|
r.url = urlStr
|
||||||
u.Host = removeEmptyPort(u.Host)
|
u.Host = removeEmptyPort(u.Host)
|
||||||
r.httpReq.Host = u.Host
|
|
||||||
r.httpReq.URL = u
|
r.httpReq.URL = u
|
||||||
|
r.syncRequestHost()
|
||||||
// 更新 TLS ServerName
|
r.invalidatePreparedState()
|
||||||
if r.config.TLS.Config != nil {
|
|
||||||
r.config.TLS.Config.ServerName = u.Hostname()
|
|
||||||
}
|
|
||||||
|
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *Request) effectiveRequestHost() string {
|
||||||
|
if r == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if r.config != nil && r.config.Host != "" {
|
||||||
|
return r.config.Host
|
||||||
|
}
|
||||||
|
if r.httpReq != nil && r.httpReq.URL != nil {
|
||||||
|
return removeEmptyPort(r.httpReq.URL.Host)
|
||||||
|
}
|
||||||
|
if r.url == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
u, err := url.Parse(r.url)
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return removeEmptyPort(u.Host)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Request) syncRequestHost() {
|
||||||
|
if r == nil || r.httpReq == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r.httpReq.Host = r.effectiveRequestHost()
|
||||||
|
}
|
||||||
|
|
||||||
// RawRequest 获取底层 http.Request
|
// RawRequest 获取底层 http.Request
|
||||||
func (r *Request) RawRequest() *http.Request {
|
func (r *Request) RawRequest() *http.Request {
|
||||||
|
if r != nil && r.doRaw && r.rawTemplate != nil && !r.applied {
|
||||||
|
return r.rawTemplate
|
||||||
|
}
|
||||||
return r.httpReq
|
return r.httpReq
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetRawRequest 设置底层 http.Request(启用原始模式)
|
// SetRawRequest 设置底层 http.Request(启用原始模式)
|
||||||
func (r *Request) SetRawRequest(httpReq *http.Request) *Request {
|
func (r *Request) SetRawRequest(httpReq *http.Request) *Request {
|
||||||
if r.err != nil {
|
return r.applyMutation(mutateRawRequest(httpReq))
|
||||||
return r
|
|
||||||
}
|
|
||||||
r.httpReq = httpReq
|
|
||||||
r.doRaw = true
|
|
||||||
if httpReq == nil {
|
|
||||||
r.err = fmt.Errorf("httpReq cannot be nil")
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
return r
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// EnableRawMode 启用原始模式(不修改请求)
|
// EnableRawMode 启用原始模式(不修改请求)
|
||||||
func (r *Request) EnableRawMode() *Request {
|
func (r *Request) EnableRawMode() *Request {
|
||||||
|
if r.doRaw {
|
||||||
|
return r
|
||||||
|
}
|
||||||
r.doRaw = true
|
r.doRaw = true
|
||||||
|
r.invalidatePreparedState()
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
// DisableRawMode 禁用原始模式
|
// DisableRawMode 禁用原始模式
|
||||||
func (r *Request) DisableRawMode() *Request {
|
func (r *Request) DisableRawMode() *Request {
|
||||||
|
if !r.doRaw {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
if r.rawSourceExternal {
|
||||||
|
r.err = fmt.Errorf("cannot disable raw mode after SetRawRequest")
|
||||||
|
return r
|
||||||
|
}
|
||||||
r.doRaw = false
|
r.doRaw = false
|
||||||
|
r.invalidatePreparedState()
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -329,6 +485,10 @@ func (r *Request) Do() (*Response, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *Request) doOnce() (*Response, error) {
|
func (r *Request) doOnce() (*Response, error) {
|
||||||
|
if err := r.rebuildExecutionRequestBase(); err != nil {
|
||||||
|
return nil, wrapError(err, "rebuild execution request")
|
||||||
|
}
|
||||||
|
|
||||||
// 准备请求
|
// 准备请求
|
||||||
if err := r.prepare(); err != nil {
|
if err := r.prepare(); err != nil {
|
||||||
return nil, wrapError(err, "prepare request")
|
return nil, wrapError(err, "prepare request")
|
||||||
|
|||||||
276
request_body.go
276
request_body.go
@ -1,16 +1,9 @@
|
|||||||
package starnet
|
package starnet
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
|
||||||
"io"
|
"io"
|
||||||
"mime/multipart"
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// SetBody 设置请求体(字节)
|
// SetBody 设置请求体(字节)
|
||||||
@ -21,12 +14,13 @@ func (r *Request) SetBody(body []byte) *Request {
|
|||||||
if r.doRaw {
|
if r.doRaw {
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
r.config.Body.Bytes = body
|
setBytesBodyConfig(&r.config.Body, body)
|
||||||
r.config.Body.Reader = nil
|
r.invalidatePreparedState()
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetBodyReader 设置请求体(Reader)
|
// SetBodyReader 设置请求体(Reader)。
|
||||||
|
// 出于避免重复写的保守策略,Reader 形态的 body 在非幂等方法上不会自动参与 retry。
|
||||||
func (r *Request) SetBodyReader(reader io.Reader) *Request {
|
func (r *Request) SetBodyReader(reader io.Reader) *Request {
|
||||||
if r.err != nil {
|
if r.err != nil {
|
||||||
return r
|
return r
|
||||||
@ -34,8 +28,8 @@ func (r *Request) SetBodyReader(reader io.Reader) *Request {
|
|||||||
if r.doRaw {
|
if r.doRaw {
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
r.config.Body.Reader = reader
|
setReaderBodyConfig(&r.config.Body, reader)
|
||||||
r.config.Body.Bytes = nil
|
r.invalidatePreparedState()
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -67,7 +61,8 @@ func (r *Request) SetFormData(data map[string][]string) *Request {
|
|||||||
if r.doRaw {
|
if r.doRaw {
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
r.config.Body.FormData = cloneStringMapSlice(data)
|
setFormBodyConfig(&r.config.Body, data)
|
||||||
|
r.invalidatePreparedState()
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -79,7 +74,9 @@ func (r *Request) AddFormData(key, value string) *Request {
|
|||||||
if r.doRaw {
|
if r.doRaw {
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
ensureFormMode(&r.config.Body)
|
||||||
r.config.Body.FormData[key] = append(r.config.Body.FormData[key], value)
|
r.config.Body.FormData[key] = append(r.config.Body.FormData[key], value)
|
||||||
|
r.invalidatePreparedState()
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -91,9 +88,11 @@ func (r *Request) AddFormDataMap(data map[string]string) *Request {
|
|||||||
if r.doRaw {
|
if r.doRaw {
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
for k, v := range data {
|
ensureFormMode(&r.config.Body)
|
||||||
r.config.Body.FormData[k] = append(r.config.Body.FormData[k], v)
|
for key, value := range data {
|
||||||
|
r.config.Body.FormData[key] = append(r.config.Body.FormData[key], value)
|
||||||
}
|
}
|
||||||
|
r.invalidatePreparedState()
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -109,6 +108,7 @@ func (r *Request) AddFile(formName, filePath string) *Request {
|
|||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ensureMultipartMode(&r.config.Body)
|
||||||
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
|
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
|
||||||
FormName: formName,
|
FormName: formName,
|
||||||
FileName: stat.Name(),
|
FileName: stat.Name(),
|
||||||
@ -116,6 +116,7 @@ func (r *Request) AddFile(formName, filePath string) *Request {
|
|||||||
FileSize: stat.Size(),
|
FileSize: stat.Size(),
|
||||||
FileType: ContentTypeOctetStream,
|
FileType: ContentTypeOctetStream,
|
||||||
})
|
})
|
||||||
|
r.invalidatePreparedState()
|
||||||
|
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
@ -132,6 +133,7 @@ func (r *Request) AddFileWithName(formName, filePath, fileName string) *Request
|
|||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ensureMultipartMode(&r.config.Body)
|
||||||
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
|
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
|
||||||
FormName: formName,
|
FormName: formName,
|
||||||
FileName: fileName,
|
FileName: fileName,
|
||||||
@ -139,6 +141,7 @@ func (r *Request) AddFileWithName(formName, filePath, fileName string) *Request
|
|||||||
FileSize: stat.Size(),
|
FileSize: stat.Size(),
|
||||||
FileType: ContentTypeOctetStream,
|
FileType: ContentTypeOctetStream,
|
||||||
})
|
})
|
||||||
|
r.invalidatePreparedState()
|
||||||
|
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
@ -155,6 +158,7 @@ func (r *Request) AddFileWithType(formName, filePath, fileType string) *Request
|
|||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ensureMultipartMode(&r.config.Body)
|
||||||
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
|
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
|
||||||
FormName: formName,
|
FormName: formName,
|
||||||
FileName: stat.Name(),
|
FileName: stat.Name(),
|
||||||
@ -162,6 +166,7 @@ func (r *Request) AddFileWithType(formName, filePath, fileType string) *Request
|
|||||||
FileSize: stat.Size(),
|
FileSize: stat.Size(),
|
||||||
FileType: fileType,
|
FileType: fileType,
|
||||||
})
|
})
|
||||||
|
r.invalidatePreparedState()
|
||||||
|
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
@ -177,6 +182,7 @@ func (r *Request) AddFileStream(formName, fileName string, size int64, reader io
|
|||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ensureMultipartMode(&r.config.Body)
|
||||||
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
|
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
|
||||||
FormName: formName,
|
FormName: formName,
|
||||||
FileName: fileName,
|
FileName: fileName,
|
||||||
@ -184,6 +190,7 @@ func (r *Request) AddFileStream(formName, fileName string, size int64, reader io
|
|||||||
FileSize: size,
|
FileSize: size,
|
||||||
FileType: ContentTypeOctetStream,
|
FileType: ContentTypeOctetStream,
|
||||||
})
|
})
|
||||||
|
r.invalidatePreparedState()
|
||||||
|
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
@ -199,6 +206,7 @@ func (r *Request) AddFileStreamWithType(formName, fileName, fileType string, siz
|
|||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ensureMultipartMode(&r.config.Body)
|
||||||
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
|
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
|
||||||
FormName: formName,
|
FormName: formName,
|
||||||
FileName: fileName,
|
FileName: fileName,
|
||||||
@ -206,243 +214,7 @@ func (r *Request) AddFileStreamWithType(formName, fileName, fileType string, siz
|
|||||||
FileSize: size,
|
FileSize: size,
|
||||||
FileType: fileType,
|
FileType: fileType,
|
||||||
})
|
})
|
||||||
|
r.invalidatePreparedState()
|
||||||
|
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
// applyBody 应用请求体
|
|
||||||
func (r *Request) applyBody() error {
|
|
||||||
// 优先级:Reader > Bytes > Files > FormData
|
|
||||||
|
|
||||||
// 1. Reader
|
|
||||||
if r.config.Body.Reader != nil {
|
|
||||||
r.httpReq.Body = io.NopCloser(r.config.Body.Reader)
|
|
||||||
|
|
||||||
// 尝试获取长度
|
|
||||||
switch v := r.config.Body.Reader.(type) {
|
|
||||||
case *bytes.Buffer:
|
|
||||||
r.httpReq.ContentLength = int64(v.Len())
|
|
||||||
case *bytes.Reader:
|
|
||||||
r.httpReq.ContentLength = int64(v.Len())
|
|
||||||
case *strings.Reader:
|
|
||||||
r.httpReq.ContentLength = int64(v.Len())
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// 2. Bytes
|
|
||||||
if len(r.config.Body.Bytes) > 0 {
|
|
||||||
r.httpReq.Body = io.NopCloser(bytes.NewReader(r.config.Body.Bytes))
|
|
||||||
r.httpReq.ContentLength = int64(len(r.config.Body.Bytes))
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// 3. Files(multipart/form-data)
|
|
||||||
if len(r.config.Body.Files) > 0 {
|
|
||||||
return r.applyMultipartBody()
|
|
||||||
}
|
|
||||||
|
|
||||||
// 4. FormData(application/x-www-form-urlencoded)
|
|
||||||
if len(r.config.Body.FormData) > 0 {
|
|
||||||
values := url.Values{}
|
|
||||||
for k, vs := range r.config.Body.FormData {
|
|
||||||
for _, v := range vs {
|
|
||||||
values.Add(k, v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
encoded := values.Encode()
|
|
||||||
r.httpReq.Body = io.NopCloser(strings.NewReader(encoded))
|
|
||||||
r.httpReq.ContentLength = int64(len(encoded))
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// applyMultipartBody 应用 multipart 请求体
|
|
||||||
func (r *Request) applyMultipartBody() error {
|
|
||||||
pr, pw := io.Pipe()
|
|
||||||
writer := multipart.NewWriter(pw)
|
|
||||||
|
|
||||||
// 设置 Content-Type
|
|
||||||
r.httpReq.Header.Set("Content-Type", writer.FormDataContentType())
|
|
||||||
r.httpReq.Body = pr
|
|
||||||
|
|
||||||
// 在 goroutine 中写入数据
|
|
||||||
go func() {
|
|
||||||
defer pw.Close()
|
|
||||||
defer writer.Close()
|
|
||||||
|
|
||||||
// 写入表单字段
|
|
||||||
for k, vs := range r.config.Body.FormData {
|
|
||||||
for _, v := range vs {
|
|
||||||
if err := writer.WriteField(k, v); err != nil {
|
|
||||||
pw.CloseWithError(wrapError(err, "write form field"))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 写入文件
|
|
||||||
for _, file := range r.config.Body.Files {
|
|
||||||
if err := r.writeFile(writer, file); err != nil {
|
|
||||||
pw.CloseWithError(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// writeFile 写入文件到 multipart writer
|
|
||||||
func (r *Request) writeFile(writer *multipart.Writer, file RequestFile) error {
|
|
||||||
// 创建文件字段
|
|
||||||
part, err := writer.CreateFormFile(file.FormName, file.FileName)
|
|
||||||
if err != nil {
|
|
||||||
return wrapError(err, "create form file")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 获取文件数据源
|
|
||||||
var reader io.Reader
|
|
||||||
if file.FileData != nil {
|
|
||||||
reader = file.FileData
|
|
||||||
} else if file.FilePath != "" {
|
|
||||||
f, err := os.Open(file.FilePath)
|
|
||||||
if err != nil {
|
|
||||||
return wrapError(err, "open file")
|
|
||||||
}
|
|
||||||
defer f.Close()
|
|
||||||
reader = f
|
|
||||||
} else {
|
|
||||||
return ErrNilReader
|
|
||||||
}
|
|
||||||
|
|
||||||
// 复制文件数据(带进度)
|
|
||||||
if r.config.UploadProgress != nil {
|
|
||||||
_, err = copyWithProgress(r.ctx, part, reader, file.FileName, file.FileSize, r.config.UploadProgress)
|
|
||||||
} else {
|
|
||||||
_, err = io.Copy(part, reader)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return wrapError(err, "copy file data")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// prepare 准备请求(应用配置)
|
|
||||||
func (r *Request) prepare() error {
|
|
||||||
if r.applied {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// 即使 raw 模式也要确保有 httpClient
|
|
||||||
if r.httpClient == nil {
|
|
||||||
var err error
|
|
||||||
r.httpClient, err = r.buildHTTPClient()
|
|
||||||
if err != nil {
|
|
||||||
return err // ← 失败时不设置 applied
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if r.httpReq == nil {
|
|
||||||
return fmt.Errorf("http request is nil")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 原始模式不修改请求内容
|
|
||||||
if !r.doRaw {
|
|
||||||
// 应用查询参数
|
|
||||||
if len(r.config.Queries) > 0 {
|
|
||||||
q := r.httpReq.URL.Query()
|
|
||||||
for k, values := range r.config.Queries {
|
|
||||||
for _, v := range values {
|
|
||||||
q.Add(k, v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
r.httpReq.URL.RawQuery = q.Encode()
|
|
||||||
}
|
|
||||||
|
|
||||||
// 应用 Headers
|
|
||||||
for k, values := range r.config.Headers {
|
|
||||||
for _, v := range values {
|
|
||||||
r.httpReq.Header.Add(k, v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 应用 Cookies
|
|
||||||
for _, cookie := range r.config.Cookies {
|
|
||||||
r.httpReq.AddCookie(cookie)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 应用 Basic Auth
|
|
||||||
if r.config.BasicAuth[0] != "" || r.config.BasicAuth[1] != "" {
|
|
||||||
r.httpReq.SetBasicAuth(r.config.BasicAuth[0], r.config.BasicAuth[1])
|
|
||||||
}
|
|
||||||
|
|
||||||
// 应用请求体
|
|
||||||
if err := r.applyBody(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// 应用 Content-Length
|
|
||||||
if r.config.ContentLength > 0 {
|
|
||||||
r.httpReq.ContentLength = r.config.ContentLength
|
|
||||||
} else if r.config.ContentLength < 0 {
|
|
||||||
r.httpReq.ContentLength = 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// 自动计算 Content-Length
|
|
||||||
if r.config.AutoCalcContentLength && r.httpReq.Body != nil {
|
|
||||||
data, err := io.ReadAll(r.httpReq.Body)
|
|
||||||
if err != nil {
|
|
||||||
return wrapError(err, "read body for content length")
|
|
||||||
}
|
|
||||||
r.httpReq.ContentLength = int64(len(data))
|
|
||||||
r.httpReq.Body = io.NopCloser(bytes.NewBuffer(data))
|
|
||||||
}
|
|
||||||
|
|
||||||
// 设置 TLS ServerName(如果有 TLS Config)
|
|
||||||
if r.config.TLS.Config != nil && r.httpReq.URL != nil {
|
|
||||||
r.config.TLS.Config.ServerName = r.httpReq.URL.Hostname()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
execCtx := r.ctx
|
|
||||||
if !r.doRaw {
|
|
||||||
// raw 模式下不注入请求级网络配置,只应用 context/超时。
|
|
||||||
execCtx = injectRequestConfig(execCtx, r.config)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 请求级总超时通过 context 控制,避免污染共享 http.Client。
|
|
||||||
if r.config.Network.Timeout > 0 {
|
|
||||||
execCtx, r.cancel = context.WithTimeout(execCtx, r.config.Network.Timeout)
|
|
||||||
}
|
|
||||||
|
|
||||||
r.execCtx = execCtx
|
|
||||||
r.httpReq = r.httpReq.WithContext(r.execCtx)
|
|
||||||
|
|
||||||
r.applied = true
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// buildHTTPClient 构建 HTTP Client
|
|
||||||
func (r *Request) buildHTTPClient() (*http.Client, error) {
|
|
||||||
// 优先使用请求关联的 Client
|
|
||||||
if r.client != nil {
|
|
||||||
return r.client.HTTPClient(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// 自定义 Transport
|
|
||||||
if r.config.CustomTransport && r.config.Transport != nil {
|
|
||||||
return &http.Client{
|
|
||||||
Transport: &Transport{base: r.config.Transport},
|
|
||||||
Timeout: 0,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// 默认全局 client
|
|
||||||
return DefaultHTTPClient(), nil
|
|
||||||
}
|
|
||||||
|
|||||||
@ -1,282 +0,0 @@
|
|||||||
package starnet
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto/tls"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// SetTimeout 设置请求总超时时间
|
|
||||||
// timeout > 0: 为本次请求注入 context 超时
|
|
||||||
// timeout = 0: 不额外设置请求总超时
|
|
||||||
// timeout < 0: 禁用 starnet 默认总超时
|
|
||||||
func (r *Request) SetTimeout(timeout time.Duration) *Request {
|
|
||||||
if r.err != nil {
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
r.config.Network.Timeout = timeout
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetDialTimeout 设置连接超时时间
|
|
||||||
func (r *Request) SetDialTimeout(timeout time.Duration) *Request {
|
|
||||||
if r.err != nil {
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
r.config.Network.DialTimeout = timeout
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetProxy 设置代理
|
|
||||||
func (r *Request) SetProxy(proxy string) *Request {
|
|
||||||
if r.err != nil {
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
r.config.Network.Proxy = proxy
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetDialFunc 设置自定义 Dial 函数
|
|
||||||
func (r *Request) SetDialFunc(fn func(ctx context.Context, network, addr string) (net.Conn, error)) *Request {
|
|
||||||
if r.err != nil {
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
r.config.Network.DialFunc = fn
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetTLSConfig 设置 TLS 配置
|
|
||||||
func (r *Request) SetTLSConfig(tlsConfig *tls.Config) *Request {
|
|
||||||
if r.err != nil {
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
r.config.TLS.Config = tlsConfig
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetSkipTLSVerify 设置是否跳过 TLS 验证
|
|
||||||
func (r *Request) SetSkipTLSVerify(skip bool) *Request {
|
|
||||||
if r.err != nil {
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
r.config.TLS.SkipVerify = skip
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetCustomIP 设置自定义 IP(直接指定 IP,跳过 DNS)
|
|
||||||
func (r *Request) SetCustomIP(ips []string) *Request {
|
|
||||||
if r.err != nil {
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
// 验证 IP 格式
|
|
||||||
for _, ip := range ips {
|
|
||||||
if net.ParseIP(ip) == nil {
|
|
||||||
r.err = wrapError(ErrInvalidIP, "ip: %s", ip)
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
r.config.DNS.CustomIP = ips
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddCustomIP 添加自定义 IP
|
|
||||||
func (r *Request) AddCustomIP(ip string) *Request {
|
|
||||||
if r.err != nil {
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
if net.ParseIP(ip) == nil {
|
|
||||||
r.err = wrapError(ErrInvalidIP, "ip: %s", ip)
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
r.config.DNS.CustomIP = append(r.config.DNS.CustomIP, ip)
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetCustomDNS 设置自定义 DNS 服务器
|
|
||||||
func (r *Request) SetCustomDNS(dnsServers []string) *Request {
|
|
||||||
if r.err != nil {
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
// 验证 DNS 服务器格式
|
|
||||||
for _, dns := range dnsServers {
|
|
||||||
if net.ParseIP(dns) == nil {
|
|
||||||
r.err = wrapError(ErrInvalidDNS, "dns: %s", dns)
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
r.config.DNS.CustomDNS = dnsServers
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddCustomDNS 添加自定义 DNS 服务器
|
|
||||||
func (r *Request) AddCustomDNS(dns string) *Request {
|
|
||||||
if r.err != nil {
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
if net.ParseIP(dns) == nil {
|
|
||||||
r.err = wrapError(ErrInvalidDNS, "dns: %s", dns)
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
r.config.DNS.CustomDNS = append(r.config.DNS.CustomDNS, dns)
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetLookupFunc 设置自定义 DNS 解析函数
|
|
||||||
func (r *Request) SetLookupFunc(fn func(ctx context.Context, host string) ([]net.IPAddr, error)) *Request {
|
|
||||||
if r.err != nil {
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
r.config.DNS.LookupFunc = fn
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetBasicAuth 设置 Basic 认证
|
|
||||||
func (r *Request) SetBasicAuth(username, password string) *Request {
|
|
||||||
if r.err != nil {
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
r.config.BasicAuth = [2]string{username, password}
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetContentLength 设置 Content-Length
|
|
||||||
func (r *Request) SetContentLength(length int64) *Request {
|
|
||||||
if r.err != nil {
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
r.config.ContentLength = length
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetAutoCalcContentLength 设置是否自动计算 Content-Length
|
|
||||||
// 警告:启用后会将整个 body 读入内存
|
|
||||||
func (r *Request) SetAutoCalcContentLength(auto bool) *Request {
|
|
||||||
if r.err != nil {
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
if r.doRaw {
|
|
||||||
r.err = fmt.Errorf("cannot set auto calc content length in raw mode")
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
r.config.AutoCalcContentLength = auto
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetTransport 设置自定义 Transport
|
|
||||||
func (r *Request) SetTransport(transport *http.Transport) *Request {
|
|
||||||
if r.err != nil {
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
r.config.Transport = transport
|
|
||||||
r.config.CustomTransport = true
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetUploadProgress 设置文件上传进度回调
|
|
||||||
func (r *Request) SetUploadProgress(fn UploadProgressFunc) *Request {
|
|
||||||
if r.err != nil {
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
r.config.UploadProgress = fn
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetMaxRespBodyBytes 设置响应体最大读取字节数(<=0 表示不限制)
|
|
||||||
func (r *Request) SetMaxRespBodyBytes(maxBytes int64) *Request {
|
|
||||||
if r.err != nil {
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
if maxBytes < 0 {
|
|
||||||
r.err = fmt.Errorf("max response body bytes must be >= 0")
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
r.config.MaxRespBodyBytes = maxBytes
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddQuery 添加查询参数
|
|
||||||
func (r *Request) AddQuery(key, value string) *Request {
|
|
||||||
if r.err != nil {
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
r.config.Queries[key] = append(r.config.Queries[key], value)
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetQuery 设置查询参数(覆盖)
|
|
||||||
func (r *Request) SetQuery(key, value string) *Request {
|
|
||||||
if r.err != nil {
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
r.config.Queries[key] = []string{value}
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetQueries 设置所有查询参数(覆盖)
|
|
||||||
func (r *Request) SetQueries(queries map[string][]string) *Request {
|
|
||||||
if r.err != nil {
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
r.config.Queries = cloneStringMapSlice(queries)
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddQueries 批量添加查询参数
|
|
||||||
func (r *Request) AddQueries(queries map[string]string) *Request {
|
|
||||||
if r.err != nil {
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
for k, v := range queries {
|
|
||||||
r.config.Queries[k] = append(r.config.Queries[k], v)
|
|
||||||
}
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteQuery 删除查询参数
|
|
||||||
func (r *Request) DeleteQuery(key string) *Request {
|
|
||||||
if r.err != nil {
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
delete(r.config.Queries, key)
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteQueryValue 删除查询参数的特定值
|
|
||||||
func (r *Request) DeleteQueryValue(key, value string) *Request {
|
|
||||||
if r.err != nil {
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
values, ok := r.config.Queries[key]
|
|
||||||
if !ok {
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
newValues := make([]string, 0, len(values))
|
|
||||||
for _, v := range values {
|
|
||||||
if v != value {
|
|
||||||
newValues = append(newValues, v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(newValues) == 0 {
|
|
||||||
delete(r.config.Queries, key)
|
|
||||||
} else {
|
|
||||||
r.config.Queries[key] = newValues
|
|
||||||
}
|
|
||||||
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
34
request_execution.go
Normal file
34
request_execution.go
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import "net/http"
|
||||||
|
|
||||||
|
// SetBasicAuth 设置 Basic 认证
|
||||||
|
func (r *Request) SetBasicAuth(username, password string) *Request {
|
||||||
|
return r.applyMutation(mutateBasicAuth(username, password))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetContentLength 设置 Content-Length
|
||||||
|
func (r *Request) SetContentLength(length int64) *Request {
|
||||||
|
return r.applyMutation(mutateContentLength(length))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetAutoCalcContentLength 设置是否自动计算 Content-Length
|
||||||
|
// 警告:启用后会将整个 body 读入内存
|
||||||
|
func (r *Request) SetAutoCalcContentLength(auto bool) *Request {
|
||||||
|
return r.applyMutation(mutateAutoCalcContentLength(auto))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTransport 设置自定义 Transport
|
||||||
|
func (r *Request) SetTransport(transport *http.Transport) *Request {
|
||||||
|
return r.applyMutation(mutateTransport(transport))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetUploadProgress 设置文件上传进度回调
|
||||||
|
func (r *Request) SetUploadProgress(fn UploadProgressFunc) *Request {
|
||||||
|
return r.applyMutation(mutateUploadProgress(fn))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetMaxRespBodyBytes 设置响应体最大读取字节数(<=0 表示不限制)
|
||||||
|
func (r *Request) SetMaxRespBodyBytes(maxBytes int64) *Request {
|
||||||
|
return r.applyMutation(mutateMaxRespBodyBytes(maxBytes))
|
||||||
|
}
|
||||||
172
request_execution_regression_test.go
Normal file
172
request_execution_regression_test.go
Normal file
@ -0,0 +1,172 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRequestDoTwiceRebuildsExecutionState(t *testing.T) {
|
||||||
|
var attempts int32
|
||||||
|
|
||||||
|
req := NewSimpleRequest("http://example.com/path", http.MethodPost).
|
||||||
|
SetHeader("X-Test", "one").
|
||||||
|
AddQuery("q", "v").
|
||||||
|
SetBodyReader(strings.NewReader("payload"))
|
||||||
|
req.client = &Client{client: &http.Client{
|
||||||
|
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||||
|
if err := r.Context().Err(); err != nil {
|
||||||
|
t.Fatalf("request context already done: %v", err)
|
||||||
|
}
|
||||||
|
if values := r.Header.Values("X-Test"); len(values) != 1 || values[0] != "one" {
|
||||||
|
t.Fatalf("header values=%v", values)
|
||||||
|
}
|
||||||
|
if values := r.URL.Query()["q"]; len(values) != 1 || values[0] != "v" {
|
||||||
|
t.Fatalf("query values=%v", values)
|
||||||
|
}
|
||||||
|
body, err := io.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
_ = r.Body.Close()
|
||||||
|
if string(body) != "payload" {
|
||||||
|
t.Fatalf("body=%q", string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
n := atomic.AddInt32(&attempts, 1)
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: make(http.Header),
|
||||||
|
Body: io.NopCloser(strings.NewReader(fmt.Sprintf("ok-%d", n))),
|
||||||
|
Request: r,
|
||||||
|
}, nil
|
||||||
|
}),
|
||||||
|
}}
|
||||||
|
|
||||||
|
resp1, err := req.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("first Do() error: %v", err)
|
||||||
|
}
|
||||||
|
if err := resp1.Close(); err != nil {
|
||||||
|
t.Fatalf("first Close() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp2, err := req.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("second Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp2.Close()
|
||||||
|
|
||||||
|
if got := atomic.LoadInt32(&attempts); got != 2 {
|
||||||
|
t.Fatalf("attempts=%d; want 2", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestPrepareRawDynamicPathInjectsAggregatedRequestContext(t *testing.T) {
|
||||||
|
rawReq, err := http.NewRequest(http.MethodGet, "https://example.com/resource", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewRequest() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := NewSimpleRequest("", http.MethodGet).
|
||||||
|
SetRawRequest(rawReq).
|
||||||
|
SetProxy("http://proxy.example:8080").
|
||||||
|
SetCustomIP([]string{"127.0.0.1"}).
|
||||||
|
SetSkipTLSVerify(true).
|
||||||
|
SetTLSServerName("override.example")
|
||||||
|
|
||||||
|
if err := req.prepare(); err != nil {
|
||||||
|
t.Fatalf("prepare() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
raw := req.execCtx.Value(ctxKeyRequestContext)
|
||||||
|
rc, ok := raw.(*RequestContext)
|
||||||
|
if !ok || rc == nil {
|
||||||
|
t.Fatalf("expected request context, got %#v", raw)
|
||||||
|
}
|
||||||
|
if rc.Proxy != "http://proxy.example:8080" {
|
||||||
|
t.Fatalf("proxy=%q", rc.Proxy)
|
||||||
|
}
|
||||||
|
if len(rc.CustomIP) != 1 || rc.CustomIP[0] != "127.0.0.1" {
|
||||||
|
t.Fatalf("custom ip=%v", rc.CustomIP)
|
||||||
|
}
|
||||||
|
if rc.TLSConfig == nil || !rc.TLSConfig.InsecureSkipVerify {
|
||||||
|
t.Fatalf("tls config=%#v", rc.TLSConfig)
|
||||||
|
}
|
||||||
|
if rc.TLSServerName != "override.example" {
|
||||||
|
t.Fatalf("tls server name=%q", rc.TLSServerName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestSetFormDataOverridesBytesBody(t *testing.T) {
|
||||||
|
req := NewSimpleRequest("http://example.com", http.MethodPost).
|
||||||
|
SetBodyString("stale").
|
||||||
|
SetFormData(map[string][]string{"k": []string{"v"}})
|
||||||
|
|
||||||
|
if req.config.Body.Mode != bodyModeForm {
|
||||||
|
t.Fatalf("body mode=%v", req.config.Body.Mode)
|
||||||
|
}
|
||||||
|
if req.config.Body.Reader != nil || req.config.Body.Bytes != nil || len(req.config.Body.Files) != 0 {
|
||||||
|
t.Fatalf("unexpected stale body state: %#v", req.config.Body)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := req.prepare(); err != nil {
|
||||||
|
t.Fatalf("prepare() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := req.httpReq.GetBody()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetBody() error: %v", err)
|
||||||
|
}
|
||||||
|
defer body.Close()
|
||||||
|
|
||||||
|
data, err := io.ReadAll(body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ReadAll() error: %v", err)
|
||||||
|
}
|
||||||
|
if string(data) != "k=v" {
|
||||||
|
t.Fatalf("body=%q; want k=v", string(data))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestAddFileClearsPreviousBytesBody(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
filePath := filepath.Join(tmpDir, "payload.txt")
|
||||||
|
if err := os.WriteFile(filePath, []byte("file-body"), 0644); err != nil {
|
||||||
|
t.Fatalf("WriteFile() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := NewSimpleRequest("http://example.com", http.MethodPost).
|
||||||
|
SetJSON(map[string]string{"old": "json-only"}).
|
||||||
|
AddFile("file", filePath)
|
||||||
|
|
||||||
|
if req.config.Body.Mode != bodyModeMultipart {
|
||||||
|
t.Fatalf("body mode=%v", req.config.Body.Mode)
|
||||||
|
}
|
||||||
|
if req.config.Body.Reader != nil || req.config.Body.Bytes != nil {
|
||||||
|
t.Fatalf("unexpected stale simple body state: %#v", req.config.Body)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := req.prepare(); err != nil {
|
||||||
|
t.Fatalf("prepare() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := io.ReadAll(req.httpReq.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ReadAll() error: %v", err)
|
||||||
|
}
|
||||||
|
if !strings.Contains(req.httpReq.Header.Get("Content-Type"), "multipart/form-data") {
|
||||||
|
t.Fatalf("content-type=%q", req.httpReq.Header.Get("Content-Type"))
|
||||||
|
}
|
||||||
|
if !strings.Contains(string(data), "file-body") {
|
||||||
|
t.Fatalf("multipart body missing file content: %q", string(data))
|
||||||
|
}
|
||||||
|
if strings.Contains(string(data), "json-only") {
|
||||||
|
t.Fatalf("multipart body still contains stale json: %q", string(data))
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -4,6 +4,25 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func isHostHeaderKey(key string) bool {
|
||||||
|
return http.CanonicalHeaderKey(key) == "Host"
|
||||||
|
}
|
||||||
|
|
||||||
|
func setRequestHostConfig(config *RequestConfig, host string) {
|
||||||
|
if config == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if config.Headers == nil {
|
||||||
|
config.Headers = make(http.Header)
|
||||||
|
}
|
||||||
|
config.Host = host
|
||||||
|
if host == "" {
|
||||||
|
config.Headers.Del("Host")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
config.Headers.Set("Host", host)
|
||||||
|
}
|
||||||
|
|
||||||
// SetHeader 设置 Header(覆盖)
|
// SetHeader 设置 Header(覆盖)
|
||||||
func (r *Request) SetHeader(key, value string) *Request {
|
func (r *Request) SetHeader(key, value string) *Request {
|
||||||
if r.err != nil {
|
if r.err != nil {
|
||||||
@ -12,7 +31,11 @@ func (r *Request) SetHeader(key, value string) *Request {
|
|||||||
if r.doRaw {
|
if r.doRaw {
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
if isHostHeaderKey(key) {
|
||||||
|
return r.SetHost(value)
|
||||||
|
}
|
||||||
r.config.Headers.Set(key, value)
|
r.config.Headers.Set(key, value)
|
||||||
|
r.invalidatePreparedState()
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -24,7 +47,11 @@ func (r *Request) AddHeader(key, value string) *Request {
|
|||||||
if r.doRaw {
|
if r.doRaw {
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
if isHostHeaderKey(key) {
|
||||||
|
return r.SetHost(value)
|
||||||
|
}
|
||||||
r.config.Headers.Add(key, value)
|
r.config.Headers.Add(key, value)
|
||||||
|
r.invalidatePreparedState()
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -37,6 +64,9 @@ func (r *Request) SetHeaders(headers http.Header) *Request {
|
|||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
r.config.Headers = cloneHeader(headers)
|
r.config.Headers = cloneHeader(headers)
|
||||||
|
r.config.Host = r.config.Headers.Get("Host")
|
||||||
|
r.syncRequestHost()
|
||||||
|
r.invalidatePreparedState()
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -49,8 +79,14 @@ func (r *Request) AddHeaders(headers map[string]string) *Request {
|
|||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
for k, v := range headers {
|
for k, v := range headers {
|
||||||
|
if isHostHeaderKey(k) {
|
||||||
|
setRequestHostConfig(r.config, v)
|
||||||
|
continue
|
||||||
|
}
|
||||||
r.config.Headers.Add(k, v)
|
r.config.Headers.Add(k, v)
|
||||||
}
|
}
|
||||||
|
r.syncRequestHost()
|
||||||
|
r.invalidatePreparedState()
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -62,18 +98,56 @@ func (r *Request) DeleteHeader(key string) *Request {
|
|||||||
if r.doRaw {
|
if r.doRaw {
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
if isHostHeaderKey(key) {
|
||||||
|
setRequestHostConfig(r.config, "")
|
||||||
|
r.syncRequestHost()
|
||||||
|
r.invalidatePreparedState()
|
||||||
|
return r
|
||||||
|
}
|
||||||
r.config.Headers.Del(key)
|
r.config.Headers.Del(key)
|
||||||
|
r.invalidatePreparedState()
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetHeader 获取 Header
|
// GetHeader 获取 Header
|
||||||
func (r *Request) GetHeader(key string) string {
|
func (r *Request) GetHeader(key string) string {
|
||||||
|
if isHostHeaderKey(key) {
|
||||||
|
return r.config.Host
|
||||||
|
}
|
||||||
return r.config.Headers.Get(key)
|
return r.config.Headers.Get(key)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Headers 获取所有 Headers
|
// Headers 获取所有 Headers
|
||||||
func (r *Request) Headers() http.Header {
|
func (r *Request) Headers() http.Header {
|
||||||
return r.config.Headers
|
if r == nil || r.config == nil {
|
||||||
|
return make(http.Header)
|
||||||
|
}
|
||||||
|
return cloneHeader(r.config.Headers)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetHost 设置请求 Host 头覆盖。
|
||||||
|
func (r *Request) SetHost(host string) *Request {
|
||||||
|
if r.err != nil {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
if r.doRaw {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
setRequestHostConfig(r.config, host)
|
||||||
|
r.syncRequestHost()
|
||||||
|
r.invalidatePreparedState()
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// Host 获取显式 Host 覆盖。
|
||||||
|
func (r *Request) Host() string {
|
||||||
|
if r.config != nil && r.config.Host != "" {
|
||||||
|
return r.config.Host
|
||||||
|
}
|
||||||
|
if r.httpReq != nil {
|
||||||
|
return r.httpReq.Host
|
||||||
|
}
|
||||||
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetContentType 设置 Content-Type
|
// SetContentType 设置 Content-Type
|
||||||
@ -104,7 +178,8 @@ func (r *Request) AddCookie(cookie *http.Cookie) *Request {
|
|||||||
if r.doRaw {
|
if r.doRaw {
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
r.config.Cookies = append(r.config.Cookies, cookie)
|
r.config.Cookies = append(r.config.Cookies, cloneCookie(cookie))
|
||||||
|
r.invalidatePreparedState()
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -134,7 +209,8 @@ func (r *Request) SetCookies(cookies []*http.Cookie) *Request {
|
|||||||
if r.doRaw {
|
if r.doRaw {
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
r.config.Cookies = cookies
|
r.config.Cookies = cloneCookies(cookies)
|
||||||
|
r.invalidatePreparedState()
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -153,12 +229,16 @@ func (r *Request) AddCookies(cookies map[string]string) *Request {
|
|||||||
Path: "/",
|
Path: "/",
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
r.invalidatePreparedState()
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
// Cookies 获取所有 Cookies
|
// Cookies 获取所有 Cookies
|
||||||
func (r *Request) Cookies() []*http.Cookie {
|
func (r *Request) Cookies() []*http.Cookie {
|
||||||
return r.config.Cookies
|
if r == nil || r.config == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return cloneCookies(r.config.Cookies)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ResetHeaders 重置所有 Headers
|
// ResetHeaders 重置所有 Headers
|
||||||
@ -167,6 +247,9 @@ func (r *Request) ResetHeaders() *Request {
|
|||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
r.config.Headers = make(http.Header)
|
r.config.Headers = make(http.Header)
|
||||||
|
r.config.Host = ""
|
||||||
|
r.syncRequestHost()
|
||||||
|
r.invalidatePreparedState()
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -176,5 +259,6 @@ func (r *Request) ResetCookies() *Request {
|
|||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
r.config.Cookies = []*http.Cookie{}
|
r.config.Cookies = []*http.Cookie{}
|
||||||
|
r.invalidatePreparedState()
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|||||||
69
request_multipart.go
Normal file
69
request_multipart.go
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"mime/multipart"
|
||||||
|
"os"
|
||||||
|
)
|
||||||
|
|
||||||
|
// applyMultipartBody 应用 multipart 请求体
|
||||||
|
func (r *Request) applyMultipartBody(execCtx context.Context) error {
|
||||||
|
pr, pw := io.Pipe()
|
||||||
|
writer := multipart.NewWriter(pw)
|
||||||
|
|
||||||
|
r.httpReq.Header.Set("Content-Type", writer.FormDataContentType())
|
||||||
|
r.httpReq.Body = pr
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer pw.Close()
|
||||||
|
defer writer.Close()
|
||||||
|
|
||||||
|
for key, values := range r.config.Body.FormData {
|
||||||
|
for _, value := range values {
|
||||||
|
if err := writer.WriteField(key, value); err != nil {
|
||||||
|
pw.CloseWithError(wrapError(err, "write form field"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, file := range r.config.Body.Files {
|
||||||
|
if err := r.writeFile(execCtx, writer, file); err != nil {
|
||||||
|
pw.CloseWithError(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeFile 写入文件到 multipart writer
|
||||||
|
func (r *Request) writeFile(execCtx context.Context, writer *multipart.Writer, file RequestFile) error {
|
||||||
|
part, err := writer.CreateFormFile(file.FormName, file.FileName)
|
||||||
|
if err != nil {
|
||||||
|
return wrapError(err, "create form file")
|
||||||
|
}
|
||||||
|
|
||||||
|
var reader io.Reader
|
||||||
|
if file.FileData != nil {
|
||||||
|
reader = file.FileData
|
||||||
|
} else if file.FilePath != "" {
|
||||||
|
f, err := os.Open(file.FilePath)
|
||||||
|
if err != nil {
|
||||||
|
return wrapError(err, "open file")
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
reader = f
|
||||||
|
} else {
|
||||||
|
return ErrNilReader
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = copyWithProgress(execCtx, part, reader, file.FileName, file.FileSize, r.config.UploadProgress)
|
||||||
|
if err != nil {
|
||||||
|
return wrapError(err, "copy file data")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
326
request_mutation.go
Normal file
326
request_mutation.go
Normal file
@ -0,0 +1,326 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type requestMutation func(*Request) error
|
||||||
|
|
||||||
|
func (r *Request) applyMutation(mutation requestMutation) *Request {
|
||||||
|
if r == nil || r.err != nil {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
if err := mutation(r); err != nil {
|
||||||
|
r.err = err
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
r.invalidatePreparedState()
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func requestOptFromMutation(mutation requestMutation) RequestOpt {
|
||||||
|
return func(r *Request) error {
|
||||||
|
if r == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return mutation(r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateCustomIPs(ips []string) error {
|
||||||
|
for _, ip := range ips {
|
||||||
|
if net.ParseIP(ip) == nil {
|
||||||
|
return wrapError(ErrInvalidIP, "ip: %s", ip)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateCustomDNS(dnsServers []string) error {
|
||||||
|
for _, dns := range dnsServers {
|
||||||
|
if net.ParseIP(dns) == nil {
|
||||||
|
return wrapError(ErrInvalidDNS, "dns: %s", dns)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseProxyURL(proxy string) (*url.URL, error) {
|
||||||
|
if proxy == "" {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
proxyURL, err := url.Parse(proxy)
|
||||||
|
if err != nil {
|
||||||
|
return nil, wrapError(err, "parse proxy url")
|
||||||
|
}
|
||||||
|
if proxyURL.Scheme == "" {
|
||||||
|
return nil, fmt.Errorf("proxy scheme is required: %s", proxy)
|
||||||
|
}
|
||||||
|
if proxyURL.Host == "" {
|
||||||
|
return nil, fmt.Errorf("proxy host is required: %s", proxy)
|
||||||
|
}
|
||||||
|
|
||||||
|
return proxyURL, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func mutateTimeout(timeout time.Duration) requestMutation {
|
||||||
|
return func(r *Request) error {
|
||||||
|
r.config.Network.Timeout = timeout
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mutateDialTimeout(timeout time.Duration) requestMutation {
|
||||||
|
return func(r *Request) error {
|
||||||
|
r.config.Network.DialTimeout = timeout
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mutateProxy(proxy string) requestMutation {
|
||||||
|
return func(r *Request) error {
|
||||||
|
if _, err := parseProxyURL(proxy); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
r.config.Network.Proxy = proxy
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mutateDialFunc(fn func(ctx context.Context, network, addr string) (net.Conn, error)) requestMutation {
|
||||||
|
return func(r *Request) error {
|
||||||
|
r.config.Network.DialFunc = fn
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mutateTLSConfig(tlsConfig *tls.Config) requestMutation {
|
||||||
|
return func(r *Request) error {
|
||||||
|
r.config.TLS.Config = tlsConfig
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mutateTLSServerName(serverName string) requestMutation {
|
||||||
|
return func(r *Request) error {
|
||||||
|
r.config.TLS.ServerName = serverName
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mutateTraceHooks(hooks *TraceHooks) requestMutation {
|
||||||
|
return func(r *Request) error {
|
||||||
|
r.traceHooks = hooks
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mutateSkipTLSVerify(skip bool) requestMutation {
|
||||||
|
return func(r *Request) error {
|
||||||
|
r.config.TLS.SkipVerify = skip
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mutateCustomIP(ips []string) requestMutation {
|
||||||
|
return func(r *Request) error {
|
||||||
|
if err := validateCustomIPs(ips); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
r.config.DNS.CustomIP = cloneStringSlice(ips)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mutateAddCustomIP(ip string) requestMutation {
|
||||||
|
return func(r *Request) error {
|
||||||
|
if err := validateCustomIPs([]string{ip}); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
r.config.DNS.CustomIP = append(r.config.DNS.CustomIP, ip)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mutateCustomDNS(dnsServers []string) requestMutation {
|
||||||
|
return func(r *Request) error {
|
||||||
|
if err := validateCustomDNS(dnsServers); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
r.config.DNS.CustomDNS = cloneStringSlice(dnsServers)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mutateAddCustomDNS(dns string) requestMutation {
|
||||||
|
return func(r *Request) error {
|
||||||
|
if err := validateCustomDNS([]string{dns}); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
r.config.DNS.CustomDNS = append(r.config.DNS.CustomDNS, dns)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mutateLookupFunc(fn func(ctx context.Context, host string) ([]net.IPAddr, error)) requestMutation {
|
||||||
|
return func(r *Request) error {
|
||||||
|
r.config.DNS.LookupFunc = fn
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mutateBasicAuth(username, password string) requestMutation {
|
||||||
|
return func(r *Request) error {
|
||||||
|
r.config.BasicAuth = [2]string{username, password}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mutateContentLength(length int64) requestMutation {
|
||||||
|
return func(r *Request) error {
|
||||||
|
r.config.ContentLength = length
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mutateAutoCalcContentLength(auto bool) requestMutation {
|
||||||
|
return func(r *Request) error {
|
||||||
|
if r.doRaw {
|
||||||
|
return fmt.Errorf("cannot set auto calc content length in raw mode")
|
||||||
|
}
|
||||||
|
r.config.AutoCalcContentLength = auto
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mutateTransport(transport *http.Transport) requestMutation {
|
||||||
|
return func(r *Request) error {
|
||||||
|
r.config.Transport = transport
|
||||||
|
r.config.CustomTransport = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mutateUploadProgress(fn UploadProgressFunc) requestMutation {
|
||||||
|
return func(r *Request) error {
|
||||||
|
r.config.UploadProgress = fn
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mutateAutoFetch(auto bool) requestMutation {
|
||||||
|
return func(r *Request) error {
|
||||||
|
r.autoFetch = auto
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mutateMaxRespBodyBytes(maxBytes int64) requestMutation {
|
||||||
|
return func(r *Request) error {
|
||||||
|
if maxBytes < 0 {
|
||||||
|
return fmt.Errorf("max response body bytes must be >= 0")
|
||||||
|
}
|
||||||
|
r.config.MaxRespBodyBytes = maxBytes
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mutateContext(ctx context.Context) requestMutation {
|
||||||
|
return func(r *Request) error {
|
||||||
|
ctx = normalizeContext(ctx)
|
||||||
|
r.ctx = ctx
|
||||||
|
if r.doRaw && r.rawTemplate != nil {
|
||||||
|
r.rawTemplate = r.rawTemplate.WithContext(ctx)
|
||||||
|
}
|
||||||
|
if r.httpReq != nil {
|
||||||
|
r.httpReq = r.httpReq.WithContext(ctx)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mutateRawRequest(httpReq *http.Request) requestMutation {
|
||||||
|
return func(r *Request) error {
|
||||||
|
if httpReq == nil {
|
||||||
|
return fmt.Errorf("httpReq cannot be nil")
|
||||||
|
}
|
||||||
|
r.httpReq = httpReq
|
||||||
|
r.rawTemplate = httpReq
|
||||||
|
r.ctx = normalizeContext(httpReq.Context())
|
||||||
|
r.method = httpReq.Method
|
||||||
|
if httpReq.URL != nil {
|
||||||
|
r.url = httpReq.URL.String()
|
||||||
|
}
|
||||||
|
r.doRaw = true
|
||||||
|
r.rawSourceExternal = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mutateAddQuery(key, value string) requestMutation {
|
||||||
|
return func(r *Request) error {
|
||||||
|
r.config.Queries[key] = append(r.config.Queries[key], value)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mutateSetQuery(key, value string) requestMutation {
|
||||||
|
return func(r *Request) error {
|
||||||
|
r.config.Queries[key] = []string{value}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mutateSetQueries(queries map[string][]string) requestMutation {
|
||||||
|
return func(r *Request) error {
|
||||||
|
r.config.Queries = cloneStringMapSlice(queries)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mutateAddQueries(queries map[string]string) requestMutation {
|
||||||
|
return func(r *Request) error {
|
||||||
|
for key, value := range queries {
|
||||||
|
r.config.Queries[key] = append(r.config.Queries[key], value)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mutateDeleteQuery(key string) requestMutation {
|
||||||
|
return func(r *Request) error {
|
||||||
|
delete(r.config.Queries, key)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mutateDeleteQueryValue(key, value string) requestMutation {
|
||||||
|
return func(r *Request) error {
|
||||||
|
values, ok := r.config.Queries[key]
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
newValues := make([]string, 0, len(values))
|
||||||
|
for _, item := range values {
|
||||||
|
if item != value {
|
||||||
|
newValues = append(newValues, item)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(newValues) == 0 {
|
||||||
|
delete(r.config.Queries, key)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
r.config.Queries[key] = newValues
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
71
request_network.go
Normal file
71
request_network.go
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SetTimeout 设置请求总超时时间
|
||||||
|
// timeout > 0: 为本次请求注入 context 超时
|
||||||
|
// timeout = 0: 不额外设置请求总超时
|
||||||
|
// timeout < 0: 禁用 starnet 默认总超时
|
||||||
|
func (r *Request) SetTimeout(timeout time.Duration) *Request {
|
||||||
|
return r.applyMutation(mutateTimeout(timeout))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDialTimeout 设置连接超时时间
|
||||||
|
func (r *Request) SetDialTimeout(timeout time.Duration) *Request {
|
||||||
|
return r.applyMutation(mutateDialTimeout(timeout))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetProxy 设置代理
|
||||||
|
func (r *Request) SetProxy(proxy string) *Request {
|
||||||
|
return r.applyMutation(mutateProxy(proxy))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDialFunc 设置自定义 Dial 函数
|
||||||
|
func (r *Request) SetDialFunc(fn func(ctx context.Context, network, addr string) (net.Conn, error)) *Request {
|
||||||
|
return r.applyMutation(mutateDialFunc(fn))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTLSConfig 设置 TLS 配置
|
||||||
|
func (r *Request) SetTLSConfig(tlsConfig *tls.Config) *Request {
|
||||||
|
return r.applyMutation(mutateTLSConfig(tlsConfig))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTLSServerName 设置显式 TLS ServerName/SNI。
|
||||||
|
func (r *Request) SetTLSServerName(serverName string) *Request {
|
||||||
|
return r.applyMutation(mutateTLSServerName(serverName))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetSkipTLSVerify 设置是否跳过 TLS 验证
|
||||||
|
func (r *Request) SetSkipTLSVerify(skip bool) *Request {
|
||||||
|
return r.applyMutation(mutateSkipTLSVerify(skip))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetCustomIP 设置自定义 IP(直接指定 IP,跳过 DNS)
|
||||||
|
func (r *Request) SetCustomIP(ips []string) *Request {
|
||||||
|
return r.applyMutation(mutateCustomIP(ips))
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddCustomIP 添加自定义 IP
|
||||||
|
func (r *Request) AddCustomIP(ip string) *Request {
|
||||||
|
return r.applyMutation(mutateAddCustomIP(ip))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetCustomDNS 设置自定义 DNS 服务器
|
||||||
|
func (r *Request) SetCustomDNS(dnsServers []string) *Request {
|
||||||
|
return r.applyMutation(mutateCustomDNS(dnsServers))
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddCustomDNS 添加自定义 DNS 服务器
|
||||||
|
func (r *Request) AddCustomDNS(dns string) *Request {
|
||||||
|
return r.applyMutation(mutateAddCustomDNS(dns))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetLookupFunc 设置自定义 DNS 解析函数
|
||||||
|
func (r *Request) SetLookupFunc(fn func(ctx context.Context, host string) ([]net.IPAddr, error)) *Request {
|
||||||
|
return r.applyMutation(mutateLookupFunc(fn))
|
||||||
|
}
|
||||||
314
request_prepare.go
Normal file
314
request_prepare.go
Normal file
@ -0,0 +1,314 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptrace"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setReplayableRequestBodyBytes(httpReq *http.Request, data []byte) {
|
||||||
|
if httpReq == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
httpReq.Body = io.NopCloser(bytes.NewReader(data))
|
||||||
|
httpReq.ContentLength = int64(len(data))
|
||||||
|
httpReq.GetBody = func() (io.ReadCloser, error) {
|
||||||
|
return io.NopCloser(bytes.NewReader(data)), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func clearSimpleBodyState(body *BodyConfig) {
|
||||||
|
if body == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
body.Bytes = nil
|
||||||
|
body.Reader = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func resetFormBodyState(body *BodyConfig) {
|
||||||
|
if body == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
body.FormData = make(map[string][]string)
|
||||||
|
}
|
||||||
|
|
||||||
|
func resetMultipartBodyState(body *BodyConfig) {
|
||||||
|
if body == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
body.Files = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func setBytesBodyConfig(body *BodyConfig, data []byte) {
|
||||||
|
if body == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
body.Mode = bodyModeBytes
|
||||||
|
body.Bytes = cloneBytes(data)
|
||||||
|
body.Reader = nil
|
||||||
|
resetFormBodyState(body)
|
||||||
|
resetMultipartBodyState(body)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setReaderBodyConfig(body *BodyConfig, reader io.Reader) {
|
||||||
|
if body == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
body.Mode = bodyModeReader
|
||||||
|
body.Reader = reader
|
||||||
|
body.Bytes = nil
|
||||||
|
resetFormBodyState(body)
|
||||||
|
resetMultipartBodyState(body)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setFormBodyConfig(body *BodyConfig, data map[string][]string) {
|
||||||
|
if body == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
body.Mode = bodyModeForm
|
||||||
|
clearSimpleBodyState(body)
|
||||||
|
resetMultipartBodyState(body)
|
||||||
|
body.FormData = cloneStringMapSlice(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ensureFormMode(body *BodyConfig) {
|
||||||
|
if body == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if body.Mode == bodyModeForm || body.Mode == bodyModeMultipart {
|
||||||
|
if body.FormData == nil {
|
||||||
|
body.FormData = make(map[string][]string)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
clearSimpleBodyState(body)
|
||||||
|
resetMultipartBodyState(body)
|
||||||
|
body.FormData = make(map[string][]string)
|
||||||
|
body.Mode = bodyModeForm
|
||||||
|
}
|
||||||
|
|
||||||
|
func ensureMultipartMode(body *BodyConfig) {
|
||||||
|
if body == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if body.Mode == bodyModeMultipart {
|
||||||
|
if body.FormData == nil {
|
||||||
|
body.FormData = make(map[string][]string)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if body.Mode != bodyModeForm {
|
||||||
|
clearSimpleBodyState(body)
|
||||||
|
body.FormData = make(map[string][]string)
|
||||||
|
}
|
||||||
|
body.Mode = bodyModeMultipart
|
||||||
|
if body.FormData == nil {
|
||||||
|
body.FormData = make(map[string][]string)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func snapshotBytesReader(reader *bytes.Reader) ([]byte, error) {
|
||||||
|
if reader == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
data := make([]byte, reader.Len())
|
||||||
|
_, err := reader.ReadAt(data, reader.Size()-int64(reader.Len()))
|
||||||
|
if err != nil && err != io.EOF {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func snapshotStringReader(reader *strings.Reader) ([]byte, error) {
|
||||||
|
if reader == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
data := make([]byte, reader.Len())
|
||||||
|
_, err := reader.ReadAt(data, reader.Size()-int64(reader.Len()))
|
||||||
|
if err != nil && err != io.EOF {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyBody 应用请求体
|
||||||
|
func (r *Request) applyBody(execCtx context.Context) error {
|
||||||
|
r.httpReq.Body = nil
|
||||||
|
r.httpReq.GetBody = nil
|
||||||
|
r.httpReq.ContentLength = 0
|
||||||
|
|
||||||
|
switch r.config.Body.Mode {
|
||||||
|
case bodyModeReader:
|
||||||
|
if r.config.Body.Reader == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
switch reader := r.config.Body.Reader.(type) {
|
||||||
|
case *bytes.Buffer:
|
||||||
|
setReplayableRequestBodyBytes(r.httpReq, append([]byte(nil), reader.Bytes()...))
|
||||||
|
case *bytes.Reader:
|
||||||
|
data, err := snapshotBytesReader(reader)
|
||||||
|
if err != nil {
|
||||||
|
return wrapError(err, "snapshot bytes reader")
|
||||||
|
}
|
||||||
|
setReplayableRequestBodyBytes(r.httpReq, data)
|
||||||
|
case *strings.Reader:
|
||||||
|
data, err := snapshotStringReader(reader)
|
||||||
|
if err != nil {
|
||||||
|
return wrapError(err, "snapshot strings reader")
|
||||||
|
}
|
||||||
|
setReplayableRequestBodyBytes(r.httpReq, data)
|
||||||
|
default:
|
||||||
|
r.httpReq.Body = io.NopCloser(r.config.Body.Reader)
|
||||||
|
}
|
||||||
|
switch reader := r.config.Body.Reader.(type) {
|
||||||
|
case *bytes.Buffer:
|
||||||
|
r.httpReq.ContentLength = int64(reader.Len())
|
||||||
|
case *bytes.Reader:
|
||||||
|
r.httpReq.ContentLength = int64(reader.Len())
|
||||||
|
case *strings.Reader:
|
||||||
|
r.httpReq.ContentLength = int64(reader.Len())
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
case bodyModeBytes:
|
||||||
|
setReplayableRequestBodyBytes(r.httpReq, r.config.Body.Bytes)
|
||||||
|
return nil
|
||||||
|
case bodyModeMultipart:
|
||||||
|
return r.applyMultipartBody(execCtx)
|
||||||
|
case bodyModeForm:
|
||||||
|
values := url.Values{}
|
||||||
|
for key, items := range r.config.Body.FormData {
|
||||||
|
for _, value := range items {
|
||||||
|
values.Add(key, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
encoded := values.Encode()
|
||||||
|
setReplayableRequestBodyBytes(r.httpReq, []byte(encoded))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// prepare 准备请求(应用配置)
|
||||||
|
func (r *Request) prepare() (err error) {
|
||||||
|
if r.applied {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.httpReq == nil {
|
||||||
|
return fmt.Errorf("http request is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
execCtx := r.ctx
|
||||||
|
if execCtx == nil {
|
||||||
|
execCtx = context.Background()
|
||||||
|
}
|
||||||
|
defaultTLSServerName := ""
|
||||||
|
if r.httpReq.URL != nil && r.httpReq.URL.Scheme == "https" {
|
||||||
|
defaultTLSServerName = r.httpReq.URL.Hostname()
|
||||||
|
}
|
||||||
|
execCtx = injectRequestConfig(execCtx, r.config, defaultTLSServerName)
|
||||||
|
|
||||||
|
var traceState *traceState
|
||||||
|
if r.traceHooks != nil {
|
||||||
|
traceState = newTraceState(r.traceHooks)
|
||||||
|
execCtx = withTraceState(execCtx, traceState)
|
||||||
|
if clientTrace := traceState.clientTrace(); clientTrace != nil {
|
||||||
|
execCtx = httptrace.WithClientTrace(execCtx, clientTrace)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var cancel context.CancelFunc
|
||||||
|
if r.config.Network.Timeout > 0 {
|
||||||
|
execCtx, cancel = context.WithTimeout(execCtx, r.config.Network.Timeout)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err != nil && cancel != nil {
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if r.httpClient == nil {
|
||||||
|
r.httpClient, err = r.buildHTTPClient()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !r.doRaw {
|
||||||
|
if len(r.config.Queries) > 0 {
|
||||||
|
query := r.httpReq.URL.Query()
|
||||||
|
for key, values := range r.config.Queries {
|
||||||
|
for _, value := range values {
|
||||||
|
query.Add(key, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
r.httpReq.URL.RawQuery = query.Encode()
|
||||||
|
}
|
||||||
|
|
||||||
|
for key, values := range r.config.Headers {
|
||||||
|
if isHostHeaderKey(key) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, value := range values {
|
||||||
|
r.httpReq.Header.Add(key, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, cookie := range r.config.Cookies {
|
||||||
|
r.httpReq.AddCookie(cookie)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.config.BasicAuth[0] != "" || r.config.BasicAuth[1] != "" {
|
||||||
|
r.httpReq.SetBasicAuth(r.config.BasicAuth[0], r.config.BasicAuth[1])
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.applyBody(execCtx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.config.ContentLength > 0 {
|
||||||
|
r.httpReq.ContentLength = r.config.ContentLength
|
||||||
|
} else if r.config.ContentLength < 0 {
|
||||||
|
r.httpReq.ContentLength = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.config.AutoCalcContentLength && r.httpReq.Body != nil {
|
||||||
|
data, err := io.ReadAll(r.httpReq.Body)
|
||||||
|
if err != nil {
|
||||||
|
return wrapError(err, "read body for content length")
|
||||||
|
}
|
||||||
|
setReplayableRequestBodyBytes(r.httpReq, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.syncRequestHost()
|
||||||
|
}
|
||||||
|
|
||||||
|
r.execCtx = execCtx
|
||||||
|
r.traceState = traceState
|
||||||
|
r.cancel = cancel
|
||||||
|
r.httpReq = r.httpReq.WithContext(r.execCtx)
|
||||||
|
r.applied = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildHTTPClient 构建 HTTP Client
|
||||||
|
func (r *Request) buildHTTPClient() (*http.Client, error) {
|
||||||
|
if r.client != nil {
|
||||||
|
return r.client.HTTPClient(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.config.CustomTransport && r.config.Transport != nil {
|
||||||
|
return &http.Client{
|
||||||
|
Transport: &Transport{base: r.config.Transport},
|
||||||
|
Timeout: 0,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return DefaultHTTPClient(), nil
|
||||||
|
}
|
||||||
335
request_prepare_regression_test.go
Normal file
335
request_prepare_regression_test.go
Normal file
@ -0,0 +1,335 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"mime/multipart"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type roundTripFunc func(*http.Request) (*http.Response, error)
|
||||||
|
|
||||||
|
func (fn roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
return fn(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestPreparedMutationReappliesHeadersAndBody(t *testing.T) {
|
||||||
|
req := NewSimpleRequest("http://example.com", http.MethodPost).
|
||||||
|
SetHeader("X-Test", "one").
|
||||||
|
SetBodyString("first")
|
||||||
|
req.client = &Client{client: &http.Client{
|
||||||
|
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||||
|
body, err := io.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
_ = r.Body.Close()
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: make(http.Header),
|
||||||
|
Body: io.NopCloser(strings.NewReader(r.Header.Get("X-Test") + ":" + string(body))),
|
||||||
|
Request: r,
|
||||||
|
}, nil
|
||||||
|
}),
|
||||||
|
}}
|
||||||
|
|
||||||
|
if _, err := req.HTTPClient(); err != nil {
|
||||||
|
t.Fatalf("HTTPClient() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req.SetHeader("X-Test", "two").SetBodyString("second")
|
||||||
|
|
||||||
|
resp, err := req.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
body, err := resp.Body().String()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Body().String() error: %v", err)
|
||||||
|
}
|
||||||
|
if body != "two:second" {
|
||||||
|
t.Fatalf("body=%q; want %q", body, "two:second")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestPreparedMutationReappliesTimeout(t *testing.T) {
|
||||||
|
var attempts int32
|
||||||
|
req := NewSimpleRequest("http://example.com", http.MethodGet)
|
||||||
|
req.client = &Client{client: &http.Client{
|
||||||
|
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||||
|
if atomic.AddInt32(&attempts, 1) == 1 {
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusNoContent,
|
||||||
|
Header: make(http.Header),
|
||||||
|
Body: io.NopCloser(strings.NewReader("")),
|
||||||
|
Request: r,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-time.After(50 * time.Millisecond):
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusNoContent,
|
||||||
|
Header: make(http.Header),
|
||||||
|
Body: io.NopCloser(strings.NewReader("")),
|
||||||
|
Request: r,
|
||||||
|
}, nil
|
||||||
|
case <-r.Context().Done():
|
||||||
|
return nil, r.Context().Err()
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
}}
|
||||||
|
|
||||||
|
resp, err := req.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("first Do() error: %v", err)
|
||||||
|
}
|
||||||
|
_ = resp.Close()
|
||||||
|
|
||||||
|
_, err = req.SetTimeout(10 * time.Millisecond).Do()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("second Do() succeeded; want timeout error")
|
||||||
|
}
|
||||||
|
if !IsTimeout(err) && !errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
t.Fatalf("second Do() error=%v; want timeout", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteFileUsesExecContextWithoutProgressHook(t *testing.T) {
|
||||||
|
req := NewSimpleRequest("http://example.com", http.MethodPost)
|
||||||
|
|
||||||
|
pr, pw := io.Pipe()
|
||||||
|
writer := multipart.NewWriter(pw)
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
_, _ = io.Copy(io.Discard, pr)
|
||||||
|
_ = pr.Close()
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
err := req.writeFile(ctx, writer, RequestFile{
|
||||||
|
FormName: "file",
|
||||||
|
FileName: "payload.txt",
|
||||||
|
FileData: strings.NewReader("payload"),
|
||||||
|
FileSize: int64(len("payload")),
|
||||||
|
})
|
||||||
|
_ = writer.Close()
|
||||||
|
_ = pw.Close()
|
||||||
|
<-done
|
||||||
|
|
||||||
|
if !errors.Is(err, context.Canceled) {
|
||||||
|
t.Fatalf("writeFile() error=%v; want context.Canceled", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCopyWithProgressHonorsCanceledContextWithoutHook(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
_, err := copyWithProgress(ctx, io.Discard, strings.NewReader("payload"), "payload.txt", int64(len("payload")), nil)
|
||||||
|
if !errors.Is(err, context.Canceled) {
|
||||||
|
t.Fatalf("copyWithProgress() error=%v; want context.Canceled", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPrepareSetsGetBodyForReplayableBodies(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
req *Request
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "bytes",
|
||||||
|
req: NewSimpleRequest("http://example.com", http.MethodPost).SetBody([]byte("payload")),
|
||||||
|
want: "payload",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "bytes-reader",
|
||||||
|
req: NewSimpleRequest("http://example.com", http.MethodPost).SetBodyReader(bytes.NewReader([]byte("payload"))),
|
||||||
|
want: "payload",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "strings-reader",
|
||||||
|
req: NewSimpleRequest("http://example.com", http.MethodPost).SetBodyReader(strings.NewReader("payload")),
|
||||||
|
want: "payload",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "form-data",
|
||||||
|
req: NewSimpleRequest("http://example.com", http.MethodPost).AddFormData("k", "v"),
|
||||||
|
want: "k=v",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if err := tt.req.prepare(); err != nil {
|
||||||
|
t.Fatalf("prepare() error: %v", err)
|
||||||
|
}
|
||||||
|
if tt.req.httpReq.GetBody == nil {
|
||||||
|
t.Fatal("GetBody is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := tt.req.httpReq.GetBody()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetBody() error: %v", err)
|
||||||
|
}
|
||||||
|
defer body.Close()
|
||||||
|
|
||||||
|
data, err := io.ReadAll(body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ReadAll() error: %v", err)
|
||||||
|
}
|
||||||
|
if string(data) != tt.want {
|
||||||
|
t.Fatalf("body=%q; want %q", string(data), tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type replayRoundTripper struct {
|
||||||
|
attempts int
|
||||||
|
bodies []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rt *replayRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
body, err := io.ReadAll(req.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
_ = req.Body.Close()
|
||||||
|
|
||||||
|
rt.attempts++
|
||||||
|
rt.bodies = append(rt.bodies, string(body))
|
||||||
|
if rt.attempts == 1 {
|
||||||
|
return nil, errors.New("first target failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: make(http.Header),
|
||||||
|
Body: io.NopCloser(strings.NewReader("ok")),
|
||||||
|
Request: req,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoundTripResolvedTargetsReplaysPreparedBody(t *testing.T) {
|
||||||
|
req := NewSimpleRequest("http://example.com/upload", http.MethodPut).
|
||||||
|
SetBodyReader(strings.NewReader("payload"))
|
||||||
|
|
||||||
|
if err := req.prepare(); err != nil {
|
||||||
|
t.Fatalf("prepare() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rt := &replayRoundTripper{}
|
||||||
|
resp, err := roundTripResolvedTargets(rt, req.httpReq, []string{"127.0.0.2:80", "127.0.0.1:80"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("roundTripResolvedTargets() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if len(rt.bodies) != 2 {
|
||||||
|
t.Fatalf("attempt bodies=%v; want 2 attempts", rt.bodies)
|
||||||
|
}
|
||||||
|
if rt.bodies[0] != "payload" || rt.bodies[1] != "payload" {
|
||||||
|
t.Fatalf("attempt bodies=%v; want both payload", rt.bodies)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoundTripResolvedTargetsDoesNotFallbackNonIdempotentRequest(t *testing.T) {
|
||||||
|
req := NewSimpleRequest("http://example.com/upload", http.MethodPost).
|
||||||
|
SetBodyReader(strings.NewReader("payload"))
|
||||||
|
|
||||||
|
if err := req.prepare(); err != nil {
|
||||||
|
t.Fatalf("prepare() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rt := &replayRoundTripper{}
|
||||||
|
_, err := roundTripResolvedTargets(rt, req.httpReq, []string{"127.0.0.2:80", "127.0.0.1:80"})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("roundTripResolvedTargets() succeeded; want first target error")
|
||||||
|
}
|
||||||
|
if len(rt.bodies) != 1 {
|
||||||
|
t.Fatalf("attempt bodies=%v; want only first target attempt", rt.bodies)
|
||||||
|
}
|
||||||
|
if rt.bodies[0] != "payload" {
|
||||||
|
t.Fatalf("attempt body=%q; want payload", rt.bodies[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRetryReplayableReaderBody(t *testing.T) {
|
||||||
|
var attempts int32
|
||||||
|
req := NewSimpleRequest("http://example.com/upload", http.MethodPut).
|
||||||
|
SetBodyReader(strings.NewReader("payload")).
|
||||||
|
SetRetry(1, WithRetryBackoff(0, 0, 1), WithRetryJitter(0))
|
||||||
|
req.client = &Client{client: &http.Client{
|
||||||
|
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||||
|
body, err := io.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
_ = r.Body.Close()
|
||||||
|
if string(body) != "payload" {
|
||||||
|
t.Fatalf("body=%q; want payload", string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
if atomic.AddInt32(&attempts, 1) == 1 {
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusServiceUnavailable,
|
||||||
|
Header: make(http.Header),
|
||||||
|
Body: io.NopCloser(strings.NewReader("retry")),
|
||||||
|
Request: r,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: make(http.Header),
|
||||||
|
Body: io.NopCloser(strings.NewReader("ok")),
|
||||||
|
Request: r,
|
||||||
|
}, nil
|
||||||
|
}),
|
||||||
|
}}
|
||||||
|
|
||||||
|
resp, err := req.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
if got := atomic.LoadInt32(&attempts); got != 2 {
|
||||||
|
t.Fatalf("attempts=%d; want 2", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithProxyInvalidReturnsError(t *testing.T) {
|
||||||
|
_, err := NewRequest("http://example.com", http.MethodGet, WithProxy("://bad-proxy"))
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("NewRequest() succeeded; want invalid proxy error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientNewRequestWithInvalidProxyReturnsError(t *testing.T) {
|
||||||
|
client := NewClientNoErr()
|
||||||
|
|
||||||
|
_, err := client.NewRequest("http://example.com", http.MethodGet, WithProxy("://bad-proxy"))
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("Client.NewRequest() succeeded; want invalid proxy error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewClientWithInvalidProxyReturnsError(t *testing.T) {
|
||||||
|
_, err := NewClient(WithProxy("://bad-proxy"))
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("NewClient() succeeded; want invalid proxy error")
|
||||||
|
}
|
||||||
|
}
|
||||||
31
request_query.go
Normal file
31
request_query.go
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
// AddQuery 添加查询参数
|
||||||
|
func (r *Request) AddQuery(key, value string) *Request {
|
||||||
|
return r.applyMutation(mutateAddQuery(key, value))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetQuery 设置查询参数(覆盖)
|
||||||
|
func (r *Request) SetQuery(key, value string) *Request {
|
||||||
|
return r.applyMutation(mutateSetQuery(key, value))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetQueries 设置所有查询参数(覆盖)
|
||||||
|
func (r *Request) SetQueries(queries map[string][]string) *Request {
|
||||||
|
return r.applyMutation(mutateSetQueries(queries))
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddQueries 批量添加查询参数
|
||||||
|
func (r *Request) AddQueries(queries map[string]string) *Request {
|
||||||
|
return r.applyMutation(mutateAddQueries(queries))
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteQuery 删除查询参数
|
||||||
|
func (r *Request) DeleteQuery(key string) *Request {
|
||||||
|
return r.applyMutation(mutateDeleteQuery(key))
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteQueryValue 删除查询参数的特定值
|
||||||
|
func (r *Request) DeleteQueryValue(key, value string) *Request {
|
||||||
|
return r.applyMutation(mutateDeleteQueryValue(key, value))
|
||||||
|
}
|
||||||
168
request_state_boundary_test.go
Normal file
168
request_state_boundary_test.go
Normal file
@ -0,0 +1,168 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type stateRoundTripperFunc func(*http.Request) (*http.Response, error)
|
||||||
|
|
||||||
|
func (fn stateRoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
return fn(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetContextNilUsesBackground(t *testing.T) {
|
||||||
|
req := NewSimpleRequest("http://example.com", http.MethodGet)
|
||||||
|
req.client = &Client{client: &http.Client{
|
||||||
|
Transport: stateRoundTripperFunc(func(r *http.Request) (*http.Response, error) {
|
||||||
|
if r.Context() == nil {
|
||||||
|
t.Fatal("request context is nil")
|
||||||
|
}
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: make(http.Header),
|
||||||
|
Body: io.NopCloser(strings.NewReader("ok")),
|
||||||
|
Request: r,
|
||||||
|
}, nil
|
||||||
|
}),
|
||||||
|
}}
|
||||||
|
|
||||||
|
resp, err := req.SetContext(nil).Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
if req.Context() == nil {
|
||||||
|
t.Fatal("request Context() is nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithContextNilRetryPathDoesNotPanic(t *testing.T) {
|
||||||
|
var hits int32
|
||||||
|
req, err := NewRequest("http://example.com", http.MethodGet, WithContext(nil))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewRequest() error: %v", err)
|
||||||
|
}
|
||||||
|
req.client = &Client{client: &http.Client{
|
||||||
|
Transport: stateRoundTripperFunc(func(r *http.Request) (*http.Response, error) {
|
||||||
|
if r.Context() == nil {
|
||||||
|
t.Fatal("retry request context is nil")
|
||||||
|
}
|
||||||
|
if atomic.AddInt32(&hits, 1) == 1 {
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusServiceUnavailable,
|
||||||
|
Header: make(http.Header),
|
||||||
|
Body: io.NopCloser(strings.NewReader("retry")),
|
||||||
|
Request: r,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: make(http.Header),
|
||||||
|
Body: io.NopCloser(strings.NewReader("ok")),
|
||||||
|
Request: r,
|
||||||
|
}, nil
|
||||||
|
}),
|
||||||
|
}}
|
||||||
|
|
||||||
|
resp, err := req.
|
||||||
|
SetTimeout(DefaultTimeout).
|
||||||
|
SetRetry(1, WithRetryBackoff(0, 0, 1), WithRetryJitter(0)).
|
||||||
|
Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
if got := atomic.LoadInt32(&hits); got != 2 {
|
||||||
|
t.Fatalf("hits=%d; want 2", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCloneRawRequestCreatesIndependentCopy(t *testing.T) {
|
||||||
|
rawReq, err := http.NewRequest(http.MethodPost, "http://example.com/upload", strings.NewReader("payload"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewRequest() error: %v", err)
|
||||||
|
}
|
||||||
|
rawReq.Header.Set("X-Test", "one")
|
||||||
|
|
||||||
|
req := NewSimpleRequest("", http.MethodPost).SetRawRequest(rawReq)
|
||||||
|
cloned := req.Clone()
|
||||||
|
|
||||||
|
if cloned.Err() != nil {
|
||||||
|
t.Fatalf("Clone() err = %v", cloned.Err())
|
||||||
|
}
|
||||||
|
if cloned.RawRequest() == rawReq {
|
||||||
|
t.Fatal("raw request pointer reused")
|
||||||
|
}
|
||||||
|
|
||||||
|
cloned.RawRequest().Header.Set("X-Test", "two")
|
||||||
|
if rawReq.Header.Get("X-Test") != "one" {
|
||||||
|
t.Fatalf("original header mutated: %q", rawReq.Header.Get("X-Test"))
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := cloned.RawRequest().GetBody()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetBody() error: %v", err)
|
||||||
|
}
|
||||||
|
defer body.Close()
|
||||||
|
|
||||||
|
data, err := io.ReadAll(body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ReadAll() error: %v", err)
|
||||||
|
}
|
||||||
|
if string(data) != "payload" {
|
||||||
|
t.Fatalf("body=%q; want payload", string(data))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCloneRawRequestWithNonReplayableBodyFailsExplicitly(t *testing.T) {
|
||||||
|
rawReq := &http.Request{
|
||||||
|
Method: http.MethodPost,
|
||||||
|
URL: mustParseURL(t, "http://example.com/upload"),
|
||||||
|
Header: make(http.Header),
|
||||||
|
Body: io.NopCloser(io.MultiReader(strings.NewReader("payload"))),
|
||||||
|
}
|
||||||
|
|
||||||
|
req := NewSimpleRequest("", http.MethodPost).SetRawRequest(rawReq)
|
||||||
|
cloned := req.Clone()
|
||||||
|
|
||||||
|
if cloned.Err() == nil {
|
||||||
|
t.Fatal("Clone() should fail for non-replayable raw body")
|
||||||
|
}
|
||||||
|
if !strings.Contains(cloned.Err().Error(), "non-replayable") {
|
||||||
|
t.Fatalf("Clone() err=%v; want non-replayable body error", cloned.Err())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDisableRawModeAfterSetRawRequestReturnsError(t *testing.T) {
|
||||||
|
rawReq, err := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewRequest() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := NewSimpleRequest("", http.MethodGet).SetRawRequest(rawReq).DisableRawMode()
|
||||||
|
if req.Err() == nil {
|
||||||
|
t.Fatal("DisableRawMode() should set error")
|
||||||
|
}
|
||||||
|
if !strings.Contains(req.Err().Error(), "cannot disable raw mode") {
|
||||||
|
t.Fatalf("DisableRawMode() err=%v", req.Err())
|
||||||
|
}
|
||||||
|
if !req.doRaw {
|
||||||
|
t.Fatal("request should remain in raw mode")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustParseURL(t *testing.T, raw string) *url.URL {
|
||||||
|
t.Helper()
|
||||||
|
parsed, err := url.Parse(raw)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("url.Parse() error: %v", err)
|
||||||
|
}
|
||||||
|
return parsed
|
||||||
|
}
|
||||||
6
request_trace.go
Normal file
6
request_trace.go
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
// SetTraceHooks 设置请求 trace 回调。
|
||||||
|
func (r *Request) SetTraceHooks(hooks *TraceHooks) *Request {
|
||||||
|
return r.applyMutation(mutateTraceHooks(hooks))
|
||||||
|
}
|
||||||
85
retry.go
85
retry.go
@ -1,6 +1,7 @@
|
|||||||
package starnet
|
package starnet
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -9,6 +10,7 @@ import (
|
|||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -87,6 +89,9 @@ func buildRetryPolicy(max int, opts ...RetryOpt) (*retryPolicy, error) {
|
|||||||
return policy, nil
|
return policy, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithRetry 为请求启用自动重试。
|
||||||
|
// 默认只重试幂等方法;即使显式关闭幂等限制,Reader 形态的 body 仍会对非幂等方法保持保守禁用,
|
||||||
|
// 以避免请求体已落地后再次发送。
|
||||||
func WithRetry(max int, opts ...RetryOpt) RequestOpt {
|
func WithRetry(max int, opts ...RetryOpt) RequestOpt {
|
||||||
return func(r *Request) error {
|
return func(r *Request) error {
|
||||||
policy, err := buildRetryPolicy(max, opts...)
|
policy, err := buildRetryPolicy(max, opts...)
|
||||||
@ -98,6 +103,9 @@ func WithRetry(max int, opts ...RetryOpt) RequestOpt {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetRetry 为请求启用自动重试。
|
||||||
|
// 默认只重试幂等方法;即使显式关闭幂等限制,Reader 形态的 body 仍会对非幂等方法保持保守禁用,
|
||||||
|
// 以避免请求体已落地后再次发送。
|
||||||
func (r *Request) SetRetry(max int, opts ...RetryOpt) *Request {
|
func (r *Request) SetRetry(max int, opts ...RetryOpt) *Request {
|
||||||
if r.err != nil {
|
if r.err != nil {
|
||||||
return r
|
return r
|
||||||
@ -226,10 +234,10 @@ func (r *Request) doWithRetry() (*Response, error) {
|
|||||||
return r.doOnce()
|
return r.doOnce()
|
||||||
}
|
}
|
||||||
|
|
||||||
retryCtx := r.ctx
|
retryCtx := normalizeContext(r.ctx)
|
||||||
retryCancel := func() {}
|
retryCancel := func() {}
|
||||||
if r.config.Network.Timeout > 0 {
|
if r.config.Network.Timeout > 0 {
|
||||||
retryCtx, retryCancel = context.WithTimeout(r.ctx, r.config.Network.Timeout)
|
retryCtx, retryCancel = context.WithTimeout(retryCtx, r.config.Network.Timeout)
|
||||||
}
|
}
|
||||||
defer retryCancel()
|
defer retryCancel()
|
||||||
|
|
||||||
@ -238,6 +246,12 @@ func (r *Request) doWithRetry() (*Response, error) {
|
|||||||
var lastErr error
|
var lastErr error
|
||||||
|
|
||||||
for attempt := 0; attempt < maxAttempts; attempt++ {
|
for attempt := 0; attempt < maxAttempts; attempt++ {
|
||||||
|
attemptNo := attempt + 1
|
||||||
|
emitRetryAttemptStart(r.traceHooks, TraceRetryAttemptStartInfo{
|
||||||
|
Attempt: attemptNo,
|
||||||
|
MaxAttempts: maxAttempts,
|
||||||
|
})
|
||||||
|
|
||||||
attemptReq, err := r.newRetryAttempt(retryCtx)
|
attemptReq, err := r.newRetryAttempt(retryCtx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, wrapError(err, "build retry attempt")
|
return nil, wrapError(err, "build retry attempt")
|
||||||
@ -248,7 +262,19 @@ func (r *Request) doWithRetry() (*Response, error) {
|
|||||||
resp.request = r
|
resp.request = r
|
||||||
}
|
}
|
||||||
|
|
||||||
if !policy.shouldRetry(resp, err, attempt, maxAttempts, retryCtx) {
|
willRetry := policy.shouldRetry(resp, err, attempt, maxAttempts, retryCtx)
|
||||||
|
statusCode := 0
|
||||||
|
if resp != nil {
|
||||||
|
statusCode = resp.StatusCode
|
||||||
|
}
|
||||||
|
emitRetryAttemptDone(r.traceHooks, TraceRetryAttemptDoneInfo{
|
||||||
|
Attempt: attemptNo,
|
||||||
|
MaxAttempts: maxAttempts,
|
||||||
|
StatusCode: statusCode,
|
||||||
|
Err: err,
|
||||||
|
WillRetry: willRetry,
|
||||||
|
})
|
||||||
|
if !willRetry {
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -262,6 +288,10 @@ func (r *Request) doWithRetry() (*Response, error) {
|
|||||||
if delay <= 0 {
|
if delay <= 0 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
emitRetryBackoff(r.traceHooks, TraceRetryBackoffInfo{
|
||||||
|
Attempt: attemptNo,
|
||||||
|
Delay: delay,
|
||||||
|
})
|
||||||
|
|
||||||
timer := time.NewTimer(delay)
|
timer := time.NewTimer(delay)
|
||||||
select {
|
select {
|
||||||
@ -293,19 +323,9 @@ func (r *Request) newRetryAttempt(ctx context.Context) (*Request, error) {
|
|||||||
return attempt, nil
|
return attempt, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.httpReq == nil {
|
raw, err := cloneRawHTTPRequest(r.httpReq, ctx)
|
||||||
return nil, fmt.Errorf("http request is nil")
|
if err != nil {
|
||||||
}
|
return nil, err
|
||||||
|
|
||||||
raw := r.httpReq.Clone(ctx)
|
|
||||||
if r.httpReq.GetBody != nil {
|
|
||||||
body, err := r.httpReq.GetBody()
|
|
||||||
if err != nil {
|
|
||||||
return nil, wrapError(err, "get raw request body")
|
|
||||||
}
|
|
||||||
raw.Body = body
|
|
||||||
} else if r.httpReq.Body != nil && r.httpReq.Body != http.NoBody {
|
|
||||||
return nil, fmt.Errorf("raw request body is not replayable")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
attempt.httpReq = raw
|
attempt.httpReq = raw
|
||||||
@ -316,6 +336,9 @@ func (p *retryPolicy) canRetryRequest(r *Request) bool {
|
|||||||
if p.idempotentOnly && !isIdempotentMethod(r.method) {
|
if p.idempotentOnly && !isIdempotentMethod(r.method) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
if hasReaderRequestBody(r) && !isIdempotentMethod(r.method) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
return isReplayableRequest(r)
|
return isReplayableRequest(r)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -347,20 +370,40 @@ func isReplayableRequest(r *Request) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reader / stream body 通常不可重放,保守地不重试。
|
return isReplayableConfiguredBody(r.config.Body)
|
||||||
if r.config.Body.Reader != nil {
|
}
|
||||||
|
|
||||||
|
func hasReaderRequestBody(r *Request) bool {
|
||||||
|
if r == nil || r.config == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
return r.config.Body.Mode == bodyModeReader && r.config.Body.Reader != nil
|
||||||
|
}
|
||||||
|
|
||||||
for _, f := range r.config.Body.Files {
|
func isReplayableConfiguredBody(body BodyConfig) bool {
|
||||||
if f.FileData != nil || f.FilePath == "" {
|
switch body.Mode {
|
||||||
return false
|
case bodyModeReader:
|
||||||
|
return isReplayableBodyReader(body.Reader)
|
||||||
|
case bodyModeMultipart:
|
||||||
|
for _, file := range body.Files {
|
||||||
|
if file.FileData != nil || file.FilePath == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isReplayableBodyReader(reader io.Reader) bool {
|
||||||
|
switch reader.(type) {
|
||||||
|
case *bytes.Buffer, *bytes.Reader, *strings.Reader:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (p *retryPolicy) shouldRetry(resp *Response, err error, attempt, maxAttempts int, ctx context.Context) bool {
|
func (p *retryPolicy) shouldRetry(resp *Response, err error, attempt, maxAttempts int, ctx context.Context) bool {
|
||||||
if attempt >= maxAttempts-1 {
|
if attempt >= maxAttempts-1 {
|
||||||
return false
|
return false
|
||||||
|
|||||||
244
review_regression_test.go
Normal file
244
review_regression_test.go
Normal file
@ -0,0 +1,244 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRequestProxyWithCustomIPFallbackTriesNextResolvedTarget(t *testing.T) {
|
||||||
|
tlsServer := newIPv4TLSServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer tlsServer.Close()
|
||||||
|
|
||||||
|
_, port, err := net.SplitHostPort(tlsServer.Listener.Addr().String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("split tls server addr: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
firstTarget := net.JoinHostPort("127.0.0.2", port)
|
||||||
|
secondTarget := net.JoinHostPort("127.0.0.1", port)
|
||||||
|
|
||||||
|
var (
|
||||||
|
mu sync.Mutex
|
||||||
|
connectTargets []string
|
||||||
|
)
|
||||||
|
proxyServer := newIPv4Server(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodConnect {
|
||||||
|
http.Error(w, "connect required", http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
mu.Lock()
|
||||||
|
connectTargets = append(connectTargets, r.Host)
|
||||||
|
mu.Unlock()
|
||||||
|
|
||||||
|
if r.Host == firstTarget {
|
||||||
|
http.Error(w, "first target failed", http.StatusBadGateway)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
targetConn, err := net.Dial("tcp", r.Host)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusBadGateway)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
hijacker, ok := w.(http.Hijacker)
|
||||||
|
if !ok {
|
||||||
|
targetConn.Close()
|
||||||
|
t.Fatal("proxy response writer is not a hijacker")
|
||||||
|
}
|
||||||
|
|
||||||
|
clientConn, rw, err := hijacker.Hijack()
|
||||||
|
if err != nil {
|
||||||
|
targetConn.Close()
|
||||||
|
t.Fatalf("hijack proxy conn: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := rw.WriteString("HTTP/1.1 200 Connection Established\r\n\r\n"); err != nil {
|
||||||
|
clientConn.Close()
|
||||||
|
targetConn.Close()
|
||||||
|
t.Fatalf("write connect response: %v", err)
|
||||||
|
}
|
||||||
|
if err := rw.Flush(); err != nil {
|
||||||
|
clientConn.Close()
|
||||||
|
targetConn.Close()
|
||||||
|
t.Fatalf("flush connect response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
relayProxyConns(clientConn, targetConn)
|
||||||
|
}))
|
||||||
|
defer proxyServer.Close()
|
||||||
|
|
||||||
|
reqURL := fmt.Sprintf("https://proxy-fallback.test:%s", port)
|
||||||
|
resp, err := NewSimpleRequest(reqURL, http.MethodGet).
|
||||||
|
SetProxy(proxyServer.URL).
|
||||||
|
SetCustomIP([]string{"127.0.0.2", "127.0.0.1"}).
|
||||||
|
SetSkipTLSVerify(true).
|
||||||
|
Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
if len(connectTargets) != 2 {
|
||||||
|
t.Fatalf("connect target attempts=%d; want 2 (%v)", len(connectTargets), connectTargets)
|
||||||
|
}
|
||||||
|
if connectTargets[0] != firstTarget {
|
||||||
|
t.Fatalf("first connect target=%q; want %q", connectTargets[0], firstTarget)
|
||||||
|
}
|
||||||
|
if connectTargets[1] != secondTarget {
|
||||||
|
t.Fatalf("second connect target=%q; want %q", connectTargets[1], secondTarget)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTraceHooksDefaultResolverEmitsDNSEvents(t *testing.T) {
|
||||||
|
server := newIPv4Server(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
addr, err := net.ResolveTCPAddr("tcp", server.Listener.Addr().String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ResolveTCPAddr() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
mu sync.Mutex
|
||||||
|
dnsStartCount int
|
||||||
|
dnsDoneCount int
|
||||||
|
lastHost string
|
||||||
|
)
|
||||||
|
hooks := &TraceHooks{
|
||||||
|
DNSStart: func(info TraceDNSStartInfo) {
|
||||||
|
mu.Lock()
|
||||||
|
dnsStartCount++
|
||||||
|
lastHost = info.Host
|
||||||
|
mu.Unlock()
|
||||||
|
},
|
||||||
|
DNSDone: func(info TraceDNSDoneInfo) {
|
||||||
|
mu.Lock()
|
||||||
|
dnsDoneCount++
|
||||||
|
mu.Unlock()
|
||||||
|
if info.Err != nil {
|
||||||
|
t.Errorf("unexpected dns error: %v", info.Err)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
reqURL := "http://localhost:" + strconv.Itoa(addr.Port)
|
||||||
|
resp, err := NewSimpleRequest(reqURL, http.MethodGet).
|
||||||
|
SetDialTimeout(DefaultDialTimeout + 200*time.Millisecond).
|
||||||
|
SetTraceHooks(hooks).
|
||||||
|
Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
if dnsStartCount != 1 {
|
||||||
|
t.Fatalf("dnsStartCount=%d", dnsStartCount)
|
||||||
|
}
|
||||||
|
if dnsDoneCount != 1 {
|
||||||
|
t.Fatalf("dnsDoneCount=%d", dnsDoneCount)
|
||||||
|
}
|
||||||
|
if lastHost != "localhost" {
|
||||||
|
t.Fatalf("lastHost=%q; want localhost", lastHost)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestHeadersReturnsCopy(t *testing.T) {
|
||||||
|
req := NewSimpleRequest("http://example.com", http.MethodGet).
|
||||||
|
SetHeader("X-Test", "one").
|
||||||
|
SetHost("origin.example")
|
||||||
|
|
||||||
|
headers := req.Headers()
|
||||||
|
headers.Set("X-Test", "two")
|
||||||
|
headers.Set("Host", "mutated.example")
|
||||||
|
|
||||||
|
if got := req.GetHeader("X-Test"); got != "one" {
|
||||||
|
t.Fatalf("request header=%q; want one", got)
|
||||||
|
}
|
||||||
|
if got := req.Host(); got != "origin.example" {
|
||||||
|
t.Fatalf("request host=%q; want origin.example", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestCookiesIsolation(t *testing.T) {
|
||||||
|
req := NewSimpleRequest("http://example.com", http.MethodGet)
|
||||||
|
source := []*http.Cookie{{
|
||||||
|
Name: "session",
|
||||||
|
Value: "one",
|
||||||
|
Path: "/",
|
||||||
|
}}
|
||||||
|
|
||||||
|
req.SetCookies(source)
|
||||||
|
source[0].Value = "mutated-outside"
|
||||||
|
|
||||||
|
got := req.Cookies()
|
||||||
|
if len(got) != 1 || got[0].Value != "one" {
|
||||||
|
t.Fatalf("cookies after SetCookies=%v", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
got[0].Value = "mutated-copy"
|
||||||
|
if latest := req.Cookies()[0].Value; latest != "one" {
|
||||||
|
t.Fatalf("internal cookie mutated via getter, got %q", latest)
|
||||||
|
}
|
||||||
|
|
||||||
|
cookie := &http.Cookie{Name: "auth", Value: "token"}
|
||||||
|
req.ResetCookies().AddCookie(cookie)
|
||||||
|
cookie.Value = "changed"
|
||||||
|
|
||||||
|
latest := req.Cookies()
|
||||||
|
if len(latest) != 1 || latest[0].Value != "token" {
|
||||||
|
t.Fatalf("cookies after AddCookie=%v", latest)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTraceHooksLookupFuncStillEmitsDNSEvents(t *testing.T) {
|
||||||
|
server := newIPv4Server(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
addr, err := net.ResolveTCPAddr("tcp", server.Listener.Addr().String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ResolveTCPAddr() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var dnsStartCount int
|
||||||
|
var dnsDoneCount int
|
||||||
|
hooks := &TraceHooks{
|
||||||
|
DNSStart: func(info TraceDNSStartInfo) {
|
||||||
|
dnsStartCount++
|
||||||
|
},
|
||||||
|
DNSDone: func(info TraceDNSDoneInfo) {
|
||||||
|
dnsDoneCount++
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := NewSimpleRequest("http://lookup-copy.test:"+strconv.Itoa(addr.Port), http.MethodGet).
|
||||||
|
SetLookupFunc(func(ctx context.Context, host string) ([]net.IPAddr, error) {
|
||||||
|
return []net.IPAddr{{IP: addr.IP}}, nil
|
||||||
|
}).
|
||||||
|
SetTraceHooks(hooks).
|
||||||
|
Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
if dnsStartCount != 1 || dnsDoneCount != 1 {
|
||||||
|
t.Fatalf("dns trace counts start=%d done=%d", dnsStartCount, dnsDoneCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
125
tls_test.go
125
tls_test.go
@ -104,7 +104,34 @@ func TestRequestLevelTLSOverride(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRequestTls(t *testing.T) {
|
func TestRequestTls(t *testing.T) {
|
||||||
resp, err := NewSimpleRequest("https://www.b612.me", "GET").Do()
|
var requestCount int
|
||||||
|
server, pool := newTrustedIPv4TLSServer(t, "localhost", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
requestCount++
|
||||||
|
switch requestCount {
|
||||||
|
case 1:
|
||||||
|
if r.Header.Get("Hello") != "" {
|
||||||
|
t.Fatalf("unexpected hello header on first request: %q", r.Header.Get("Hello"))
|
||||||
|
}
|
||||||
|
if auth := r.Header.Get("Authorization"); auth != "" {
|
||||||
|
t.Fatalf("unexpected authorization on first request: %q", auth)
|
||||||
|
}
|
||||||
|
case 2:
|
||||||
|
if got := r.Header.Get("Hello"); got != "world" {
|
||||||
|
t.Fatalf("hello header=%q; want world", got)
|
||||||
|
}
|
||||||
|
if got := r.Header.Get("Authorization"); got != "Bearer ddddddd" {
|
||||||
|
t.Fatalf("authorization=%q; want bearer token", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("OK"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
localURL := httpsURLForHost(t, server, "localhost")
|
||||||
|
resp, err := NewSimpleRequest(localURL, "GET").
|
||||||
|
SetTLSConfig(&tls.Config{RootCAs: pool}).
|
||||||
|
Do()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Do() error: %v", err)
|
t.Fatalf("Do() error: %v", err)
|
||||||
}
|
}
|
||||||
@ -114,11 +141,13 @@ func TestRequestTls(t *testing.T) {
|
|||||||
t.Errorf("StatusCode = %v; want %v", resp.StatusCode, http.StatusOK)
|
t.Errorf("StatusCode = %v; want %v", resp.StatusCode, http.StatusOK)
|
||||||
}
|
}
|
||||||
t.Logf("Response: %v", resp.Body().MustString())
|
t.Logf("Response: %v", resp.Body().MustString())
|
||||||
|
|
||||||
client, err := NewClient()
|
client, err := NewClient()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("NewClient() error: %v", err)
|
t.Fatalf("NewClient() error: %v", err)
|
||||||
}
|
}
|
||||||
resp, err = client.NewSimpleRequest("https://www.b612.me", "GET",
|
resp, err = client.NewSimpleRequest(localURL, "GET",
|
||||||
|
WithTLSConfig(&tls.Config{RootCAs: pool}),
|
||||||
WithHeader("hello", "world"),
|
WithHeader("hello", "world"),
|
||||||
WithContext(context.Background()),
|
WithContext(context.Background()),
|
||||||
WithBearerToken("ddddddd")).Do()
|
WithBearerToken("ddddddd")).Do()
|
||||||
@ -134,14 +163,24 @@ func TestRequestTls(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestTLSWithProxyPath(t *testing.T) {
|
func TestTLSWithProxyPath(t *testing.T) {
|
||||||
|
server, pool := newTrustedIPv4TLSServer(t, "localhost", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("proxied"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
proxy := newIPv4ConnectProxyServer(t, nil)
|
||||||
|
defer proxy.Close()
|
||||||
|
|
||||||
client, err := NewClient()
|
client, err := NewClient()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := client.NewRequest("https://registry-1.docker.io/v2/", "GET",
|
req, err := client.NewRequest(httpsURLForHost(t, server, "localhost"), "GET",
|
||||||
WithTimeout(10*time.Second),
|
WithTimeout(10*time.Second),
|
||||||
WithProxy("http://127.0.0.1:29992"),
|
WithProxy(proxy.URL),
|
||||||
|
WithTLSConfig(&tls.Config{RootCAs: pool}),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
@ -152,10 +191,22 @@ func TestTLSWithProxyPath(t *testing.T) {
|
|||||||
t.Fatalf("Do error: %v", err)
|
t.Fatalf("Do error: %v", err)
|
||||||
}
|
}
|
||||||
defer resp.Close()
|
defer resp.Close()
|
||||||
|
if targets := proxy.Targets(); len(targets) != 1 {
|
||||||
|
t.Fatalf("proxy targets=%v; want 1 target", targets)
|
||||||
|
}
|
||||||
t.Log(resp.Status)
|
t.Log(resp.Status)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTLSWithProxyBug(t *testing.T) {
|
func TestTLSWithProxyBug(t *testing.T) {
|
||||||
|
server, pool := newTrustedIPv4TLSServer(t, "proxy-bug.test", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("ok"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
proxy := newIPv4ConnectProxyServer(t, nil)
|
||||||
|
defer proxy.Close()
|
||||||
|
|
||||||
client, err := NewClient()
|
client, err := NewClient()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
@ -163,9 +214,11 @@ func TestTLSWithProxyBug(t *testing.T) {
|
|||||||
|
|
||||||
// 关键:使用 WithProxy 触发 needsDynamicTransport
|
// 关键:使用 WithProxy 触发 needsDynamicTransport
|
||||||
// 即使 proxy 是空串或无效地址,只要设置了就会走 buildDynamicTransport 分支
|
// 即使 proxy 是空串或无效地址,只要设置了就会走 buildDynamicTransport 分支
|
||||||
req, err := client.NewRequest("https://registry-1.docker.io/v2/", "GET",
|
req, err := client.NewRequest(httpsURLForHost(t, server, "proxy-bug.test"), "GET",
|
||||||
WithTimeout(10*time.Second),
|
WithTimeout(10*time.Second),
|
||||||
WithProxy("http://127.0.0.1:29992"), // 随便一个 proxy 地址,触发动态 transport
|
WithProxy(proxy.URL),
|
||||||
|
WithCustomIP([]string{"127.0.0.1"}),
|
||||||
|
WithTLSConfig(&tls.Config{RootCAs: pool}),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
@ -177,20 +230,30 @@ func TestTLSWithProxyBug(t *testing.T) {
|
|||||||
t.Fatalf("Do error: %v", err)
|
t.Fatalf("Do error: %v", err)
|
||||||
}
|
}
|
||||||
defer resp.Close()
|
defer resp.Close()
|
||||||
|
if targets := proxy.Targets(); len(targets) != 1 || targets[0] == "" {
|
||||||
|
t.Fatalf("proxy targets=%v", targets)
|
||||||
|
}
|
||||||
t.Logf("Status: %s", resp.Status)
|
t.Logf("Status: %s", resp.Status)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 更精准的复现:直接测试有问题的分支
|
// 更精准的复现:直接测试有问题的分支
|
||||||
func TestTLSDialWithoutServerName(t *testing.T) {
|
func TestTLSDialWithoutServerName(t *testing.T) {
|
||||||
|
server, pool := newTrustedIPv4TLSServer(t, "custom-ip.test", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("ok"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
client, err := NewClient()
|
client, err := NewClient()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 使用 WithCustomIP 也能触发 defaultDialTLSFunc
|
// 使用 WithCustomIP 也能触发 defaultDialTLSFunc
|
||||||
req, err := client.NewRequest("https://www.google.com", "GET",
|
req, err := client.NewRequest(httpsURLForHost(t, server, "custom-ip.test"), "GET",
|
||||||
WithTimeout(10*time.Second),
|
WithTimeout(10*time.Second),
|
||||||
WithCustomIP([]string{"142.250.185.46"}), // Google 的一个 IP
|
WithCustomIP([]string{"127.0.0.1"}),
|
||||||
|
WithTLSConfig(&tls.Config{RootCAs: pool}),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
@ -206,14 +269,21 @@ func TestTLSDialWithoutServerName(t *testing.T) {
|
|||||||
|
|
||||||
// 最小复现:只要触发 needsDynamicTransport 即可
|
// 最小复现:只要触发 needsDynamicTransport 即可
|
||||||
func TestMinimalTLSBug(t *testing.T) {
|
func TestMinimalTLSBug(t *testing.T) {
|
||||||
|
server, pool := newTrustedIPv4TLSServer(t, "localhost", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("ok"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
client, err := NewClient()
|
client, err := NewClient()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithDialTimeout 也会触发动态 transport
|
// WithDialTimeout 也会触发动态 transport
|
||||||
req, err := client.NewRequest("https://www.baidu.com", "GET",
|
req, err := client.NewRequest(httpsURLForHost(t, server, "localhost"), "GET",
|
||||||
WithDialTimeout(5*time.Second),
|
WithDialTimeout(5*time.Second),
|
||||||
|
WithTLSConfig(&tls.Config{RootCAs: pool}),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
@ -227,3 +297,40 @@ func TestMinimalTLSBug(t *testing.T) {
|
|||||||
defer resp.Close()
|
defer resp.Close()
|
||||||
t.Logf("Status: %s", resp.Status)
|
t.Logf("Status: %s", resp.Status)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTLSWithSOCKS5ProxyPath(t *testing.T) {
|
||||||
|
server, pool := newTrustedIPv4TLSServer(t, "socks5-proxy.test", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("ok"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
proxy := newSOCKS5ProxyServer(t, nil)
|
||||||
|
defer proxy.Close()
|
||||||
|
|
||||||
|
client, err := NewClient()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := client.NewRequest(httpsURLForHost(t, server, "socks5-proxy.test"), "GET",
|
||||||
|
WithTimeout(10*time.Second),
|
||||||
|
WithProxy(proxy.URL()),
|
||||||
|
WithCustomIP([]string{"127.0.0.1"}),
|
||||||
|
WithTLSConfig(&tls.Config{RootCAs: pool}),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := req.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
if targets := proxy.Targets(); len(targets) != 1 || targets[0] == "" {
|
||||||
|
t.Fatalf("socks5 targets=%v", targets)
|
||||||
|
}
|
||||||
|
t.Logf("Status: %s", resp.Status)
|
||||||
|
}
|
||||||
|
|||||||
360
tlssniffer.go
360
tlssniffer.go
@ -4,12 +4,13 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"encoding/binary"
|
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"b612.me/starnet/internal/tlssniffercore"
|
||||||
)
|
)
|
||||||
|
|
||||||
// replayConn replays buffered bytes first, then reads from live conn.
|
// replayConn replays buffered bytes first, then reads from live conn.
|
||||||
@ -51,214 +52,35 @@ type TLSSniffer struct{}
|
|||||||
|
|
||||||
// Sniff detects TLS and extracts SNI when possible.
|
// Sniff detects TLS and extracts SNI when possible.
|
||||||
func (s TLSSniffer) Sniff(conn net.Conn, maxBytes int) (SniffResult, error) {
|
func (s TLSSniffer) Sniff(conn net.Conn, maxBytes int) (SniffResult, error) {
|
||||||
if maxBytes <= 0 {
|
res, err := (tlssniffercore.Sniffer{}).Sniff(conn, maxBytes)
|
||||||
maxBytes = 64 * 1024
|
if err != nil {
|
||||||
|
return SniffResult{}, err
|
||||||
}
|
}
|
||||||
|
return convertCoreSniffResult(res), nil
|
||||||
|
}
|
||||||
|
|
||||||
var buf bytes.Buffer
|
func convertCoreSniffResult(res tlssniffercore.SniffResult) SniffResult {
|
||||||
limited := &io.LimitedReader{R: conn, N: int64(maxBytes)}
|
|
||||||
meta, isTLS := sniffClientHello(limited, &buf, conn)
|
|
||||||
|
|
||||||
out := SniffResult{
|
out := SniffResult{
|
||||||
IsTLS: isTLS,
|
IsTLS: res.IsTLS,
|
||||||
Buffer: bytes.NewBuffer(append([]byte(nil), buf.Bytes()...)),
|
Buffer: res.Buffer,
|
||||||
}
|
}
|
||||||
if isTLS {
|
if res.ClientHello != nil {
|
||||||
out.ClientHello = meta
|
out.ClientHello = convertCoreClientHelloMeta(res.ClientHello)
|
||||||
}
|
}
|
||||||
return out, nil
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
func sniffClientHello(r io.Reader, buf *bytes.Buffer, conn net.Conn) (*ClientHelloMeta, bool) {
|
func convertCoreClientHelloMeta(meta *tlssniffercore.ClientHelloMeta) *ClientHelloMeta {
|
||||||
meta := &ClientHelloMeta{
|
if meta == nil {
|
||||||
LocalAddr: conn.LocalAddr(),
|
return nil
|
||||||
RemoteAddr: conn.RemoteAddr(),
|
|
||||||
}
|
}
|
||||||
|
return &ClientHelloMeta{
|
||||||
header, complete := readTLSRecordHeader(r, buf)
|
ServerName: meta.ServerName,
|
||||||
if len(header) < 3 {
|
LocalAddr: meta.LocalAddr,
|
||||||
return nil, false
|
RemoteAddr: meta.RemoteAddr,
|
||||||
}
|
SupportedProtos: append([]string(nil), meta.SupportedProtos...),
|
||||||
isTLS := header[0] == 0x16 && header[1] == 0x03
|
SupportedVersions: append([]uint16(nil), meta.SupportedVersions...),
|
||||||
if !isTLS {
|
CipherSuites: append([]uint16(nil), meta.CipherSuites...),
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
if len(header) < 5 || !complete {
|
|
||||||
return meta, true
|
|
||||||
}
|
|
||||||
|
|
||||||
recordLen := int(binary.BigEndian.Uint16(header[3:5]))
|
|
||||||
recordBody, bodyOK := readBufferedBytes(r, buf, recordLen)
|
|
||||||
if !bodyOK {
|
|
||||||
return meta, true
|
|
||||||
}
|
|
||||||
if len(recordBody) < 4 || recordBody[0] != 0x01 {
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
|
|
||||||
helloLen := int(recordBody[1])<<16 | int(recordBody[2])<<8 | int(recordBody[3])
|
|
||||||
helloBytes := append([]byte(nil), recordBody[4:]...)
|
|
||||||
for len(helloBytes) < helloLen {
|
|
||||||
nextHeader, nextOK := readTLSRecordHeader(r, buf)
|
|
||||||
if len(nextHeader) < 5 || !nextOK {
|
|
||||||
return meta, true
|
|
||||||
}
|
|
||||||
if nextHeader[0] != 0x16 || nextHeader[1] != 0x03 {
|
|
||||||
return meta, true
|
|
||||||
}
|
|
||||||
nextLen := int(binary.BigEndian.Uint16(nextHeader[3:5]))
|
|
||||||
nextBody, nextBodyOK := readBufferedBytes(r, buf, nextLen)
|
|
||||||
if !nextBodyOK {
|
|
||||||
return meta, true
|
|
||||||
}
|
|
||||||
helloBytes = append(helloBytes, nextBody...)
|
|
||||||
}
|
|
||||||
|
|
||||||
parseClientHelloBody(meta, helloBytes[:helloLen])
|
|
||||||
return meta, true
|
|
||||||
}
|
|
||||||
|
|
||||||
func readTLSRecordHeader(r io.Reader, buf *bytes.Buffer) ([]byte, bool) {
|
|
||||||
return readBufferedBytes(r, buf, 5)
|
|
||||||
}
|
|
||||||
|
|
||||||
func readBufferedBytes(r io.Reader, buf *bytes.Buffer, n int) ([]byte, bool) {
|
|
||||||
if n <= 0 {
|
|
||||||
return nil, true
|
|
||||||
}
|
|
||||||
tmp := make([]byte, n)
|
|
||||||
readN, err := io.ReadFull(r, tmp)
|
|
||||||
if readN > 0 {
|
|
||||||
buf.Write(tmp[:readN])
|
|
||||||
}
|
|
||||||
return append([]byte(nil), tmp[:readN]...), err == nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseClientHelloBody(meta *ClientHelloMeta, body []byte) {
|
|
||||||
if meta == nil || len(body) < 34 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
offset := 2 + 32
|
|
||||||
sessionIDLen := int(body[offset])
|
|
||||||
offset++
|
|
||||||
if offset+sessionIDLen > len(body) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
offset += sessionIDLen
|
|
||||||
|
|
||||||
if offset+2 > len(body) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
cipherSuitesLen := int(binary.BigEndian.Uint16(body[offset : offset+2]))
|
|
||||||
offset += 2
|
|
||||||
if offset+cipherSuitesLen > len(body) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
for i := 0; i+1 < cipherSuitesLen; i += 2 {
|
|
||||||
meta.CipherSuites = append(meta.CipherSuites, binary.BigEndian.Uint16(body[offset+i:offset+i+2]))
|
|
||||||
}
|
|
||||||
offset += cipherSuitesLen
|
|
||||||
|
|
||||||
if offset >= len(body) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
compressionMethodsLen := int(body[offset])
|
|
||||||
offset++
|
|
||||||
if offset+compressionMethodsLen > len(body) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
offset += compressionMethodsLen
|
|
||||||
|
|
||||||
if offset+2 > len(body) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
extensionsLen := int(binary.BigEndian.Uint16(body[offset : offset+2]))
|
|
||||||
offset += 2
|
|
||||||
if offset+extensionsLen > len(body) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
parseClientHelloExtensions(meta, body[offset:offset+extensionsLen])
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseClientHelloExtensions(meta *ClientHelloMeta, exts []byte) {
|
|
||||||
for offset := 0; offset+4 <= len(exts); {
|
|
||||||
extType := binary.BigEndian.Uint16(exts[offset : offset+2])
|
|
||||||
extLen := int(binary.BigEndian.Uint16(exts[offset+2 : offset+4]))
|
|
||||||
offset += 4
|
|
||||||
if offset+extLen > len(exts) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
extData := exts[offset : offset+extLen]
|
|
||||||
offset += extLen
|
|
||||||
|
|
||||||
switch extType {
|
|
||||||
case 0:
|
|
||||||
parseServerNameExtension(meta, extData)
|
|
||||||
case 16:
|
|
||||||
parseALPNExtension(meta, extData)
|
|
||||||
case 43:
|
|
||||||
parseSupportedVersionsExtension(meta, extData)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseServerNameExtension(meta *ClientHelloMeta, data []byte) {
|
|
||||||
if len(data) < 2 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
listLen := int(binary.BigEndian.Uint16(data[:2]))
|
|
||||||
if listLen == 0 || 2+listLen > len(data) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
list := data[2 : 2+listLen]
|
|
||||||
for offset := 0; offset+3 <= len(list); {
|
|
||||||
nameType := list[offset]
|
|
||||||
nameLen := int(binary.BigEndian.Uint16(list[offset+1 : offset+3]))
|
|
||||||
offset += 3
|
|
||||||
if offset+nameLen > len(list) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if nameType == 0 {
|
|
||||||
meta.ServerName = string(list[offset : offset+nameLen])
|
|
||||||
return
|
|
||||||
}
|
|
||||||
offset += nameLen
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseALPNExtension(meta *ClientHelloMeta, data []byte) {
|
|
||||||
if len(data) < 2 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
listLen := int(binary.BigEndian.Uint16(data[:2]))
|
|
||||||
if listLen == 0 || 2+listLen > len(data) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
list := data[2 : 2+listLen]
|
|
||||||
for offset := 0; offset < len(list); {
|
|
||||||
nameLen := int(list[offset])
|
|
||||||
offset++
|
|
||||||
if offset+nameLen > len(list) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
meta.SupportedProtos = append(meta.SupportedProtos, string(list[offset:offset+nameLen]))
|
|
||||||
offset += nameLen
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseSupportedVersionsExtension(meta *ClientHelloMeta, data []byte) {
|
|
||||||
if len(data) < 1 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
listLen := int(data[0])
|
|
||||||
if listLen == 0 || 1+listLen > len(data) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
list := data[1 : 1+listLen]
|
|
||||||
for offset := 0; offset+1 < len(list); offset += 2 {
|
|
||||||
meta.SupportedVersions = append(meta.SupportedVersions, binary.BigEndian.Uint16(list[offset:offset+2]))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -290,17 +112,17 @@ type Conn struct {
|
|||||||
|
|
||||||
func newConn(raw net.Conn, cfg ListenerConfig, stats *Stats) *Conn {
|
func newConn(raw net.Conn, cfg ListenerConfig, stats *Stats) *Conn {
|
||||||
return &Conn{
|
return &Conn{
|
||||||
Conn: raw,
|
Conn: raw,
|
||||||
plainConn: raw,
|
plainConn: raw,
|
||||||
baseTLSConfig: cfg.BaseTLSConfig,
|
baseTLSConfig: cfg.BaseTLSConfig,
|
||||||
getConfigForClient: cfg.GetConfigForClient,
|
getConfigForClient: cfg.GetConfigForClient,
|
||||||
getConfigForClientHello: cfg.GetConfigForClientHello,
|
getConfigForClientHello: cfg.GetConfigForClientHello,
|
||||||
allowNonTLS: cfg.AllowNonTLS,
|
allowNonTLS: cfg.AllowNonTLS,
|
||||||
sniffer: TLSSniffer{},
|
sniffer: TLSSniffer{},
|
||||||
sniffTimeout: cfg.SniffTimeout,
|
sniffTimeout: cfg.SniffTimeout,
|
||||||
maxClientHello: cfg.MaxClientHelloBytes,
|
maxClientHello: cfg.MaxClientHelloBytes,
|
||||||
logger: cfg.Logger,
|
logger: cfg.Logger,
|
||||||
stats: stats,
|
stats: stats,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -433,123 +255,11 @@ func (c *Conn) serverName() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func composeServerTLSConfig(base, selected *tls.Config) *tls.Config {
|
func composeServerTLSConfig(base, selected *tls.Config) *tls.Config {
|
||||||
if base == nil {
|
return tlssniffercore.ComposeServerTLSConfig(base, selected)
|
||||||
return selected
|
|
||||||
}
|
|
||||||
if selected == nil {
|
|
||||||
return base
|
|
||||||
}
|
|
||||||
|
|
||||||
out := base.Clone()
|
|
||||||
applyServerTLSOverrides(out, selected)
|
|
||||||
return out
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func applyServerTLSOverrides(dst, src *tls.Config) {
|
func applyServerTLSOverrides(dst, src *tls.Config) {
|
||||||
if dst == nil || src == nil {
|
tlssniffercore.ApplyServerTLSOverrides(dst, src)
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if src.Rand != nil {
|
|
||||||
dst.Rand = src.Rand
|
|
||||||
}
|
|
||||||
if src.Time != nil {
|
|
||||||
dst.Time = src.Time
|
|
||||||
}
|
|
||||||
if len(src.Certificates) > 0 {
|
|
||||||
dst.Certificates = append([]tls.Certificate(nil), src.Certificates...)
|
|
||||||
}
|
|
||||||
if len(src.NameToCertificate) > 0 {
|
|
||||||
m := make(map[string]*tls.Certificate, len(src.NameToCertificate))
|
|
||||||
for k, v := range src.NameToCertificate {
|
|
||||||
m[k] = v
|
|
||||||
}
|
|
||||||
dst.NameToCertificate = m
|
|
||||||
}
|
|
||||||
if src.GetCertificate != nil {
|
|
||||||
dst.GetCertificate = src.GetCertificate
|
|
||||||
}
|
|
||||||
if src.GetClientCertificate != nil {
|
|
||||||
dst.GetClientCertificate = src.GetClientCertificate
|
|
||||||
}
|
|
||||||
if src.GetConfigForClient != nil {
|
|
||||||
dst.GetConfigForClient = src.GetConfigForClient
|
|
||||||
}
|
|
||||||
if src.VerifyPeerCertificate != nil {
|
|
||||||
dst.VerifyPeerCertificate = src.VerifyPeerCertificate
|
|
||||||
}
|
|
||||||
if src.VerifyConnection != nil {
|
|
||||||
dst.VerifyConnection = src.VerifyConnection
|
|
||||||
}
|
|
||||||
if src.RootCAs != nil {
|
|
||||||
dst.RootCAs = src.RootCAs
|
|
||||||
}
|
|
||||||
if len(src.NextProtos) > 0 {
|
|
||||||
dst.NextProtos = append([]string(nil), src.NextProtos...)
|
|
||||||
}
|
|
||||||
if src.ServerName != "" {
|
|
||||||
dst.ServerName = src.ServerName
|
|
||||||
}
|
|
||||||
if src.ClientAuth > dst.ClientAuth {
|
|
||||||
dst.ClientAuth = src.ClientAuth
|
|
||||||
}
|
|
||||||
if src.ClientCAs != nil {
|
|
||||||
dst.ClientCAs = src.ClientCAs
|
|
||||||
}
|
|
||||||
if src.InsecureSkipVerify {
|
|
||||||
dst.InsecureSkipVerify = true
|
|
||||||
}
|
|
||||||
if len(src.CipherSuites) > 0 {
|
|
||||||
dst.CipherSuites = append([]uint16(nil), src.CipherSuites...)
|
|
||||||
}
|
|
||||||
if src.PreferServerCipherSuites {
|
|
||||||
dst.PreferServerCipherSuites = true
|
|
||||||
}
|
|
||||||
if src.SessionTicketsDisabled {
|
|
||||||
dst.SessionTicketsDisabled = true
|
|
||||||
}
|
|
||||||
if src.SessionTicketKey != ([32]byte{}) {
|
|
||||||
dst.SessionTicketKey = src.SessionTicketKey
|
|
||||||
}
|
|
||||||
if src.ClientSessionCache != nil {
|
|
||||||
dst.ClientSessionCache = src.ClientSessionCache
|
|
||||||
}
|
|
||||||
if src.UnwrapSession != nil {
|
|
||||||
dst.UnwrapSession = src.UnwrapSession
|
|
||||||
}
|
|
||||||
if src.WrapSession != nil {
|
|
||||||
dst.WrapSession = src.WrapSession
|
|
||||||
}
|
|
||||||
if src.MinVersion != 0 && (dst.MinVersion == 0 || src.MinVersion > dst.MinVersion) {
|
|
||||||
dst.MinVersion = src.MinVersion
|
|
||||||
}
|
|
||||||
if src.MaxVersion != 0 && (dst.MaxVersion == 0 || src.MaxVersion < dst.MaxVersion) {
|
|
||||||
dst.MaxVersion = src.MaxVersion
|
|
||||||
}
|
|
||||||
if len(src.CurvePreferences) > 0 {
|
|
||||||
dst.CurvePreferences = append([]tls.CurveID(nil), src.CurvePreferences...)
|
|
||||||
}
|
|
||||||
if src.DynamicRecordSizingDisabled {
|
|
||||||
dst.DynamicRecordSizingDisabled = true
|
|
||||||
}
|
|
||||||
if src.Renegotiation != 0 {
|
|
||||||
dst.Renegotiation = src.Renegotiation
|
|
||||||
}
|
|
||||||
if src.KeyLogWriter != nil {
|
|
||||||
dst.KeyLogWriter = src.KeyLogWriter
|
|
||||||
}
|
|
||||||
if len(src.EncryptedClientHelloConfigList) > 0 {
|
|
||||||
dst.EncryptedClientHelloConfigList = append([]byte(nil), src.EncryptedClientHelloConfigList...)
|
|
||||||
}
|
|
||||||
if src.EncryptedClientHelloRejectionVerify != nil {
|
|
||||||
dst.EncryptedClientHelloRejectionVerify = src.EncryptedClientHelloRejectionVerify
|
|
||||||
}
|
|
||||||
if src.GetEncryptedClientHelloKeys != nil {
|
|
||||||
dst.GetEncryptedClientHelloKeys = src.GetEncryptedClientHelloKeys
|
|
||||||
}
|
|
||||||
if len(src.EncryptedClientHelloKeys) > 0 {
|
|
||||||
dst.EncryptedClientHelloKeys = append([]tls.EncryptedClientHelloKey(nil), src.EncryptedClientHelloKeys...)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) IsTLS() bool {
|
func (c *Conn) IsTLS() bool {
|
||||||
|
|||||||
340
trace.go
Normal file
340
trace.go
Normal file
@ -0,0 +1,340 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"net"
|
||||||
|
"net/http/httptrace"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type traceContextKey struct{}
|
||||||
|
|
||||||
|
// TraceHooks defines optional callbacks for network lifecycle events.
|
||||||
|
// Hooks may be called concurrently.
|
||||||
|
type TraceHooks struct {
|
||||||
|
GetConn func(TraceGetConnInfo)
|
||||||
|
GotConn func(TraceGotConnInfo)
|
||||||
|
PutIdleConn func(TracePutIdleConnInfo)
|
||||||
|
DNSStart func(TraceDNSStartInfo)
|
||||||
|
DNSDone func(TraceDNSDoneInfo)
|
||||||
|
ConnectStart func(TraceConnectStartInfo)
|
||||||
|
ConnectDone func(TraceConnectDoneInfo)
|
||||||
|
TLSHandshakeStart func(TraceTLSHandshakeStartInfo)
|
||||||
|
TLSHandshakeDone func(TraceTLSHandshakeDoneInfo)
|
||||||
|
WroteHeaderField func(TraceWroteHeaderFieldInfo)
|
||||||
|
WroteHeaders func()
|
||||||
|
WroteRequest func(TraceWroteRequestInfo)
|
||||||
|
GotFirstResponseByte func()
|
||||||
|
RetryAttemptStart func(TraceRetryAttemptStartInfo)
|
||||||
|
RetryAttemptDone func(TraceRetryAttemptDoneInfo)
|
||||||
|
RetryBackoff func(TraceRetryBackoffInfo)
|
||||||
|
}
|
||||||
|
|
||||||
|
type TraceGetConnInfo struct {
|
||||||
|
Addr string
|
||||||
|
}
|
||||||
|
|
||||||
|
type TraceGotConnInfo struct {
|
||||||
|
Conn net.Conn
|
||||||
|
Reused bool
|
||||||
|
WasIdle bool
|
||||||
|
IdleTime time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
type TracePutIdleConnInfo struct {
|
||||||
|
Err error
|
||||||
|
}
|
||||||
|
|
||||||
|
type TraceDNSStartInfo struct {
|
||||||
|
Host string
|
||||||
|
}
|
||||||
|
|
||||||
|
type TraceDNSDoneInfo struct {
|
||||||
|
Addrs []net.IPAddr
|
||||||
|
Coalesced bool
|
||||||
|
Err error
|
||||||
|
}
|
||||||
|
|
||||||
|
type TraceConnectStartInfo struct {
|
||||||
|
Network string
|
||||||
|
Addr string
|
||||||
|
}
|
||||||
|
|
||||||
|
type TraceConnectDoneInfo struct {
|
||||||
|
Network string
|
||||||
|
Addr string
|
||||||
|
Err error
|
||||||
|
}
|
||||||
|
|
||||||
|
type TraceTLSHandshakeStartInfo struct {
|
||||||
|
Network string
|
||||||
|
Addr string
|
||||||
|
ServerName string
|
||||||
|
}
|
||||||
|
|
||||||
|
type TraceTLSHandshakeDoneInfo struct {
|
||||||
|
Network string
|
||||||
|
Addr string
|
||||||
|
ServerName string
|
||||||
|
ConnectionState tls.ConnectionState
|
||||||
|
Err error
|
||||||
|
}
|
||||||
|
|
||||||
|
type TraceWroteHeaderFieldInfo struct {
|
||||||
|
Key string
|
||||||
|
Values []string
|
||||||
|
}
|
||||||
|
|
||||||
|
type TraceWroteRequestInfo struct {
|
||||||
|
Err error
|
||||||
|
}
|
||||||
|
|
||||||
|
type TraceRetryAttemptStartInfo struct {
|
||||||
|
Attempt int
|
||||||
|
MaxAttempts int
|
||||||
|
}
|
||||||
|
|
||||||
|
type TraceRetryAttemptDoneInfo struct {
|
||||||
|
Attempt int
|
||||||
|
MaxAttempts int
|
||||||
|
StatusCode int
|
||||||
|
Err error
|
||||||
|
WillRetry bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type TraceRetryBackoffInfo struct {
|
||||||
|
Attempt int
|
||||||
|
Delay time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
type traceState struct {
|
||||||
|
hooks *TraceHooks
|
||||||
|
customTLS atomic.Uint32
|
||||||
|
manualDNSRefs atomic.Int32
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTraceState(hooks *TraceHooks) *traceState {
|
||||||
|
if hooks == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &traceState{hooks: hooks}
|
||||||
|
}
|
||||||
|
|
||||||
|
func withTraceState(ctx context.Context, state *traceState) context.Context {
|
||||||
|
if state == nil {
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
return context.WithValue(ctx, traceContextKey{}, state)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getTraceState(ctx context.Context) *traceState {
|
||||||
|
if ctx == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
state, _ := ctx.Value(traceContextKey{}).(*traceState)
|
||||||
|
return state
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *traceState) needsHTTPTrace() bool {
|
||||||
|
if t == nil || t.hooks == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
h := t.hooks
|
||||||
|
return h.GetConn != nil ||
|
||||||
|
h.GotConn != nil ||
|
||||||
|
h.PutIdleConn != nil ||
|
||||||
|
h.DNSStart != nil ||
|
||||||
|
h.DNSDone != nil ||
|
||||||
|
h.ConnectStart != nil ||
|
||||||
|
h.ConnectDone != nil ||
|
||||||
|
h.TLSHandshakeStart != nil ||
|
||||||
|
h.TLSHandshakeDone != nil ||
|
||||||
|
h.WroteHeaderField != nil ||
|
||||||
|
h.WroteHeaders != nil ||
|
||||||
|
h.WroteRequest != nil ||
|
||||||
|
h.GotFirstResponseByte != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *traceState) clientTrace() *httptrace.ClientTrace {
|
||||||
|
if !t.needsHTTPTrace() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
h := t.hooks
|
||||||
|
trace := &httptrace.ClientTrace{}
|
||||||
|
if h.GetConn != nil {
|
||||||
|
trace.GetConn = func(hostPort string) {
|
||||||
|
h.GetConn(TraceGetConnInfo{Addr: hostPort})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if h.GotConn != nil {
|
||||||
|
trace.GotConn = func(info httptrace.GotConnInfo) {
|
||||||
|
h.GotConn(TraceGotConnInfo{
|
||||||
|
Conn: info.Conn,
|
||||||
|
Reused: info.Reused,
|
||||||
|
WasIdle: info.WasIdle,
|
||||||
|
IdleTime: info.IdleTime,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if h.PutIdleConn != nil {
|
||||||
|
trace.PutIdleConn = func(err error) {
|
||||||
|
h.PutIdleConn(TracePutIdleConnInfo{Err: err})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if h.DNSStart != nil {
|
||||||
|
trace.DNSStart = func(info httptrace.DNSStartInfo) {
|
||||||
|
if t.usesManualDNS() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.DNSStart(TraceDNSStartInfo{Host: info.Host})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if h.DNSDone != nil {
|
||||||
|
trace.DNSDone = func(info httptrace.DNSDoneInfo) {
|
||||||
|
if t.usesManualDNS() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.DNSDone(TraceDNSDoneInfo{
|
||||||
|
Addrs: append([]net.IPAddr(nil), info.Addrs...),
|
||||||
|
Coalesced: info.Coalesced,
|
||||||
|
Err: info.Err,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if h.ConnectStart != nil {
|
||||||
|
trace.ConnectStart = func(network, addr string) {
|
||||||
|
h.ConnectStart(TraceConnectStartInfo{Network: network, Addr: addr})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if h.ConnectDone != nil {
|
||||||
|
trace.ConnectDone = func(network, addr string, err error) {
|
||||||
|
h.ConnectDone(TraceConnectDoneInfo{Network: network, Addr: addr, Err: err})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if h.TLSHandshakeStart != nil {
|
||||||
|
trace.TLSHandshakeStart = func() {
|
||||||
|
if t.usesCustomTLS() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.TLSHandshakeStart(TraceTLSHandshakeStartInfo{})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if h.TLSHandshakeDone != nil {
|
||||||
|
trace.TLSHandshakeDone = func(state tls.ConnectionState, err error) {
|
||||||
|
if t.usesCustomTLS() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.TLSHandshakeDone(TraceTLSHandshakeDoneInfo{
|
||||||
|
ConnectionState: state,
|
||||||
|
Err: err,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if h.WroteHeaderField != nil {
|
||||||
|
trace.WroteHeaderField = func(key string, value []string) {
|
||||||
|
h.WroteHeaderField(TraceWroteHeaderFieldInfo{
|
||||||
|
Key: key,
|
||||||
|
Values: value,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if h.WroteHeaders != nil {
|
||||||
|
trace.WroteHeaders = h.WroteHeaders
|
||||||
|
}
|
||||||
|
if h.WroteRequest != nil {
|
||||||
|
trace.WroteRequest = func(info httptrace.WroteRequestInfo) {
|
||||||
|
h.WroteRequest(TraceWroteRequestInfo{Err: info.Err})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if h.GotFirstResponseByte != nil {
|
||||||
|
trace.GotFirstResponseByte = h.GotFirstResponseByte
|
||||||
|
}
|
||||||
|
return trace
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *traceState) markCustomTLS() {
|
||||||
|
if t == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.customTLS.Store(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *traceState) usesCustomTLS() bool {
|
||||||
|
if t == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return t.customTLS.Load() != 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *traceState) beginManualDNS() {
|
||||||
|
if t == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.manualDNSRefs.Add(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *traceState) endManualDNS() {
|
||||||
|
if t == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.manualDNSRefs.Add(-1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *traceState) usesManualDNS() bool {
|
||||||
|
if t == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return t.manualDNSRefs.Load() > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *traceState) tlsHandshakeStart(info TraceTLSHandshakeStartInfo) {
|
||||||
|
if t == nil || t.hooks == nil || t.hooks.TLSHandshakeStart == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.hooks.TLSHandshakeStart(info)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *traceState) tlsHandshakeDone(info TraceTLSHandshakeDoneInfo) {
|
||||||
|
if t == nil || t.hooks == nil || t.hooks.TLSHandshakeDone == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.hooks.TLSHandshakeDone(info)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *traceState) dnsStart(info TraceDNSStartInfo) {
|
||||||
|
if t == nil || t.hooks == nil || t.hooks.DNSStart == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.hooks.DNSStart(info)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *traceState) dnsDone(info TraceDNSDoneInfo) {
|
||||||
|
if t == nil || t.hooks == nil || t.hooks.DNSDone == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.hooks.DNSDone(info)
|
||||||
|
}
|
||||||
|
|
||||||
|
func emitRetryAttemptStart(hooks *TraceHooks, info TraceRetryAttemptStartInfo) {
|
||||||
|
if hooks == nil || hooks.RetryAttemptStart == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
hooks.RetryAttemptStart(info)
|
||||||
|
}
|
||||||
|
|
||||||
|
func emitRetryAttemptDone(hooks *TraceHooks, info TraceRetryAttemptDoneInfo) {
|
||||||
|
if hooks == nil || hooks.RetryAttemptDone == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
hooks.RetryAttemptDone(info)
|
||||||
|
}
|
||||||
|
|
||||||
|
func emitRetryBackoff(hooks *TraceHooks, info TraceRetryBackoffInfo) {
|
||||||
|
if hooks == nil || hooks.RetryBackoff == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
hooks.RetryBackoff(info)
|
||||||
|
}
|
||||||
324
trace_test.go
Normal file
324
trace_test.go
Normal file
@ -0,0 +1,324 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTraceHooksStandardHTTPSPath(t *testing.T) {
|
||||||
|
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("ok"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
var mu sync.Mutex
|
||||||
|
events := map[string]int{}
|
||||||
|
hooks := &TraceHooks{
|
||||||
|
GetConn: func(info TraceGetConnInfo) {
|
||||||
|
mu.Lock()
|
||||||
|
events["get_conn"]++
|
||||||
|
mu.Unlock()
|
||||||
|
},
|
||||||
|
GotConn: func(info TraceGotConnInfo) {
|
||||||
|
mu.Lock()
|
||||||
|
events["got_conn"]++
|
||||||
|
mu.Unlock()
|
||||||
|
},
|
||||||
|
TLSHandshakeStart: func(info TraceTLSHandshakeStartInfo) {
|
||||||
|
mu.Lock()
|
||||||
|
events["tls_start"]++
|
||||||
|
mu.Unlock()
|
||||||
|
},
|
||||||
|
TLSHandshakeDone: func(info TraceTLSHandshakeDoneInfo) {
|
||||||
|
mu.Lock()
|
||||||
|
events["tls_done"]++
|
||||||
|
mu.Unlock()
|
||||||
|
if info.Err != nil {
|
||||||
|
t.Errorf("unexpected tls handshake error: %v", info.Err)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
WroteHeaders: func() {
|
||||||
|
mu.Lock()
|
||||||
|
events["wrote_headers"]++
|
||||||
|
mu.Unlock()
|
||||||
|
},
|
||||||
|
WroteRequest: func(info TraceWroteRequestInfo) {
|
||||||
|
mu.Lock()
|
||||||
|
events["wrote_request"]++
|
||||||
|
mu.Unlock()
|
||||||
|
if info.Err != nil {
|
||||||
|
t.Errorf("unexpected write error: %v", info.Err)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
GotFirstResponseByte: func() {
|
||||||
|
mu.Lock()
|
||||||
|
events["first_byte"]++
|
||||||
|
mu.Unlock()
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := NewSimpleRequest(server.URL, http.MethodGet).
|
||||||
|
SetSkipTLSVerify(true).
|
||||||
|
SetTraceHooks(hooks).
|
||||||
|
Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
for _, key := range []string{"get_conn", "got_conn", "tls_start", "tls_done", "wrote_headers", "wrote_request", "first_byte"} {
|
||||||
|
if events[key] == 0 {
|
||||||
|
t.Fatalf("expected trace event %q", key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTraceHooksDynamicHTTPSPathDoesNotDuplicateTLSHandshake(t *testing.T) {
|
||||||
|
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
var mu sync.Mutex
|
||||||
|
tlsStartCount := 0
|
||||||
|
tlsDoneCount := 0
|
||||||
|
var lastInfo TraceTLSHandshakeDoneInfo
|
||||||
|
hooks := &TraceHooks{
|
||||||
|
TLSHandshakeStart: func(info TraceTLSHandshakeStartInfo) {
|
||||||
|
mu.Lock()
|
||||||
|
tlsStartCount++
|
||||||
|
mu.Unlock()
|
||||||
|
},
|
||||||
|
TLSHandshakeDone: func(info TraceTLSHandshakeDoneInfo) {
|
||||||
|
mu.Lock()
|
||||||
|
tlsDoneCount++
|
||||||
|
lastInfo = info
|
||||||
|
mu.Unlock()
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := NewSimpleRequest(server.URL, http.MethodGet).
|
||||||
|
SetSkipTLSVerify(true).
|
||||||
|
SetDialTimeout(1500 * time.Millisecond).
|
||||||
|
SetTraceHooks(hooks).
|
||||||
|
Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
if tlsStartCount != 1 {
|
||||||
|
t.Fatalf("tlsStartCount=%d", tlsStartCount)
|
||||||
|
}
|
||||||
|
if tlsDoneCount != 1 {
|
||||||
|
t.Fatalf("tlsDoneCount=%d", tlsDoneCount)
|
||||||
|
}
|
||||||
|
if lastInfo.Err != nil {
|
||||||
|
t.Fatalf("unexpected tls handshake error: %v", lastInfo.Err)
|
||||||
|
}
|
||||||
|
if lastInfo.ConnectionState.Version == 0 {
|
||||||
|
t.Fatal("expected tls connection state")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTraceHooksCustomLookupFuncEmitsDNSEvents(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
addr, err := net.ResolveTCPAddr("tcp", server.Listener.Addr().String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ResolveTCPAddr() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var mu sync.Mutex
|
||||||
|
dnsStartCount := 0
|
||||||
|
dnsDoneCount := 0
|
||||||
|
var dnsStartHost string
|
||||||
|
hooks := &TraceHooks{
|
||||||
|
DNSStart: func(info TraceDNSStartInfo) {
|
||||||
|
mu.Lock()
|
||||||
|
dnsStartCount++
|
||||||
|
dnsStartHost = info.Host
|
||||||
|
mu.Unlock()
|
||||||
|
},
|
||||||
|
DNSDone: func(info TraceDNSDoneInfo) {
|
||||||
|
mu.Lock()
|
||||||
|
dnsDoneCount++
|
||||||
|
mu.Unlock()
|
||||||
|
if info.Err != nil {
|
||||||
|
t.Errorf("unexpected dns error: %v", info.Err)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
url := "http://trace.example.test:" + strconv.Itoa(addr.Port)
|
||||||
|
resp, err := NewSimpleRequest(url, http.MethodGet).
|
||||||
|
SetLookupFunc(func(ctx context.Context, host string) ([]net.IPAddr, error) {
|
||||||
|
return []net.IPAddr{{IP: addr.IP}}, nil
|
||||||
|
}).
|
||||||
|
SetTraceHooks(hooks).
|
||||||
|
Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
if dnsStartCount != 1 {
|
||||||
|
t.Fatalf("dnsStartCount=%d", dnsStartCount)
|
||||||
|
}
|
||||||
|
if dnsDoneCount != 1 {
|
||||||
|
t.Fatalf("dnsDoneCount=%d", dnsDoneCount)
|
||||||
|
}
|
||||||
|
if dnsStartHost != "trace.example.test" {
|
||||||
|
t.Fatalf("dnsStartHost=%q", dnsStartHost)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTraceHooksCustomDialFuncEmitsConnectEvents(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
var mu sync.Mutex
|
||||||
|
connectStartCount := 0
|
||||||
|
connectDoneCount := 0
|
||||||
|
hooks := &TraceHooks{
|
||||||
|
ConnectStart: func(info TraceConnectStartInfo) {
|
||||||
|
mu.Lock()
|
||||||
|
connectStartCount++
|
||||||
|
mu.Unlock()
|
||||||
|
},
|
||||||
|
ConnectDone: func(info TraceConnectDoneInfo) {
|
||||||
|
mu.Lock()
|
||||||
|
connectDoneCount++
|
||||||
|
mu.Unlock()
|
||||||
|
if info.Err != nil {
|
||||||
|
t.Errorf("unexpected connect error: %v", info.Err)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := NewSimpleRequest(server.URL, http.MethodGet).
|
||||||
|
SetDialFunc(func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
var dialer net.Dialer
|
||||||
|
return dialer.DialContext(context.Background(), network, addr)
|
||||||
|
}).
|
||||||
|
SetTraceHooks(hooks).
|
||||||
|
Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
if connectStartCount != 1 {
|
||||||
|
t.Fatalf("connectStartCount=%d", connectStartCount)
|
||||||
|
}
|
||||||
|
if connectDoneCount != 1 {
|
||||||
|
t.Fatalf("connectDoneCount=%d", connectDoneCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTraceHooksRetryEvents(t *testing.T) {
|
||||||
|
var hits int
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
hits++
|
||||||
|
if hits == 1 {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
var mu sync.Mutex
|
||||||
|
starts := 0
|
||||||
|
dones := 0
|
||||||
|
backoffs := 0
|
||||||
|
var finalDone TraceRetryAttemptDoneInfo
|
||||||
|
hooks := &TraceHooks{
|
||||||
|
RetryAttemptStart: func(info TraceRetryAttemptStartInfo) {
|
||||||
|
mu.Lock()
|
||||||
|
starts++
|
||||||
|
mu.Unlock()
|
||||||
|
},
|
||||||
|
RetryAttemptDone: func(info TraceRetryAttemptDoneInfo) {
|
||||||
|
mu.Lock()
|
||||||
|
dones++
|
||||||
|
finalDone = info
|
||||||
|
mu.Unlock()
|
||||||
|
},
|
||||||
|
RetryBackoff: func(info TraceRetryBackoffInfo) {
|
||||||
|
mu.Lock()
|
||||||
|
backoffs++
|
||||||
|
mu.Unlock()
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := NewSimpleRequest(server.URL, http.MethodGet).
|
||||||
|
SetRetry(1, WithRetryBackoff(time.Millisecond, time.Millisecond, 1), WithRetryJitter(0)).
|
||||||
|
SetTraceHooks(hooks).
|
||||||
|
Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
if starts != 2 {
|
||||||
|
t.Fatalf("starts=%d", starts)
|
||||||
|
}
|
||||||
|
if dones != 2 {
|
||||||
|
t.Fatalf("dones=%d", dones)
|
||||||
|
}
|
||||||
|
if backoffs != 1 {
|
||||||
|
t.Fatalf("backoffs=%d", backoffs)
|
||||||
|
}
|
||||||
|
if finalDone.WillRetry {
|
||||||
|
t.Fatal("expected final attempt not to retry")
|
||||||
|
}
|
||||||
|
if finalDone.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("final status=%d", finalDone.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTraceHooksCustomLookupFuncPropagatesDNSError(t *testing.T) {
|
||||||
|
var gotErr error
|
||||||
|
hooks := &TraceHooks{
|
||||||
|
DNSDone: func(info TraceDNSDoneInfo) {
|
||||||
|
gotErr = info.Err
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := NewSimpleRequest("http://trace.example.test:80", http.MethodGet).
|
||||||
|
SetLookupFunc(func(ctx context.Context, host string) ([]net.IPAddr, error) {
|
||||||
|
return nil, errors.New("lookup failed")
|
||||||
|
}).
|
||||||
|
SetTraceHooks(hooks).
|
||||||
|
Do()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected request error")
|
||||||
|
}
|
||||||
|
if gotErr == nil || gotErr.Error() != "lookup failed" {
|
||||||
|
t.Fatalf("gotErr=%v", gotErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
377
transport.go
377
transport.go
@ -1,61 +1,220 @@
|
|||||||
package starnet
|
package starnet
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const dynamicTransportCacheMaxEntries = 64
|
||||||
|
|
||||||
|
type dynamicTransportCacheKey struct {
|
||||||
|
proxyKey string
|
||||||
|
dialTimeout time.Duration
|
||||||
|
customIPs string
|
||||||
|
customDNS string
|
||||||
|
tlsServerName string
|
||||||
|
skipVerify bool
|
||||||
|
}
|
||||||
|
|
||||||
// Transport 自定义 Transport(支持请求级配置)
|
// Transport 自定义 Transport(支持请求级配置)
|
||||||
type Transport struct {
|
type Transport struct {
|
||||||
base *http.Transport
|
base *http.Transport
|
||||||
mu sync.RWMutex
|
dynamicCache map[dynamicTransportCacheKey]*http.Transport
|
||||||
|
dynamicCacheOrder []dynamicTransportCacheKey
|
||||||
|
mu sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// RoundTrip 实现 http.RoundTripper 接口
|
// RoundTrip 实现 http.RoundTripper 接口
|
||||||
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
|
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
// 确保 base 已初始化
|
t.ensureBase()
|
||||||
if t.base == nil {
|
|
||||||
t.mu.Lock()
|
|
||||||
if t.base == nil {
|
|
||||||
t.base = &http.Transport{
|
|
||||||
ForceAttemptHTTP2: true,
|
|
||||||
MaxIdleConns: 100,
|
|
||||||
MaxIdleConnsPerHost: 10,
|
|
||||||
IdleConnTimeout: 90 * time.Second,
|
|
||||||
TLSHandshakeTimeout: 10 * time.Second,
|
|
||||||
ExpectContinueTimeout: 1 * time.Second,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
t.mu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// 提取请求级别的配置
|
// 提取请求级别的配置
|
||||||
reqCtx := getRequestContext(req.Context())
|
reqCtx := getRequestContext(req.Context())
|
||||||
|
traceState := getTraceState(req.Context())
|
||||||
|
execReq := req
|
||||||
|
execReqCtx := reqCtx
|
||||||
|
var targetAddrs []string
|
||||||
|
|
||||||
// 优先级1:完全自定义的 transport
|
// 优先级1:完全自定义的 transport
|
||||||
if reqCtx.Transport != nil {
|
if execReqCtx.Transport != nil {
|
||||||
return reqCtx.Transport.RoundTrip(req)
|
return execReqCtx.Transport.RoundTrip(execReq)
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
execReq, execReqCtx, targetAddrs, err = prepareProxyTargetRequest(execReq, execReqCtx, traceState)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 优先级2:需要动态配置
|
// 优先级2:需要动态配置
|
||||||
if needsDynamicTransport(reqCtx) {
|
if needsDynamicTransport(execReqCtx) {
|
||||||
dynamicTransport := t.buildDynamicTransport(reqCtx)
|
dynamicTransport := t.getDynamicTransport(execReqCtx, traceState)
|
||||||
return dynamicTransport.RoundTrip(req)
|
if len(targetAddrs) > 0 {
|
||||||
|
return roundTripResolvedTargets(dynamicTransport, execReq, targetAddrs)
|
||||||
|
}
|
||||||
|
return dynamicTransport.RoundTrip(execReq)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 优先级3:使用基础 transport
|
// 优先级3:使用基础 transport
|
||||||
t.mu.RLock()
|
t.mu.RLock()
|
||||||
defer t.mu.RUnlock()
|
baseTransport := t.base
|
||||||
return t.base.RoundTrip(req)
|
t.mu.RUnlock()
|
||||||
|
if len(targetAddrs) > 0 {
|
||||||
|
return roundTripResolvedTargets(baseTransport, execReq, targetAddrs)
|
||||||
|
}
|
||||||
|
return baseTransport.RoundTrip(execReq)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newBaseHTTPTransport() *http.Transport {
|
||||||
|
return &http.Transport{
|
||||||
|
ForceAttemptHTTP2: true,
|
||||||
|
MaxIdleConns: 100,
|
||||||
|
MaxIdleConnsPerHost: 10,
|
||||||
|
IdleConnTimeout: 90 * time.Second,
|
||||||
|
TLSHandshakeTimeout: 10 * time.Second,
|
||||||
|
ExpectContinueTimeout: 1 * time.Second,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Transport) ensureBase() {
|
||||||
|
if t.base != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.mu.Lock()
|
||||||
|
defer t.mu.Unlock()
|
||||||
|
t.ensureBaseLocked()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Transport) ensureBaseLocked() {
|
||||||
|
if t.base == nil {
|
||||||
|
t.base = newBaseHTTPTransport()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Transport) getDynamicTransport(rc *RequestContext, traceState *traceState) *http.Transport {
|
||||||
|
if key, ok := newDynamicTransportCacheKey(rc); ok {
|
||||||
|
return t.getOrCreateCachedDynamicTransport(key, rc)
|
||||||
|
}
|
||||||
|
return t.buildDynamicTransport(rc, traceState)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Transport) getOrCreateCachedDynamicTransport(key dynamicTransportCacheKey, rc *RequestContext) *http.Transport {
|
||||||
|
t.mu.RLock()
|
||||||
|
if transport := t.dynamicCache[key]; transport != nil {
|
||||||
|
t.mu.RUnlock()
|
||||||
|
return transport
|
||||||
|
}
|
||||||
|
t.mu.RUnlock()
|
||||||
|
|
||||||
|
t.mu.Lock()
|
||||||
|
defer t.mu.Unlock()
|
||||||
|
|
||||||
|
t.ensureBaseLocked()
|
||||||
|
if transport := t.dynamicCache[key]; transport != nil {
|
||||||
|
return transport
|
||||||
|
}
|
||||||
|
|
||||||
|
transport := buildDynamicTransportFromBase(t.base, rc, nil)
|
||||||
|
if t.dynamicCache == nil {
|
||||||
|
t.dynamicCache = make(map[dynamicTransportCacheKey]*http.Transport)
|
||||||
|
}
|
||||||
|
if len(t.dynamicCacheOrder) >= dynamicTransportCacheMaxEntries {
|
||||||
|
oldestKey := t.dynamicCacheOrder[0]
|
||||||
|
t.dynamicCacheOrder = t.dynamicCacheOrder[1:]
|
||||||
|
if oldest := t.dynamicCache[oldestKey]; oldest != nil {
|
||||||
|
oldest.CloseIdleConnections()
|
||||||
|
delete(t.dynamicCache, oldestKey)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
t.dynamicCache[key] = transport
|
||||||
|
t.dynamicCacheOrder = append(t.dynamicCacheOrder, key)
|
||||||
|
return transport
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Transport) resetDynamicTransportCacheLocked() {
|
||||||
|
for _, key := range t.dynamicCacheOrder {
|
||||||
|
if transport := t.dynamicCache[key]; transport != nil {
|
||||||
|
transport.CloseIdleConnections()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
t.dynamicCache = nil
|
||||||
|
t.dynamicCacheOrder = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newDynamicTransportCacheKey(rc *RequestContext) (dynamicTransportCacheKey, bool) {
|
||||||
|
if rc == nil {
|
||||||
|
return dynamicTransportCacheKey{}, false
|
||||||
|
}
|
||||||
|
if rc.Transport != nil || rc.DialFn != nil || rc.LookupIPFn != nil {
|
||||||
|
return dynamicTransportCacheKey{}, false
|
||||||
|
}
|
||||||
|
if rc.TLSConfig != nil && !rc.TLSConfigCacheable {
|
||||||
|
return dynamicTransportCacheKey{}, false
|
||||||
|
}
|
||||||
|
|
||||||
|
key := dynamicTransportCacheKey{
|
||||||
|
proxyKey: normalizeProxyCacheKey(rc.Proxy),
|
||||||
|
dialTimeout: rc.DialTimeout,
|
||||||
|
customIPs: serializeTransportCacheList(rc.CustomIP),
|
||||||
|
customDNS: serializeTransportCacheList(rc.CustomDNS),
|
||||||
|
tlsServerName: effectiveTLSServerName(rc),
|
||||||
|
}
|
||||||
|
if rc.TLSConfig != nil {
|
||||||
|
key.skipVerify = rc.TLSConfig.InsecureSkipVerify
|
||||||
|
}
|
||||||
|
return key, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeProxyCacheKey(proxy string) string {
|
||||||
|
if proxy == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
proxyURL, err := parseProxyURL(proxy)
|
||||||
|
if err != nil {
|
||||||
|
return "\x00invalid:" + proxy
|
||||||
|
}
|
||||||
|
return proxyURL.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func serializeTransportCacheList(values []string) string {
|
||||||
|
if len(values) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
var builder strings.Builder
|
||||||
|
for _, value := range values {
|
||||||
|
builder.WriteString(value)
|
||||||
|
builder.WriteByte(0)
|
||||||
|
}
|
||||||
|
return builder.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func effectiveTLSServerName(rc *RequestContext) string {
|
||||||
|
if rc == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if rc.TLSConfig != nil && rc.TLSConfig.ServerName != "" {
|
||||||
|
return rc.TLSConfig.ServerName
|
||||||
|
}
|
||||||
|
return rc.TLSServerName
|
||||||
}
|
}
|
||||||
|
|
||||||
// buildDynamicTransport 构建动态 Transport
|
// buildDynamicTransport 构建动态 Transport
|
||||||
func (t *Transport) buildDynamicTransport(rc *RequestContext) *http.Transport {
|
func (t *Transport) buildDynamicTransport(rc *RequestContext, traceState *traceState) *http.Transport {
|
||||||
|
t.ensureBase()
|
||||||
t.mu.RLock()
|
t.mu.RLock()
|
||||||
transport := t.base.Clone()
|
baseTransport := t.base
|
||||||
t.mu.RUnlock()
|
t.mu.RUnlock()
|
||||||
|
return buildDynamicTransportFromBase(baseTransport, rc, traceState)
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildDynamicTransportFromBase(baseTransport *http.Transport, rc *RequestContext, traceState *traceState) *http.Transport {
|
||||||
|
transport := baseTransport.Clone()
|
||||||
|
|
||||||
// 应用 TLS 配置(即使为 nil 也要检查 SkipVerify)
|
// 应用 TLS 配置(即使为 nil 也要检查 SkipVerify)
|
||||||
if rc.TLSConfig != nil {
|
if rc.TLSConfig != nil {
|
||||||
@ -64,15 +223,33 @@ func (t *Transport) buildDynamicTransport(rc *RequestContext) *http.Transport {
|
|||||||
|
|
||||||
// 应用代理配置
|
// 应用代理配置
|
||||||
if rc.Proxy != "" {
|
if rc.Proxy != "" {
|
||||||
proxyURL, err := url.Parse(rc.Proxy)
|
proxyURL, err := parseProxyURL(rc.Proxy)
|
||||||
if err == nil {
|
if err != nil {
|
||||||
|
transport.Proxy = func(*http.Request) (*url.URL, error) {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
transport.Proxy = http.ProxyURL(proxyURL)
|
transport.Proxy = http.ProxyURL(proxyURL)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 应用自定义 Dial 函数
|
// 应用自定义 Dial 函数
|
||||||
if rc.DialFn != nil {
|
if rc.DialFn != nil {
|
||||||
transport.DialContext = rc.DialFn
|
if traceState != nil && traceState.hooks != nil && (traceState.hooks.ConnectStart != nil || traceState.hooks.ConnectDone != nil) {
|
||||||
|
dialFn := rc.DialFn
|
||||||
|
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
if traceState.hooks.ConnectStart != nil {
|
||||||
|
traceState.hooks.ConnectStart(TraceConnectStartInfo{Network: network, Addr: addr})
|
||||||
|
}
|
||||||
|
conn, err := dialFn(ctx, network, addr)
|
||||||
|
if traceState.hooks.ConnectDone != nil {
|
||||||
|
traceState.hooks.ConnectDone(TraceConnectDoneInfo{Network: network, Addr: addr, Err: err})
|
||||||
|
}
|
||||||
|
return conn, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
transport.DialContext = rc.DialFn
|
||||||
|
}
|
||||||
} else if len(rc.CustomIP) > 0 || len(rc.CustomDNS) > 0 || rc.DialTimeout > 0 || rc.LookupIPFn != nil {
|
} else if len(rc.CustomIP) > 0 || len(rc.CustomDNS) > 0 || rc.DialTimeout > 0 || rc.LookupIPFn != nil {
|
||||||
// 使用默认 Dial 函数(会从 context 读取配置)
|
// 使用默认 Dial 函数(会从 context 读取配置)
|
||||||
transport.DialContext = defaultDialFunc
|
transport.DialContext = defaultDialFunc
|
||||||
@ -93,5 +270,147 @@ func (t *Transport) Base() *http.Transport {
|
|||||||
func (t *Transport) SetBase(base *http.Transport) {
|
func (t *Transport) SetBase(base *http.Transport) {
|
||||||
t.mu.Lock()
|
t.mu.Lock()
|
||||||
t.base = base
|
t.base = base
|
||||||
|
t.resetDynamicTransportCacheLocked()
|
||||||
t.mu.Unlock()
|
t.mu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func prepareProxyTargetRequest(req *http.Request, reqCtx *RequestContext, traceState *traceState) (*http.Request, *RequestContext, []string, error) {
|
||||||
|
if req == nil || req.URL == nil || reqCtx == nil {
|
||||||
|
return req, reqCtx, nil, nil
|
||||||
|
}
|
||||||
|
if reqCtx.Proxy == "" || reqCtx.DialFn != nil {
|
||||||
|
return req, reqCtx, nil, nil
|
||||||
|
}
|
||||||
|
if len(reqCtx.CustomIP) == 0 && len(reqCtx.CustomDNS) == 0 && reqCtx.LookupIPFn == nil {
|
||||||
|
return req, reqCtx, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
host := req.URL.Hostname()
|
||||||
|
if host == "" {
|
||||||
|
return req, reqCtx, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
targetAddrs, err := resolveDialAddresses(req.Context(), reqCtx, host, req.URL.Port(), traceState)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, err
|
||||||
|
}
|
||||||
|
if len(targetAddrs) == 0 {
|
||||||
|
return req, reqCtx, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
execReqCtx := *reqCtx
|
||||||
|
execReqCtx.CustomIP = nil
|
||||||
|
execReqCtx.CustomDNS = nil
|
||||||
|
execReqCtx.LookupIPFn = nil
|
||||||
|
|
||||||
|
if req.URL.Scheme == "https" {
|
||||||
|
execReqCtx.TLSConfig = withDefaultServerName(execReqCtx.TLSConfig, host)
|
||||||
|
if execReqCtx.TLSConfigCacheable || reqCtx.TLSConfig == nil {
|
||||||
|
execReqCtx.TLSConfigCacheable = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
execCtx := clearTargetResolutionContext(req.Context())
|
||||||
|
execReq := req.Clone(execCtx)
|
||||||
|
execReq.Host = req.Host
|
||||||
|
if len(targetAddrs) == 1 {
|
||||||
|
execReq.URL.Host = targetAddrs[0]
|
||||||
|
return execReq, &execReqCtx, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return execReq, &execReqCtx, targetAddrs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func clearTargetResolutionContext(ctx context.Context) context.Context {
|
||||||
|
if v := ctx.Value(ctxKeyRequestContext); v != nil {
|
||||||
|
if rc, ok := v.(*RequestContext); ok && rc != nil {
|
||||||
|
cloned := cloneRequestContext(rc)
|
||||||
|
cloned.CustomIP = nil
|
||||||
|
cloned.CustomDNS = nil
|
||||||
|
cloned.LookupIPFn = nil
|
||||||
|
ctx = context.WithValue(ctx, ctxKeyRequestContext, cloned)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ctx = context.WithValue(ctx, ctxKeyCustomIP, []string(nil))
|
||||||
|
ctx = context.WithValue(ctx, ctxKeyCustomDNS, []string(nil))
|
||||||
|
ctx = context.WithValue(ctx, ctxKeyLookupIP, (func(context.Context, string) ([]net.IPAddr, error))(nil))
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
|
||||||
|
func withDefaultServerName(cfg *tls.Config, serverName string) *tls.Config {
|
||||||
|
if serverName == "" {
|
||||||
|
return cfg
|
||||||
|
}
|
||||||
|
if cfg != nil {
|
||||||
|
if cfg.ServerName != "" {
|
||||||
|
return cfg
|
||||||
|
}
|
||||||
|
cloned := cfg.Clone()
|
||||||
|
cloned.ServerName = serverName
|
||||||
|
return cloned
|
||||||
|
}
|
||||||
|
return &tls.Config{
|
||||||
|
ServerName: serverName,
|
||||||
|
NextProtos: []string{"h2", "http/1.1"},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func roundTripResolvedTargets(rt http.RoundTripper, baseReq *http.Request, targetAddrs []string) (*http.Response, error) {
|
||||||
|
if rt == nil || baseReq == nil || len(targetAddrs) == 0 {
|
||||||
|
return rt.RoundTrip(baseReq)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !requestAllowsResolvedTargetFallback(baseReq) && len(targetAddrs) > 1 {
|
||||||
|
targetAddrs = targetAddrs[:1]
|
||||||
|
}
|
||||||
|
|
||||||
|
var lastErr error
|
||||||
|
for _, targetAddr := range targetAddrs {
|
||||||
|
attemptReq, err := cloneRequestForResolvedTarget(baseReq, targetAddr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := rt.RoundTrip(attemptReq)
|
||||||
|
if err == nil {
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
lastErr = err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, lastErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func requestAllowsResolvedTargetFallback(req *http.Request) bool {
|
||||||
|
if req == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !isIdempotentMethod(req.Method) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if req.Body == nil || req.Body == http.NoBody {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return req.GetBody != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func cloneRequestForResolvedTarget(baseReq *http.Request, targetAddr string) (*http.Request, error) {
|
||||||
|
req := baseReq.Clone(baseReq.Context())
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case baseReq.Body == nil || baseReq.Body == http.NoBody:
|
||||||
|
req.Body = baseReq.Body
|
||||||
|
case baseReq.GetBody != nil:
|
||||||
|
body, err := baseReq.GetBody()
|
||||||
|
if err != nil {
|
||||||
|
return nil, wrapError(err, "clone request body for resolved target")
|
||||||
|
}
|
||||||
|
req.Body = body
|
||||||
|
default:
|
||||||
|
req.Body = baseReq.Body
|
||||||
|
}
|
||||||
|
|
||||||
|
req.URL.Host = targetAddr
|
||||||
|
req.Host = baseReq.Host
|
||||||
|
return req, nil
|
||||||
|
}
|
||||||
|
|||||||
224
transport_cache_test.go
Normal file
224
transport_cache_test.go
Normal file
@ -0,0 +1,224 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTransportDynamicCacheReusesSafeProfile(t *testing.T) {
|
||||||
|
transport := &Transport{base: newBaseHTTPTransport()}
|
||||||
|
|
||||||
|
first := transport.getDynamicTransport(&RequestContext{
|
||||||
|
Proxy: "http://127.0.0.1:8080",
|
||||||
|
DialTimeout: 2 * time.Second,
|
||||||
|
CustomIP: []string{"127.0.0.1"},
|
||||||
|
TLSServerName: "cache.test",
|
||||||
|
}, nil)
|
||||||
|
second := transport.getDynamicTransport(&RequestContext{
|
||||||
|
Proxy: "http://127.0.0.1:8080",
|
||||||
|
DialTimeout: 2 * time.Second,
|
||||||
|
CustomIP: []string{"127.0.0.1"},
|
||||||
|
TLSServerName: "cache.test",
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
if first != second {
|
||||||
|
t.Fatal("expected cached dynamic transport to be reused")
|
||||||
|
}
|
||||||
|
if got := len(transport.dynamicCache); got != 1 {
|
||||||
|
t.Fatalf("dynamic cache size=%d; want 1", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTransportDynamicCacheSeparatesTLSServerName(t *testing.T) {
|
||||||
|
transport := &Transport{base: newBaseHTTPTransport()}
|
||||||
|
|
||||||
|
first := transport.getDynamicTransport(&RequestContext{
|
||||||
|
CustomIP: []string{"127.0.0.1"},
|
||||||
|
TLSServerName: "first.test",
|
||||||
|
}, nil)
|
||||||
|
second := transport.getDynamicTransport(&RequestContext{
|
||||||
|
CustomIP: []string{"127.0.0.1"},
|
||||||
|
TLSServerName: "second.test",
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
if first == second {
|
||||||
|
t.Fatal("expected distinct tls server names to use different transports")
|
||||||
|
}
|
||||||
|
if got := len(transport.dynamicCache); got != 2 {
|
||||||
|
t.Fatalf("dynamic cache size=%d; want 2", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTransportDynamicCacheSkipsUserTLSConfig(t *testing.T) {
|
||||||
|
transport := &Transport{base: newBaseHTTPTransport()}
|
||||||
|
reqCtx := &RequestContext{
|
||||||
|
CustomIP: []string{"127.0.0.1"},
|
||||||
|
TLSConfig: &tls.Config{InsecureSkipVerify: true},
|
||||||
|
}
|
||||||
|
|
||||||
|
first := transport.getDynamicTransport(reqCtx, nil)
|
||||||
|
second := transport.getDynamicTransport(reqCtx, nil)
|
||||||
|
|
||||||
|
if first == second {
|
||||||
|
t.Fatal("expected user tls config to bypass dynamic transport cache")
|
||||||
|
}
|
||||||
|
if got := len(transport.dynamicCache); got != 0 {
|
||||||
|
t.Fatalf("dynamic cache size=%d; want 0", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTransportDynamicCacheResetOnDefaultTLSChange(t *testing.T) {
|
||||||
|
client := NewClientNoErr()
|
||||||
|
transport, ok := client.HTTPClient().Transport.(*Transport)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("transport type=%T; want *Transport", client.HTTPClient().Transport)
|
||||||
|
}
|
||||||
|
|
||||||
|
reqCtx := &RequestContext{CustomIP: []string{"127.0.0.1"}}
|
||||||
|
first := transport.getDynamicTransport(reqCtx, nil)
|
||||||
|
if got := len(transport.dynamicCache); got != 1 {
|
||||||
|
t.Fatalf("dynamic cache size=%d; want 1 before reset", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
client.SetDefaultSkipTLSVerify(true)
|
||||||
|
if got := len(transport.dynamicCache); got != 0 {
|
||||||
|
t.Fatalf("dynamic cache size=%d; want 0 after reset", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
second := transport.getDynamicTransport(reqCtx, nil)
|
||||||
|
if first == second {
|
||||||
|
t.Fatal("expected cache reset after default tls change")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDynamicTransportCacheReusesConnectionForCustomIP(t *testing.T) {
|
||||||
|
server := newIPv4Server(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("ok"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
addr, err := net.ResolveTCPAddr("tcp", server.Listener.Addr().String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ResolveTCPAddr() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
client := NewClientNoErr()
|
||||||
|
targetURL := "http://cache-reuse.test:" + strconv.Itoa(addr.Port)
|
||||||
|
|
||||||
|
runRequest := func() bool {
|
||||||
|
var (
|
||||||
|
mu sync.Mutex
|
||||||
|
gotConn bool
|
||||||
|
reused bool
|
||||||
|
)
|
||||||
|
|
||||||
|
resp, err := client.NewSimpleRequest(targetURL, http.MethodGet).
|
||||||
|
SetCustomIP([]string{"127.0.0.1"}).
|
||||||
|
SetTraceHooks(&TraceHooks{
|
||||||
|
GotConn: func(info TraceGotConnInfo) {
|
||||||
|
mu.Lock()
|
||||||
|
gotConn = true
|
||||||
|
reused = info.Reused
|
||||||
|
mu.Unlock()
|
||||||
|
},
|
||||||
|
}).
|
||||||
|
Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
if _, err := resp.Body().Bytes(); err != nil {
|
||||||
|
t.Fatalf("Body().Bytes() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
if !gotConn {
|
||||||
|
t.Fatal("expected GotConn trace event")
|
||||||
|
}
|
||||||
|
return reused
|
||||||
|
}
|
||||||
|
|
||||||
|
if runRequest() {
|
||||||
|
t.Fatal("first request unexpectedly reused a connection")
|
||||||
|
}
|
||||||
|
if !runRequest() {
|
||||||
|
t.Fatal("second request did not reuse cached dynamic transport connection")
|
||||||
|
}
|
||||||
|
|
||||||
|
transport, ok := client.HTTPClient().Transport.(*Transport)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("transport type=%T; want *Transport", client.HTTPClient().Transport)
|
||||||
|
}
|
||||||
|
if got := len(transport.dynamicCache); got != 1 {
|
||||||
|
t.Fatalf("dynamic cache size=%d; want 1", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPrepareProxyTargetRequestSingleTargetRewritesExecRequest(t *testing.T) {
|
||||||
|
req, err := http.NewRequest(http.MethodGet, "https://proxy-single.test:8443/path", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("http.NewRequest() error: %v", err)
|
||||||
|
}
|
||||||
|
req.Host = req.URL.Host
|
||||||
|
|
||||||
|
execReq, execReqCtx, targetAddrs, err := prepareProxyTargetRequest(req, &RequestContext{
|
||||||
|
Proxy: "http://127.0.0.1:8080",
|
||||||
|
CustomIP: []string{"127.0.0.1"},
|
||||||
|
}, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("prepareProxyTargetRequest() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if execReq == req {
|
||||||
|
t.Fatal("expected cloned request for proxy target preparation")
|
||||||
|
}
|
||||||
|
if got := execReq.URL.Host; got != "127.0.0.1:8443" {
|
||||||
|
t.Fatalf("execReq.URL.Host=%q; want %q", got, "127.0.0.1:8443")
|
||||||
|
}
|
||||||
|
if got := req.URL.Host; got != "proxy-single.test:8443" {
|
||||||
|
t.Fatalf("original req.URL.Host=%q; want %q", got, "proxy-single.test:8443")
|
||||||
|
}
|
||||||
|
if len(targetAddrs) != 0 {
|
||||||
|
t.Fatalf("targetAddrs=%v; want empty after single target rewrite", targetAddrs)
|
||||||
|
}
|
||||||
|
if execReqCtx == nil || execReqCtx.TLSConfig == nil {
|
||||||
|
t.Fatal("expected synthesized tls config for single target proxy request")
|
||||||
|
}
|
||||||
|
if got := execReqCtx.TLSConfig.ServerName; got != "proxy-single.test" {
|
||||||
|
t.Fatalf("tls server name=%q; want %q", got, "proxy-single.test")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPrepareProxyTargetRequestMultiTargetPreservesFallbackList(t *testing.T) {
|
||||||
|
req, err := http.NewRequest(http.MethodGet, "https://proxy-multi.test:9443/path", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("http.NewRequest() error: %v", err)
|
||||||
|
}
|
||||||
|
req.Host = req.URL.Host
|
||||||
|
|
||||||
|
execReq, _, targetAddrs, err := prepareProxyTargetRequest(req, &RequestContext{
|
||||||
|
Proxy: "http://127.0.0.1:8080",
|
||||||
|
CustomIP: []string{"127.0.0.1", "127.0.0.2"},
|
||||||
|
}, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("prepareProxyTargetRequest() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := execReq.URL.Host; got != "proxy-multi.test:9443" {
|
||||||
|
t.Fatalf("execReq.URL.Host=%q; want original host", got)
|
||||||
|
}
|
||||||
|
if len(targetAddrs) != 2 {
|
||||||
|
t.Fatalf("targetAddrs=%v; want 2 targets", targetAddrs)
|
||||||
|
}
|
||||||
|
if targetAddrs[0] != "127.0.0.1:9443" || targetAddrs[1] != "127.0.0.2:9443" {
|
||||||
|
t.Fatalf("targetAddrs=%v; want ordered fallback targets", targetAddrs)
|
||||||
|
}
|
||||||
|
}
|
||||||
16
types.go
16
types.go
@ -53,6 +53,7 @@ type NetworkConfig struct {
|
|||||||
type TLSConfig struct {
|
type TLSConfig struct {
|
||||||
Config *tls.Config // TLS 配置
|
Config *tls.Config // TLS 配置
|
||||||
SkipVerify bool // 跳过证书验证
|
SkipVerify bool // 跳过证书验证
|
||||||
|
ServerName string // 显式 TLS ServerName/SNI 覆盖
|
||||||
}
|
}
|
||||||
|
|
||||||
// DNSConfig DNS 配置
|
// DNSConfig DNS 配置
|
||||||
@ -62,8 +63,19 @@ type DNSConfig struct {
|
|||||||
LookupFunc func(ctx context.Context, host string) ([]net.IPAddr, error) // 自定义解析函数
|
LookupFunc func(ctx context.Context, host string) ([]net.IPAddr, error) // 自定义解析函数
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type bodyMode uint8
|
||||||
|
|
||||||
|
const (
|
||||||
|
bodyModeUnset bodyMode = iota
|
||||||
|
bodyModeBytes
|
||||||
|
bodyModeReader
|
||||||
|
bodyModeForm
|
||||||
|
bodyModeMultipart
|
||||||
|
)
|
||||||
|
|
||||||
// BodyConfig 请求体配置
|
// BodyConfig 请求体配置
|
||||||
type BodyConfig struct {
|
type BodyConfig struct {
|
||||||
|
Mode bodyMode // 当前 body 来源模式
|
||||||
Bytes []byte // 原始字节
|
Bytes []byte // 原始字节
|
||||||
Reader io.Reader // 数据流
|
Reader io.Reader // 数据流
|
||||||
FormData map[string][]string // 表单数据
|
FormData map[string][]string // 表单数据
|
||||||
@ -82,6 +94,7 @@ type RequestConfig struct {
|
|||||||
|
|
||||||
// 其他配置
|
// 其他配置
|
||||||
BasicAuth [2]string // Basic 认证
|
BasicAuth [2]string // Basic 认证
|
||||||
|
Host string // 显式 Host 头覆盖
|
||||||
ContentLength int64 // 手动设置的 Content-Length
|
ContentLength int64 // 手动设置的 Content-Length
|
||||||
AutoCalcContentLength bool // 自动计算 Content-Length
|
AutoCalcContentLength bool // 自动计算 Content-Length
|
||||||
MaxRespBodyBytes int64 // 响应体最大读取字节数(<=0 表示不限制)
|
MaxRespBodyBytes int64 // 响应体最大读取字节数(<=0 表示不限制)
|
||||||
@ -104,6 +117,7 @@ func (c *RequestConfig) Clone() *RequestConfig {
|
|||||||
TLS: TLSConfig{
|
TLS: TLSConfig{
|
||||||
Config: cloneTLSConfig(c.TLS.Config),
|
Config: cloneTLSConfig(c.TLS.Config),
|
||||||
SkipVerify: c.TLS.SkipVerify,
|
SkipVerify: c.TLS.SkipVerify,
|
||||||
|
ServerName: c.TLS.ServerName,
|
||||||
},
|
},
|
||||||
DNS: DNSConfig{
|
DNS: DNSConfig{
|
||||||
CustomIP: cloneStringSlice(c.DNS.CustomIP),
|
CustomIP: cloneStringSlice(c.DNS.CustomIP),
|
||||||
@ -111,6 +125,7 @@ func (c *RequestConfig) Clone() *RequestConfig {
|
|||||||
LookupFunc: c.DNS.LookupFunc,
|
LookupFunc: c.DNS.LookupFunc,
|
||||||
},
|
},
|
||||||
Body: BodyConfig{
|
Body: BodyConfig{
|
||||||
|
Mode: c.Body.Mode,
|
||||||
Bytes: cloneBytes(c.Body.Bytes),
|
Bytes: cloneBytes(c.Body.Bytes),
|
||||||
Reader: c.Body.Reader, // Reader 不可克隆
|
Reader: c.Body.Reader, // Reader 不可克隆
|
||||||
FormData: cloneStringMapSlice(c.Body.FormData),
|
FormData: cloneStringMapSlice(c.Body.FormData),
|
||||||
@ -120,6 +135,7 @@ func (c *RequestConfig) Clone() *RequestConfig {
|
|||||||
Cookies: cloneCookies(c.Cookies),
|
Cookies: cloneCookies(c.Cookies),
|
||||||
Queries: cloneStringMapSlice(c.Queries),
|
Queries: cloneStringMapSlice(c.Queries),
|
||||||
BasicAuth: c.BasicAuth,
|
BasicAuth: c.BasicAuth,
|
||||||
|
Host: c.Host,
|
||||||
ContentLength: c.ContentLength,
|
ContentLength: c.ContentLength,
|
||||||
AutoCalcContentLength: c.AutoCalcContentLength,
|
AutoCalcContentLength: c.AutoCalcContentLength,
|
||||||
MaxRespBodyBytes: c.MaxRespBodyBytes,
|
MaxRespBodyBytes: c.MaxRespBodyBytes,
|
||||||
|
|||||||
51
utils.go
51
utils.go
@ -101,24 +101,31 @@ func cloneCookies(cookies []*http.Cookie) []*http.Cookie {
|
|||||||
}
|
}
|
||||||
newCookies := make([]*http.Cookie, len(cookies))
|
newCookies := make([]*http.Cookie, len(cookies))
|
||||||
for i, c := range cookies {
|
for i, c := range cookies {
|
||||||
newCookies[i] = &http.Cookie{
|
newCookies[i] = cloneCookie(c)
|
||||||
Name: c.Name,
|
|
||||||
Value: c.Value,
|
|
||||||
Path: c.Path,
|
|
||||||
Domain: c.Domain,
|
|
||||||
Expires: c.Expires,
|
|
||||||
RawExpires: c.RawExpires,
|
|
||||||
MaxAge: c.MaxAge,
|
|
||||||
Secure: c.Secure,
|
|
||||||
HttpOnly: c.HttpOnly,
|
|
||||||
SameSite: c.SameSite,
|
|
||||||
Raw: c.Raw,
|
|
||||||
Unparsed: append([]string(nil), c.Unparsed...),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return newCookies
|
return newCookies
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func cloneCookie(cookie *http.Cookie) *http.Cookie {
|
||||||
|
if cookie == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &http.Cookie{
|
||||||
|
Name: cookie.Name,
|
||||||
|
Value: cookie.Value,
|
||||||
|
Path: cookie.Path,
|
||||||
|
Domain: cookie.Domain,
|
||||||
|
Expires: cookie.Expires,
|
||||||
|
RawExpires: cookie.RawExpires,
|
||||||
|
MaxAge: cookie.MaxAge,
|
||||||
|
Secure: cookie.Secure,
|
||||||
|
HttpOnly: cookie.HttpOnly,
|
||||||
|
SameSite: cookie.SameSite,
|
||||||
|
Raw: cookie.Raw,
|
||||||
|
Unparsed: append([]string(nil), cookie.Unparsed...),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// cloneStringMapSlice 克隆 map[string][]string
|
// cloneStringMapSlice 克隆 map[string][]string
|
||||||
func cloneStringMapSlice(m map[string][]string) map[string][]string {
|
func cloneStringMapSlice(m map[string][]string) map[string][]string {
|
||||||
if m == nil {
|
if m == nil {
|
||||||
@ -171,8 +178,8 @@ func cloneTLSConfig(cfg *tls.Config) *tls.Config {
|
|||||||
|
|
||||||
// copyWithProgress 带进度的复制
|
// copyWithProgress 带进度的复制
|
||||||
func copyWithProgress(ctx context.Context, dst io.Writer, src io.Reader, filename string, total int64, progress UploadProgressFunc) (int64, error) {
|
func copyWithProgress(ctx context.Context, dst io.Writer, src io.Reader, filename string, total int64, progress UploadProgressFunc) (int64, error) {
|
||||||
if progress == nil {
|
if ctx == nil {
|
||||||
return io.Copy(dst, src)
|
ctx = context.Background()
|
||||||
}
|
}
|
||||||
|
|
||||||
var written int64
|
var written int64
|
||||||
@ -190,8 +197,10 @@ func copyWithProgress(ctx context.Context, dst io.Writer, src io.Reader, filenam
|
|||||||
nw, ew := dst.Write(buf[:nr])
|
nw, ew := dst.Write(buf[:nr])
|
||||||
if nw > 0 {
|
if nw > 0 {
|
||||||
written += int64(nw)
|
written += int64(nw)
|
||||||
// 同步调用进度回调(不使用 goroutine)
|
if progress != nil {
|
||||||
progress(filename, written, total)
|
// 同步调用进度回调(不使用 goroutine)
|
||||||
|
progress(filename, written, total)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if ew != nil {
|
if ew != nil {
|
||||||
return written, ew
|
return written, ew
|
||||||
@ -202,8 +211,10 @@ func copyWithProgress(ctx context.Context, dst io.Writer, src io.Reader, filenam
|
|||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
// 最后一次进度回调
|
if progress != nil {
|
||||||
progress(filename, written, total)
|
// 最后一次进度回调
|
||||||
|
progress(filename, written, total)
|
||||||
|
}
|
||||||
return written, nil
|
return written, nil
|
||||||
}
|
}
|
||||||
return written, err
|
return written, err
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user