package starnet import ( "context" "errors" "fmt" "io" "math" "math/rand" "net" "net/http" "time" ) type RetryOpt func(*retryPolicy) error type retryPolicy struct { maxRetries int baseDelay time.Duration maxDelay time.Duration factor float64 jitter float64 idempotentOnly bool statuses map[int]struct{} onError func(error) bool } func cloneRetryPolicy(p *retryPolicy) *retryPolicy { if p == nil { return nil } cloned := &retryPolicy{ maxRetries: p.maxRetries, baseDelay: p.baseDelay, maxDelay: p.maxDelay, factor: p.factor, jitter: p.jitter, idempotentOnly: p.idempotentOnly, onError: p.onError, } if p.statuses != nil { cloned.statuses = make(map[int]struct{}, len(p.statuses)) for code := range p.statuses { cloned.statuses[code] = struct{}{} } } return cloned } func defaultRetryPolicy(max int) *retryPolicy { return &retryPolicy{ maxRetries: max, baseDelay: 100 * time.Millisecond, maxDelay: 2 * time.Second, factor: 2.0, jitter: 0.1, idempotentOnly: true, statuses: map[int]struct{}{ http.StatusRequestTimeout: {}, http.StatusTooEarly: {}, http.StatusTooManyRequests: {}, http.StatusInternalServerError: {}, http.StatusBadGateway: {}, http.StatusServiceUnavailable: {}, http.StatusGatewayTimeout: {}, }, } } func buildRetryPolicy(max int, opts ...RetryOpt) (*retryPolicy, error) { if max < 0 { return nil, fmt.Errorf("max retry must be >= 0") } if max == 0 { return nil, nil } policy := defaultRetryPolicy(max) for _, opt := range opts { if opt == nil { continue } if err := opt(policy); err != nil { return nil, err } } return policy, nil } func WithRetry(max int, opts ...RetryOpt) RequestOpt { return func(r *Request) error { policy, err := buildRetryPolicy(max, opts...) if err != nil { return err } r.retry = policy return nil } } func (r *Request) SetRetry(max int, opts ...RetryOpt) *Request { if r.err != nil { return r } policy, err := buildRetryPolicy(max, opts...) if err != nil { r.err = err return r } r.retry = policy return r } func (r *Request) DisableRetry() *Request { if r.err != nil { return r } r.retry = nil return r } func (r *Request) applyRetryOpt(opt RetryOpt) *Request { if r.err != nil { return r } if opt == nil { return r } if r.retry == nil { r.err = fmt.Errorf("retry policy is not enabled, call SetRetry first") return r } if err := opt(r.retry); err != nil { r.err = err } return r } func (r *Request) SetRetryBackoff(base, max time.Duration, factor float64) *Request { return r.applyRetryOpt(WithRetryBackoff(base, max, factor)) } func (r *Request) SetRetryJitter(ratio float64) *Request { return r.applyRetryOpt(WithRetryJitter(ratio)) } func (r *Request) SetRetryStatuses(codes ...int) *Request { return r.applyRetryOpt(WithRetryStatuses(codes...)) } func (r *Request) SetRetryIdempotentOnly(enabled bool) *Request { return r.applyRetryOpt(WithRetryIdempotentOnly(enabled)) } func (r *Request) SetRetryOnError(fn func(error) bool) *Request { return r.applyRetryOpt(WithRetryOnError(fn)) } func WithRetryBackoff(base, max time.Duration, factor float64) RetryOpt { return func(p *retryPolicy) error { if base < 0 { return fmt.Errorf("retry base delay must be >= 0") } if max < 0 { return fmt.Errorf("retry max delay must be >= 0") } if factor <= 0 { return fmt.Errorf("retry factor must be > 0") } p.baseDelay = base p.maxDelay = max p.factor = factor return nil } } func WithRetryJitter(ratio float64) RetryOpt { return func(p *retryPolicy) error { if ratio < 0 || ratio > 1 { return fmt.Errorf("retry jitter ratio must be in [0,1]") } p.jitter = ratio return nil } } func WithRetryStatuses(codes ...int) RetryOpt { return func(p *retryPolicy) error { statuses := make(map[int]struct{}, len(codes)) for _, code := range codes { if code < 100 || code > 999 { return fmt.Errorf("invalid retry status code: %d", code) } statuses[code] = struct{}{} } p.statuses = statuses return nil } } func WithRetryIdempotentOnly(enabled bool) RetryOpt { return func(p *retryPolicy) error { p.idempotentOnly = enabled return nil } } func WithRetryOnError(fn func(error) bool) RetryOpt { return func(p *retryPolicy) error { p.onError = fn return nil } } func (r *Request) hasRetryPolicy() bool { return r.retry != nil && r.retry.maxRetries > 0 } func (r *Request) doWithRetry() (*Response, error) { policy := cloneRetryPolicy(r.retry) if policy == nil || policy.maxRetries <= 0 { return r.doOnce() } if !policy.canRetryRequest(r) { return r.doOnce() } retryCtx := r.ctx retryCancel := func() {} if r.config.Network.Timeout > 0 { retryCtx, retryCancel = context.WithTimeout(r.ctx, r.config.Network.Timeout) } defer retryCancel() maxAttempts := policy.maxRetries + 1 var lastResp *Response var lastErr error for attempt := 0; attempt < maxAttempts; attempt++ { attemptReq, err := r.newRetryAttempt(retryCtx) if err != nil { return nil, wrapError(err, "build retry attempt") } resp, err := attemptReq.doOnce() if resp != nil { resp.request = r } if !policy.shouldRetry(resp, err, attempt, maxAttempts, retryCtx) { return resp, err } lastResp = resp lastErr = err if lastResp != nil { _ = lastResp.Close() } delay := policy.nextDelay(attempt) if delay <= 0 { continue } timer := time.NewTimer(delay) select { case <-retryCtx.Done(): timer.Stop() return lastResp, wrapError(retryCtx.Err(), "retry context done") case <-timer.C: } } return lastResp, lastErr } func (r *Request) newRetryAttempt(ctx context.Context) (*Request, error) { attempt := r.Clone() attempt.retry = nil attempt.cancel = nil attempt.applied = false attempt.execCtx = nil attempt.ctx = ctx // 共享总超时上下文后,避免每次 attempt 再创建一次 timeout context。 if attempt.config != nil && attempt.config.Network.Timeout > 0 { attempt.config.Network.Timeout = 0 } if !attempt.doRaw { attempt.httpReq = attempt.httpReq.WithContext(ctx) return attempt, nil } if r.httpReq == nil { return nil, fmt.Errorf("http request is nil") } raw := r.httpReq.Clone(ctx) if r.httpReq.GetBody != nil { body, err := r.httpReq.GetBody() if err != nil { return nil, wrapError(err, "get raw request body") } raw.Body = body } else if r.httpReq.Body != nil && r.httpReq.Body != http.NoBody { return nil, fmt.Errorf("raw request body is not replayable") } attempt.httpReq = raw return attempt, nil } func (p *retryPolicy) canRetryRequest(r *Request) bool { if p.idempotentOnly && !isIdempotentMethod(r.method) { return false } return isReplayableRequest(r) } func isIdempotentMethod(method string) bool { switch method { case http.MethodGet, http.MethodHead, http.MethodPut, http.MethodDelete, http.MethodOptions, http.MethodTrace: return true default: return false } } func isReplayableRequest(r *Request) bool { if r == nil { return false } if r.doRaw { if r.httpReq == nil { return false } if r.httpReq.Body == nil || r.httpReq.Body == http.NoBody { return true } return r.httpReq.GetBody != nil } if r.config == nil { return false } // Reader / stream body 通常不可重放,保守地不重试。 if r.config.Body.Reader != nil { return false } for _, f := range r.config.Body.Files { if f.FileData != nil || f.FilePath == "" { return false } } return true } func (p *retryPolicy) shouldRetry(resp *Response, err error, attempt, maxAttempts int, ctx context.Context) bool { if attempt >= maxAttempts-1 { return false } if ctx != nil && ctx.Err() != nil { return false } if err != nil { if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { return false } if p.onError != nil { return p.onError(err) } return isRetryableError(err) } if resp == nil || resp.Response == nil { return false } _, ok := p.statuses[resp.StatusCode] return ok } func isRetryableError(err error) bool { var netErr net.Error if errors.As(err, &netErr) { if netErr.Timeout() { return true } if netErr.Temporary() { return true } } return errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) } func (p *retryPolicy) nextDelay(attempt int) time.Duration { if p.baseDelay <= 0 { return 0 } delay := time.Duration(float64(p.baseDelay) * math.Pow(p.factor, float64(attempt))) if p.maxDelay > 0 && delay > p.maxDelay { delay = p.maxDelay } if p.jitter <= 0 { return delay } low := 1 - p.jitter if low < 0 { low = 0 } high := 1 + p.jitter scale := low + rand.Float64()*(high-low) return time.Duration(float64(delay) * scale) }