424 lines
8.7 KiB
Go
424 lines
8.7 KiB
Go
package starnet
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"math"
|
|
"math/rand"
|
|
"net"
|
|
"net/http"
|
|
"time"
|
|
)
|
|
|
|
type (
	// RetryOpt configures a retryPolicy. Options are applied on top of the
	// defaults by buildRetryPolicy, or onto an already-enabled policy by the
	// Request.SetRetry* helpers. A non-nil error aborts the configuration.
	RetryOpt func(*retryPolicy) error

	// retryPolicy holds the retry settings attached to a single Request.
	retryPolicy struct {
		maxRetries     int              // retries after the first attempt; total attempts = maxRetries+1
		baseDelay      time.Duration    // delay before the first retry; <= 0 disables waiting
		maxDelay       time.Duration    // upper bound on the computed delay; 0 means uncapped
		factor         float64          // per-attempt multiplier for exponential backoff
		jitter         float64          // random spread ratio applied around the delay
		idempotentOnly bool             // restrict retries to idempotent HTTP methods
		statuses       map[int]struct{} // response status codes that trigger a retry
		onError        func(error) bool // optional override for classifying transport errors
	}
)
|
|
|
|
func cloneRetryPolicy(p *retryPolicy) *retryPolicy {
|
|
if p == nil {
|
|
return nil
|
|
}
|
|
cloned := &retryPolicy{
|
|
maxRetries: p.maxRetries,
|
|
baseDelay: p.baseDelay,
|
|
maxDelay: p.maxDelay,
|
|
factor: p.factor,
|
|
jitter: p.jitter,
|
|
idempotentOnly: p.idempotentOnly,
|
|
onError: p.onError,
|
|
}
|
|
if p.statuses != nil {
|
|
cloned.statuses = make(map[int]struct{}, len(p.statuses))
|
|
for code := range p.statuses {
|
|
cloned.statuses[code] = struct{}{}
|
|
}
|
|
}
|
|
return cloned
|
|
}
|
|
|
|
func defaultRetryPolicy(max int) *retryPolicy {
|
|
return &retryPolicy{
|
|
maxRetries: max,
|
|
baseDelay: 100 * time.Millisecond,
|
|
maxDelay: 2 * time.Second,
|
|
factor: 2.0,
|
|
jitter: 0.1,
|
|
idempotentOnly: true,
|
|
statuses: map[int]struct{}{
|
|
http.StatusRequestTimeout: {},
|
|
http.StatusTooEarly: {},
|
|
http.StatusTooManyRequests: {},
|
|
http.StatusInternalServerError: {},
|
|
http.StatusBadGateway: {},
|
|
http.StatusServiceUnavailable: {},
|
|
http.StatusGatewayTimeout: {},
|
|
},
|
|
}
|
|
}
|
|
|
|
func buildRetryPolicy(max int, opts ...RetryOpt) (*retryPolicy, error) {
|
|
if max < 0 {
|
|
return nil, fmt.Errorf("max retry must be >= 0")
|
|
}
|
|
if max == 0 {
|
|
return nil, nil
|
|
}
|
|
|
|
policy := defaultRetryPolicy(max)
|
|
for _, opt := range opts {
|
|
if opt == nil {
|
|
continue
|
|
}
|
|
if err := opt(policy); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
return policy, nil
|
|
}
|
|
|
|
func WithRetry(max int, opts ...RetryOpt) RequestOpt {
|
|
return func(r *Request) error {
|
|
policy, err := buildRetryPolicy(max, opts...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
r.retry = policy
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func (r *Request) SetRetry(max int, opts ...RetryOpt) *Request {
|
|
if r.err != nil {
|
|
return r
|
|
}
|
|
policy, err := buildRetryPolicy(max, opts...)
|
|
if err != nil {
|
|
r.err = err
|
|
return r
|
|
}
|
|
r.retry = policy
|
|
return r
|
|
}
|
|
|
|
func (r *Request) DisableRetry() *Request {
|
|
if r.err != nil {
|
|
return r
|
|
}
|
|
r.retry = nil
|
|
return r
|
|
}
|
|
|
|
func (r *Request) applyRetryOpt(opt RetryOpt) *Request {
|
|
if r.err != nil {
|
|
return r
|
|
}
|
|
if opt == nil {
|
|
return r
|
|
}
|
|
if r.retry == nil {
|
|
r.err = fmt.Errorf("retry policy is not enabled, call SetRetry first")
|
|
return r
|
|
}
|
|
if err := opt(r.retry); err != nil {
|
|
r.err = err
|
|
}
|
|
return r
|
|
}
|
|
|
|
func (r *Request) SetRetryBackoff(base, max time.Duration, factor float64) *Request {
|
|
return r.applyRetryOpt(WithRetryBackoff(base, max, factor))
|
|
}
|
|
|
|
func (r *Request) SetRetryJitter(ratio float64) *Request {
|
|
return r.applyRetryOpt(WithRetryJitter(ratio))
|
|
}
|
|
|
|
func (r *Request) SetRetryStatuses(codes ...int) *Request {
|
|
return r.applyRetryOpt(WithRetryStatuses(codes...))
|
|
}
|
|
|
|
func (r *Request) SetRetryIdempotentOnly(enabled bool) *Request {
|
|
return r.applyRetryOpt(WithRetryIdempotentOnly(enabled))
|
|
}
|
|
|
|
func (r *Request) SetRetryOnError(fn func(error) bool) *Request {
|
|
return r.applyRetryOpt(WithRetryOnError(fn))
|
|
}
|
|
|
|
func WithRetryBackoff(base, max time.Duration, factor float64) RetryOpt {
|
|
return func(p *retryPolicy) error {
|
|
if base < 0 {
|
|
return fmt.Errorf("retry base delay must be >= 0")
|
|
}
|
|
if max < 0 {
|
|
return fmt.Errorf("retry max delay must be >= 0")
|
|
}
|
|
if factor <= 0 {
|
|
return fmt.Errorf("retry factor must be > 0")
|
|
}
|
|
p.baseDelay = base
|
|
p.maxDelay = max
|
|
p.factor = factor
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func WithRetryJitter(ratio float64) RetryOpt {
|
|
return func(p *retryPolicy) error {
|
|
if ratio < 0 || ratio > 1 {
|
|
return fmt.Errorf("retry jitter ratio must be in [0,1]")
|
|
}
|
|
p.jitter = ratio
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func WithRetryStatuses(codes ...int) RetryOpt {
|
|
return func(p *retryPolicy) error {
|
|
statuses := make(map[int]struct{}, len(codes))
|
|
for _, code := range codes {
|
|
if code < 100 || code > 999 {
|
|
return fmt.Errorf("invalid retry status code: %d", code)
|
|
}
|
|
statuses[code] = struct{}{}
|
|
}
|
|
p.statuses = statuses
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func WithRetryIdempotentOnly(enabled bool) RetryOpt {
|
|
return func(p *retryPolicy) error {
|
|
p.idempotentOnly = enabled
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func WithRetryOnError(fn func(error) bool) RetryOpt {
|
|
return func(p *retryPolicy) error {
|
|
p.onError = fn
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func (r *Request) hasRetryPolicy() bool {
|
|
return r.retry != nil && r.retry.maxRetries > 0
|
|
}
|
|
|
|
func (r *Request) doWithRetry() (*Response, error) {
|
|
policy := cloneRetryPolicy(r.retry)
|
|
if policy == nil || policy.maxRetries <= 0 {
|
|
return r.doOnce()
|
|
}
|
|
|
|
if !policy.canRetryRequest(r) {
|
|
return r.doOnce()
|
|
}
|
|
|
|
retryCtx := r.ctx
|
|
retryCancel := func() {}
|
|
if r.config.Network.Timeout > 0 {
|
|
retryCtx, retryCancel = context.WithTimeout(r.ctx, r.config.Network.Timeout)
|
|
}
|
|
defer retryCancel()
|
|
|
|
maxAttempts := policy.maxRetries + 1
|
|
var lastResp *Response
|
|
var lastErr error
|
|
|
|
for attempt := 0; attempt < maxAttempts; attempt++ {
|
|
attemptReq, err := r.newRetryAttempt(retryCtx)
|
|
if err != nil {
|
|
return nil, wrapError(err, "build retry attempt")
|
|
}
|
|
|
|
resp, err := attemptReq.doOnce()
|
|
if resp != nil {
|
|
resp.request = r
|
|
}
|
|
|
|
if !policy.shouldRetry(resp, err, attempt, maxAttempts, retryCtx) {
|
|
return resp, err
|
|
}
|
|
|
|
lastResp = resp
|
|
lastErr = err
|
|
if lastResp != nil {
|
|
_ = lastResp.Close()
|
|
}
|
|
|
|
delay := policy.nextDelay(attempt)
|
|
if delay <= 0 {
|
|
continue
|
|
}
|
|
|
|
timer := time.NewTimer(delay)
|
|
select {
|
|
case <-retryCtx.Done():
|
|
timer.Stop()
|
|
return lastResp, wrapError(retryCtx.Err(), "retry context done")
|
|
case <-timer.C:
|
|
}
|
|
}
|
|
|
|
return lastResp, lastErr
|
|
}
|
|
|
|
func (r *Request) newRetryAttempt(ctx context.Context) (*Request, error) {
|
|
attempt := r.Clone()
|
|
attempt.retry = nil
|
|
attempt.cancel = nil
|
|
attempt.applied = false
|
|
attempt.execCtx = nil
|
|
attempt.ctx = ctx
|
|
|
|
// 共享总超时上下文后,避免每次 attempt 再创建一次 timeout context。
|
|
if attempt.config != nil && attempt.config.Network.Timeout > 0 {
|
|
attempt.config.Network.Timeout = 0
|
|
}
|
|
|
|
if !attempt.doRaw {
|
|
attempt.httpReq = attempt.httpReq.WithContext(ctx)
|
|
return attempt, nil
|
|
}
|
|
|
|
if r.httpReq == nil {
|
|
return nil, fmt.Errorf("http request is nil")
|
|
}
|
|
|
|
raw := r.httpReq.Clone(ctx)
|
|
if r.httpReq.GetBody != nil {
|
|
body, err := r.httpReq.GetBody()
|
|
if err != nil {
|
|
return nil, wrapError(err, "get raw request body")
|
|
}
|
|
raw.Body = body
|
|
} else if r.httpReq.Body != nil && r.httpReq.Body != http.NoBody {
|
|
return nil, fmt.Errorf("raw request body is not replayable")
|
|
}
|
|
|
|
attempt.httpReq = raw
|
|
return attempt, nil
|
|
}
|
|
|
|
func (p *retryPolicy) canRetryRequest(r *Request) bool {
|
|
if p.idempotentOnly && !isIdempotentMethod(r.method) {
|
|
return false
|
|
}
|
|
return isReplayableRequest(r)
|
|
}
|
|
|
|
// isIdempotentMethod reports whether method is one of the HTTP methods RFC
// 9110 defines as idempotent: PUT, DELETE and the safe methods GET, HEAD,
// OPTIONS and TRACE. The comparison is exact, so method must be upper-case.
func isIdempotentMethod(method string) bool {
	switch method {
	case http.MethodPut, http.MethodDelete:
		return true
	case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace:
		return true
	}
	return false
}
|
|
|
|
func isReplayableRequest(r *Request) bool {
|
|
if r == nil {
|
|
return false
|
|
}
|
|
|
|
if r.doRaw {
|
|
if r.httpReq == nil {
|
|
return false
|
|
}
|
|
if r.httpReq.Body == nil || r.httpReq.Body == http.NoBody {
|
|
return true
|
|
}
|
|
return r.httpReq.GetBody != nil
|
|
}
|
|
|
|
if r.config == nil {
|
|
return false
|
|
}
|
|
|
|
// Reader / stream body 通常不可重放,保守地不重试。
|
|
if r.config.Body.Reader != nil {
|
|
return false
|
|
}
|
|
|
|
for _, f := range r.config.Body.Files {
|
|
if f.FileData != nil || f.FilePath == "" {
|
|
return false
|
|
}
|
|
}
|
|
|
|
return true
|
|
}
|
|
|
|
func (p *retryPolicy) shouldRetry(resp *Response, err error, attempt, maxAttempts int, ctx context.Context) bool {
|
|
if attempt >= maxAttempts-1 {
|
|
return false
|
|
}
|
|
if ctx != nil && ctx.Err() != nil {
|
|
return false
|
|
}
|
|
|
|
if err != nil {
|
|
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
|
return false
|
|
}
|
|
if p.onError != nil {
|
|
return p.onError(err)
|
|
}
|
|
return isRetryableError(err)
|
|
}
|
|
|
|
if resp == nil || resp.Response == nil {
|
|
return false
|
|
}
|
|
_, ok := p.statuses[resp.StatusCode]
|
|
return ok
|
|
}
|
|
|
|
// isRetryableError is the default classifier for transport errors. It treats
// these as retryable: io.EOF and io.ErrUnexpectedEOF (wrapped or not), and any
// net.Error in the chain that reports Timeout() or Temporary().
func isRetryableError(err error) bool {
	if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
		return true
	}

	var netErr net.Error
	if !errors.As(err, &netErr) {
		return false
	}
	// NOTE(review): net.Error.Temporary is deprecated since Go 1.18 because
	// "temporary" is ill-defined. Consider replacing it with explicit checks
	// (e.g. specific syscall errnos) in a follow-up.
	//nolint:staticcheck // kept to preserve existing classification
	return netErr.Timeout() || netErr.Temporary()
}
|
|
|
|
func (p *retryPolicy) nextDelay(attempt int) time.Duration {
|
|
if p.baseDelay <= 0 {
|
|
return 0
|
|
}
|
|
|
|
delay := time.Duration(float64(p.baseDelay) * math.Pow(p.factor, float64(attempt)))
|
|
if p.maxDelay > 0 && delay > p.maxDelay {
|
|
delay = p.maxDelay
|
|
}
|
|
|
|
if p.jitter <= 0 {
|
|
return delay
|
|
}
|
|
|
|
low := 1 - p.jitter
|
|
if low < 0 {
|
|
low = 0
|
|
}
|
|
high := 1 + p.jitter
|
|
scale := low + rand.Float64()*(high-low)
|
|
return time.Duration(float64(delay) * scale)
|
|
}
|