starnet/retry.go
starainrt 2f4c7158cf
feat: 增加请求级 trace 摘要与诊断能力
- 新增 TraceRecorder 和 TraceSummary,汇总 DNS、连接、TLS、写请求、首包等关键事件
  - 为请求执行链接入结构化 trace hooks,补充标准路径与动态路径的 TLS 元信息
  - 增加 Request.TraceSummary() 和 Response.TraceSummary(),提供请求级与响应级摘要快照
  - 修复共享 TraceRecorder 在 Client 默认选项、Clone 和请求复用场景下的状态串扰问题
  - 修复 Response.TraceSummary() 回读 Request 最近状态导致的非快照语义
  - 收口自定义 DialFunc 下的 TLS trace 元数据,避免伪造连接地址
  - 补充 trace 相关回归测试,覆盖 HTTPS、DNS/Connect、连接复用、共享 recorder、响应快照和自定义拨号场景
  - 更新 README,补充 trace、Host 与 TLSServerName 的行为说明
2026-04-20 17:54:43 +08:00

469 lines
10 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package starnet
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"math"
"math/rand"
"net"
"net/http"
"strings"
"time"
)
// RetryOpt adjusts a retry policy. It returns an error when the requested
// setting is invalid.
type RetryOpt func(*retryPolicy) error
// retryPolicy holds the retry settings attached to a Request.
type retryPolicy struct {
// maxRetries is the number of retries allowed after the first attempt.
maxRetries int
// baseDelay is the backoff before the first retry; a non-positive value
// disables waiting between attempts.
baseDelay time.Duration
// maxDelay caps the exponential backoff before jitter is applied; a
// non-positive value means no cap.
maxDelay time.Duration
// factor is the multiplier applied to the backoff for each further attempt.
factor float64
// jitter is the ratio used to randomize each backoff within
// [1-jitter, 1+jitter) of its nominal value.
jitter float64
// idempotentOnly restricts retries to idempotent HTTP methods.
idempotentOnly bool
// statuses is the set of response status codes that trigger a retry.
statuses map[int]struct{}
// onError, when non-nil, decides whether a failed attempt is retried
// instead of the built-in error classification.
onError func(error) bool
}
// cloneRetryPolicy returns an independent copy of p, or nil when p is nil.
// The status set is duplicated so that changes to the copy never reach the
// original; the onError callback is shared between both policies.
func cloneRetryPolicy(p *retryPolicy) *retryPolicy {
	if p == nil {
		return nil
	}
	dup := *p
	if p.statuses == nil {
		return &dup
	}
	// Replace the aliased map with a private one.
	dup.statuses = make(map[int]struct{}, len(p.statuses))
	for status := range p.statuses {
		dup.statuses[status] = struct{}{}
	}
	return &dup
}
// defaultRetryPolicy builds the baseline policy for max retries: 100ms base
// delay doubling up to 2s, 10% jitter, idempotent methods only, and retries on
// 408, 425, 429, 500, 502, 503 and 504 responses.
func defaultRetryPolicy(max int) *retryPolicy {
	retryable := [...]int{
		http.StatusRequestTimeout,
		http.StatusTooEarly,
		http.StatusTooManyRequests,
		http.StatusInternalServerError,
		http.StatusBadGateway,
		http.StatusServiceUnavailable,
		http.StatusGatewayTimeout,
	}
	statuses := make(map[int]struct{}, len(retryable))
	for _, code := range retryable {
		statuses[code] = struct{}{}
	}
	return &retryPolicy{
		maxRetries:     max,
		baseDelay:      100 * time.Millisecond,
		maxDelay:       2 * time.Second,
		factor:         2.0,
		jitter:         0.1,
		idempotentOnly: true,
		statuses:       statuses,
	}
}
// buildRetryPolicy creates the policy for max retries and applies opts in
// order. A negative max is an error; a max of zero yields a nil policy (retry
// disabled) without running any option. Nil options are skipped, and the first
// option that fails aborts the build with its error.
func buildRetryPolicy(max int, opts ...RetryOpt) (*retryPolicy, error) {
	switch {
	case max < 0:
		return nil, fmt.Errorf("max retry must be >= 0")
	case max == 0:
		return nil, nil
	}
	built := defaultRetryPolicy(max)
	for _, apply := range opts {
		if apply != nil {
			if err := apply(built); err != nil {
				return nil, err
			}
		}
	}
	return built, nil
}
// WithRetry enables automatic retries for a request.
// By default only idempotent methods are retried. Even when the
// idempotent-only restriction is explicitly disabled, Reader-style bodies stay
// conservatively excluded for non-idempotent methods, so a request body that
// has already been delivered is not sent again.
func WithRetry(max int, opts ...RetryOpt) RequestOpt {
	return func(r *Request) error {
		policy, err := buildRetryPolicy(max, opts...)
		if err == nil {
			r.retry = policy
		}
		return err
	}
}
// SetRetry enables automatic retries for the request and returns it for
// chaining.
// By default only idempotent methods are retried. Even when the
// idempotent-only restriction is explicitly disabled, Reader-style bodies stay
// conservatively excluded for non-idempotent methods, so a request body that
// has already been delivered is not sent again.
// A request that already carries an error is returned unchanged; an invalid
// setting is recorded as the request's error.
func (r *Request) SetRetry(max int, opts ...RetryOpt) *Request {
	if r.err != nil {
		return r
	}
	if policy, err := buildRetryPolicy(max, opts...); err != nil {
		r.err = err
	} else {
		r.retry = policy
	}
	return r
}
// DisableRetry removes the retry policy from the request and returns the
// request for chaining. A request that already carries an error is left
// untouched.
func (r *Request) DisableRetry() *Request {
	if r.err == nil {
		r.retry = nil
	}
	return r
}
// applyRetryOpt applies opt to the retry policy already attached to the
// request and returns the request for chaining. It does nothing when the
// request already carries an error or opt is nil, and records an error when no
// retry policy has been enabled or when opt rejects its setting.
func (r *Request) applyRetryOpt(opt RetryOpt) *Request {
	switch {
	case r.err != nil || opt == nil:
		// Nothing to apply: keep the earlier error, or ignore a nil option.
	case r.retry == nil:
		r.err = fmt.Errorf("retry policy is not enabled, call SetRetry first")
	default:
		// r.err is nil on this path, so storing the option's result records a
		// failure and leaves a success as "no error".
		r.err = opt(r.retry)
	}
	return r
}
// SetRetryBackoff changes the backoff settings of the request's existing retry
// policy (see WithRetryBackoff) and returns the request for chaining.
func (r *Request) SetRetryBackoff(base, max time.Duration, factor float64) *Request {
	opt := WithRetryBackoff(base, max, factor)
	return r.applyRetryOpt(opt)
}
// SetRetryJitter changes the jitter ratio of the request's existing retry
// policy (see WithRetryJitter) and returns the request for chaining.
func (r *Request) SetRetryJitter(ratio float64) *Request {
	opt := WithRetryJitter(ratio)
	return r.applyRetryOpt(opt)
}
// SetRetryStatuses replaces the retryable status codes of the request's
// existing retry policy (see WithRetryStatuses) and returns the request for
// chaining.
func (r *Request) SetRetryStatuses(codes ...int) *Request {
	opt := WithRetryStatuses(codes...)
	return r.applyRetryOpt(opt)
}
// SetRetryIdempotentOnly toggles the idempotent-method restriction of the
// request's existing retry policy (see WithRetryIdempotentOnly) and returns
// the request for chaining.
func (r *Request) SetRetryIdempotentOnly(enabled bool) *Request {
	opt := WithRetryIdempotentOnly(enabled)
	return r.applyRetryOpt(opt)
}
// SetRetryOnError installs fn as the error classifier of the request's
// existing retry policy (see WithRetryOnError) and returns the request for
// chaining.
func (r *Request) SetRetryOnError(fn func(error) bool) *Request {
	opt := WithRetryOnError(fn)
	return r.applyRetryOpt(opt)
}
// WithRetryBackoff sets the exponential backoff parameters of a retry policy.
//
// base is the delay before the first retry (zero disables waiting), max caps
// the delay (zero means no cap), and factor is the per-attempt multiplier.
// The option fails, leaving the policy unchanged, when base or max is
// negative or when factor is not a finite number greater than zero. NaN and
// +Inf are rejected explicitly because they would otherwise slip through an
// ordinary "<= 0" comparison and produce an undefined delay later on.
func WithRetryBackoff(base, max time.Duration, factor float64) RetryOpt {
	return func(p *retryPolicy) error {
		switch {
		case base < 0:
			return fmt.Errorf("retry base delay must be >= 0")
		case max < 0:
			return fmt.Errorf("retry max delay must be >= 0")
		case math.IsNaN(factor) || factor <= 0:
			// NaN compares false against everything, so test it explicitly.
			return fmt.Errorf("retry factor must be > 0")
		case math.IsInf(factor, 1):
			return fmt.Errorf("retry factor must be finite")
		}
		p.baseDelay = base
		p.maxDelay = max
		p.factor = factor
		return nil
	}
}
// WithRetryJitter sets the jitter ratio of a retry policy. The ratio must lie
// in [0,1]; anything else, including NaN, is rejected and leaves the policy
// unchanged.
func WithRetryJitter(ratio float64) RetryOpt {
	return func(p *retryPolicy) error {
		// Written as a positive range test so that NaN, which compares false
		// against every bound, is rejected as well.
		if !(ratio >= 0 && ratio <= 1) {
			return fmt.Errorf("retry jitter ratio must be in [0,1]")
		}
		p.jitter = ratio
		return nil
	}
}
// WithRetryStatuses replaces the set of response status codes that trigger a
// retry. Every code must lie in [100, 999]; the first code outside that range
// fails the option and leaves the existing set in place. Passing no codes
// installs an empty set.
func WithRetryStatuses(codes ...int) RetryOpt {
	return func(p *retryPolicy) error {
		set := make(map[int]struct{}, len(codes))
		for i := range codes {
			if c := codes[i]; c >= 100 && c <= 999 {
				set[c] = struct{}{}
				continue
			}
			return fmt.Errorf("invalid retry status code: %d", codes[i])
		}
		p.statuses = set
		return nil
	}
}
// WithRetryIdempotentOnly controls whether a retry policy is limited to
// idempotent HTTP methods. The returned option never fails.
func WithRetryIdempotentOnly(enabled bool) RetryOpt {
	setFlag := func(policy *retryPolicy) error {
		policy.idempotentOnly = enabled
		return nil
	}
	return setFlag
}
// WithRetryOnError installs fn as the callback that decides whether a failed
// attempt is retried. Passing nil removes a previously installed callback. The
// returned option never fails.
func WithRetryOnError(fn func(error) bool) RetryOpt {
	install := func(policy *retryPolicy) error {
		policy.onError = fn
		return nil
	}
	return install
}
// hasRetryPolicy reports whether the request carries a retry policy that
// allows at least one retry.
func (r *Request) hasRetryPolicy() bool {
	if r.retry == nil {
		return false
	}
	return r.retry.maxRetries > 0
}
// doWithRetry executes the request, re-sending it according to the request's
// retry policy.
//
// It falls back to a single doOnce call when no policy is set, when the policy
// allows no retries, or when canRetryRequest rejects the request. Otherwise
// every attempt runs on a fresh request built by newRetryAttempt, all attempts
// share one context (bounded by r.config.Network.Timeout when that is
// positive), and trace hooks are emitted at the start and end of each attempt
// and before each backoff wait.
//
// It returns the response and error of the attempt that was not retried. When
// the shared context ends during a backoff wait, it returns the previous
// attempt's response together with the wrapped context error.
func (r *Request) doWithRetry() (*Response, error) {
// Work on a private copy so the policy stays stable for the whole run.
policy := cloneRetryPolicy(r.retry)
if policy == nil || policy.maxRetries <= 0 {
return r.doOnce()
}
if !policy.canRetryRequest(r) {
return r.doOnce()
}
retryCtx := normalizeContext(r.ctx)
retryCancel := func() {}
// The configured timeout covers all attempts and backoff waits together.
if r.config.Network.Timeout > 0 {
retryCtx, retryCancel = context.WithTimeout(retryCtx, r.config.Network.Timeout)
}
defer retryCancel()
// The first attempt is not a retry, hence the extra one.
maxAttempts := policy.maxRetries + 1
var lastResp *Response
var lastErr error
for attempt := 0; attempt < maxAttempts; attempt++ {
// Trace events use 1-based attempt numbers.
attemptNo := attempt + 1
emitRetryAttemptStart(r.traceHooks, TraceRetryAttemptStartInfo{
Attempt: attemptNo,
MaxAttempts: maxAttempts,
})
attemptReq, err := r.newRetryAttempt(retryCtx)
if err != nil {
return nil, wrapError(err, "build retry attempt")
}
resp, err := attemptReq.doOnce()
// Point the response back at the original request rather than the
// per-attempt clone.
if resp != nil {
resp.request = r
}
willRetry := policy.shouldRetry(resp, err, attempt, maxAttempts, retryCtx)
statusCode := 0
if resp != nil {
statusCode = resp.StatusCode
}
emitRetryAttemptDone(r.traceHooks, TraceRetryAttemptDoneInfo{
Attempt: attemptNo,
MaxAttempts: maxAttempts,
StatusCode: statusCode,
Err: err,
WillRetry: willRetry,
})
if !willRetry {
return resp, err
}
lastResp = resp
lastErr = err
// The response of a retried attempt is discarded; a Close error is
// deliberately ignored.
if lastResp != nil {
_ = lastResp.Close()
}
delay := policy.nextDelay(attempt)
// A non-positive delay means "retry immediately".
if delay <= 0 {
continue
}
emitRetryBackoff(r.traceHooks, TraceRetryBackoffInfo{
Attempt: attemptNo,
Delay: delay,
})
timer := time.NewTimer(delay)
select {
case <-retryCtx.Done():
timer.Stop()
// NOTE(review): lastResp was already closed above when non-nil, so
// callers receive a closed response alongside the error.
return lastResp, wrapError(retryCtx.Err(), "retry context done")
case <-timer.C:
}
}
// NOTE(review): shouldRetry returns false on the final attempt, so the loop
// appears to always return from inside; this looks like a defensive fallback.
return lastResp, lastErr
}
// newRetryAttempt clones the request into a fresh, single-shot request bound
// to ctx. The clone has its retry policy, cancel function, applied flag,
// execution context and last trace summary reset, and shares the original's
// trace run. For raw requests the underlying HTTP request is re-cloned from
// the original via cloneRawHTTPRequest; otherwise the clone's HTTP request is
// rebound to ctx.
func (r *Request) newRetryAttempt(ctx context.Context) (*Request, error) {
	next := r.Clone()
	next.retry = nil
	next.cancel = nil
	next.applied = false
	next.execCtx = nil
	next.ctx = ctx
	next.traceRun = r.traceRun
	next.lastTraceSummary = nil

	// The overall timeout context is shared by all attempts, so do not let
	// each attempt create a timeout context of its own.
	if cfg := next.config; cfg != nil && cfg.Network.Timeout > 0 {
		cfg.Network.Timeout = 0
	}

	if next.doRaw {
		raw, err := cloneRawHTTPRequest(r.httpReq, ctx)
		if err != nil {
			return nil, err
		}
		next.httpReq = raw
	} else {
		next.httpReq = next.httpReq.WithContext(ctx)
	}
	return next, nil
}
// canRetryRequest reports whether the policy permits retrying r. A
// non-idempotent method is refused when the policy is idempotent-only or when
// the request carries a Reader body; every other request is retryable only if
// it can be replayed.
func (p *retryPolicy) canRetryRequest(r *Request) bool {
	switch {
	case p.idempotentOnly && !isIdempotentMethod(r.method):
		return false
	case hasReaderRequestBody(r) && !isIdempotentMethod(r.method):
		return false
	default:
		return isReplayableRequest(r)
	}
}
// isIdempotentMethod reports whether method is one of GET, HEAD, PUT, DELETE,
// OPTIONS or TRACE. The comparison is exact, so lower-case or empty method
// names are not considered idempotent.
func isIdempotentMethod(method string) bool {
	idempotent := [...]string{
		http.MethodGet,
		http.MethodHead,
		http.MethodPut,
		http.MethodDelete,
		http.MethodOptions,
		http.MethodTrace,
	}
	for _, candidate := range idempotent {
		if candidate == method {
			return true
		}
	}
	return false
}
// isReplayableRequest reports whether the request body can be sent again.
// A raw request is replayable when it has no body (nil or http.NoBody) or
// provides GetBody; a configured request defers to its body configuration.
// A nil request, a raw request without an HTTP request, and a configured
// request without a config are not replayable.
func isReplayableRequest(r *Request) bool {
	switch {
	case r == nil:
		return false
	case r.doRaw:
		req := r.httpReq
		if req == nil {
			return false
		}
		return req.Body == nil || req.Body == http.NoBody || req.GetBody != nil
	case r.config == nil:
		return false
	}
	return isReplayableConfiguredBody(r.config.Body)
}
// hasReaderRequestBody reports whether the request is configured with a
// non-nil Reader body. A nil request or a request without a config has none.
func hasReaderRequestBody(r *Request) bool {
	if r == nil || r.config == nil {
		return false
	}
	body := &r.config.Body
	return body.Mode == bodyModeReader && body.Reader != nil
}
// isReplayableConfiguredBody reports whether a configured body can be sent
// more than once. Reader bodies depend on the reader's concrete type;
// multipart bodies are replayable only when every file has a path and no
// in-memory FileData; all other body modes are replayable.
func isReplayableConfiguredBody(body BodyConfig) bool {
	if body.Mode == bodyModeReader {
		return isReplayableBodyReader(body.Reader)
	}
	if body.Mode != bodyModeMultipart {
		return true
	}
	for _, part := range body.Files {
		if part.FileData != nil || part.FilePath == "" {
			return false
		}
	}
	return true
}
// isReplayableBodyReader reports whether reader is one of the in-memory reader
// types treated as replayable: *bytes.Buffer, *bytes.Reader or
// *strings.Reader. Any other reader, including nil, is not.
func isReplayableBodyReader(reader io.Reader) bool {
	if _, ok := reader.(*bytes.Buffer); ok {
		return true
	}
	if _, ok := reader.(*bytes.Reader); ok {
		return true
	}
	_, ok := reader.(*strings.Reader)
	return ok
}
// shouldRetry decides whether another attempt should follow the one that
// produced resp and err.
//
// It never retries after the last permitted attempt, once ctx has ended, or
// when err is a context cancellation or deadline error. Other errors are
// judged by the policy's onError callback when set, otherwise by
// isRetryableError. Without an error, the response is retried only when its
// status code is in the policy's status set.
func (p *retryPolicy) shouldRetry(resp *Response, err error, attempt, maxAttempts int, ctx context.Context) bool {
	if attempt >= maxAttempts-1 || (ctx != nil && ctx.Err() != nil) {
		return false
	}
	if err == nil {
		if resp == nil || resp.Response == nil {
			return false
		}
		_, retryable := p.statuses[resp.StatusCode]
		return retryable
	}
	switch {
	case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded):
		return false
	case p.onError != nil:
		return p.onError(err)
	default:
		return isRetryableError(err)
	}
}
// isRetryableError reports whether err looks transient: a net.Error in its
// chain that is a timeout or temporary, or an io.EOF / io.ErrUnexpectedEOF
// anywhere in its chain.
func isRetryableError(err error) bool {
	var ne net.Error
	switch {
	case errors.As(err, &ne) && (ne.Timeout() || ne.Temporary()):
		return true
	case errors.Is(err, io.EOF), errors.Is(err, io.ErrUnexpectedEOF):
		return true
	default:
		return false
	}
}
// nextDelay returns how long to wait before the retry that follows the given
// zero-based attempt.
//
// The delay is baseDelay * factor^attempt, capped at maxDelay when maxDelay is
// positive, then scaled by a random factor drawn from [1-jitter, 1+jitter)
// when jitter is positive. Jitter is applied after the cap, so a jittered
// delay may exceed maxDelay by up to the jitter ratio. A non-positive
// baseDelay disables backoff and yields 0.
//
// The exponential term is computed in float64 and saturated before it is
// converted to time.Duration. Converting an out-of-range float to an integer
// is implementation-defined in Go and can yield a negative duration, which
// would bypass the maxDelay cap and make the caller retry without waiting.
func (p *retryPolicy) nextDelay(attempt int) time.Duration {
	if p.baseDelay <= 0 {
		return 0
	}
	grown := float64(p.baseDelay) * math.Pow(p.factor, float64(attempt))
	// If the growth cannot be computed (NaN), fall back to the base delay.
	delay := saturateRetryDelay(grown, p.baseDelay)
	if p.maxDelay > 0 && delay > p.maxDelay {
		delay = p.maxDelay
	}
	if p.jitter <= 0 {
		return delay
	}
	low := 1 - p.jitter
	if low < 0 {
		low = 0
	}
	high := 1 + p.jitter
	scale := low + rand.Float64()*(high-low)
	// Scaling can push an uncapped delay past the int64 range again.
	return saturateRetryDelay(float64(delay)*scale, delay)
}

// saturateRetryDelay converts a nanosecond count held in a float64 to a
// time.Duration, clamping values outside the int64 range to the nearest
// representable duration and returning fallback for NaN.
func saturateRetryDelay(ns float64, fallback time.Duration) time.Duration {
	switch {
	case math.IsNaN(ns):
		return fallback
	case ns >= float64(math.MaxInt64):
		return time.Duration(math.MaxInt64)
	case ns <= float64(math.MinInt64):
		return time.Duration(math.MinInt64)
	default:
		return time.Duration(ns)
	}
}