starnet/retry.go
starainrt b5bd7595a1
1. 优化ping功能
2. 新增重试机制
3. 优化错误处理逻辑
2026-03-19 16:42:45 +08:00

424 lines
8.7 KiB
Go

package starnet
import (
"context"
"errors"
"fmt"
"io"
"math"
"math/rand"
"net"
"net/http"
"time"
)
// RetryOpt is a functional option that adjusts a retry policy.
// It returns an error when the supplied setting is invalid.
type RetryOpt func(*retryPolicy) error
// retryPolicy holds the settings that control how a request is retried.
type retryPolicy struct {
// maxRetries is the number of retries allowed after the first attempt.
maxRetries int
// baseDelay is the delay before the first retry; values <= 0 disable waiting.
baseDelay time.Duration
// maxDelay caps the backoff delay before jitter is applied; 0 means no cap.
maxDelay time.Duration
// factor is the multiplier applied to the delay for each further attempt.
factor float64
// jitter is the ratio by which each delay is randomly scaled up or down.
jitter float64
// idempotentOnly restricts retries to idempotent HTTP methods.
idempotentOnly bool
// statuses is the set of HTTP status codes that trigger a retry.
statuses map[int]struct{}
// onError, when non-nil, decides whether an error is retryable and
// replaces the built-in error classification.
onError func(error) bool
}
// cloneRetryPolicy returns an independent copy of p, or nil when p is nil.
// The status-code set is duplicated so that the copy and the original do not
// share a map; a nil set stays nil.
func cloneRetryPolicy(p *retryPolicy) *retryPolicy {
	if p == nil {
		return nil
	}

	// Copy every scalar field (and the onError callback) in one go.
	dup := *p
	if p.statuses != nil {
		dup.statuses = make(map[int]struct{}, len(p.statuses))
		for code := range p.statuses {
			dup.statuses[code] = struct{}{}
		}
	}
	return &dup
}
// defaultRetryPolicy returns the baseline policy for max retries:
// 100ms base delay, 2s delay cap, factor 2, 10% jitter, idempotent methods
// only, and a fixed set of retryable HTTP status codes.
func defaultRetryPolicy(max int) *retryPolicy {
	codes := [...]int{
		http.StatusRequestTimeout,
		http.StatusTooEarly,
		http.StatusTooManyRequests,
		http.StatusInternalServerError,
		http.StatusBadGateway,
		http.StatusServiceUnavailable,
		http.StatusGatewayTimeout,
	}
	retryable := make(map[int]struct{}, len(codes))
	for _, code := range codes {
		retryable[code] = struct{}{}
	}

	p := new(retryPolicy)
	p.maxRetries = max
	p.baseDelay = 100 * time.Millisecond
	p.maxDelay = 2 * time.Second
	p.factor = 2.0
	p.jitter = 0.1
	p.idempotentOnly = true
	p.statuses = retryable
	return p
}
// buildRetryPolicy creates a retry policy for max retries and applies opts
// to it in order.
//
// A negative max is an error. A max of 0 yields a nil policy and a nil
// error, which callers treat as "retry disabled"; the options are not
// applied in that case. Nil options are skipped. The first option that
// fails aborts the build and its error is returned.
func buildRetryPolicy(max int, opts ...RetryOpt) (*retryPolicy, error) {
	switch {
	case max < 0:
		return nil, errors.New("max retry must be >= 0")
	case max == 0:
		return nil, nil
	}

	p := defaultRetryPolicy(max)
	for i := range opts {
		apply := opts[i]
		if apply == nil {
			continue
		}
		if err := apply(p); err != nil {
			return nil, err
		}
	}
	return p, nil
}
// WithRetry returns a RequestOpt that installs a retry policy allowing up to
// max retries, customized by opts. A max of 0 clears the request's retry
// policy; an invalid max or option makes the returned RequestOpt fail and
// leaves the request's policy unchanged.
func WithRetry(max int, opts ...RetryOpt) RequestOpt {
	return func(r *Request) error {
		policy, err := buildRetryPolicy(max, opts...)
		if err == nil {
			r.retry = policy
		}
		return err
	}
}
// SetRetry installs a retry policy allowing up to max retries, customized by
// opts, and returns r for chaining. It does nothing when r already carries
// an error; a failure to build the policy is recorded in r.err and the
// existing policy is kept.
func (r *Request) SetRetry(max int, opts ...RetryOpt) *Request {
	if r.err == nil {
		if policy, err := buildRetryPolicy(max, opts...); err != nil {
			r.err = err
		} else {
			r.retry = policy
		}
	}
	return r
}
// DisableRetry removes the retry policy from r and returns r for chaining.
// It does nothing when r already carries an error.
func (r *Request) DisableRetry() *Request {
	if r.err == nil {
		r.retry = nil
	}
	return r
}
// applyRetryOpt applies opt to the retry policy already attached to r and
// returns r for chaining. It is a no-op when r carries an error or opt is
// nil. If no policy is attached, or opt fails, the error is stored in r.err.
func (r *Request) applyRetryOpt(opt RetryOpt) *Request {
	switch {
	case r.err != nil, opt == nil:
		// Nothing to apply.
	case r.retry == nil:
		r.err = errors.New("retry policy is not enabled, call SetRetry first")
	default:
		// r.err is nil here, so storing a nil result changes nothing.
		r.err = opt(r.retry)
	}
	return r
}
// SetRetryBackoff updates the backoff settings (base delay, delay cap and
// growth factor) of the retry policy attached to r and returns r for
// chaining. Invalid values, or a missing policy, are recorded in r.err.
func (r *Request) SetRetryBackoff(base, max time.Duration, factor float64) *Request {
	opt := WithRetryBackoff(base, max, factor)
	return r.applyRetryOpt(opt)
}
// SetRetryJitter updates the jitter ratio of the retry policy attached to r
// and returns r for chaining. An invalid ratio, or a missing policy, is
// recorded in r.err.
func (r *Request) SetRetryJitter(ratio float64) *Request {
	opt := WithRetryJitter(ratio)
	return r.applyRetryOpt(opt)
}
// SetRetryStatuses replaces the set of HTTP status codes that trigger a
// retry in the policy attached to r and returns r for chaining. An invalid
// code, or a missing policy, is recorded in r.err.
func (r *Request) SetRetryStatuses(codes ...int) *Request {
	opt := WithRetryStatuses(codes...)
	return r.applyRetryOpt(opt)
}
// SetRetryIdempotentOnly sets whether the retry policy attached to r only
// retries idempotent HTTP methods, and returns r for chaining. A missing
// policy is recorded in r.err.
func (r *Request) SetRetryIdempotentOnly(enabled bool) *Request {
	opt := WithRetryIdempotentOnly(enabled)
	return r.applyRetryOpt(opt)
}
// SetRetryOnError sets the callback that decides whether an error is
// retryable in the policy attached to r, and returns r for chaining. A
// missing policy is recorded in r.err.
func (r *Request) SetRetryOnError(fn func(error) bool) *Request {
	opt := WithRetryOnError(fn)
	return r.applyRetryOpt(opt)
}
// WithRetryBackoff returns a RetryOpt that sets the backoff parameters of a
// retry policy.
//
// base is the delay before the first retry (0 disables waiting), max caps
// the delay (0 means no cap) and factor is the multiplier applied per
// attempt. The option fails, leaving the policy untouched, when base or max
// is negative or when factor is not a finite number greater than 0.
func WithRetryBackoff(base, max time.Duration, factor float64) RetryOpt {
	return func(p *retryPolicy) error {
		if base < 0 {
			return fmt.Errorf("retry base delay must be >= 0")
		}
		if max < 0 {
			return fmt.Errorf("retry max delay must be >= 0")
		}
		if factor <= 0 {
			return fmt.Errorf("retry factor must be > 0")
		}
		// NaN fails every ordered comparison and +Inf passes "> 0"; both
		// would later turn the computed delay into an undefined value.
		if math.IsNaN(factor) || math.IsInf(factor, 0) {
			return fmt.Errorf("retry factor must be a finite number")
		}
		p.baseDelay = base
		p.maxDelay = max
		p.factor = factor
		return nil
	}
}
// WithRetryJitter returns a RetryOpt that sets the jitter ratio of a retry
// policy. The ratio must lie in [0,1]; anything else, including NaN, makes
// the option fail and leaves the policy untouched.
func WithRetryJitter(ratio float64) RetryOpt {
	return func(p *retryPolicy) error {
		// NaN fails both range comparisons, so it is rejected explicitly.
		if math.IsNaN(ratio) || ratio < 0 || ratio > 1 {
			return fmt.Errorf("retry jitter ratio must be in [0,1]")
		}
		p.jitter = ratio
		return nil
	}
}
// WithRetryStatuses returns a RetryOpt that replaces the set of HTTP status
// codes that trigger a retry. Every code must be in the range 100-999; the
// first code outside it makes the option fail and leaves the policy's
// current set in place.
func WithRetryStatuses(codes ...int) RetryOpt {
	return func(p *retryPolicy) error {
		set := make(map[int]struct{}, len(codes))
		for i := 0; i < len(codes); i++ {
			c := codes[i]
			if c >= 100 && c <= 999 {
				set[c] = struct{}{}
				continue
			}
			return fmt.Errorf("invalid retry status code: %d", c)
		}
		p.statuses = set
		return nil
	}
}
// WithRetryIdempotentOnly returns a RetryOpt that sets whether the policy
// retries only idempotent HTTP methods. The option never fails.
func WithRetryIdempotentOnly(enabled bool) RetryOpt {
	apply := func(p *retryPolicy) error {
		p.idempotentOnly = enabled
		return nil
	}
	return apply
}
// WithRetryOnError returns a RetryOpt that sets the callback used to decide
// whether an error should be retried. Passing nil clears the callback. The
// option never fails.
func WithRetryOnError(fn func(error) bool) RetryOpt {
	apply := func(p *retryPolicy) error {
		p.onError = fn
		return nil
	}
	return apply
}
// hasRetryPolicy reports whether r has a retry policy that allows at least
// one retry.
func (r *Request) hasRetryPolicy() bool {
	if r.retry == nil {
		return false
	}
	return r.retry.maxRetries > 0
}
// doWithRetry executes the request, repeating it according to the retry
// policy attached to r.
//
// The request is sent exactly once, without retry handling, when no policy
// is set, the policy allows no retries, or the policy reports the request as
// not retryable. Otherwise up to maxRetries+1 attempts are made. All
// attempts share one context; when r.config.Network.Timeout is positive that
// context carries the timeout, so it bounds the attempts and the waits
// between them as a whole.
//
// It returns the response and error of the attempt that ended the loop. If
// the shared context ends during a backoff wait, it returns the previous
// response (which has already been closed) together with the wrapped
// context error.
func (r *Request) doWithRetry() (*Response, error) {
	// Work on a private copy of the policy.
	policy := cloneRetryPolicy(r.retry)
	if policy == nil || policy.maxRetries <= 0 {
		return r.doOnce()
	}
	if !policy.canRetryRequest(r) {
		return r.doOnce()
	}

	// A nil context would panic in context.WithTimeout and in Done below.
	retryCtx := r.ctx
	if retryCtx == nil {
		retryCtx = context.Background()
	}
	// config is treated as optional elsewhere in this file (raw requests are
	// accepted without one), so guard the dereference.
	if r.config != nil && r.config.Network.Timeout > 0 {
		var cancel context.CancelFunc
		retryCtx, cancel = context.WithTimeout(retryCtx, r.config.Network.Timeout)
		defer cancel()
	}

	maxAttempts := policy.maxRetries + 1
	var lastResp *Response
	var lastErr error
	for attempt := 0; attempt < maxAttempts; attempt++ {
		attemptReq, err := r.newRetryAttempt(retryCtx)
		if err != nil {
			return nil, wrapError(err, "build retry attempt")
		}

		resp, err := attemptReq.doOnce()
		if resp != nil {
			// Point the response back at the original request rather than
			// at the per-attempt copy.
			resp.request = r
		}
		if !policy.shouldRetry(resp, err, attempt, maxAttempts, retryCtx) {
			return resp, err
		}

		lastResp = resp
		lastErr = err
		if lastResp != nil {
			// Best effort: this response is being discarded in favour of a
			// retry, so a failure to close it is not actionable.
			_ = lastResp.Close()
		}

		delay := policy.nextDelay(attempt)
		if delay <= 0 {
			continue
		}
		timer := time.NewTimer(delay)
		select {
		case <-retryCtx.Done():
			timer.Stop()
			return lastResp, wrapError(retryCtx.Err(), "retry context done")
		case <-timer.C:
		}
	}
	return lastResp, lastErr
}
// newRetryAttempt builds a fresh copy of r for a single attempt bound to ctx.
//
// The copy has its retry policy, cancel function, applied flag and execution
// context reset, and its network timeout zeroed because ctx is shared by all
// attempts. For raw requests the underlying *http.Request is cloned and its
// body re-obtained through GetBody; an error is returned when the raw
// request is missing or its body cannot be replayed.
func (r *Request) newRetryAttempt(ctx context.Context) (*Request, error) {
	next := r.Clone()
	next.ctx = ctx
	next.execCtx = nil
	next.applied = false
	next.cancel = nil
	next.retry = nil

	// The attempts share one overall timeout context, so avoid creating
	// another timeout context for every attempt.
	// NOTE(review): this writes through next.config; confirm that Clone
	// gives the copy its own config.
	if cfg := next.config; cfg != nil && cfg.Network.Timeout > 0 {
		cfg.Network.Timeout = 0
	}

	if !next.doRaw {
		next.httpReq = next.httpReq.WithContext(ctx)
		return next, nil
	}

	src := r.httpReq
	if src == nil {
		return nil, fmt.Errorf("http request is nil")
	}
	raw := src.Clone(ctx)
	switch {
	case src.GetBody != nil:
		body, err := src.GetBody()
		if err != nil {
			return nil, wrapError(err, "get raw request body")
		}
		raw.Body = body
	case src.Body != nil && src.Body != http.NoBody:
		return nil, fmt.Errorf("raw request body is not replayable")
	}
	next.httpReq = raw
	return next, nil
}
// canRetryRequest reports whether r may be retried under p: its method must
// be idempotent when the policy is restricted to idempotent methods, and the
// request must be replayable.
func (p *retryPolicy) canRetryRequest(r *Request) bool {
	methodAllowed := !p.idempotentOnly || isIdempotentMethod(r.method)
	return methodAllowed && isReplayableRequest(r)
}
// isIdempotentMethod reports whether method is one of GET, HEAD, PUT,
// DELETE, OPTIONS or TRACE. The comparison is case-sensitive.
func isIdempotentMethod(method string) bool {
	switch method {
	case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace:
		return true
	case http.MethodPut, http.MethodDelete:
		return true
	}
	return false
}
// isReplayableRequest reports whether the body of r can be sent again.
//
// A raw request is replayable when it has no body or provides GetBody. Any
// other request needs a config, must not use a Reader body, and every
// attached file must have a path and no in-memory FileData.
func isReplayableRequest(r *Request) bool {
	switch {
	case r == nil:
		return false
	case r.doRaw:
		req := r.httpReq
		if req == nil {
			return false
		}
		hasBody := req.Body != nil && req.Body != http.NoBody
		return !hasBody || req.GetBody != nil
	case r.config == nil:
		return false
	case r.config.Body.Reader != nil:
		// Reader / stream bodies are usually not replayable, so
		// conservatively do not retry them.
		return false
	}

	for _, f := range r.config.Body.Files {
		if !(f.FileData == nil && f.FilePath != "") {
			return false
		}
	}
	return true
}
// shouldRetry decides whether another attempt should follow the attempt
// that produced resp and err.
//
// It never retries after the last allowed attempt or once ctx has ended.
// For errors, context cancellation and deadline errors are final; otherwise
// the policy's onError callback decides when set, and isRetryableError when
// not. Without an error, the response is retried when its status code is in
// the policy's status set.
func (p *retryPolicy) shouldRetry(resp *Response, err error, attempt, maxAttempts int, ctx context.Context) bool {
	if attempt >= maxAttempts-1 || (ctx != nil && ctx.Err() != nil) {
		return false
	}

	if err == nil {
		if resp == nil || resp.Response == nil {
			return false
		}
		_, retryable := p.statuses[resp.StatusCode]
		return retryable
	}

	switch {
	case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded):
		return false
	case p.onError != nil:
		return p.onError(err)
	default:
		return isRetryableError(err)
	}
}
// isRetryableError reports whether err looks transient: an io.EOF or
// io.ErrUnexpectedEOF anywhere in its chain, or a net.Error that reports
// itself as a timeout or as temporary.
func isRetryableError(err error) bool {
	if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
		return true
	}
	var ne net.Error
	return errors.As(err, &ne) && (ne.Timeout() || ne.Temporary())
}
// nextDelay returns how long to wait before the retry that follows the
// given zero-based attempt.
//
// The delay is baseDelay * factor^attempt, capped at maxDelay when maxDelay
// is positive, and then scaled by a random factor in [1-jitter, 1+jitter]
// (lower bound clamped to 0) when jitter is positive. A non-positive
// baseDelay yields 0.
//
// Intermediate values are saturated to the time.Duration range before being
// converted back from float64: an out-of-range float-to-integer conversion
// has an implementation-dependent result in Go and could otherwise produce a
// negative delay that slips past the maxDelay cap.
func (p *retryPolicy) nextDelay(attempt int) time.Duration {
	if p.baseDelay <= 0 {
		return 0
	}

	// saturate converts nanoseconds to a Duration without overflowing.
	// NaN maps to fallback, anything below zero maps to zero.
	saturate := func(ns float64, fallback time.Duration) time.Duration {
		switch {
		case math.IsNaN(ns):
			return fallback
		case ns >= math.MaxInt64:
			return time.Duration(math.MaxInt64)
		case ns <= 0:
			return 0
		}
		return time.Duration(ns)
	}

	raw := float64(p.baseDelay) * math.Pow(p.factor, float64(attempt))
	delay := saturate(raw, p.baseDelay)
	if p.maxDelay > 0 && delay > p.maxDelay {
		delay = p.maxDelay
	}

	// Written as !(x > 0) so that a NaN jitter also means "no jitter".
	if !(p.jitter > 0) {
		return delay
	}
	low := math.Max(1-p.jitter, 0)
	high := 1 + p.jitter
	scale := low + rand.Float64()*(high-low)
	return saturate(float64(delay)*scale, delay)
}