starnet/retry.go
starainrt 732e81316c
fix(starnet): 重构请求执行链路并补齐代理/重试/trace边界
- 分离 Request 的配置态与执行态,修复二次 Do、raw 模式网络配置失效和 body 来源互斥问题
  - 新增 starnet trace 抽象,补齐 DNS/连接/TLS/重试事件,并优化动态 transport 缓存与代理解析路径
  - 收紧非法代理为 fail-fast,多目标目标回退仅限幂等请求,修复 Host/TLS/SNI 等语义边界
  - 补充防御性拷贝、专项回归测试、本地代理/TLS 用例与 README 行为说明
2026-04-19 15:39:51 +08:00

467 lines
10 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package starnet
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"math"
"math/rand"
"net"
"net/http"
"strings"
"time"
)
// RetryOpt adjusts a retry policy in place. A non-nil error rejects the
// option: buildRetryPolicy aborts and returns it, and applyRetryOpt records
// it on the Request.
type RetryOpt func(*retryPolicy) error
// retryPolicy holds the settings that drive the retry loop of a Request.
type retryPolicy struct {
// maxRetries is the number of retries allowed after the first attempt;
// doWithRetry performs at most maxRetries+1 attempts.
maxRetries int
// baseDelay is the wait before the first retry. A value <= 0 makes
// nextDelay return 0, i.e. retries happen without waiting.
baseDelay time.Duration
// maxDelay caps the exponentially grown delay (before jitter is applied).
// A value <= 0 means no cap.
maxDelay time.Duration
// factor is the per-attempt multiplier: the delay for attempt n is
// baseDelay * factor^n.
factor float64
// jitter is the randomisation ratio: the capped delay is scaled by a random
// factor in [1-jitter, 1+jitter). A value <= 0 disables jitter.
jitter float64
// idempotentOnly, when true, restricts retries to the methods accepted by
// isIdempotentMethod.
idempotentOnly bool
// statuses is the set of response status codes that trigger a retry.
statuses map[int]struct{}
// onError, when set, decides whether a failed attempt's error is retryable
// instead of isRetryableError. Context cancellation and deadline errors are
// never retried and never reach it.
onError func(error) bool
}
// cloneRetryPolicy returns an independent copy of p, or nil when p is nil.
//
// The status-code set is duplicated so that later changes to one policy's map
// are not visible through the other. The onError callback is shared as-is.
func cloneRetryPolicy(p *retryPolicy) *retryPolicy {
	if p == nil {
		return nil
	}

	// A plain struct copy carries over every scalar field and the callback.
	dup := *p
	if p.statuses == nil {
		return &dup
	}

	// Swap the shared map reference for a private copy.
	dup.statuses = make(map[int]struct{}, len(p.statuses))
	for status := range p.statuses {
		dup.statuses[status] = struct{}{}
	}
	return &dup
}
// defaultRetryPolicy builds the baseline policy for the given retry count:
// 100ms base delay doubling up to a 2s cap, 10% jitter, idempotent methods
// only, and retries on 408, 425, 429, 500, 502, 503 and 504 responses.
func defaultRetryPolicy(max int) *retryPolicy {
	retryable := [...]int{
		http.StatusRequestTimeout,
		http.StatusTooEarly,
		http.StatusTooManyRequests,
		http.StatusInternalServerError,
		http.StatusBadGateway,
		http.StatusServiceUnavailable,
		http.StatusGatewayTimeout,
	}
	codes := make(map[int]struct{}, len(retryable))
	for _, code := range retryable {
		codes[code] = struct{}{}
	}

	policy := new(retryPolicy)
	policy.maxRetries = max
	policy.baseDelay = 100 * time.Millisecond
	policy.maxDelay = 2 * time.Second
	policy.factor = 2.0
	policy.jitter = 0.1
	policy.idempotentOnly = true
	policy.statuses = codes
	return policy
}
// buildRetryPolicy creates a policy for max retries and applies opts in order.
//
// A negative max is an error. A max of 0 means "no retry" and yields a nil
// policy with a nil error; opts are not evaluated in that case. Nil options
// are skipped, and the first option that fails aborts the build with its error.
func buildRetryPolicy(max int, opts ...RetryOpt) (*retryPolicy, error) {
	switch {
	case max < 0:
		return nil, fmt.Errorf("max retry must be >= 0")
	case max == 0:
		return nil, nil
	}

	built := defaultRetryPolicy(max)
	for i := range opts {
		apply := opts[i]
		if apply == nil {
			continue
		}
		if err := apply(built); err != nil {
			return nil, err
		}
	}
	return built, nil
}
// WithRetry returns a request option that enables automatic retries.
//
// By default only idempotent methods are retried. Even when the
// idempotent-only restriction is explicitly switched off, a request whose body
// is supplied as a Reader still has retries conservatively disabled for
// non-idempotent methods, so that a body which may already have been delivered
// is not sent again.
//
// The policy is built when the option is applied; a build error is returned
// from the option and leaves the request's retry policy untouched.
func WithRetry(max int, opts ...RetryOpt) RequestOpt {
	return func(req *Request) error {
		built, buildErr := buildRetryPolicy(max, opts...)
		if buildErr == nil {
			req.retry = built
		}
		return buildErr
	}
}
// SetRetry enables automatic retries on the request and returns it for
// chaining.
//
// By default only idempotent methods are retried. Even when the
// idempotent-only restriction is explicitly switched off, a request whose body
// is supplied as a Reader still has retries conservatively disabled for
// non-idempotent methods, so that a body which may already have been delivered
// is not sent again.
//
// The call is a no-op if the request already carries an error. A failure to
// build the policy is recorded on the request instead of being returned.
func (r *Request) SetRetry(max int, opts ...RetryOpt) *Request {
	if r.err == nil {
		if built, buildErr := buildRetryPolicy(max, opts...); buildErr != nil {
			r.err = buildErr
		} else {
			r.retry = built
		}
	}
	return r
}
// DisableRetry removes any retry policy from the request and returns it for
// chaining. It does nothing when the request already carries an error.
func (r *Request) DisableRetry() *Request {
	if r.err == nil {
		r.retry = nil
	}
	return r
}
// applyRetryOpt applies a single option to the request's existing retry
// policy and returns the request for chaining.
//
// Nothing happens when the request already carries an error or opt is nil.
// If no policy has been enabled yet, or the option itself fails, the error is
// recorded on the request.
func (r *Request) applyRetryOpt(opt RetryOpt) *Request {
	switch {
	case r.err != nil, opt == nil:
		// Earlier failure or nothing to apply: leave the request as it is.
	case r.retry == nil:
		r.err = fmt.Errorf("retry policy is not enabled, call SetRetry first")
	default:
		if applyErr := opt(r.retry); applyErr != nil {
			r.err = applyErr
		}
	}
	return r
}
// SetRetryBackoff sets the base delay, delay cap and growth factor of the
// request's retry policy. A policy must already be enabled; otherwise, or if
// the values are rejected, the error is recorded on the request.
func (r *Request) SetRetryBackoff(base, max time.Duration, factor float64) *Request {
	backoff := WithRetryBackoff(base, max, factor)
	return r.applyRetryOpt(backoff)
}
// SetRetryJitter sets the jitter ratio of the request's retry policy. A policy
// must already be enabled; otherwise, or if the ratio is rejected, the error
// is recorded on the request.
func (r *Request) SetRetryJitter(ratio float64) *Request {
	jitter := WithRetryJitter(ratio)
	return r.applyRetryOpt(jitter)
}
// SetRetryStatuses replaces the set of response status codes that trigger a
// retry. A policy must already be enabled; otherwise, or if a code is
// rejected, the error is recorded on the request.
func (r *Request) SetRetryStatuses(codes ...int) *Request {
	statuses := WithRetryStatuses(codes...)
	return r.applyRetryOpt(statuses)
}
// SetRetryIdempotentOnly switches the idempotent-methods-only restriction of
// the request's retry policy on or off. A policy must already be enabled;
// otherwise an error is recorded on the request.
func (r *Request) SetRetryIdempotentOnly(enabled bool) *Request {
	restrict := WithRetryIdempotentOnly(enabled)
	return r.applyRetryOpt(restrict)
}
// SetRetryOnError installs fn as the predicate that decides whether an
// attempt's error should be retried. A policy must already be enabled;
// otherwise an error is recorded on the request.
func (r *Request) SetRetryOnError(fn func(error) bool) *Request {
	classifier := WithRetryOnError(fn)
	return r.applyRetryOpt(classifier)
}
// WithRetryBackoff sets the exponential backoff parameters of a retry policy.
//
// base is the delay before the first retry (0 disables waiting), max caps the
// grown delay (0 means no cap) and factor is the per-attempt multiplier.
//
// The option fails, leaving the policy untouched, when base or max is
// negative or when factor is not a number greater than zero.
func WithRetryBackoff(base, max time.Duration, factor float64) RetryOpt {
	return func(p *retryPolicy) error {
		if base < 0 {
			return fmt.Errorf("retry base delay must be >= 0")
		}
		if max < 0 {
			return fmt.Errorf("retry max delay must be >= 0")
		}
		// NaN compares false against every number, so "factor <= 0" alone
		// would let it through and poison every later delay computation.
		if factor <= 0 || math.IsNaN(factor) {
			return fmt.Errorf("retry factor must be > 0")
		}
		p.baseDelay = base
		p.maxDelay = max
		p.factor = factor
		return nil
	}
}
// WithRetryJitter sets the jitter ratio of a retry policy. Each backoff delay
// is scaled by a random factor in [1-ratio, 1+ratio); 0 disables jitter.
//
// The option fails, leaving the policy untouched, when ratio is NaN or lies
// outside [0,1].
func WithRetryJitter(ratio float64) RetryOpt {
	return func(p *retryPolicy) error {
		// NaN fails both range comparisons, so it has to be rejected
		// explicitly; otherwise it would turn every jittered delay into NaN.
		if math.IsNaN(ratio) || ratio < 0 || ratio > 1 {
			return fmt.Errorf("retry jitter ratio must be in [0,1]")
		}
		p.jitter = ratio
		return nil
	}
}
// WithRetryStatuses replaces the set of response status codes that trigger a
// retry. Passing no codes installs an empty set.
//
// Every code must lie in [100, 999]; the first one that does not makes the
// option fail and leaves the policy's current set untouched.
func WithRetryStatuses(codes ...int) RetryOpt {
	return func(p *retryPolicy) error {
		set := make(map[int]struct{}, len(codes))
		for i := 0; i < len(codes); i++ {
			c := codes[i]
			if c >= 100 && c <= 999 {
				set[c] = struct{}{}
				continue
			}
			return fmt.Errorf("invalid retry status code: %d", c)
		}
		// Assign only after every code has been validated.
		p.statuses = set
		return nil
	}
}
// WithRetryIdempotentOnly controls whether retries are restricted to
// idempotent HTTP methods. The option never fails.
func WithRetryIdempotentOnly(enabled bool) RetryOpt {
	var opt RetryOpt = func(policy *retryPolicy) error {
		policy.idempotentOnly = enabled
		return nil
	}
	return opt
}
// WithRetryOnError installs fn as the predicate deciding whether an attempt's
// error is retryable. Passing nil clears any predicate, which makes the
// policy fall back to isRetryableError. The option never fails.
func WithRetryOnError(fn func(error) bool) RetryOpt {
	var opt RetryOpt = func(policy *retryPolicy) error {
		policy.onError = fn
		return nil
	}
	return opt
}
// hasRetryPolicy reports whether the request carries a policy that allows at
// least one retry.
func (r *Request) hasRetryPolicy() bool {
	if r.retry == nil {
		return false
	}
	return r.retry.maxRetries > 0
}
// doWithRetry runs the request, repeating it according to the retry policy.
//
// The policy is cloned up front so the loop works on its own copy. When no
// policy is active, or canRetryRequest rejects the request, this degrades to a
// single doOnce call. Otherwise up to maxRetries+1 attempts are made, each on
// a fresh request built by newRetryAttempt. All attempts share one context;
// when a network timeout is configured it is applied to that context, so it
// bounds the attempts and the waits between them as a whole.
//
// It returns the response and error of the attempt that ended the loop. If
// the shared context finishes during a backoff wait, the previous (already
// closed) response is returned together with the wrapped context error.
func (r *Request) doWithRetry() (*Response, error) {
	policy := cloneRetryPolicy(r.retry)
	if policy == nil || policy.maxRetries <= 0 {
		return r.doOnce()
	}
	if !policy.canRetryRequest(r) {
		return r.doOnce()
	}

	retryCtx := normalizeContext(r.ctx)
	retryCancel := func() {}
	// Raw requests pass canRetryRequest without r.config being inspected, so
	// guard against a nil config here the same way the sibling helpers do.
	if r.config != nil && r.config.Network.Timeout > 0 {
		retryCtx, retryCancel = context.WithTimeout(retryCtx, r.config.Network.Timeout)
	}
	// NOTE(review): retryCancel also fires when a successful response is
	// returned. Confirm that the returned Response no longer depends on
	// retryCtx (for example to read its body) after this function returns.
	defer retryCancel()

	maxAttempts := policy.maxRetries + 1
	var lastResp *Response
	var lastErr error
	for attempt := 0; attempt < maxAttempts; attempt++ {
		attemptNo := attempt + 1 // trace events count attempts from 1
		emitRetryAttemptStart(r.traceHooks, TraceRetryAttemptStartInfo{
			Attempt:     attemptNo,
			MaxAttempts: maxAttempts,
		})

		attemptReq, err := r.newRetryAttempt(retryCtx)
		if err != nil {
			return nil, wrapError(err, "build retry attempt")
		}
		resp, err := attemptReq.doOnce()
		if resp != nil {
			// Point the response back at the caller's request rather than
			// at the per-attempt clone.
			resp.request = r
		}

		willRetry := policy.shouldRetry(resp, err, attempt, maxAttempts, retryCtx)
		statusCode := 0
		if resp != nil {
			statusCode = resp.StatusCode
		}
		emitRetryAttemptDone(r.traceHooks, TraceRetryAttemptDoneInfo{
			Attempt:     attemptNo,
			MaxAttempts: maxAttempts,
			StatusCode:  statusCode,
			Err:         err,
			WillRetry:   willRetry,
		})
		if !willRetry {
			return resp, err
		}

		lastResp = resp
		lastErr = err
		if lastResp != nil {
			// The response is about to be superseded by the next attempt;
			// closing is best-effort, so its error is deliberately ignored.
			_ = lastResp.Close()
		}

		delay := policy.nextDelay(attempt)
		if delay <= 0 {
			continue
		}
		emitRetryBackoff(r.traceHooks, TraceRetryBackoffInfo{
			Attempt: attemptNo,
			Delay:   delay,
		})

		// Wait out the backoff, but give up as soon as the shared context ends.
		timer := time.NewTimer(delay)
		select {
		case <-retryCtx.Done():
			timer.Stop()
			return lastResp, wrapError(retryCtx.Err(), "retry context done")
		case <-timer.C:
		}
	}
	return lastResp, lastErr
}
// newRetryAttempt derives a fresh request for one retry attempt, bound to ctx.
//
// The clone has its retry policy removed and its cancel, applied and execCtx
// fields reset. Its own network timeout is zeroed, because ctx already
// carries the overall deadline. For raw requests the underlying HTTP request
// is re-cloned from the original through cloneRawHTTPRequest, whose error is
// returned unchanged; otherwise the clone's HTTP request is rebound to ctx.
func (r *Request) newRetryAttempt(ctx context.Context) (*Request, error) {
	next := r.Clone()
	next.retry = nil
	next.cancel = nil
	next.applied = false
	next.execCtx = nil
	next.ctx = ctx

	// The total-timeout context is shared across attempts, so avoid creating
	// another timeout context inside every attempt.
	if cfg := next.config; cfg != nil && cfg.Network.Timeout > 0 {
		cfg.Network.Timeout = 0
	}

	if next.doRaw {
		rawReq, err := cloneRawHTTPRequest(r.httpReq, ctx)
		if err != nil {
			return nil, err
		}
		next.httpReq = rawReq
		return next, nil
	}

	next.httpReq = next.httpReq.WithContext(ctx)
	return next, nil
}
// canRetryRequest reports whether r may be retried under this policy.
//
// A non-idempotent method is refused when the policy is idempotent-only or
// when the body is supplied as a Reader. Everything else is retryable only if
// the request can be replayed.
func (p *retryPolicy) canRetryRequest(r *Request) bool {
	if !isIdempotentMethod(r.method) {
		if p.idempotentOnly || hasReaderRequestBody(r) {
			return false
		}
	}
	return isReplayableRequest(r)
}
// isIdempotentMethod reports whether method is one of GET, HEAD, PUT, DELETE,
// OPTIONS or TRACE. The comparison is exact, so lower-case or empty method
// names are not treated as idempotent.
func isIdempotentMethod(method string) bool {
	idempotent := [...]string{
		http.MethodGet,
		http.MethodHead,
		http.MethodPut,
		http.MethodDelete,
		http.MethodOptions,
		http.MethodTrace,
	}
	for _, candidate := range idempotent {
		if method == candidate {
			return true
		}
	}
	return false
}
// isReplayableRequest reports whether the request can be sent more than once.
//
// A raw request is replayable when it has no body (nil or http.NoBody) or
// provides GetBody. Any other request is judged by its configured body. A nil
// request, a raw request without an HTTP request, and a non-raw request
// without config are not replayable.
func isReplayableRequest(r *Request) bool {
	switch {
	case r == nil:
		return false
	case r.doRaw:
		raw := r.httpReq
		if raw == nil {
			return false
		}
		bodyless := raw.Body == nil || raw.Body == http.NoBody
		return bodyless || raw.GetBody != nil
	case r.config == nil:
		return false
	default:
		return isReplayableConfiguredBody(r.config.Body)
	}
}
// hasReaderRequestBody reports whether the request's configured body is in
// reader mode with a non-nil Reader. A nil request or nil config yields false.
func hasReaderRequestBody(r *Request) bool {
	if r != nil && r.config != nil {
		body := &r.config.Body
		return body.Mode == bodyModeReader && body.Reader != nil
	}
	return false
}
// isReplayableConfiguredBody reports whether a configured body can be
// produced again for another attempt.
//
// Reader bodies defer to isReplayableBodyReader. Multipart bodies are
// replayable only if every file has no FileData and a non-empty FilePath.
// All other modes are considered replayable.
func isReplayableConfiguredBody(body BodyConfig) bool {
	if body.Mode == bodyModeReader {
		return isReplayableBodyReader(body.Reader)
	}
	if body.Mode != bodyModeMultipart {
		return true
	}
	for _, part := range body.Files {
		pathOnly := part.FileData == nil && part.FilePath != ""
		if !pathOnly {
			return false
		}
	}
	return true
}
// isReplayableBodyReader reports whether reader has one of the concrete types
// treated as replayable: *bytes.Buffer, *bytes.Reader or *strings.Reader.
// A nil reader is not replayable.
func isReplayableBodyReader(reader io.Reader) bool {
	if _, ok := reader.(*bytes.Buffer); ok {
		return true
	}
	if _, ok := reader.(*bytes.Reader); ok {
		return true
	}
	_, ok := reader.(*strings.Reader)
	return ok
}
// shouldRetry decides whether another attempt should follow the one that just
// produced resp and err.
//
// attempt is zero-based. No retry happens after the last allowed attempt or
// once ctx has ended. For a failed attempt, context cancellation and deadline
// errors are never retried; any other error is judged by the policy's onError
// predicate when set, or by isRetryableError otherwise. For an attempt without
// error, a retry happens only if the response status code is in the policy's
// status set.
func (p *retryPolicy) shouldRetry(resp *Response, err error, attempt, maxAttempts int, ctx context.Context) bool {
	if attempt >= maxAttempts-1 || (ctx != nil && ctx.Err() != nil) {
		return false
	}

	if err == nil {
		if resp == nil || resp.Response == nil {
			return false
		}
		_, retryable := p.statuses[resp.StatusCode]
		return retryable
	}

	switch {
	case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded):
		return false
	case p.onError != nil:
		return p.onError(err)
	default:
		return isRetryableError(err)
	}
}
// isRetryableError is the default error classifier. An error is retryable when
// its chain contains io.EOF or io.ErrUnexpectedEOF, or a net.Error that
// reports Timeout or Temporary. A nil error is not retryable.
//
// net.Error.Temporary is deprecated in the standard library; the check is kept
// to preserve the existing classification.
func isRetryableError(err error) bool {
	if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
		return true
	}
	var ne net.Error
	return errors.As(err, &ne) && (ne.Timeout() || ne.Temporary())
}
// nextDelay returns how long to wait after the given zero-based attempt.
//
// The delay is baseDelay * factor^attempt, capped at maxDelay when maxDelay is
// positive, then scaled by a random factor in [1-jitter, 1+jitter) when jitter
// is positive. Jitter is applied after the cap, so the result may exceed
// maxDelay by up to the jitter ratio. A baseDelay <= 0 yields 0.
//
// All arithmetic is clamped in float64 space before converting back to a
// Duration. For large attempt numbers the raw product exceeds the int64 range;
// converting such a value directly is implementation-defined and can come out
// negative, which would bypass the cap and make the caller retry without any
// wait.
func (p *retryPolicy) nextDelay(attempt int) time.Duration {
	if p.baseDelay <= 0 {
		return 0
	}

	const maxDuration = time.Duration(math.MaxInt64)

	// saturate converts a nanosecond count to a Duration in [0, ceiling]
	// without ever performing an out-of-range float-to-int conversion.
	saturate := func(ns float64, ceiling time.Duration) time.Duration {
		switch {
		case ns <= 0:
			return 0
		case ns >= float64(ceiling):
			return ceiling
		default:
			return time.Duration(ns)
		}
	}

	ceiling := maxDuration
	if p.maxDelay > 0 {
		ceiling = p.maxDelay
	}

	grown := float64(p.baseDelay) * math.Pow(p.factor, float64(attempt))
	if math.IsNaN(grown) {
		// Only reachable with a NaN factor; fall back to the plain base delay.
		grown = float64(p.baseDelay)
	}
	delay := saturate(grown, ceiling)

	if p.jitter <= 0 || math.IsNaN(p.jitter) {
		return delay
	}
	low := 1 - p.jitter
	if low < 0 {
		low = 0
	}
	high := 1 + p.jitter
	scale := low + rand.Float64()*(high-low)

	jittered := float64(delay) * scale
	if math.IsNaN(jittered) {
		return delay
	}
	return saturate(jittered, maxDuration)
}