From 2f4c7158cfcf6e97c8f042ca51750f20ab0b705e Mon Sep 17 00:00:00 2001 From: starainrt Date: Mon, 20 Apr 2026 17:54:43 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=A2=9E=E5=8A=A0=E8=AF=B7=E6=B1=82?= =?UTF-8?q?=E7=BA=A7=20trace=20=E6=91=98=E8=A6=81=E4=B8=8E=E8=AF=8A?= =?UTF-8?q?=E6=96=AD=E8=83=BD=E5=8A=9B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 TraceRecorder 和 TraceSummary,汇总 DNS、连接、TLS、写请求、首包等关键事件 - 为请求执行链接入结构化 trace hooks,补充标准路径与动态路径的 TLS 元信息 - 增加 Request.TraceSummary() 和 Response.TraceSummary(),提供请求级与响应级摘要快照 - 修复共享 TraceRecorder 在 Client 默认选项、Clone 和请求复用场景下的状态串扰问题 - 修复 Response.TraceSummary() 回读 Request 最近状态导致的非快照语义 - 收口自定义 DialFunc 下的 TLS trace 元数据,避免伪造连接地址 - 补充 trace 相关回归测试,覆盖 HTTPS、DNS/Connect、连接复用、共享 recorder、响应快照和自定义拨号场景 - 更新 README,补充 trace、Host 与 TLSServerName 的行为说明 --- README.md | 11 + options_config.go | 5 + request.go | 62 +++-- request_mutation.go | 7 + request_prepare.go | 36 ++- request_trace.go | 50 ++++ response.go | 18 +- retry.go | 2 + trace.go | 274 ++++++++++++++++++++- trace_summary.go | 550 ++++++++++++++++++++++++++++++++++++++++++ trace_summary_test.go | 488 +++++++++++++++++++++++++++++++++++++ 11 files changed, 1471 insertions(+), 32 deletions(-) create mode 100644 trace_summary.go create mode 100644 trace_summary_test.go diff --git a/README.md b/README.md index d4bca50..996fc9a 100644 --- a/README.md +++ b/README.md @@ -6,6 +6,7 @@ - 基于 `context` 的请求级超时控制,不修改共享 `http.Client` 的全局超时 - 请求级网络控制:代理、自定义 IP / DNS、拨号超时、TLS 配置 +- 支持请求级 `Host` 覆盖、显式 `TLSServerName/SNI` 控制,以及结构化 trace 回调 / 摘要 - 内置重试机制,支持重试次数、退避、抖动、状态码白名单和自定义错误判定 - 响应体大小限制,避免一次性读取过大内容 - 错误分类辅助:`ClassifyError`、`IsTimeout`、`IsDNS`、`IsTLS`、`IsProxy`、`IsCanceled` @@ -19,9 +20,17 @@ - 同时提供 `WithXxx` 选项和 `SetXxx` 链式调用两套接口 - 支持 `Get`、`Post`、`Put`、`Delete`、`Head`、`Patch`、`Options`、`Trace`、`Connect` - 支持 JSON、表单、`multipart/form-data`、流式请求体等常见请求体形态 +- 支持显式 `Host` 覆盖与 `TLSServerName` 设置,便于直连 IP、虚拟主机和证书校验场景分离控制 - Header、Cookie、Query 等输入在关键路径上做防御性拷贝,降低外部可变状态污染风险 - `Request.Clone()` 可用于并发场景或同一基础请求的变体构造 +### Trace 与诊断 + +- 支持 `TraceHooks`,可接收 DNS、建连、TLS 握手、写请求、首包等结构化事件 +- 支持 `TraceRecorder` / `TraceSummary`,用于汇总一次请求的关键网络过程和 TLS 摘要 +- `Request.TraceSummary()` 返回该请求最近一次执行的摘要快照,`Response.TraceSummary()` 返回当前响应对应的摘要快照 +- 若多个请求共享同一个 `TraceRecorder`,其 `Summary()` 表示最近一次完成请求的摘要 + ### 超时与重试 - 请求超时通过 `context` 截止时间控制,不污染共享客户端配置 @@ -98,8 +107,10 @@ func main() { - `NewClient`、`NewRequest` 以及请求构造相关接口在遇到非法选项时会直接返回错误,例如格式不合法的代理地址。 - `NewClientNoErr` 是便利构造函数;如果选项校验失败,仍可能返回一个占位 `Client`,需要严格校验配置时应优先使用 `NewClient`。 +- `SetHost` / `WithHost` 只覆盖 HTTP 请求的 `Host`;如需单独控制 TLS SNI 或证书校验名,应配合 `SetTLSServerName` / `WithTLSServerName` 使用。 - 重试默认仅对幂等方法生效。即使显式关闭“仅幂等方法重试”,通过 `SetBodyReader` 或 `WithBodyReader` 构造的请求在非幂等方法上仍不会自动重试。 - 当同时使用 `proxy + custom IP/DNS` 且解析出多个目标地址时,自动目标回退仅对幂等请求生效,以避免重复写入。 +- 绑定到请求上的 `TraceRecorder` 用于发布已完成请求的摘要;请求执行中的中间状态不保证通过共享 recorder 实时可见。 ## 稳定性说明 diff --git a/options_config.go b/options_config.go index c0941c2..8adcdf6 100644 --- a/options_config.go +++ b/options_config.go @@ -46,6 +46,11 @@ func WithTraceHooks(hooks *TraceHooks) RequestOpt { return requestOptFromMutation(mutateTraceHooks(hooks)) } +// WithTraceRecorder 设置请求级 trace 摘要记录器。 +func WithTraceRecorder(recorder *TraceRecorder) RequestOpt { + return requestOptFromMutation(mutateTraceRecorder(recorder)) +} + // WithSkipTLSVerify 设置是否跳过 TLS 验证 func WithSkipTLSVerify(skip bool) RequestOpt { return requestOptFromMutation(mutateSkipTLSVerify(skip)) diff --git a/request.go b/request.go index a4a25e7..a3c6ca9 100644 --- a/request.go +++ b/request.go @@ -17,13 +17,16 @@ type Request struct { method string err error // 累积的错误 - config *RequestConfig - client *Client - httpClient *http.Client - httpReq *http.Request - retry *retryPolicy - traceHooks *TraceHooks - traceState *traceState + config *RequestConfig + client *Client + httpClient *http.Client + httpReq *http.Request + retry *retryPolicy + traceHooks *TraceHooks + traceRecorder *TraceRecorder + traceRun *TraceRecorder + lastTraceSummary *TraceSummary + traceState *traceState applied bool // 是否已应用配置 doRaw bool // 是否使用原始请求(不修改) @@ -81,6 +84,7 @@ func (r *Request) invalidatePreparedState() { r.cancel = nil } r.execCtx = nil + r.traceRun = nil r.traceState = nil r.httpClient = nil @@ -273,18 +277,19 @@ func NewSimpleRequestWithContext(ctx context.Context, url, method string, opts . // Clone 克隆请求 func (r *Request) Clone() *Request { cloned := &Request{ - ctx: r.ctx, - url: r.url, - method: r.method, - err: r.err, - config: r.config.Clone(), - client: r.client, - httpClient: r.httpClient, - retry: cloneRetryPolicy(r.retry), - traceHooks: r.traceHooks, - applied: false, // 重置应用状态 - doRaw: r.doRaw, - autoFetch: r.autoFetch, + ctx: r.ctx, + url: r.url, + method: r.method, + err: r.err, + config: r.config.Clone(), + client: r.client, + httpClient: r.httpClient, + retry: cloneRetryPolicy(r.retry), + traceHooks: r.traceHooks, + traceRecorder: r.traceRecorder, + applied: false, // 重置应用状态 + doRaw: r.doRaw, + autoFetch: r.autoFetch, rawSourceExternal: r.rawSourceExternal, } @@ -477,11 +482,20 @@ func (r *Request) Do() (*Response, error) { return nil, r.err } + r.startTraceExecution() + + var ( + resp *Response + err error + ) if r.hasRetryPolicy() { - return r.doWithRetry() + resp, err = r.doWithRetry() + } else { + resp, err = r.doOnce() } - return r.doOnce() + r.finishTraceExecution(resp) + return resp, err } func (r *Request) doOnce() (*Response, error) { @@ -493,6 +507,9 @@ func (r *Request) doOnce() (*Response, error) { if err := r.prepare(); err != nil { return nil, wrapError(err, "prepare request") } + if r.traceRun != nil { + r.traceRun.observePreparedRequest(r.httpReq) + } // 执行请求 httpResp, err := r.httpClient.Do(r.httpReq) @@ -508,6 +525,9 @@ func (r *Request) doOnce() (*Response, error) { body: &Body{}, }, wrapError(err, "do request") } + if r.traceRun != nil { + r.traceRun.observeResponse(httpResp) + } rawBody := httpResp.Body if r.cancel != nil { diff --git a/request_mutation.go b/request_mutation.go index 2e33e26..b6300a8 100644 --- a/request_mutation.go +++ b/request_mutation.go @@ -122,6 +122,13 @@ func mutateTraceHooks(hooks *TraceHooks) requestMutation { } } +func mutateTraceRecorder(recorder *TraceRecorder) requestMutation { + return func(r *Request) error { + r.traceRecorder = recorder + return nil + } +} + func mutateSkipTLSVerify(skip bool) requestMutation { return func(r *Request) error { r.config.TLS.SkipVerify = skip diff --git a/request_prepare.go b/request_prepare.go index efcb7a5..0889c02 100644 --- a/request_prepare.go +++ b/request_prepare.go @@ -194,6 +194,36 @@ func (r *Request) applyBody(execCtx context.Context) error { return nil } +func buildTraceTLSHandshakeInfo(req *http.Request, execCtx context.Context, defaultServerName string) TraceTLSHandshakeStartInfo { + if req == nil || req.URL == nil || req.URL.Scheme != "https" { + return TraceTLSHandshakeStartInfo{} + } + + reqCtx := getRequestContext(execCtx) + info := TraceTLSHandshakeStartInfo{} + // 自定义 DialFunc 的真实落点由调用方决定,这里只在默认拨号路径下预填地址,避免 trace 元信息误导。 + if reqCtx == nil || reqCtx.DialFn == nil { + info.Network = "tcp" + info.Addr = req.URL.Host + } + if reqCtx != nil { + if reqCtx.TLSConfig != nil && reqCtx.TLSConfig.ServerName != "" { + info.ServerName = reqCtx.TLSConfig.ServerName + } else if reqCtx.TLSServerName != "" { + info.ServerName = reqCtx.TLSServerName + } + } + if info.ServerName == "" { + if defaultServerName != "" { + info.ServerName = defaultServerName + } else { + info.ServerName = req.URL.Hostname() + } + } + + return info +} + // prepare 准备请求(应用配置) func (r *Request) prepare() (err error) { if r.applied { @@ -215,8 +245,10 @@ func (r *Request) prepare() (err error) { execCtx = injectRequestConfig(execCtx, r.config, defaultTLSServerName) var traceState *traceState - if r.traceHooks != nil { - traceState = newTraceState(r.traceHooks) + traceHooks := composeTraceHooks(r.traceHooks, traceRecorderHooks(r.traceRun)) + if traceHooks != nil { + traceState = newTraceState(traceHooks) + traceState.setDefaultTLSHandshakeInfo(buildTraceTLSHandshakeInfo(r.httpReq, execCtx, defaultTLSServerName)) execCtx = withTraceState(execCtx, traceState) if clientTrace := traceState.clientTrace(); clientTrace != nil { execCtx = httptrace.WithClientTrace(execCtx, clientTrace) diff --git a/request_trace.go b/request_trace.go index f052e28..9533c79 100644 --- a/request_trace.go +++ b/request_trace.go @@ -4,3 +4,53 @@ package starnet func (r *Request) SetTraceHooks(hooks *TraceHooks) *Request { return r.applyMutation(mutateTraceHooks(hooks)) } + +// SetTraceRecorder 设置请求级 trace 摘要记录器。 +// 记录器会保存最近一次已完成请求的摘要;若多个请求共享同一个记录器,则以最后一次完成的请求为准。 +func (r *Request) SetTraceRecorder(recorder *TraceRecorder) *Request { + return r.applyMutation(mutateTraceRecorder(recorder)) +} + +// TraceSummary 返回当前请求最近一次执行的 trace 摘要快照。 +func (r *Request) TraceSummary() *TraceSummary { + if r == nil || r.lastTraceSummary == nil { + return nil + } + summary := cloneTraceSummary(*r.lastTraceSummary) + return &summary +} + +func (r *Request) startTraceExecution() { + if r == nil { + return + } + r.traceRun = nil + if r.traceRecorder == nil { + return + } + r.traceRun = r.traceRecorder.forkExecution() + if r.traceRun != nil { + r.traceRun.startRequest() + } +} + +func (r *Request) finishTraceExecution(resp *Response) { + if r == nil { + return + } + if r.traceRun == nil { + r.lastTraceSummary = nil + if resp != nil { + resp.traceSummary = nil + } + return + } + + summary := r.traceRun.Summary() + r.lastTraceSummary = cloneTraceSummaryPtr(summary) + r.traceRecorder.publishSummary(summary) + if resp != nil { + resp.traceSummary = cloneTraceSummaryPtr(summary) + } + r.traceRun = nil +} diff --git a/response.go b/response.go index 3942f7e..5d02d04 100644 --- a/response.go +++ b/response.go @@ -11,10 +11,11 @@ import ( // Response HTTP 响应 type Response struct { *http.Response - request *Request - httpClient *http.Client - cancel func() - body *Body + request *Request + httpClient *http.Client + cancel func() + body *Body + traceSummary *TraceSummary } // Body 响应体 @@ -47,6 +48,15 @@ func (r *Response) Request() *Request { return r.request } +// TraceSummary 获取当前响应对应的 trace 摘要快照。 +func (r *Response) TraceSummary() *TraceSummary { + if r == nil || r.traceSummary == nil { + return nil + } + summary := cloneTraceSummary(*r.traceSummary) + return &summary +} + // Body 获取响应体 func (r *Response) Body() *Body { return r.body diff --git a/retry.go b/retry.go index f63628b..b9a808a 100644 --- a/retry.go +++ b/retry.go @@ -312,6 +312,8 @@ func (r *Request) newRetryAttempt(ctx context.Context) (*Request, error) { attempt.applied = false attempt.execCtx = nil attempt.ctx = ctx + attempt.traceRun = r.traceRun + attempt.lastTraceSummary = nil // 共享总超时上下文后,避免每次 attempt 再创建一次 timeout context。 if attempt.config != nil && attempt.config.Network.Timeout > 0 { diff --git a/trace.go b/trace.go index 7b1153d..194ad28 100644 --- a/trace.go +++ b/trace.go @@ -110,9 +110,10 @@ type TraceRetryBackoffInfo struct { } type traceState struct { - hooks *TraceHooks - customTLS atomic.Uint32 - manualDNSRefs atomic.Int32 + hooks *TraceHooks + customTLS atomic.Uint32 + manualDNSRefs atomic.Int32 + defaultTLSHandshakeInfo TraceTLSHandshakeStartInfo } func newTraceState(hooks *TraceHooks) *traceState { @@ -122,6 +123,20 @@ func newTraceState(hooks *TraceHooks) *traceState { return &traceState{hooks: hooks} } +func (t *traceState) setDefaultTLSHandshakeInfo(info TraceTLSHandshakeStartInfo) { + if t == nil { + return + } + t.defaultTLSHandshakeInfo = info +} + +func (t *traceState) getDefaultTLSHandshakeInfo() TraceTLSHandshakeStartInfo { + if t == nil { + return TraceTLSHandshakeStartInfo{} + } + return t.defaultTLSHandshakeInfo +} + func withTraceState(ctx context.Context, state *traceState) context.Context { if state == nil { return ctx @@ -219,7 +234,7 @@ func (t *traceState) clientTrace() *httptrace.ClientTrace { if t.usesCustomTLS() { return } - h.TLSHandshakeStart(TraceTLSHandshakeStartInfo{}) + h.TLSHandshakeStart(t.getDefaultTLSHandshakeInfo()) } } if h.TLSHandshakeDone != nil { @@ -227,7 +242,11 @@ func (t *traceState) clientTrace() *httptrace.ClientTrace { if t.usesCustomTLS() { return } + info := t.getDefaultTLSHandshakeInfo() h.TLSHandshakeDone(TraceTLSHandshakeDoneInfo{ + Network: info.Network, + Addr: info.Addr, + ServerName: info.ServerName, ConnectionState: state, Err: err, }) @@ -237,7 +256,7 @@ func (t *traceState) clientTrace() *httptrace.ClientTrace { trace.WroteHeaderField = func(key string, value []string) { h.WroteHeaderField(TraceWroteHeaderFieldInfo{ Key: key, - Values: value, + Values: append([]string(nil), value...), }) } } @@ -338,3 +357,248 @@ func emitRetryBackoff(hooks *TraceHooks, info TraceRetryBackoffInfo) { } hooks.RetryBackoff(info) } + +func traceRecorderHooks(recorder *TraceRecorder) *TraceHooks { + if recorder == nil { + return nil + } + return recorder.Hooks() +} + +func composeTraceHooks(first, second *TraceHooks) *TraceHooks { + switch { + case first == nil: + return second + case second == nil: + return first + } + + return &TraceHooks{ + GetConn: composeTraceGetConnHook(first.GetConn, second.GetConn), + GotConn: composeTraceGotConnHook(first.GotConn, second.GotConn), + PutIdleConn: composeTracePutIdleConnHook(first.PutIdleConn, second.PutIdleConn), + DNSStart: composeTraceDNSStartHook(first.DNSStart, second.DNSStart), + DNSDone: composeTraceDNSDoneHook(first.DNSDone, second.DNSDone), + ConnectStart: composeTraceConnectStartHook(first.ConnectStart, second.ConnectStart), + ConnectDone: composeTraceConnectDoneHook(first.ConnectDone, second.ConnectDone), + TLSHandshakeStart: composeTraceTLSHandshakeStartHook(first.TLSHandshakeStart, second.TLSHandshakeStart), + TLSHandshakeDone: composeTraceTLSHandshakeDoneHook(first.TLSHandshakeDone, second.TLSHandshakeDone), + WroteHeaderField: composeTraceWroteHeaderFieldHook(first.WroteHeaderField, second.WroteHeaderField), + WroteHeaders: composeTraceSimpleHook(first.WroteHeaders, second.WroteHeaders), + WroteRequest: composeTraceWroteRequestHook(first.WroteRequest, second.WroteRequest), + GotFirstResponseByte: composeTraceSimpleHook(first.GotFirstResponseByte, second.GotFirstResponseByte), + RetryAttemptStart: composeTraceRetryAttemptStartHook(first.RetryAttemptStart, second.RetryAttemptStart), + RetryAttemptDone: composeTraceRetryAttemptDoneHook(first.RetryAttemptDone, second.RetryAttemptDone), + RetryBackoff: composeTraceRetryBackoffHook(first.RetryBackoff, second.RetryBackoff), + } +} + +func composeTraceGetConnHook(first, second func(TraceGetConnInfo)) func(TraceGetConnInfo) { + switch { + case first == nil: + return second + case second == nil: + return first + default: + return func(info TraceGetConnInfo) { + first(info) + second(info) + } + } +} + +func composeTraceGotConnHook(first, second func(TraceGotConnInfo)) func(TraceGotConnInfo) { + switch { + case first == nil: + return second + case second == nil: + return first + default: + return func(info TraceGotConnInfo) { + first(info) + second(info) + } + } +} + +func composeTracePutIdleConnHook(first, second func(TracePutIdleConnInfo)) func(TracePutIdleConnInfo) { + switch { + case first == nil: + return second + case second == nil: + return first + default: + return func(info TracePutIdleConnInfo) { + first(info) + second(info) + } + } +} + +func composeTraceDNSStartHook(first, second func(TraceDNSStartInfo)) func(TraceDNSStartInfo) { + switch { + case first == nil: + return second + case second == nil: + return first + default: + return func(info TraceDNSStartInfo) { + first(info) + second(info) + } + } +} + +func composeTraceDNSDoneHook(first, second func(TraceDNSDoneInfo)) func(TraceDNSDoneInfo) { + switch { + case first == nil: + return second + case second == nil: + return first + default: + return func(info TraceDNSDoneInfo) { + first(info) + second(info) + } + } +} + +func composeTraceConnectStartHook(first, second func(TraceConnectStartInfo)) func(TraceConnectStartInfo) { + switch { + case first == nil: + return second + case second == nil: + return first + default: + return func(info TraceConnectStartInfo) { + first(info) + second(info) + } + } +} + +func composeTraceConnectDoneHook(first, second func(TraceConnectDoneInfo)) func(TraceConnectDoneInfo) { + switch { + case first == nil: + return second + case second == nil: + return first + default: + return func(info TraceConnectDoneInfo) { + first(info) + second(info) + } + } +} + +func composeTraceTLSHandshakeStartHook(first, second func(TraceTLSHandshakeStartInfo)) func(TraceTLSHandshakeStartInfo) { + switch { + case first == nil: + return second + case second == nil: + return first + default: + return func(info TraceTLSHandshakeStartInfo) { + first(info) + second(info) + } + } +} + +func composeTraceTLSHandshakeDoneHook(first, second func(TraceTLSHandshakeDoneInfo)) func(TraceTLSHandshakeDoneInfo) { + switch { + case first == nil: + return second + case second == nil: + return first + default: + return func(info TraceTLSHandshakeDoneInfo) { + first(info) + second(info) + } + } +} + +func composeTraceWroteHeaderFieldHook(first, second func(TraceWroteHeaderFieldInfo)) func(TraceWroteHeaderFieldInfo) { + switch { + case first == nil: + return second + case second == nil: + return first + default: + return func(info TraceWroteHeaderFieldInfo) { + first(info) + second(info) + } + } +} + +func composeTraceWroteRequestHook(first, second func(TraceWroteRequestInfo)) func(TraceWroteRequestInfo) { + switch { + case first == nil: + return second + case second == nil: + return first + default: + return func(info TraceWroteRequestInfo) { + first(info) + second(info) + } + } +} + +func composeTraceRetryAttemptStartHook(first, second func(TraceRetryAttemptStartInfo)) func(TraceRetryAttemptStartInfo) { + switch { + case first == nil: + return second + case second == nil: + return first + default: + return func(info TraceRetryAttemptStartInfo) { + first(info) + second(info) + } + } +} + +func composeTraceRetryAttemptDoneHook(first, second func(TraceRetryAttemptDoneInfo)) func(TraceRetryAttemptDoneInfo) { + switch { + case first == nil: + return second + case second == nil: + return first + default: + return func(info TraceRetryAttemptDoneInfo) { + first(info) + second(info) + } + } +} + +func composeTraceRetryBackoffHook(first, second func(TraceRetryBackoffInfo)) func(TraceRetryBackoffInfo) { + switch { + case first == nil: + return second + case second == nil: + return first + default: + return func(info TraceRetryBackoffInfo) { + first(info) + second(info) + } + } +} + +func composeTraceSimpleHook(first, second func()) func() { + switch { + case first == nil: + return second + case second == nil: + return first + default: + return func() { + first() + second() + } + } +} diff --git a/trace_summary.go b/trace_summary.go new file mode 100644 index 0000000..eec9ece --- /dev/null +++ b/trace_summary.go @@ -0,0 +1,550 @@ +package starnet + +import ( + "crypto/tls" + "net" + "net/http" + "sync" + "time" +) + +// TraceSummary 是一次请求执行的 trace 摘要。 +type TraceSummary struct { + Method string + URL string + StartedAt time.Time + ResponseAt time.Time + StatusCode int + ResponseProto string + Conn TraceConnSummary + DNS *TraceDNSSummary + DNSEvents []TraceDNSSummary + Connect []TraceConnectSummary + TLS *TraceTLSSummary + RequestWrittenAt time.Time + RequestWriteErr error + FirstResponseByteAt time.Time +} + +// TraceConnSummary 是连接复用与套接字信息摘要。 +type TraceConnSummary struct { + Addr string + LocalAddr string + RemoteAddr string + Reused bool + WasIdle bool + IdleTime time.Duration +} + +// TraceDNSSummary 是 DNS 解析摘要。 +type TraceDNSSummary struct { + Host string + Addrs []string + Coalesced bool + StartedAt time.Time + CompletedAt time.Time + Duration time.Duration + Err error +} + +// TraceConnectSummary 是单次建连尝试摘要。 +type TraceConnectSummary struct { + Network string + Addr string + StartedAt time.Time + CompletedAt time.Time + Duration time.Duration + Err error +} + +// TraceTLSSummary 是 TLS 握手与连接状态摘要。 +type TraceTLSSummary struct { + Network string + Addr string + ServerName string + Version uint16 + VersionName string + CipherSuite uint16 + CipherSuiteName string + CurveID tls.CurveID + CurveName string + NegotiatedProtocol string + DidResume bool + ECHAccepted bool + VerifiedChains int + StartedAt time.Time + CompletedAt time.Time + Duration time.Duration + Err error + PeerCertificates []TraceCertificateSummary +} + +// TraceCertificateSummary 是单张证书的关键信息摘要。 +type TraceCertificateSummary struct { + Subject string + Issuer string + DNSNames []string + IPAddresses []string +} + +// TraceRecorder 聚合最近一次发布的 trace 摘要。 +// 通过 Request/Client 绑定时,starnet 会为每次执行创建私有运行态并在完成后发布摘要; +// 直接使用 Hooks() 时,调用方仍需自行管理 Reset 与生命周期。 +type TraceRecorder struct { + mu sync.Mutex + + summary TraceSummary + pendingDNS []TraceDNSSummary + pendingConnectStarts map[string][]time.Time + pendingTLSStart time.Time + hooks *TraceHooks +} + +// NewTraceRecorder 创建请求级 trace 记录器。 +func NewTraceRecorder() *TraceRecorder { + recorder := &TraceRecorder{} + recorder.hooks = &TraceHooks{ + GetConn: recorder.onGetConn, + GotConn: recorder.onGotConn, + DNSStart: recorder.onDNSStart, + DNSDone: recorder.onDNSDone, + ConnectStart: recorder.onConnectStart, + ConnectDone: recorder.onConnectDone, + TLSHandshakeStart: recorder.onTLSHandshakeStart, + TLSHandshakeDone: recorder.onTLSHandshakeDone, + WroteRequest: recorder.onWroteRequest, + GotFirstResponseByte: recorder.onGotFirstResponseByte, + } + return recorder +} + +// Hooks 返回可挂到请求上的底层 trace hooks。 +func (r *TraceRecorder) Hooks() *TraceHooks { + if r == nil { + return nil + } + if r.hooks == nil { + r.hooks = &TraceHooks{ + GetConn: r.onGetConn, + GotConn: r.onGotConn, + DNSStart: r.onDNSStart, + DNSDone: r.onDNSDone, + ConnectStart: r.onConnectStart, + ConnectDone: r.onConnectDone, + TLSHandshakeStart: r.onTLSHandshakeStart, + TLSHandshakeDone: r.onTLSHandshakeDone, + WroteRequest: r.onWroteRequest, + GotFirstResponseByte: r.onGotFirstResponseByte, + } + } + return r.hooks +} + +// Reset 清空当前摘要和内部状态。 +func (r *TraceRecorder) Reset() { + if r == nil { + return + } + r.mu.Lock() + defer r.mu.Unlock() + r.resetLocked() +} + +// Summary 返回当前 trace 摘要的快照。 +func (r *TraceRecorder) Summary() TraceSummary { + if r == nil { + return TraceSummary{} + } + r.mu.Lock() + defer r.mu.Unlock() + return cloneTraceSummary(r.summary) +} + +func (r *TraceRecorder) forkExecution() *TraceRecorder { + if r == nil { + return nil + } + return NewTraceRecorder() +} + +func (r *TraceRecorder) publishSummary(summary TraceSummary) { + if r == nil { + return + } + r.mu.Lock() + defer r.mu.Unlock() + r.summary = cloneTraceSummary(summary) +} + +func (r *TraceRecorder) startRequest() { + if r == nil { + return + } + r.mu.Lock() + defer r.mu.Unlock() + r.resetLocked() + r.summary.StartedAt = time.Now() +} + +func (r *TraceRecorder) observePreparedRequest(req *http.Request) { + if r == nil || req == nil || req.URL == nil { + return + } + r.mu.Lock() + defer r.mu.Unlock() + r.ensureStartedLocked(time.Now()) + r.summary.Method = req.Method + r.summary.URL = req.URL.String() +} + +func (r *TraceRecorder) observeResponse(resp *http.Response) { + if r == nil || resp == nil { + return + } + r.mu.Lock() + defer r.mu.Unlock() + + now := time.Now() + r.ensureStartedLocked(now) + r.summary.ResponseAt = now + r.summary.StatusCode = resp.StatusCode + r.summary.ResponseProto = resp.Proto + if resp.TLS != nil { + r.summary.TLS = mergeTraceTLSSummary(r.summary.TLS, *resp.TLS) + } +} + +func (r *TraceRecorder) ensureStartedLocked(now time.Time) { + if r.summary.StartedAt.IsZero() { + r.summary.StartedAt = now + } +} + +func (r *TraceRecorder) resetLocked() { + r.summary = TraceSummary{} + r.pendingDNS = nil + r.pendingTLSStart = time.Time{} + if len(r.pendingConnectStarts) == 0 { + r.pendingConnectStarts = nil + return + } + for key := range r.pendingConnectStarts { + delete(r.pendingConnectStarts, key) + } + r.pendingConnectStarts = nil +} + +func (r *TraceRecorder) onGetConn(info TraceGetConnInfo) { + r.mu.Lock() + defer r.mu.Unlock() + r.ensureStartedLocked(time.Now()) + r.summary.Conn.Addr = info.Addr +} + +func (r *TraceRecorder) onGotConn(info TraceGotConnInfo) { + r.mu.Lock() + defer r.mu.Unlock() + now := time.Now() + r.ensureStartedLocked(now) + r.summary.Conn.Reused = info.Reused + r.summary.Conn.WasIdle = info.WasIdle + r.summary.Conn.IdleTime = info.IdleTime + if info.Conn != nil { + r.summary.Conn.LocalAddr = traceAddrString(info.Conn.LocalAddr()) + r.summary.Conn.RemoteAddr = traceAddrString(info.Conn.RemoteAddr()) + } +} + +func (r *TraceRecorder) onDNSStart(info TraceDNSStartInfo) { + r.mu.Lock() + defer r.mu.Unlock() + now := time.Now() + r.ensureStartedLocked(now) + dns := TraceDNSSummary{ + Host: info.Host, + StartedAt: now, + } + r.pendingDNS = append(r.pendingDNS, dns) + copyDNS := dns + r.summary.DNS = ©DNS +} + +func (r *TraceRecorder) onDNSDone(info TraceDNSDoneInfo) { + r.mu.Lock() + defer r.mu.Unlock() + now := time.Now() + r.ensureStartedLocked(now) + + dns := TraceDNSSummary{ + Host: "", + Addrs: traceIPAddrsToStrings(info.Addrs), + Coalesced: info.Coalesced, + CompletedAt: now, + Err: info.Err, + } + if len(r.pendingDNS) > 0 { + dns.Host = r.pendingDNS[0].Host + dns.StartedAt = r.pendingDNS[0].StartedAt + if len(r.pendingDNS) == 1 { + r.pendingDNS = nil + } else { + r.pendingDNS = append([]TraceDNSSummary(nil), r.pendingDNS[1:]...) + } + } else if r.summary.DNS != nil { + dns.Host = r.summary.DNS.Host + dns.StartedAt = r.summary.DNS.StartedAt + } + if !dns.StartedAt.IsZero() { + dns.Duration = now.Sub(dns.StartedAt) + } + r.summary.DNSEvents = append(r.summary.DNSEvents, dns) + copyDNS := dns + r.summary.DNS = ©DNS +} + +func (r *TraceRecorder) onConnectStart(info TraceConnectStartInfo) { + r.mu.Lock() + defer r.mu.Unlock() + now := time.Now() + r.ensureStartedLocked(now) + if r.pendingConnectStarts == nil { + r.pendingConnectStarts = make(map[string][]time.Time) + } + key := traceConnectKey(info.Network, info.Addr) + r.pendingConnectStarts[key] = append(r.pendingConnectStarts[key], now) + r.summary.Connect = append(r.summary.Connect, TraceConnectSummary{ + Network: info.Network, + Addr: info.Addr, + StartedAt: now, + }) +} + +func (r *TraceRecorder) onConnectDone(info TraceConnectDoneInfo) { + r.mu.Lock() + defer r.mu.Unlock() + now := time.Now() + r.ensureStartedLocked(now) + + start := time.Time{} + key := traceConnectKey(info.Network, info.Addr) + if starts := r.pendingConnectStarts[key]; len(starts) > 0 { + start = starts[0] + if len(starts) == 1 { + delete(r.pendingConnectStarts, key) + } else { + r.pendingConnectStarts[key] = starts[1:] + } + } + + connect := TraceConnectSummary{ + Network: info.Network, + Addr: info.Addr, + StartedAt: start, + CompletedAt: now, + Err: info.Err, + } + if !start.IsZero() { + connect.Duration = now.Sub(start) + } + + for index := len(r.summary.Connect) - 1; index >= 0; index-- { + item := &r.summary.Connect[index] + if item.Network != info.Network || item.Addr != info.Addr || !item.CompletedAt.IsZero() { + continue + } + item.CompletedAt = now + item.Duration = connect.Duration + item.Err = info.Err + return + } + r.summary.Connect = append(r.summary.Connect, connect) +} + +func (r *TraceRecorder) onTLSHandshakeStart(info TraceTLSHandshakeStartInfo) { + r.mu.Lock() + defer r.mu.Unlock() + now := time.Now() + r.ensureStartedLocked(now) + r.pendingTLSStart = now + r.summary.TLS = &TraceTLSSummary{ + Network: info.Network, + Addr: info.Addr, + ServerName: info.ServerName, + StartedAt: now, + } +} + +func (r *TraceRecorder) onTLSHandshakeDone(info TraceTLSHandshakeDoneInfo) { + r.mu.Lock() + defer r.mu.Unlock() + now := time.Now() + r.ensureStartedLocked(now) + + var tlsSummary *TraceTLSSummary + if r.summary.TLS != nil { + copied := *r.summary.TLS + tlsSummary = &copied + } else { + tlsSummary = &TraceTLSSummary{} + } + if tlsSummary.Network == "" { + tlsSummary.Network = info.Network + } + if tlsSummary.Addr == "" { + tlsSummary.Addr = info.Addr + } + if tlsSummary.ServerName == "" { + tlsSummary.ServerName = info.ServerName + } + if tlsSummary.StartedAt.IsZero() { + tlsSummary.StartedAt = r.pendingTLSStart + } + tlsSummary.CompletedAt = now + if !tlsSummary.StartedAt.IsZero() { + tlsSummary.Duration = now.Sub(tlsSummary.StartedAt) + } + tlsSummary.Err = info.Err + tlsSummary = mergeTraceTLSSummary(tlsSummary, info.ConnectionState) + if tlsSummary.ServerName == "" { + tlsSummary.ServerName = info.ServerName + } + if tlsSummary.Addr == "" { + tlsSummary.Addr = info.Addr + } + if tlsSummary.Network == "" { + tlsSummary.Network = info.Network + } + r.pendingTLSStart = time.Time{} + r.summary.TLS = tlsSummary +} + +func (r *TraceRecorder) onWroteRequest(info TraceWroteRequestInfo) { + r.mu.Lock() + defer r.mu.Unlock() + now := time.Now() + r.ensureStartedLocked(now) + r.summary.RequestWrittenAt = now + r.summary.RequestWriteErr = info.Err +} + +func (r *TraceRecorder) onGotFirstResponseByte() { + r.mu.Lock() + defer r.mu.Unlock() + now := time.Now() + r.ensureStartedLocked(now) + r.summary.FirstResponseByteAt = now +} + +func traceAddrString(addr net.Addr) string { + if addr == nil { + return "" + } + return addr.String() +} + +func traceConnectKey(network, addr string) string { + return network + "\x00" + addr +} + +func traceIPAddrsToStrings(addrs []net.IPAddr) []string { + if len(addrs) == 0 { + return nil + } + out := make([]string, 0, len(addrs)) + for _, addr := range addrs { + out = append(out, addr.String()) + } + return out +} + +func mergeTraceTLSSummary(summary *TraceTLSSummary, state tls.ConnectionState) *TraceTLSSummary { + if summary == nil { + summary = &TraceTLSSummary{} + } + if state.Version != 0 { + summary.Version = state.Version + summary.VersionName = tls.VersionName(state.Version) + } + if state.CipherSuite != 0 { + summary.CipherSuite = state.CipherSuite + summary.CipherSuiteName = tls.CipherSuiteName(state.CipherSuite) + } + if state.CurveID != 0 { + summary.CurveID = state.CurveID + summary.CurveName = state.CurveID.String() + } + if state.ServerName != "" { + summary.ServerName = state.ServerName + } + if state.NegotiatedProtocol != "" { + summary.NegotiatedProtocol = state.NegotiatedProtocol + } + summary.DidResume = state.DidResume + summary.ECHAccepted = state.ECHAccepted + summary.VerifiedChains = len(state.VerifiedChains) + if len(state.PeerCertificates) > 0 { + summary.PeerCertificates = make([]TraceCertificateSummary, 0, len(state.PeerCertificates)) + for _, cert := range state.PeerCertificates { + certSummary := TraceCertificateSummary{ + Subject: cert.Subject.String(), + Issuer: cert.Issuer.String(), + DNSNames: append([]string(nil), cert.DNSNames...), + } + if len(cert.IPAddresses) > 0 { + certSummary.IPAddresses = make([]string, 0, len(cert.IPAddresses)) + for _, ip := range cert.IPAddresses { + certSummary.IPAddresses = append(certSummary.IPAddresses, ip.String()) + } + } + summary.PeerCertificates = append(summary.PeerCertificates, certSummary) + } + } + return summary +} + +func cloneTraceSummary(summary TraceSummary) TraceSummary { + cloned := summary + if summary.DNS != nil { + dns := *summary.DNS + dns.Addrs = append([]string(nil), summary.DNS.Addrs...) + cloned.DNS = &dns + } + if len(summary.DNSEvents) > 0 { + cloned.DNSEvents = make([]TraceDNSSummary, 0, len(summary.DNSEvents)) + for _, dns := range summary.DNSEvents { + cloned.DNSEvents = append(cloned.DNSEvents, TraceDNSSummary{ + Host: dns.Host, + Addrs: append([]string(nil), dns.Addrs...), + Coalesced: dns.Coalesced, + StartedAt: dns.StartedAt, + CompletedAt: dns.CompletedAt, + Duration: dns.Duration, + Err: dns.Err, + }) + } + } + if len(summary.Connect) > 0 { + cloned.Connect = append([]TraceConnectSummary(nil), summary.Connect...) + } + if summary.TLS != nil { + tlsSummary := *summary.TLS + if len(summary.TLS.PeerCertificates) > 0 { + tlsSummary.PeerCertificates = make([]TraceCertificateSummary, 0, len(summary.TLS.PeerCertificates)) + for _, cert := range summary.TLS.PeerCertificates { + tlsSummary.PeerCertificates = append(tlsSummary.PeerCertificates, TraceCertificateSummary{ + Subject: cert.Subject, + Issuer: cert.Issuer, + DNSNames: append([]string(nil), cert.DNSNames...), + IPAddresses: append([]string(nil), cert.IPAddresses...), + }) + } + } + cloned.TLS = &tlsSummary + } + return cloned +} + +func cloneTraceSummaryPtr(summary TraceSummary) *TraceSummary { + cloned := cloneTraceSummary(summary) + return &cloned +} diff --git a/trace_summary_test.go b/trace_summary_test.go new file mode 100644 index 0000000..e3136df --- /dev/null +++ b/trace_summary_test.go @@ -0,0 +1,488 @@ +package starnet + +import ( + "context" + "crypto/tls" + "net" + "net/http" + "strconv" + "strings" + "testing" +) + +func TestTraceRecorderCapturesHTTPSummary(t *testing.T) { + server := newIPv4TLSServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + defer server.Close() + + recorder := NewTraceRecorder() + req := NewSimpleRequest(server.URL, http.MethodGet). + SetSkipTLSVerify(true). + SetTraceRecorder(recorder) + + resp, err := req.Do() + if err != nil { + t.Fatalf("Do() error: %v", err) + } + defer resp.Close() + + if _, err := resp.Body().Bytes(); err != nil { + t.Fatalf("Body().Bytes() error: %v", err) + } + + summary := recorder.Summary() + if summary.Method != http.MethodGet { + t.Fatalf("method=%q", summary.Method) + } + if summary.URL != server.URL { + t.Fatalf("url=%q", summary.URL) + } + if summary.StatusCode != http.StatusOK { + t.Fatalf("status=%d", summary.StatusCode) + } + if summary.ResponseProto == "" { + t.Fatal("expected response proto") + } + if summary.RequestWrittenAt.IsZero() { + t.Fatal("expected request write timestamp") + } + if summary.FirstResponseByteAt.IsZero() { + t.Fatal("expected first response byte timestamp") + } + if summary.Conn.Addr == "" { + t.Fatal("expected get-conn target address") + } + if summary.TLS == nil { + t.Fatal("expected tls summary") + } + + tlsSummary := summary.TLS + if tlsSummary.Version == 0 || tlsSummary.VersionName == "" { + t.Fatalf("unexpected tls version summary: %+v", tlsSummary) + } + if tlsSummary.CipherSuite == 0 || tlsSummary.CipherSuiteName == "" { + t.Fatalf("unexpected cipher suite summary: %+v", tlsSummary) + } + if tlsSummary.ServerName == "" { + t.Fatal("expected tls server name") + } + if resp.TLS == nil { + t.Fatal("expected response TLS state") + } + if tlsSummary.NegotiatedProtocol != resp.TLS.NegotiatedProtocol { + t.Fatalf("alpn=%q resp=%q", tlsSummary.NegotiatedProtocol, resp.TLS.NegotiatedProtocol) + } + if len(tlsSummary.PeerCertificates) == 0 { + t.Fatal("expected certificate summaries") + } + + leaf := tlsSummary.PeerCertificates[0] + if leaf.Subject == "" || leaf.Issuer == "" { + t.Fatalf("unexpected leaf certificate summary: %+v", leaf) + } + if len(leaf.DNSNames) == 0 && len(leaf.IPAddresses) == 0 { + t.Fatalf("expected DNS or IP SANs in leaf certificate: %+v", leaf) + } + + if got := req.TraceSummary(); got == nil || got.StatusCode != http.StatusOK { + t.Fatalf("request trace summary=%+v", got) + } + if got := resp.TraceSummary(); got == nil || got.StatusCode != http.StatusOK { + t.Fatalf("response trace summary=%+v", got) + } +} + +func TestTraceRecorderCapturesDNSAndConnectSummary(t *testing.T) { + server := newIPv4Server(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + defer server.Close() + + addr, err := net.ResolveTCPAddr("tcp", server.Listener.Addr().String()) + if err != nil { + t.Fatalf("ResolveTCPAddr() error: %v", err) + } + + recorder := NewTraceRecorder() + targetURL := "http://trace-summary.example.test:" + strconv.Itoa(addr.Port) + + resp, err := NewSimpleRequest(targetURL, http.MethodGet). + SetLookupFunc(func(ctx context.Context, host string) ([]net.IPAddr, error) { + return []net.IPAddr{{IP: addr.IP}}, nil + }). + SetTraceRecorder(recorder). + Do() + if err != nil { + t.Fatalf("Do() error: %v", err) + } + defer resp.Close() + + if _, err := resp.Body().Bytes(); err != nil { + t.Fatalf("Body().Bytes() error: %v", err) + } + + summary := recorder.Summary() + if summary.DNS == nil { + t.Fatal("expected dns summary") + } + if summary.DNS.Host != "trace-summary.example.test" { + t.Fatalf("dns host=%q", summary.DNS.Host) + } + if len(summary.DNS.Addrs) == 0 { + t.Fatal("expected resolved addresses") + } + if !strings.Contains(summary.DNS.Addrs[0], addr.IP.String()) { + t.Fatalf("dns addrs=%v", summary.DNS.Addrs) + } + if summary.DNS.CompletedAt.IsZero() { + t.Fatal("expected dns completion timestamp") + } + if len(summary.Connect) == 0 { + t.Fatal("expected connect attempts") + } + + connect := summary.Connect[0] + if connect.Network == "" || connect.Addr == "" { + t.Fatalf("unexpected connect summary: %+v", connect) + } + if connect.CompletedAt.IsZero() { + t.Fatalf("expected connect completion timestamp: %+v", connect) + } +} + +func TestTraceRecorderUsesResponseTLSForReusedConnection(t *testing.T) { + server := newIPv4TLSServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + defer server.Close() + + client := NewClientNoErr() + + firstResp, err := client.NewSimpleRequest(server.URL, http.MethodGet). + SetSkipTLSVerify(true). + Do() + if err != nil { + t.Fatalf("first Do() error: %v", err) + } + if _, err := firstResp.Body().Bytes(); err != nil { + t.Fatalf("first Body().Bytes() error: %v", err) + } + if err := firstResp.Close(); err != nil { + t.Fatalf("first Close() error: %v", err) + } + + recorder := NewTraceRecorder() + secondResp, err := client.NewSimpleRequest(server.URL, http.MethodGet). + SetSkipTLSVerify(true). + SetTraceRecorder(recorder). + Do() + if err != nil { + t.Fatalf("second Do() error: %v", err) + } + defer secondResp.Close() + + if _, err := secondResp.Body().Bytes(); err != nil { + t.Fatalf("second Body().Bytes() error: %v", err) + } + + summary := recorder.Summary() + if !summary.Conn.Reused { + t.Fatalf("expected reused connection summary, got %+v", summary.Conn) + } + if summary.TLS == nil || summary.TLS.Version == 0 { + t.Fatalf("expected tls summary from response fallback, got %+v", summary.TLS) + } +} + +func TestTraceRecorderCoexistsWithTraceHooks(t *testing.T) { + server := newIPv4Server(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + defer server.Close() + + recorder := NewTraceRecorder() + wroteRequest := 0 + + resp, err := NewSimpleRequest(server.URL, http.MethodGet). + SetTraceHooks(&TraceHooks{ + WroteRequest: func(info TraceWroteRequestInfo) { + wroteRequest++ + }, + }). + SetTraceRecorder(recorder). + Do() + if err != nil { + t.Fatalf("Do() error: %v", err) + } + defer resp.Close() + + if _, err := resp.Body().Bytes(); err != nil { + t.Fatalf("Body().Bytes() error: %v", err) + } + if wroteRequest == 0 { + t.Fatal("expected custom trace hook to run") + } + + summary := recorder.Summary() + if summary.RequestWrittenAt.IsZero() { + t.Fatal("expected recorder to capture wrote-request event") + } +} + +func TestTraceRecorderPreservesMultipleDNSEvents(t *testing.T) { + recorder := NewTraceRecorder() + hooks := recorder.Hooks() + + hooks.DNSStart(TraceDNSStartInfo{Host: "target.example.test"}) + hooks.DNSDone(TraceDNSDoneInfo{ + Addrs: []net.IPAddr{{IP: net.ParseIP("127.0.0.1")}}, + }) + + hooks.DNSStart(TraceDNSStartInfo{Host: "proxy.example.test"}) + hooks.DNSDone(TraceDNSDoneInfo{ + Addrs: []net.IPAddr{{IP: net.ParseIP("127.0.0.2")}}, + }) + + summary := recorder.Summary() + if len(summary.DNSEvents) != 2 { + t.Fatalf("dns events=%d", len(summary.DNSEvents)) + } + if summary.DNSEvents[0].Host != "target.example.test" { + t.Fatalf("first dns host=%q", summary.DNSEvents[0].Host) + } + if summary.DNSEvents[1].Host != "proxy.example.test" { + t.Fatalf("second dns host=%q", summary.DNSEvents[1].Host) + } + if summary.DNS == nil || summary.DNS.Host != "proxy.example.test" { + t.Fatalf("last dns summary=%+v", summary.DNS) + } +} + +func TestTraceHooksStandardTLSPathIncludesMetadata(t *testing.T) { + server, pool := newTrustedIPv4TLSServer(t, "localhost", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + defer server.Close() + + client := NewClientNoErr() + transport, ok := client.HTTPClient().Transport.(*Transport) + if !ok { + t.Fatalf("transport type=%T", client.HTTPClient().Transport) + } + base := newBaseHTTPTransport() + base.TLSClientConfig = &tls.Config{RootCAs: pool} + transport.SetBase(base) + + targetURL := httpsURLForHost(t, server, "localhost") + var startInfo TraceTLSHandshakeStartInfo + var doneInfo TraceTLSHandshakeDoneInfo + + resp, err := client.NewSimpleRequest(targetURL, http.MethodGet). + SetTraceHooks(&TraceHooks{ + TLSHandshakeStart: func(info TraceTLSHandshakeStartInfo) { + startInfo = info + }, + TLSHandshakeDone: func(info TraceTLSHandshakeDoneInfo) { + doneInfo = info + }, + }). + Do() + if err != nil { + t.Fatalf("Do() error: %v", err) + } + defer resp.Close() + + if _, err := resp.Body().Bytes(); err != nil { + t.Fatalf("Body().Bytes() error: %v", err) + } + + wantAddr := strings.TrimPrefix(targetURL, "https://") + if startInfo.Network != "tcp" { + t.Fatalf("start network=%q", startInfo.Network) + } + if startInfo.Addr != wantAddr { + t.Fatalf("start addr=%q want=%q", startInfo.Addr, wantAddr) + } + if startInfo.ServerName != "localhost" { + t.Fatalf("start server name=%q", startInfo.ServerName) + } + if doneInfo.Network != "tcp" || doneInfo.Addr != wantAddr || doneInfo.ServerName != "localhost" { + t.Fatalf("done info=%+v", doneInfo) + } + if doneInfo.ConnectionState.Version == 0 { + t.Fatalf("done state=%+v", doneInfo.ConnectionState) + } +} + +func TestTraceHooksWroteHeaderFieldCopiesValues(t *testing.T) { + var captured []string + traceState := newTraceState(&TraceHooks{ + WroteHeaderField: func(info TraceWroteHeaderFieldInfo) { + captured = info.Values + }, + }) + + trace := traceState.clientTrace() + values := []string{"a", "b"} + trace.WroteHeaderField("X-Test", values) + values[0] = "mutated" + + if len(captured) != 2 { + t.Fatalf("captured=%v", captured) + } + if captured[0] != "a" { + t.Fatalf("captured=%v", captured) + } +} + +func TestTraceRecorderSharedAcrossCloneKeepsPerRequestSummaries(t *testing.T) { + server := newIPv4Server(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(r.URL.Path)) + })) + defer server.Close() + + recorder := NewTraceRecorder() + client := NewClientNoErr(WithTraceRecorder(recorder)) + + req1 := client.NewSimpleRequest(server.URL+"/one", http.MethodGet) + resp1, err := req1.Do() + if err != nil { + t.Fatalf("first Do() error: %v", err) + } + defer resp1.Close() + if _, err := resp1.Body().Bytes(); err != nil { + t.Fatalf("first Body().Bytes() error: %v", err) + } + + req2 := req1.Clone().SetURL(server.URL + "/two") + resp2, err := req2.Do() + if err != nil { + t.Fatalf("second Do() error: %v", err) + } + defer resp2.Close() + if _, err := resp2.Body().Bytes(); err != nil { + t.Fatalf("second Body().Bytes() error: %v", err) + } + + if got := req1.TraceSummary(); got == nil || got.URL != server.URL+"/one" { + t.Fatalf("req1 trace summary=%+v", got) + } + if got := resp1.TraceSummary(); got == nil || got.URL != server.URL+"/one" { + t.Fatalf("resp1 trace summary=%+v", got) + } + if got := req2.TraceSummary(); got == nil || got.URL != server.URL+"/two" { + t.Fatalf("req2 trace summary=%+v", got) + } + if got := resp2.TraceSummary(); got == nil || got.URL != server.URL+"/two" { + t.Fatalf("resp2 trace summary=%+v", got) + } + if got := recorder.Summary(); got.URL != server.URL+"/two" { + t.Fatalf("shared recorder summary=%+v", got) + } +} + +func TestResponseTraceSummaryIsStableAcrossRequestReuse(t *testing.T) { + server := newIPv4Server(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(r.URL.Path)) + })) + defer server.Close() + + req := NewSimpleRequest(server.URL+"/first", http.MethodGet). + SetTraceRecorder(NewTraceRecorder()) + + resp1, err := req.Do() + if err != nil { + t.Fatalf("first Do() error: %v", err) + } + defer resp1.Close() + if _, err := resp1.Body().Bytes(); err != nil { + t.Fatalf("first Body().Bytes() error: %v", err) + } + + req.SetURL(server.URL + "/second") + resp2, err := req.Do() + if err != nil { + t.Fatalf("second Do() error: %v", err) + } + defer resp2.Close() + if _, err := resp2.Body().Bytes(); err != nil { + t.Fatalf("second Body().Bytes() error: %v", err) + } + + if got := resp1.TraceSummary(); got == nil || got.URL != server.URL+"/first" { + t.Fatalf("resp1 trace summary=%+v", got) + } + if got := req.TraceSummary(); got == nil || got.URL != server.URL+"/second" { + t.Fatalf("request trace summary=%+v", got) + } + if got := resp2.TraceSummary(); got == nil || got.URL != server.URL+"/second" { + t.Fatalf("resp2 trace summary=%+v", got) + } +} + +func TestTraceHooksCustomDialDoesNotInventTLSAddr(t *testing.T) { + server, pool := newTrustedIPv4TLSServer(t, "trace-custom.example.test", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + defer server.Close() + + client := NewClientNoErr() + transport, ok := client.HTTPClient().Transport.(*Transport) + if !ok { + t.Fatalf("transport type=%T", client.HTTPClient().Transport) + } + base := newBaseHTTPTransport() + base.TLSClientConfig = &tls.Config{RootCAs: pool} + transport.SetBase(base) + + targetURL := httpsURLForHost(t, server, "trace-custom.example.test") + serverAddr := server.Listener.Addr().String() + + var startInfo TraceTLSHandshakeStartInfo + var doneInfo TraceTLSHandshakeDoneInfo + resp, err := client.NewSimpleRequest(targetURL, http.MethodGet). + SetDialFunc(func(ctx context.Context, network, addr string) (net.Conn, error) { + return (&net.Dialer{}).DialContext(ctx, "tcp", serverAddr) + }). + SetTraceHooks(&TraceHooks{ + TLSHandshakeStart: func(info TraceTLSHandshakeStartInfo) { + startInfo = info + }, + TLSHandshakeDone: func(info TraceTLSHandshakeDoneInfo) { + doneInfo = info + }, + }). + Do() + if err != nil { + t.Fatalf("Do() error: %v", err) + } + defer resp.Close() + if _, err := resp.Body().Bytes(); err != nil { + t.Fatalf("Body().Bytes() error: %v", err) + } + + if startInfo.Network != "" || startInfo.Addr != "" { + t.Fatalf("start info=%+v", startInfo) + } + if doneInfo.Network != "" || doneInfo.Addr != "" { + t.Fatalf("done info=%+v", doneInfo) + } + if startInfo.ServerName != "trace-custom.example.test" { + t.Fatalf("start server name=%q", startInfo.ServerName) + } + if doneInfo.ServerName != "trace-custom.example.test" { + t.Fatalf("done server name=%q", doneInfo.ServerName) + } + if doneInfo.ConnectionState.Version == 0 { + t.Fatalf("done state=%+v", doneInfo.ConnectionState) + } +}