feat: 增加请求级 trace 摘要与诊断能力

- 新增 TraceRecorder 和 TraceSummary,汇总 DNS、连接、TLS、写请求、首包等关键事件
  - 为请求执行链接入结构化 trace hooks,补充标准路径与动态路径的 TLS 元信息
  - 增加 Request.TraceSummary() 和 Response.TraceSummary(),提供请求级与响应级摘要快照
  - 修复共享 TraceRecorder 在 Client 默认选项、Clone 和请求复用场景下的状态串扰问题
  - 修复 Response.TraceSummary() 回读 Request 最近状态导致的非快照语义
  - 收口自定义 DialFunc 下的 TLS trace 元数据,避免伪造连接地址
  - 补充 trace 相关回归测试,覆盖 HTTPS、DNS/Connect、连接复用、共享 recorder、响应快照和自定义拨号场景
  - 更新 README,补充 trace、Host 与 TLSServerName 的行为说明
This commit is contained in:
兔子 2026-04-20 17:54:43 +08:00
parent 732e81316c
commit 2f4c7158cf
Signed by: b612
GPG Key ID: 99DD2222B612B612
11 changed files with 1471 additions and 32 deletions

View File

@ -6,6 +6,7 @@
- 基于 `context` 的请求级超时控制,不修改共享 `http.Client` 的全局超时
- 请求级网络控制:代理、自定义 IP / DNS、拨号超时、TLS 配置
- 支持请求级 `Host` 覆盖、显式 `TLSServerName/SNI` 控制,以及结构化 trace 回调 / 摘要
- 内置重试机制,支持重试次数、退避、抖动、状态码白名单和自定义错误判定
- 响应体大小限制,避免一次性读取过大内容
- 错误分类辅助:`ClassifyError``IsTimeout``IsDNS``IsTLS``IsProxy``IsCanceled`
@ -19,9 +20,17 @@
- 同时提供 `WithXxx` 选项和 `SetXxx` 链式调用两套接口
- 支持 `Get``Post``Put``Delete``Head``Patch``Options``Trace``Connect`
- 支持 JSON、表单、`multipart/form-data`、流式请求体等常见请求体形态
- 支持显式 `Host` 覆盖与 `TLSServerName` 设置,便于直连 IP、虚拟主机和证书校验场景分离控制
- Header、Cookie、Query 等输入在关键路径上做防御性拷贝,降低外部可变状态污染风险
- `Request.Clone()` 可用于并发场景或同一基础请求的变体构造
### Trace 与诊断
- 支持 `TraceHooks`,可接收 DNS、建连、TLS 握手、写请求、首包等结构化事件
- 支持 `TraceRecorder` / `TraceSummary`,用于汇总一次请求的关键网络过程和 TLS 摘要
- `Request.TraceSummary()` 返回该请求最近一次执行的摘要快照,`Response.TraceSummary()` 返回当前响应对应的摘要快照
- 若多个请求共享同一个 `TraceRecorder`,其 `Summary()` 表示最近一次完成请求的摘要
### 超时与重试
- 请求超时通过 `context` 截止时间控制,不污染共享客户端配置
@ -98,8 +107,10 @@ func main() {
- `NewClient``NewRequest` 以及请求构造相关接口在遇到非法选项时会直接返回错误,例如格式不合法的代理地址。
- `NewClientNoErr` 是便利构造函数;如果选项校验失败,仍可能返回一个占位 `Client`,需要严格校验配置时应优先使用 `NewClient`
- `SetHost` / `WithHost` 只覆盖 HTTP 请求的 `Host`;如需单独控制 TLS SNI 或证书校验名,应配合 `SetTLSServerName` / `WithTLSServerName` 使用。
- 重试默认仅对幂等方法生效。即使显式关闭“仅幂等方法重试”,通过 `SetBodyReader``WithBodyReader` 构造的请求在非幂等方法上仍不会自动重试。
- 当同时使用 `proxy + custom IP/DNS` 且解析出多个目标地址时,自动目标回退仅对幂等请求生效,以避免重复写入。
- 绑定到请求上的 `TraceRecorder` 用于发布已完成请求的摘要;请求执行中的中间状态不保证通过共享 recorder 实时可见。
## 稳定性说明

View File

@ -46,6 +46,11 @@ func WithTraceHooks(hooks *TraceHooks) RequestOpt {
return requestOptFromMutation(mutateTraceHooks(hooks))
}
// WithTraceRecorder 设置请求级 trace 摘要记录器。
func WithTraceRecorder(recorder *TraceRecorder) RequestOpt {
return requestOptFromMutation(mutateTraceRecorder(recorder))
}
// WithSkipTLSVerify 设置是否跳过 TLS 验证
func WithSkipTLSVerify(skip bool) RequestOpt {
return requestOptFromMutation(mutateSkipTLSVerify(skip))

View File

@ -17,13 +17,16 @@ type Request struct {
method string
err error // 累积的错误
config *RequestConfig
client *Client
httpClient *http.Client
httpReq *http.Request
retry *retryPolicy
traceHooks *TraceHooks
traceState *traceState
config *RequestConfig
client *Client
httpClient *http.Client
httpReq *http.Request
retry *retryPolicy
traceHooks *TraceHooks
traceRecorder *TraceRecorder
traceRun *TraceRecorder
lastTraceSummary *TraceSummary
traceState *traceState
applied bool // 是否已应用配置
doRaw bool // 是否使用原始请求(不修改)
@ -81,6 +84,7 @@ func (r *Request) invalidatePreparedState() {
r.cancel = nil
}
r.execCtx = nil
r.traceRun = nil
r.traceState = nil
r.httpClient = nil
@ -273,18 +277,19 @@ func NewSimpleRequestWithContext(ctx context.Context, url, method string, opts .
// Clone 克隆请求
func (r *Request) Clone() *Request {
cloned := &Request{
ctx: r.ctx,
url: r.url,
method: r.method,
err: r.err,
config: r.config.Clone(),
client: r.client,
httpClient: r.httpClient,
retry: cloneRetryPolicy(r.retry),
traceHooks: r.traceHooks,
applied: false, // 重置应用状态
doRaw: r.doRaw,
autoFetch: r.autoFetch,
ctx: r.ctx,
url: r.url,
method: r.method,
err: r.err,
config: r.config.Clone(),
client: r.client,
httpClient: r.httpClient,
retry: cloneRetryPolicy(r.retry),
traceHooks: r.traceHooks,
traceRecorder: r.traceRecorder,
applied: false, // 重置应用状态
doRaw: r.doRaw,
autoFetch: r.autoFetch,
rawSourceExternal: r.rawSourceExternal,
}
@ -477,11 +482,20 @@ func (r *Request) Do() (*Response, error) {
return nil, r.err
}
r.startTraceExecution()
var (
resp *Response
err error
)
if r.hasRetryPolicy() {
return r.doWithRetry()
resp, err = r.doWithRetry()
} else {
resp, err = r.doOnce()
}
return r.doOnce()
r.finishTraceExecution(resp)
return resp, err
}
func (r *Request) doOnce() (*Response, error) {
@ -493,6 +507,9 @@ func (r *Request) doOnce() (*Response, error) {
if err := r.prepare(); err != nil {
return nil, wrapError(err, "prepare request")
}
if r.traceRun != nil {
r.traceRun.observePreparedRequest(r.httpReq)
}
// 执行请求
httpResp, err := r.httpClient.Do(r.httpReq)
@ -508,6 +525,9 @@ func (r *Request) doOnce() (*Response, error) {
body: &Body{},
}, wrapError(err, "do request")
}
if r.traceRun != nil {
r.traceRun.observeResponse(httpResp)
}
rawBody := httpResp.Body
if r.cancel != nil {

View File

@ -122,6 +122,13 @@ func mutateTraceHooks(hooks *TraceHooks) requestMutation {
}
}
func mutateTraceRecorder(recorder *TraceRecorder) requestMutation {
return func(r *Request) error {
r.traceRecorder = recorder
return nil
}
}
func mutateSkipTLSVerify(skip bool) requestMutation {
return func(r *Request) error {
r.config.TLS.SkipVerify = skip

View File

@ -194,6 +194,36 @@ func (r *Request) applyBody(execCtx context.Context) error {
return nil
}
func buildTraceTLSHandshakeInfo(req *http.Request, execCtx context.Context, defaultServerName string) TraceTLSHandshakeStartInfo {
if req == nil || req.URL == nil || req.URL.Scheme != "https" {
return TraceTLSHandshakeStartInfo{}
}
reqCtx := getRequestContext(execCtx)
info := TraceTLSHandshakeStartInfo{}
// 自定义 DialFunc 的真实落点由调用方决定,这里只在默认拨号路径下预填地址,避免 trace 元信息误导。
if reqCtx == nil || reqCtx.DialFn == nil {
info.Network = "tcp"
info.Addr = req.URL.Host
}
if reqCtx != nil {
if reqCtx.TLSConfig != nil && reqCtx.TLSConfig.ServerName != "" {
info.ServerName = reqCtx.TLSConfig.ServerName
} else if reqCtx.TLSServerName != "" {
info.ServerName = reqCtx.TLSServerName
}
}
if info.ServerName == "" {
if defaultServerName != "" {
info.ServerName = defaultServerName
} else {
info.ServerName = req.URL.Hostname()
}
}
return info
}
// prepare 准备请求(应用配置)
func (r *Request) prepare() (err error) {
if r.applied {
@ -215,8 +245,10 @@ func (r *Request) prepare() (err error) {
execCtx = injectRequestConfig(execCtx, r.config, defaultTLSServerName)
var traceState *traceState
if r.traceHooks != nil {
traceState = newTraceState(r.traceHooks)
traceHooks := composeTraceHooks(r.traceHooks, traceRecorderHooks(r.traceRun))
if traceHooks != nil {
traceState = newTraceState(traceHooks)
traceState.setDefaultTLSHandshakeInfo(buildTraceTLSHandshakeInfo(r.httpReq, execCtx, defaultTLSServerName))
execCtx = withTraceState(execCtx, traceState)
if clientTrace := traceState.clientTrace(); clientTrace != nil {
execCtx = httptrace.WithClientTrace(execCtx, clientTrace)

View File

@ -4,3 +4,53 @@ package starnet
func (r *Request) SetTraceHooks(hooks *TraceHooks) *Request {
return r.applyMutation(mutateTraceHooks(hooks))
}
// SetTraceRecorder 设置请求级 trace 摘要记录器。
// 记录器会保存最近一次已完成请求的摘要;若多个请求共享同一个记录器,则以最后一次完成的请求为准。
func (r *Request) SetTraceRecorder(recorder *TraceRecorder) *Request {
return r.applyMutation(mutateTraceRecorder(recorder))
}
// TraceSummary 返回当前请求最近一次执行的 trace 摘要快照。
func (r *Request) TraceSummary() *TraceSummary {
if r == nil || r.lastTraceSummary == nil {
return nil
}
summary := cloneTraceSummary(*r.lastTraceSummary)
return &summary
}
func (r *Request) startTraceExecution() {
if r == nil {
return
}
r.traceRun = nil
if r.traceRecorder == nil {
return
}
r.traceRun = r.traceRecorder.forkExecution()
if r.traceRun != nil {
r.traceRun.startRequest()
}
}
func (r *Request) finishTraceExecution(resp *Response) {
if r == nil {
return
}
if r.traceRun == nil {
r.lastTraceSummary = nil
if resp != nil {
resp.traceSummary = nil
}
return
}
summary := r.traceRun.Summary()
r.lastTraceSummary = cloneTraceSummaryPtr(summary)
r.traceRecorder.publishSummary(summary)
if resp != nil {
resp.traceSummary = cloneTraceSummaryPtr(summary)
}
r.traceRun = nil
}

View File

@ -11,10 +11,11 @@ import (
// Response HTTP 响应
type Response struct {
*http.Response
request *Request
httpClient *http.Client
cancel func()
body *Body
request *Request
httpClient *http.Client
cancel func()
body *Body
traceSummary *TraceSummary
}
// Body 响应体
@ -47,6 +48,15 @@ func (r *Response) Request() *Request {
return r.request
}
// TraceSummary 获取当前响应对应的 trace 摘要快照。
func (r *Response) TraceSummary() *TraceSummary {
if r == nil || r.traceSummary == nil {
return nil
}
summary := cloneTraceSummary(*r.traceSummary)
return &summary
}
// Body 获取响应体
func (r *Response) Body() *Body {
return r.body

View File

@ -312,6 +312,8 @@ func (r *Request) newRetryAttempt(ctx context.Context) (*Request, error) {
attempt.applied = false
attempt.execCtx = nil
attempt.ctx = ctx
attempt.traceRun = r.traceRun
attempt.lastTraceSummary = nil
// 共享总超时上下文后,避免每次 attempt 再创建一次 timeout context。
if attempt.config != nil && attempt.config.Network.Timeout > 0 {

274
trace.go
View File

@ -110,9 +110,10 @@ type TraceRetryBackoffInfo struct {
}
type traceState struct {
hooks *TraceHooks
customTLS atomic.Uint32
manualDNSRefs atomic.Int32
hooks *TraceHooks
customTLS atomic.Uint32
manualDNSRefs atomic.Int32
defaultTLSHandshakeInfo TraceTLSHandshakeStartInfo
}
func newTraceState(hooks *TraceHooks) *traceState {
@ -122,6 +123,20 @@ func newTraceState(hooks *TraceHooks) *traceState {
return &traceState{hooks: hooks}
}
func (t *traceState) setDefaultTLSHandshakeInfo(info TraceTLSHandshakeStartInfo) {
if t == nil {
return
}
t.defaultTLSHandshakeInfo = info
}
func (t *traceState) getDefaultTLSHandshakeInfo() TraceTLSHandshakeStartInfo {
if t == nil {
return TraceTLSHandshakeStartInfo{}
}
return t.defaultTLSHandshakeInfo
}
func withTraceState(ctx context.Context, state *traceState) context.Context {
if state == nil {
return ctx
@ -219,7 +234,7 @@ func (t *traceState) clientTrace() *httptrace.ClientTrace {
if t.usesCustomTLS() {
return
}
h.TLSHandshakeStart(TraceTLSHandshakeStartInfo{})
h.TLSHandshakeStart(t.getDefaultTLSHandshakeInfo())
}
}
if h.TLSHandshakeDone != nil {
@ -227,7 +242,11 @@ func (t *traceState) clientTrace() *httptrace.ClientTrace {
if t.usesCustomTLS() {
return
}
info := t.getDefaultTLSHandshakeInfo()
h.TLSHandshakeDone(TraceTLSHandshakeDoneInfo{
Network: info.Network,
Addr: info.Addr,
ServerName: info.ServerName,
ConnectionState: state,
Err: err,
})
@ -237,7 +256,7 @@ func (t *traceState) clientTrace() *httptrace.ClientTrace {
trace.WroteHeaderField = func(key string, value []string) {
h.WroteHeaderField(TraceWroteHeaderFieldInfo{
Key: key,
Values: value,
Values: append([]string(nil), value...),
})
}
}
@ -338,3 +357,248 @@ func emitRetryBackoff(hooks *TraceHooks, info TraceRetryBackoffInfo) {
}
hooks.RetryBackoff(info)
}
func traceRecorderHooks(recorder *TraceRecorder) *TraceHooks {
if recorder == nil {
return nil
}
return recorder.Hooks()
}
func composeTraceHooks(first, second *TraceHooks) *TraceHooks {
switch {
case first == nil:
return second
case second == nil:
return first
}
return &TraceHooks{
GetConn: composeTraceGetConnHook(first.GetConn, second.GetConn),
GotConn: composeTraceGotConnHook(first.GotConn, second.GotConn),
PutIdleConn: composeTracePutIdleConnHook(first.PutIdleConn, second.PutIdleConn),
DNSStart: composeTraceDNSStartHook(first.DNSStart, second.DNSStart),
DNSDone: composeTraceDNSDoneHook(first.DNSDone, second.DNSDone),
ConnectStart: composeTraceConnectStartHook(first.ConnectStart, second.ConnectStart),
ConnectDone: composeTraceConnectDoneHook(first.ConnectDone, second.ConnectDone),
TLSHandshakeStart: composeTraceTLSHandshakeStartHook(first.TLSHandshakeStart, second.TLSHandshakeStart),
TLSHandshakeDone: composeTraceTLSHandshakeDoneHook(first.TLSHandshakeDone, second.TLSHandshakeDone),
WroteHeaderField: composeTraceWroteHeaderFieldHook(first.WroteHeaderField, second.WroteHeaderField),
WroteHeaders: composeTraceSimpleHook(first.WroteHeaders, second.WroteHeaders),
WroteRequest: composeTraceWroteRequestHook(first.WroteRequest, second.WroteRequest),
GotFirstResponseByte: composeTraceSimpleHook(first.GotFirstResponseByte, second.GotFirstResponseByte),
RetryAttemptStart: composeTraceRetryAttemptStartHook(first.RetryAttemptStart, second.RetryAttemptStart),
RetryAttemptDone: composeTraceRetryAttemptDoneHook(first.RetryAttemptDone, second.RetryAttemptDone),
RetryBackoff: composeTraceRetryBackoffHook(first.RetryBackoff, second.RetryBackoff),
}
}
func composeTraceGetConnHook(first, second func(TraceGetConnInfo)) func(TraceGetConnInfo) {
switch {
case first == nil:
return second
case second == nil:
return first
default:
return func(info TraceGetConnInfo) {
first(info)
second(info)
}
}
}
func composeTraceGotConnHook(first, second func(TraceGotConnInfo)) func(TraceGotConnInfo) {
switch {
case first == nil:
return second
case second == nil:
return first
default:
return func(info TraceGotConnInfo) {
first(info)
second(info)
}
}
}
func composeTracePutIdleConnHook(first, second func(TracePutIdleConnInfo)) func(TracePutIdleConnInfo) {
switch {
case first == nil:
return second
case second == nil:
return first
default:
return func(info TracePutIdleConnInfo) {
first(info)
second(info)
}
}
}
func composeTraceDNSStartHook(first, second func(TraceDNSStartInfo)) func(TraceDNSStartInfo) {
switch {
case first == nil:
return second
case second == nil:
return first
default:
return func(info TraceDNSStartInfo) {
first(info)
second(info)
}
}
}
func composeTraceDNSDoneHook(first, second func(TraceDNSDoneInfo)) func(TraceDNSDoneInfo) {
switch {
case first == nil:
return second
case second == nil:
return first
default:
return func(info TraceDNSDoneInfo) {
first(info)
second(info)
}
}
}
func composeTraceConnectStartHook(first, second func(TraceConnectStartInfo)) func(TraceConnectStartInfo) {
switch {
case first == nil:
return second
case second == nil:
return first
default:
return func(info TraceConnectStartInfo) {
first(info)
second(info)
}
}
}
func composeTraceConnectDoneHook(first, second func(TraceConnectDoneInfo)) func(TraceConnectDoneInfo) {
switch {
case first == nil:
return second
case second == nil:
return first
default:
return func(info TraceConnectDoneInfo) {
first(info)
second(info)
}
}
}
func composeTraceTLSHandshakeStartHook(first, second func(TraceTLSHandshakeStartInfo)) func(TraceTLSHandshakeStartInfo) {
switch {
case first == nil:
return second
case second == nil:
return first
default:
return func(info TraceTLSHandshakeStartInfo) {
first(info)
second(info)
}
}
}
func composeTraceTLSHandshakeDoneHook(first, second func(TraceTLSHandshakeDoneInfo)) func(TraceTLSHandshakeDoneInfo) {
switch {
case first == nil:
return second
case second == nil:
return first
default:
return func(info TraceTLSHandshakeDoneInfo) {
first(info)
second(info)
}
}
}
func composeTraceWroteHeaderFieldHook(first, second func(TraceWroteHeaderFieldInfo)) func(TraceWroteHeaderFieldInfo) {
switch {
case first == nil:
return second
case second == nil:
return first
default:
return func(info TraceWroteHeaderFieldInfo) {
first(info)
second(info)
}
}
}
func composeTraceWroteRequestHook(first, second func(TraceWroteRequestInfo)) func(TraceWroteRequestInfo) {
switch {
case first == nil:
return second
case second == nil:
return first
default:
return func(info TraceWroteRequestInfo) {
first(info)
second(info)
}
}
}
func composeTraceRetryAttemptStartHook(first, second func(TraceRetryAttemptStartInfo)) func(TraceRetryAttemptStartInfo) {
switch {
case first == nil:
return second
case second == nil:
return first
default:
return func(info TraceRetryAttemptStartInfo) {
first(info)
second(info)
}
}
}
func composeTraceRetryAttemptDoneHook(first, second func(TraceRetryAttemptDoneInfo)) func(TraceRetryAttemptDoneInfo) {
switch {
case first == nil:
return second
case second == nil:
return first
default:
return func(info TraceRetryAttemptDoneInfo) {
first(info)
second(info)
}
}
}
func composeTraceRetryBackoffHook(first, second func(TraceRetryBackoffInfo)) func(TraceRetryBackoffInfo) {
switch {
case first == nil:
return second
case second == nil:
return first
default:
return func(info TraceRetryBackoffInfo) {
first(info)
second(info)
}
}
}
func composeTraceSimpleHook(first, second func()) func() {
switch {
case first == nil:
return second
case second == nil:
return first
default:
return func() {
first()
second()
}
}
}

550
trace_summary.go Normal file
View File

@ -0,0 +1,550 @@
package starnet
import (
"crypto/tls"
"net"
"net/http"
"sync"
"time"
)
// TraceSummary 是一次请求执行的 trace 摘要。
type TraceSummary struct {
Method string
URL string
StartedAt time.Time
ResponseAt time.Time
StatusCode int
ResponseProto string
Conn TraceConnSummary
DNS *TraceDNSSummary
DNSEvents []TraceDNSSummary
Connect []TraceConnectSummary
TLS *TraceTLSSummary
RequestWrittenAt time.Time
RequestWriteErr error
FirstResponseByteAt time.Time
}
// TraceConnSummary 是连接复用与套接字信息摘要。
type TraceConnSummary struct {
Addr string
LocalAddr string
RemoteAddr string
Reused bool
WasIdle bool
IdleTime time.Duration
}
// TraceDNSSummary 是 DNS 解析摘要。
type TraceDNSSummary struct {
Host string
Addrs []string
Coalesced bool
StartedAt time.Time
CompletedAt time.Time
Duration time.Duration
Err error
}
// TraceConnectSummary 是单次建连尝试摘要。
type TraceConnectSummary struct {
Network string
Addr string
StartedAt time.Time
CompletedAt time.Time
Duration time.Duration
Err error
}
// TraceTLSSummary 是 TLS 握手与连接状态摘要。
type TraceTLSSummary struct {
Network string
Addr string
ServerName string
Version uint16
VersionName string
CipherSuite uint16
CipherSuiteName string
CurveID tls.CurveID
CurveName string
NegotiatedProtocol string
DidResume bool
ECHAccepted bool
VerifiedChains int
StartedAt time.Time
CompletedAt time.Time
Duration time.Duration
Err error
PeerCertificates []TraceCertificateSummary
}
// TraceCertificateSummary 是单张证书的关键信息摘要。
type TraceCertificateSummary struct {
Subject string
Issuer string
DNSNames []string
IPAddresses []string
}
// TraceRecorder 聚合最近一次发布的 trace 摘要。
// 通过 Request/Client 绑定时starnet 会为每次执行创建私有运行态并在完成后发布摘要;
// 直接使用 Hooks() 时,调用方仍需自行管理 Reset 与生命周期。
type TraceRecorder struct {
mu sync.Mutex
summary TraceSummary
pendingDNS []TraceDNSSummary
pendingConnectStarts map[string][]time.Time
pendingTLSStart time.Time
hooks *TraceHooks
}
// NewTraceRecorder 创建请求级 trace 记录器。
func NewTraceRecorder() *TraceRecorder {
recorder := &TraceRecorder{}
recorder.hooks = &TraceHooks{
GetConn: recorder.onGetConn,
GotConn: recorder.onGotConn,
DNSStart: recorder.onDNSStart,
DNSDone: recorder.onDNSDone,
ConnectStart: recorder.onConnectStart,
ConnectDone: recorder.onConnectDone,
TLSHandshakeStart: recorder.onTLSHandshakeStart,
TLSHandshakeDone: recorder.onTLSHandshakeDone,
WroteRequest: recorder.onWroteRequest,
GotFirstResponseByte: recorder.onGotFirstResponseByte,
}
return recorder
}
// Hooks 返回可挂到请求上的底层 trace hooks。
func (r *TraceRecorder) Hooks() *TraceHooks {
if r == nil {
return nil
}
if r.hooks == nil {
r.hooks = &TraceHooks{
GetConn: r.onGetConn,
GotConn: r.onGotConn,
DNSStart: r.onDNSStart,
DNSDone: r.onDNSDone,
ConnectStart: r.onConnectStart,
ConnectDone: r.onConnectDone,
TLSHandshakeStart: r.onTLSHandshakeStart,
TLSHandshakeDone: r.onTLSHandshakeDone,
WroteRequest: r.onWroteRequest,
GotFirstResponseByte: r.onGotFirstResponseByte,
}
}
return r.hooks
}
// Reset 清空当前摘要和内部状态。
func (r *TraceRecorder) Reset() {
if r == nil {
return
}
r.mu.Lock()
defer r.mu.Unlock()
r.resetLocked()
}
// Summary 返回当前 trace 摘要的快照。
func (r *TraceRecorder) Summary() TraceSummary {
if r == nil {
return TraceSummary{}
}
r.mu.Lock()
defer r.mu.Unlock()
return cloneTraceSummary(r.summary)
}
func (r *TraceRecorder) forkExecution() *TraceRecorder {
if r == nil {
return nil
}
return NewTraceRecorder()
}
func (r *TraceRecorder) publishSummary(summary TraceSummary) {
if r == nil {
return
}
r.mu.Lock()
defer r.mu.Unlock()
r.summary = cloneTraceSummary(summary)
}
func (r *TraceRecorder) startRequest() {
if r == nil {
return
}
r.mu.Lock()
defer r.mu.Unlock()
r.resetLocked()
r.summary.StartedAt = time.Now()
}
func (r *TraceRecorder) observePreparedRequest(req *http.Request) {
if r == nil || req == nil || req.URL == nil {
return
}
r.mu.Lock()
defer r.mu.Unlock()
r.ensureStartedLocked(time.Now())
r.summary.Method = req.Method
r.summary.URL = req.URL.String()
}
func (r *TraceRecorder) observeResponse(resp *http.Response) {
if r == nil || resp == nil {
return
}
r.mu.Lock()
defer r.mu.Unlock()
now := time.Now()
r.ensureStartedLocked(now)
r.summary.ResponseAt = now
r.summary.StatusCode = resp.StatusCode
r.summary.ResponseProto = resp.Proto
if resp.TLS != nil {
r.summary.TLS = mergeTraceTLSSummary(r.summary.TLS, *resp.TLS)
}
}
func (r *TraceRecorder) ensureStartedLocked(now time.Time) {
if r.summary.StartedAt.IsZero() {
r.summary.StartedAt = now
}
}
func (r *TraceRecorder) resetLocked() {
r.summary = TraceSummary{}
r.pendingDNS = nil
r.pendingTLSStart = time.Time{}
if len(r.pendingConnectStarts) == 0 {
r.pendingConnectStarts = nil
return
}
for key := range r.pendingConnectStarts {
delete(r.pendingConnectStarts, key)
}
r.pendingConnectStarts = nil
}
func (r *TraceRecorder) onGetConn(info TraceGetConnInfo) {
r.mu.Lock()
defer r.mu.Unlock()
r.ensureStartedLocked(time.Now())
r.summary.Conn.Addr = info.Addr
}
func (r *TraceRecorder) onGotConn(info TraceGotConnInfo) {
r.mu.Lock()
defer r.mu.Unlock()
now := time.Now()
r.ensureStartedLocked(now)
r.summary.Conn.Reused = info.Reused
r.summary.Conn.WasIdle = info.WasIdle
r.summary.Conn.IdleTime = info.IdleTime
if info.Conn != nil {
r.summary.Conn.LocalAddr = traceAddrString(info.Conn.LocalAddr())
r.summary.Conn.RemoteAddr = traceAddrString(info.Conn.RemoteAddr())
}
}
func (r *TraceRecorder) onDNSStart(info TraceDNSStartInfo) {
r.mu.Lock()
defer r.mu.Unlock()
now := time.Now()
r.ensureStartedLocked(now)
dns := TraceDNSSummary{
Host: info.Host,
StartedAt: now,
}
r.pendingDNS = append(r.pendingDNS, dns)
copyDNS := dns
r.summary.DNS = &copyDNS
}
func (r *TraceRecorder) onDNSDone(info TraceDNSDoneInfo) {
r.mu.Lock()
defer r.mu.Unlock()
now := time.Now()
r.ensureStartedLocked(now)
dns := TraceDNSSummary{
Host: "",
Addrs: traceIPAddrsToStrings(info.Addrs),
Coalesced: info.Coalesced,
CompletedAt: now,
Err: info.Err,
}
if len(r.pendingDNS) > 0 {
dns.Host = r.pendingDNS[0].Host
dns.StartedAt = r.pendingDNS[0].StartedAt
if len(r.pendingDNS) == 1 {
r.pendingDNS = nil
} else {
r.pendingDNS = append([]TraceDNSSummary(nil), r.pendingDNS[1:]...)
}
} else if r.summary.DNS != nil {
dns.Host = r.summary.DNS.Host
dns.StartedAt = r.summary.DNS.StartedAt
}
if !dns.StartedAt.IsZero() {
dns.Duration = now.Sub(dns.StartedAt)
}
r.summary.DNSEvents = append(r.summary.DNSEvents, dns)
copyDNS := dns
r.summary.DNS = &copyDNS
}
func (r *TraceRecorder) onConnectStart(info TraceConnectStartInfo) {
r.mu.Lock()
defer r.mu.Unlock()
now := time.Now()
r.ensureStartedLocked(now)
if r.pendingConnectStarts == nil {
r.pendingConnectStarts = make(map[string][]time.Time)
}
key := traceConnectKey(info.Network, info.Addr)
r.pendingConnectStarts[key] = append(r.pendingConnectStarts[key], now)
r.summary.Connect = append(r.summary.Connect, TraceConnectSummary{
Network: info.Network,
Addr: info.Addr,
StartedAt: now,
})
}
func (r *TraceRecorder) onConnectDone(info TraceConnectDoneInfo) {
r.mu.Lock()
defer r.mu.Unlock()
now := time.Now()
r.ensureStartedLocked(now)
start := time.Time{}
key := traceConnectKey(info.Network, info.Addr)
if starts := r.pendingConnectStarts[key]; len(starts) > 0 {
start = starts[0]
if len(starts) == 1 {
delete(r.pendingConnectStarts, key)
} else {
r.pendingConnectStarts[key] = starts[1:]
}
}
connect := TraceConnectSummary{
Network: info.Network,
Addr: info.Addr,
StartedAt: start,
CompletedAt: now,
Err: info.Err,
}
if !start.IsZero() {
connect.Duration = now.Sub(start)
}
for index := len(r.summary.Connect) - 1; index >= 0; index-- {
item := &r.summary.Connect[index]
if item.Network != info.Network || item.Addr != info.Addr || !item.CompletedAt.IsZero() {
continue
}
item.CompletedAt = now
item.Duration = connect.Duration
item.Err = info.Err
return
}
r.summary.Connect = append(r.summary.Connect, connect)
}
func (r *TraceRecorder) onTLSHandshakeStart(info TraceTLSHandshakeStartInfo) {
r.mu.Lock()
defer r.mu.Unlock()
now := time.Now()
r.ensureStartedLocked(now)
r.pendingTLSStart = now
r.summary.TLS = &TraceTLSSummary{
Network: info.Network,
Addr: info.Addr,
ServerName: info.ServerName,
StartedAt: now,
}
}
func (r *TraceRecorder) onTLSHandshakeDone(info TraceTLSHandshakeDoneInfo) {
r.mu.Lock()
defer r.mu.Unlock()
now := time.Now()
r.ensureStartedLocked(now)
var tlsSummary *TraceTLSSummary
if r.summary.TLS != nil {
copied := *r.summary.TLS
tlsSummary = &copied
} else {
tlsSummary = &TraceTLSSummary{}
}
if tlsSummary.Network == "" {
tlsSummary.Network = info.Network
}
if tlsSummary.Addr == "" {
tlsSummary.Addr = info.Addr
}
if tlsSummary.ServerName == "" {
tlsSummary.ServerName = info.ServerName
}
if tlsSummary.StartedAt.IsZero() {
tlsSummary.StartedAt = r.pendingTLSStart
}
tlsSummary.CompletedAt = now
if !tlsSummary.StartedAt.IsZero() {
tlsSummary.Duration = now.Sub(tlsSummary.StartedAt)
}
tlsSummary.Err = info.Err
tlsSummary = mergeTraceTLSSummary(tlsSummary, info.ConnectionState)
if tlsSummary.ServerName == "" {
tlsSummary.ServerName = info.ServerName
}
if tlsSummary.Addr == "" {
tlsSummary.Addr = info.Addr
}
if tlsSummary.Network == "" {
tlsSummary.Network = info.Network
}
r.pendingTLSStart = time.Time{}
r.summary.TLS = tlsSummary
}
func (r *TraceRecorder) onWroteRequest(info TraceWroteRequestInfo) {
r.mu.Lock()
defer r.mu.Unlock()
now := time.Now()
r.ensureStartedLocked(now)
r.summary.RequestWrittenAt = now
r.summary.RequestWriteErr = info.Err
}
func (r *TraceRecorder) onGotFirstResponseByte() {
r.mu.Lock()
defer r.mu.Unlock()
now := time.Now()
r.ensureStartedLocked(now)
r.summary.FirstResponseByteAt = now
}
func traceAddrString(addr net.Addr) string {
if addr == nil {
return ""
}
return addr.String()
}
func traceConnectKey(network, addr string) string {
return network + "\x00" + addr
}
func traceIPAddrsToStrings(addrs []net.IPAddr) []string {
if len(addrs) == 0 {
return nil
}
out := make([]string, 0, len(addrs))
for _, addr := range addrs {
out = append(out, addr.String())
}
return out
}
func mergeTraceTLSSummary(summary *TraceTLSSummary, state tls.ConnectionState) *TraceTLSSummary {
if summary == nil {
summary = &TraceTLSSummary{}
}
if state.Version != 0 {
summary.Version = state.Version
summary.VersionName = tls.VersionName(state.Version)
}
if state.CipherSuite != 0 {
summary.CipherSuite = state.CipherSuite
summary.CipherSuiteName = tls.CipherSuiteName(state.CipherSuite)
}
if state.CurveID != 0 {
summary.CurveID = state.CurveID
summary.CurveName = state.CurveID.String()
}
if state.ServerName != "" {
summary.ServerName = state.ServerName
}
if state.NegotiatedProtocol != "" {
summary.NegotiatedProtocol = state.NegotiatedProtocol
}
summary.DidResume = state.DidResume
summary.ECHAccepted = state.ECHAccepted
summary.VerifiedChains = len(state.VerifiedChains)
if len(state.PeerCertificates) > 0 {
summary.PeerCertificates = make([]TraceCertificateSummary, 0, len(state.PeerCertificates))
for _, cert := range state.PeerCertificates {
certSummary := TraceCertificateSummary{
Subject: cert.Subject.String(),
Issuer: cert.Issuer.String(),
DNSNames: append([]string(nil), cert.DNSNames...),
}
if len(cert.IPAddresses) > 0 {
certSummary.IPAddresses = make([]string, 0, len(cert.IPAddresses))
for _, ip := range cert.IPAddresses {
certSummary.IPAddresses = append(certSummary.IPAddresses, ip.String())
}
}
summary.PeerCertificates = append(summary.PeerCertificates, certSummary)
}
}
return summary
}
func cloneTraceSummary(summary TraceSummary) TraceSummary {
cloned := summary
if summary.DNS != nil {
dns := *summary.DNS
dns.Addrs = append([]string(nil), summary.DNS.Addrs...)
cloned.DNS = &dns
}
if len(summary.DNSEvents) > 0 {
cloned.DNSEvents = make([]TraceDNSSummary, 0, len(summary.DNSEvents))
for _, dns := range summary.DNSEvents {
cloned.DNSEvents = append(cloned.DNSEvents, TraceDNSSummary{
Host: dns.Host,
Addrs: append([]string(nil), dns.Addrs...),
Coalesced: dns.Coalesced,
StartedAt: dns.StartedAt,
CompletedAt: dns.CompletedAt,
Duration: dns.Duration,
Err: dns.Err,
})
}
}
if len(summary.Connect) > 0 {
cloned.Connect = append([]TraceConnectSummary(nil), summary.Connect...)
}
if summary.TLS != nil {
tlsSummary := *summary.TLS
if len(summary.TLS.PeerCertificates) > 0 {
tlsSummary.PeerCertificates = make([]TraceCertificateSummary, 0, len(summary.TLS.PeerCertificates))
for _, cert := range summary.TLS.PeerCertificates {
tlsSummary.PeerCertificates = append(tlsSummary.PeerCertificates, TraceCertificateSummary{
Subject: cert.Subject,
Issuer: cert.Issuer,
DNSNames: append([]string(nil), cert.DNSNames...),
IPAddresses: append([]string(nil), cert.IPAddresses...),
})
}
}
cloned.TLS = &tlsSummary
}
return cloned
}
func cloneTraceSummaryPtr(summary TraceSummary) *TraceSummary {
cloned := cloneTraceSummary(summary)
return &cloned
}

488
trace_summary_test.go Normal file
View File

@ -0,0 +1,488 @@
package starnet
import (
"context"
"crypto/tls"
"net"
"net/http"
"strconv"
"strings"
"testing"
)
func TestTraceRecorderCapturesHTTPSummary(t *testing.T) {
server := newIPv4TLSServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer server.Close()
recorder := NewTraceRecorder()
req := NewSimpleRequest(server.URL, http.MethodGet).
SetSkipTLSVerify(true).
SetTraceRecorder(recorder)
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
if _, err := resp.Body().Bytes(); err != nil {
t.Fatalf("Body().Bytes() error: %v", err)
}
summary := recorder.Summary()
if summary.Method != http.MethodGet {
t.Fatalf("method=%q", summary.Method)
}
if summary.URL != server.URL {
t.Fatalf("url=%q", summary.URL)
}
if summary.StatusCode != http.StatusOK {
t.Fatalf("status=%d", summary.StatusCode)
}
if summary.ResponseProto == "" {
t.Fatal("expected response proto")
}
if summary.RequestWrittenAt.IsZero() {
t.Fatal("expected request write timestamp")
}
if summary.FirstResponseByteAt.IsZero() {
t.Fatal("expected first response byte timestamp")
}
if summary.Conn.Addr == "" {
t.Fatal("expected get-conn target address")
}
if summary.TLS == nil {
t.Fatal("expected tls summary")
}
tlsSummary := summary.TLS
if tlsSummary.Version == 0 || tlsSummary.VersionName == "" {
t.Fatalf("unexpected tls version summary: %+v", tlsSummary)
}
if tlsSummary.CipherSuite == 0 || tlsSummary.CipherSuiteName == "" {
t.Fatalf("unexpected cipher suite summary: %+v", tlsSummary)
}
if tlsSummary.ServerName == "" {
t.Fatal("expected tls server name")
}
if resp.TLS == nil {
t.Fatal("expected response TLS state")
}
if tlsSummary.NegotiatedProtocol != resp.TLS.NegotiatedProtocol {
t.Fatalf("alpn=%q resp=%q", tlsSummary.NegotiatedProtocol, resp.TLS.NegotiatedProtocol)
}
if len(tlsSummary.PeerCertificates) == 0 {
t.Fatal("expected certificate summaries")
}
leaf := tlsSummary.PeerCertificates[0]
if leaf.Subject == "" || leaf.Issuer == "" {
t.Fatalf("unexpected leaf certificate summary: %+v", leaf)
}
if len(leaf.DNSNames) == 0 && len(leaf.IPAddresses) == 0 {
t.Fatalf("expected DNS or IP SANs in leaf certificate: %+v", leaf)
}
if got := req.TraceSummary(); got == nil || got.StatusCode != http.StatusOK {
t.Fatalf("request trace summary=%+v", got)
}
if got := resp.TraceSummary(); got == nil || got.StatusCode != http.StatusOK {
t.Fatalf("response trace summary=%+v", got)
}
}
func TestTraceRecorderCapturesDNSAndConnectSummary(t *testing.T) {
server := newIPv4Server(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer server.Close()
addr, err := net.ResolveTCPAddr("tcp", server.Listener.Addr().String())
if err != nil {
t.Fatalf("ResolveTCPAddr() error: %v", err)
}
recorder := NewTraceRecorder()
targetURL := "http://trace-summary.example.test:" + strconv.Itoa(addr.Port)
resp, err := NewSimpleRequest(targetURL, http.MethodGet).
SetLookupFunc(func(ctx context.Context, host string) ([]net.IPAddr, error) {
return []net.IPAddr{{IP: addr.IP}}, nil
}).
SetTraceRecorder(recorder).
Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
if _, err := resp.Body().Bytes(); err != nil {
t.Fatalf("Body().Bytes() error: %v", err)
}
summary := recorder.Summary()
if summary.DNS == nil {
t.Fatal("expected dns summary")
}
if summary.DNS.Host != "trace-summary.example.test" {
t.Fatalf("dns host=%q", summary.DNS.Host)
}
if len(summary.DNS.Addrs) == 0 {
t.Fatal("expected resolved addresses")
}
if !strings.Contains(summary.DNS.Addrs[0], addr.IP.String()) {
t.Fatalf("dns addrs=%v", summary.DNS.Addrs)
}
if summary.DNS.CompletedAt.IsZero() {
t.Fatal("expected dns completion timestamp")
}
if len(summary.Connect) == 0 {
t.Fatal("expected connect attempts")
}
connect := summary.Connect[0]
if connect.Network == "" || connect.Addr == "" {
t.Fatalf("unexpected connect summary: %+v", connect)
}
if connect.CompletedAt.IsZero() {
t.Fatalf("expected connect completion timestamp: %+v", connect)
}
}
func TestTraceRecorderUsesResponseTLSForReusedConnection(t *testing.T) {
server := newIPv4TLSServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer server.Close()
client := NewClientNoErr()
firstResp, err := client.NewSimpleRequest(server.URL, http.MethodGet).
SetSkipTLSVerify(true).
Do()
if err != nil {
t.Fatalf("first Do() error: %v", err)
}
if _, err := firstResp.Body().Bytes(); err != nil {
t.Fatalf("first Body().Bytes() error: %v", err)
}
if err := firstResp.Close(); err != nil {
t.Fatalf("first Close() error: %v", err)
}
recorder := NewTraceRecorder()
secondResp, err := client.NewSimpleRequest(server.URL, http.MethodGet).
SetSkipTLSVerify(true).
SetTraceRecorder(recorder).
Do()
if err != nil {
t.Fatalf("second Do() error: %v", err)
}
defer secondResp.Close()
if _, err := secondResp.Body().Bytes(); err != nil {
t.Fatalf("second Body().Bytes() error: %v", err)
}
summary := recorder.Summary()
if !summary.Conn.Reused {
t.Fatalf("expected reused connection summary, got %+v", summary.Conn)
}
if summary.TLS == nil || summary.TLS.Version == 0 {
t.Fatalf("expected tls summary from response fallback, got %+v", summary.TLS)
}
}
func TestTraceRecorderCoexistsWithTraceHooks(t *testing.T) {
server := newIPv4Server(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer server.Close()
recorder := NewTraceRecorder()
wroteRequest := 0
resp, err := NewSimpleRequest(server.URL, http.MethodGet).
SetTraceHooks(&TraceHooks{
WroteRequest: func(info TraceWroteRequestInfo) {
wroteRequest++
},
}).
SetTraceRecorder(recorder).
Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
if _, err := resp.Body().Bytes(); err != nil {
t.Fatalf("Body().Bytes() error: %v", err)
}
if wroteRequest == 0 {
t.Fatal("expected custom trace hook to run")
}
summary := recorder.Summary()
if summary.RequestWrittenAt.IsZero() {
t.Fatal("expected recorder to capture wrote-request event")
}
}
func TestTraceRecorderPreservesMultipleDNSEvents(t *testing.T) {
recorder := NewTraceRecorder()
hooks := recorder.Hooks()
hooks.DNSStart(TraceDNSStartInfo{Host: "target.example.test"})
hooks.DNSDone(TraceDNSDoneInfo{
Addrs: []net.IPAddr{{IP: net.ParseIP("127.0.0.1")}},
})
hooks.DNSStart(TraceDNSStartInfo{Host: "proxy.example.test"})
hooks.DNSDone(TraceDNSDoneInfo{
Addrs: []net.IPAddr{{IP: net.ParseIP("127.0.0.2")}},
})
summary := recorder.Summary()
if len(summary.DNSEvents) != 2 {
t.Fatalf("dns events=%d", len(summary.DNSEvents))
}
if summary.DNSEvents[0].Host != "target.example.test" {
t.Fatalf("first dns host=%q", summary.DNSEvents[0].Host)
}
if summary.DNSEvents[1].Host != "proxy.example.test" {
t.Fatalf("second dns host=%q", summary.DNSEvents[1].Host)
}
if summary.DNS == nil || summary.DNS.Host != "proxy.example.test" {
t.Fatalf("last dns summary=%+v", summary.DNS)
}
}
func TestTraceHooksStandardTLSPathIncludesMetadata(t *testing.T) {
server, pool := newTrustedIPv4TLSServer(t, "localhost", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer server.Close()
client := NewClientNoErr()
transport, ok := client.HTTPClient().Transport.(*Transport)
if !ok {
t.Fatalf("transport type=%T", client.HTTPClient().Transport)
}
base := newBaseHTTPTransport()
base.TLSClientConfig = &tls.Config{RootCAs: pool}
transport.SetBase(base)
targetURL := httpsURLForHost(t, server, "localhost")
var startInfo TraceTLSHandshakeStartInfo
var doneInfo TraceTLSHandshakeDoneInfo
resp, err := client.NewSimpleRequest(targetURL, http.MethodGet).
SetTraceHooks(&TraceHooks{
TLSHandshakeStart: func(info TraceTLSHandshakeStartInfo) {
startInfo = info
},
TLSHandshakeDone: func(info TraceTLSHandshakeDoneInfo) {
doneInfo = info
},
}).
Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
if _, err := resp.Body().Bytes(); err != nil {
t.Fatalf("Body().Bytes() error: %v", err)
}
wantAddr := strings.TrimPrefix(targetURL, "https://")
if startInfo.Network != "tcp" {
t.Fatalf("start network=%q", startInfo.Network)
}
if startInfo.Addr != wantAddr {
t.Fatalf("start addr=%q want=%q", startInfo.Addr, wantAddr)
}
if startInfo.ServerName != "localhost" {
t.Fatalf("start server name=%q", startInfo.ServerName)
}
if doneInfo.Network != "tcp" || doneInfo.Addr != wantAddr || doneInfo.ServerName != "localhost" {
t.Fatalf("done info=%+v", doneInfo)
}
if doneInfo.ConnectionState.Version == 0 {
t.Fatalf("done state=%+v", doneInfo.ConnectionState)
}
}
func TestTraceHooksWroteHeaderFieldCopiesValues(t *testing.T) {
var captured []string
traceState := newTraceState(&TraceHooks{
WroteHeaderField: func(info TraceWroteHeaderFieldInfo) {
captured = info.Values
},
})
trace := traceState.clientTrace()
values := []string{"a", "b"}
trace.WroteHeaderField("X-Test", values)
values[0] = "mutated"
if len(captured) != 2 {
t.Fatalf("captured=%v", captured)
}
if captured[0] != "a" {
t.Fatalf("captured=%v", captured)
}
}
func TestTraceRecorderSharedAcrossCloneKeepsPerRequestSummaries(t *testing.T) {
server := newIPv4Server(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(r.URL.Path))
}))
defer server.Close()
recorder := NewTraceRecorder()
client := NewClientNoErr(WithTraceRecorder(recorder))
req1 := client.NewSimpleRequest(server.URL+"/one", http.MethodGet)
resp1, err := req1.Do()
if err != nil {
t.Fatalf("first Do() error: %v", err)
}
defer resp1.Close()
if _, err := resp1.Body().Bytes(); err != nil {
t.Fatalf("first Body().Bytes() error: %v", err)
}
req2 := req1.Clone().SetURL(server.URL + "/two")
resp2, err := req2.Do()
if err != nil {
t.Fatalf("second Do() error: %v", err)
}
defer resp2.Close()
if _, err := resp2.Body().Bytes(); err != nil {
t.Fatalf("second Body().Bytes() error: %v", err)
}
if got := req1.TraceSummary(); got == nil || got.URL != server.URL+"/one" {
t.Fatalf("req1 trace summary=%+v", got)
}
if got := resp1.TraceSummary(); got == nil || got.URL != server.URL+"/one" {
t.Fatalf("resp1 trace summary=%+v", got)
}
if got := req2.TraceSummary(); got == nil || got.URL != server.URL+"/two" {
t.Fatalf("req2 trace summary=%+v", got)
}
if got := resp2.TraceSummary(); got == nil || got.URL != server.URL+"/two" {
t.Fatalf("resp2 trace summary=%+v", got)
}
if got := recorder.Summary(); got.URL != server.URL+"/two" {
t.Fatalf("shared recorder summary=%+v", got)
}
}
func TestResponseTraceSummaryIsStableAcrossRequestReuse(t *testing.T) {
server := newIPv4Server(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(r.URL.Path))
}))
defer server.Close()
req := NewSimpleRequest(server.URL+"/first", http.MethodGet).
SetTraceRecorder(NewTraceRecorder())
resp1, err := req.Do()
if err != nil {
t.Fatalf("first Do() error: %v", err)
}
defer resp1.Close()
if _, err := resp1.Body().Bytes(); err != nil {
t.Fatalf("first Body().Bytes() error: %v", err)
}
req.SetURL(server.URL + "/second")
resp2, err := req.Do()
if err != nil {
t.Fatalf("second Do() error: %v", err)
}
defer resp2.Close()
if _, err := resp2.Body().Bytes(); err != nil {
t.Fatalf("second Body().Bytes() error: %v", err)
}
if got := resp1.TraceSummary(); got == nil || got.URL != server.URL+"/first" {
t.Fatalf("resp1 trace summary=%+v", got)
}
if got := req.TraceSummary(); got == nil || got.URL != server.URL+"/second" {
t.Fatalf("request trace summary=%+v", got)
}
if got := resp2.TraceSummary(); got == nil || got.URL != server.URL+"/second" {
t.Fatalf("resp2 trace summary=%+v", got)
}
}
func TestTraceHooksCustomDialDoesNotInventTLSAddr(t *testing.T) {
server, pool := newTrustedIPv4TLSServer(t, "trace-custom.example.test", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer server.Close()
client := NewClientNoErr()
transport, ok := client.HTTPClient().Transport.(*Transport)
if !ok {
t.Fatalf("transport type=%T", client.HTTPClient().Transport)
}
base := newBaseHTTPTransport()
base.TLSClientConfig = &tls.Config{RootCAs: pool}
transport.SetBase(base)
targetURL := httpsURLForHost(t, server, "trace-custom.example.test")
serverAddr := server.Listener.Addr().String()
var startInfo TraceTLSHandshakeStartInfo
var doneInfo TraceTLSHandshakeDoneInfo
resp, err := client.NewSimpleRequest(targetURL, http.MethodGet).
SetDialFunc(func(ctx context.Context, network, addr string) (net.Conn, error) {
return (&net.Dialer{}).DialContext(ctx, "tcp", serverAddr)
}).
SetTraceHooks(&TraceHooks{
TLSHandshakeStart: func(info TraceTLSHandshakeStartInfo) {
startInfo = info
},
TLSHandshakeDone: func(info TraceTLSHandshakeDoneInfo) {
doneInfo = info
},
}).
Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
if _, err := resp.Body().Bytes(); err != nil {
t.Fatalf("Body().Bytes() error: %v", err)
}
if startInfo.Network != "" || startInfo.Addr != "" {
t.Fatalf("start info=%+v", startInfo)
}
if doneInfo.Network != "" || doneInfo.Addr != "" {
t.Fatalf("done info=%+v", doneInfo)
}
if startInfo.ServerName != "trace-custom.example.test" {
t.Fatalf("start server name=%q", startInfo.ServerName)
}
if doneInfo.ServerName != "trace-custom.example.test" {
t.Fatalf("done server name=%q", doneInfo.ServerName)
}
if doneInfo.ConnectionState.Version == 0 {
t.Fatalf("done state=%+v", doneInfo.ConnectionState)
}
}