- 分离 Request 的配置态与执行态,修复二次 Do、raw 模式网络配置失效和 body 来源互斥问题 - 新增 starnet trace 抽象,补齐 DNS/连接/TLS/重试事件,并优化动态 transport 缓存与代理解析路径 - 收紧非法代理为 fail-fast,多目标目标回退仅限幂等请求,修复 Host/TLS/SNI 等语义边界 - 补充防御性拷贝、专项回归测试、本地代理/TLS 用例与 README 行为说明
467 lines
10 KiB
Go
467 lines
10 KiB
Go
package starnet
|
||
|
||
import (
|
||
"bytes"
|
||
"context"
|
||
"errors"
|
||
"fmt"
|
||
"io"
|
||
"math"
|
||
"math/rand"
|
||
"net"
|
||
"net/http"
|
||
"strings"
|
||
"time"
|
||
)
|
||
|
||
// RetryOpt adjusts a retryPolicy in place. A non-nil error reports that the
// requested setting was rejected.
type RetryOpt func(*retryPolicy) error
|
||
|
||
// retryPolicy holds the retry settings attached to a Request.
type retryPolicy struct {
	// maxRetries is the number of retries allowed after the first attempt;
	// doWithRetry performs at most maxRetries+1 attempts.
	maxRetries int

	// baseDelay is the backoff before the first retry (<= 0 disables waiting);
	// maxDelay caps the computed backoff when it is > 0.
	baseDelay, maxDelay time.Duration

	// factor is the multiplier applied per attempt to baseDelay; jitter is the
	// ratio by which the computed delay is randomly scaled up or down.
	factor, jitter float64

	// idempotentOnly restricts retries to idempotent HTTP methods.
	idempotentOnly bool

	// statuses is the set of response status codes that trigger a retry.
	statuses map[int]struct{}

	// onError, when non-nil, decides whether an attempt error is retryable
	// instead of the built-in isRetryableError check.
	onError func(error) bool
}
|
||
|
||
func cloneRetryPolicy(p *retryPolicy) *retryPolicy {
|
||
if p == nil {
|
||
return nil
|
||
}
|
||
cloned := &retryPolicy{
|
||
maxRetries: p.maxRetries,
|
||
baseDelay: p.baseDelay,
|
||
maxDelay: p.maxDelay,
|
||
factor: p.factor,
|
||
jitter: p.jitter,
|
||
idempotentOnly: p.idempotentOnly,
|
||
onError: p.onError,
|
||
}
|
||
if p.statuses != nil {
|
||
cloned.statuses = make(map[int]struct{}, len(p.statuses))
|
||
for code := range p.statuses {
|
||
cloned.statuses[code] = struct{}{}
|
||
}
|
||
}
|
||
return cloned
|
||
}
|
||
|
||
func defaultRetryPolicy(max int) *retryPolicy {
|
||
return &retryPolicy{
|
||
maxRetries: max,
|
||
baseDelay: 100 * time.Millisecond,
|
||
maxDelay: 2 * time.Second,
|
||
factor: 2.0,
|
||
jitter: 0.1,
|
||
idempotentOnly: true,
|
||
statuses: map[int]struct{}{
|
||
http.StatusRequestTimeout: {},
|
||
http.StatusTooEarly: {},
|
||
http.StatusTooManyRequests: {},
|
||
http.StatusInternalServerError: {},
|
||
http.StatusBadGateway: {},
|
||
http.StatusServiceUnavailable: {},
|
||
http.StatusGatewayTimeout: {},
|
||
},
|
||
}
|
||
}
|
||
|
||
func buildRetryPolicy(max int, opts ...RetryOpt) (*retryPolicy, error) {
|
||
if max < 0 {
|
||
return nil, fmt.Errorf("max retry must be >= 0")
|
||
}
|
||
if max == 0 {
|
||
return nil, nil
|
||
}
|
||
|
||
policy := defaultRetryPolicy(max)
|
||
for _, opt := range opts {
|
||
if opt == nil {
|
||
continue
|
||
}
|
||
if err := opt(policy); err != nil {
|
||
return nil, err
|
||
}
|
||
}
|
||
return policy, nil
|
||
}
|
||
|
||
// WithRetry 为请求启用自动重试。
|
||
// 默认只重试幂等方法;即使显式关闭幂等限制,Reader 形态的 body 仍会对非幂等方法保持保守禁用,
|
||
// 以避免请求体已落地后再次发送。
|
||
func WithRetry(max int, opts ...RetryOpt) RequestOpt {
|
||
return func(r *Request) error {
|
||
policy, err := buildRetryPolicy(max, opts...)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
r.retry = policy
|
||
return nil
|
||
}
|
||
}
|
||
|
||
// SetRetry 为请求启用自动重试。
|
||
// 默认只重试幂等方法;即使显式关闭幂等限制,Reader 形态的 body 仍会对非幂等方法保持保守禁用,
|
||
// 以避免请求体已落地后再次发送。
|
||
func (r *Request) SetRetry(max int, opts ...RetryOpt) *Request {
|
||
if r.err != nil {
|
||
return r
|
||
}
|
||
policy, err := buildRetryPolicy(max, opts...)
|
||
if err != nil {
|
||
r.err = err
|
||
return r
|
||
}
|
||
r.retry = policy
|
||
return r
|
||
}
|
||
|
||
func (r *Request) DisableRetry() *Request {
|
||
if r.err != nil {
|
||
return r
|
||
}
|
||
r.retry = nil
|
||
return r
|
||
}
|
||
|
||
func (r *Request) applyRetryOpt(opt RetryOpt) *Request {
|
||
if r.err != nil {
|
||
return r
|
||
}
|
||
if opt == nil {
|
||
return r
|
||
}
|
||
if r.retry == nil {
|
||
r.err = fmt.Errorf("retry policy is not enabled, call SetRetry first")
|
||
return r
|
||
}
|
||
if err := opt(r.retry); err != nil {
|
||
r.err = err
|
||
}
|
||
return r
|
||
}
|
||
|
||
func (r *Request) SetRetryBackoff(base, max time.Duration, factor float64) *Request {
|
||
return r.applyRetryOpt(WithRetryBackoff(base, max, factor))
|
||
}
|
||
|
||
func (r *Request) SetRetryJitter(ratio float64) *Request {
|
||
return r.applyRetryOpt(WithRetryJitter(ratio))
|
||
}
|
||
|
||
func (r *Request) SetRetryStatuses(codes ...int) *Request {
|
||
return r.applyRetryOpt(WithRetryStatuses(codes...))
|
||
}
|
||
|
||
func (r *Request) SetRetryIdempotentOnly(enabled bool) *Request {
|
||
return r.applyRetryOpt(WithRetryIdempotentOnly(enabled))
|
||
}
|
||
|
||
func (r *Request) SetRetryOnError(fn func(error) bool) *Request {
|
||
return r.applyRetryOpt(WithRetryOnError(fn))
|
||
}
|
||
|
||
func WithRetryBackoff(base, max time.Duration, factor float64) RetryOpt {
|
||
return func(p *retryPolicy) error {
|
||
if base < 0 {
|
||
return fmt.Errorf("retry base delay must be >= 0")
|
||
}
|
||
if max < 0 {
|
||
return fmt.Errorf("retry max delay must be >= 0")
|
||
}
|
||
if factor <= 0 {
|
||
return fmt.Errorf("retry factor must be > 0")
|
||
}
|
||
p.baseDelay = base
|
||
p.maxDelay = max
|
||
p.factor = factor
|
||
return nil
|
||
}
|
||
}
|
||
|
||
func WithRetryJitter(ratio float64) RetryOpt {
|
||
return func(p *retryPolicy) error {
|
||
if ratio < 0 || ratio > 1 {
|
||
return fmt.Errorf("retry jitter ratio must be in [0,1]")
|
||
}
|
||
p.jitter = ratio
|
||
return nil
|
||
}
|
||
}
|
||
|
||
func WithRetryStatuses(codes ...int) RetryOpt {
|
||
return func(p *retryPolicy) error {
|
||
statuses := make(map[int]struct{}, len(codes))
|
||
for _, code := range codes {
|
||
if code < 100 || code > 999 {
|
||
return fmt.Errorf("invalid retry status code: %d", code)
|
||
}
|
||
statuses[code] = struct{}{}
|
||
}
|
||
p.statuses = statuses
|
||
return nil
|
||
}
|
||
}
|
||
|
||
func WithRetryIdempotentOnly(enabled bool) RetryOpt {
|
||
return func(p *retryPolicy) error {
|
||
p.idempotentOnly = enabled
|
||
return nil
|
||
}
|
||
}
|
||
|
||
func WithRetryOnError(fn func(error) bool) RetryOpt {
|
||
return func(p *retryPolicy) error {
|
||
p.onError = fn
|
||
return nil
|
||
}
|
||
}
|
||
|
||
func (r *Request) hasRetryPolicy() bool {
|
||
return r.retry != nil && r.retry.maxRetries > 0
|
||
}
|
||
|
||
func (r *Request) doWithRetry() (*Response, error) {
|
||
policy := cloneRetryPolicy(r.retry)
|
||
if policy == nil || policy.maxRetries <= 0 {
|
||
return r.doOnce()
|
||
}
|
||
|
||
if !policy.canRetryRequest(r) {
|
||
return r.doOnce()
|
||
}
|
||
|
||
retryCtx := normalizeContext(r.ctx)
|
||
retryCancel := func() {}
|
||
if r.config.Network.Timeout > 0 {
|
||
retryCtx, retryCancel = context.WithTimeout(retryCtx, r.config.Network.Timeout)
|
||
}
|
||
defer retryCancel()
|
||
|
||
maxAttempts := policy.maxRetries + 1
|
||
var lastResp *Response
|
||
var lastErr error
|
||
|
||
for attempt := 0; attempt < maxAttempts; attempt++ {
|
||
attemptNo := attempt + 1
|
||
emitRetryAttemptStart(r.traceHooks, TraceRetryAttemptStartInfo{
|
||
Attempt: attemptNo,
|
||
MaxAttempts: maxAttempts,
|
||
})
|
||
|
||
attemptReq, err := r.newRetryAttempt(retryCtx)
|
||
if err != nil {
|
||
return nil, wrapError(err, "build retry attempt")
|
||
}
|
||
|
||
resp, err := attemptReq.doOnce()
|
||
if resp != nil {
|
||
resp.request = r
|
||
}
|
||
|
||
willRetry := policy.shouldRetry(resp, err, attempt, maxAttempts, retryCtx)
|
||
statusCode := 0
|
||
if resp != nil {
|
||
statusCode = resp.StatusCode
|
||
}
|
||
emitRetryAttemptDone(r.traceHooks, TraceRetryAttemptDoneInfo{
|
||
Attempt: attemptNo,
|
||
MaxAttempts: maxAttempts,
|
||
StatusCode: statusCode,
|
||
Err: err,
|
||
WillRetry: willRetry,
|
||
})
|
||
if !willRetry {
|
||
return resp, err
|
||
}
|
||
|
||
lastResp = resp
|
||
lastErr = err
|
||
if lastResp != nil {
|
||
_ = lastResp.Close()
|
||
}
|
||
|
||
delay := policy.nextDelay(attempt)
|
||
if delay <= 0 {
|
||
continue
|
||
}
|
||
emitRetryBackoff(r.traceHooks, TraceRetryBackoffInfo{
|
||
Attempt: attemptNo,
|
||
Delay: delay,
|
||
})
|
||
|
||
timer := time.NewTimer(delay)
|
||
select {
|
||
case <-retryCtx.Done():
|
||
timer.Stop()
|
||
return lastResp, wrapError(retryCtx.Err(), "retry context done")
|
||
case <-timer.C:
|
||
}
|
||
}
|
||
|
||
return lastResp, lastErr
|
||
}
|
||
|
||
func (r *Request) newRetryAttempt(ctx context.Context) (*Request, error) {
|
||
attempt := r.Clone()
|
||
attempt.retry = nil
|
||
attempt.cancel = nil
|
||
attempt.applied = false
|
||
attempt.execCtx = nil
|
||
attempt.ctx = ctx
|
||
|
||
// 共享总超时上下文后,避免每次 attempt 再创建一次 timeout context。
|
||
if attempt.config != nil && attempt.config.Network.Timeout > 0 {
|
||
attempt.config.Network.Timeout = 0
|
||
}
|
||
|
||
if !attempt.doRaw {
|
||
attempt.httpReq = attempt.httpReq.WithContext(ctx)
|
||
return attempt, nil
|
||
}
|
||
|
||
raw, err := cloneRawHTTPRequest(r.httpReq, ctx)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
attempt.httpReq = raw
|
||
return attempt, nil
|
||
}
|
||
|
||
func (p *retryPolicy) canRetryRequest(r *Request) bool {
|
||
if p.idempotentOnly && !isIdempotentMethod(r.method) {
|
||
return false
|
||
}
|
||
if hasReaderRequestBody(r) && !isIdempotentMethod(r.method) {
|
||
return false
|
||
}
|
||
return isReplayableRequest(r)
|
||
}
|
||
|
||
// isIdempotentMethod reports whether method is one of GET, HEAD, PUT, DELETE,
// OPTIONS or TRACE. The comparison is exact, so method must use the canonical
// upper-case spelling.
func isIdempotentMethod(method string) bool {
	switch method {
	case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace:
		return true
	case http.MethodPut, http.MethodDelete:
		return true
	}
	return false
}
|
||
|
||
func isReplayableRequest(r *Request) bool {
|
||
if r == nil {
|
||
return false
|
||
}
|
||
|
||
if r.doRaw {
|
||
if r.httpReq == nil {
|
||
return false
|
||
}
|
||
if r.httpReq.Body == nil || r.httpReq.Body == http.NoBody {
|
||
return true
|
||
}
|
||
return r.httpReq.GetBody != nil
|
||
}
|
||
|
||
if r.config == nil {
|
||
return false
|
||
}
|
||
|
||
return isReplayableConfiguredBody(r.config.Body)
|
||
}
|
||
|
||
func hasReaderRequestBody(r *Request) bool {
|
||
if r == nil || r.config == nil {
|
||
return false
|
||
}
|
||
return r.config.Body.Mode == bodyModeReader && r.config.Body.Reader != nil
|
||
}
|
||
|
||
func isReplayableConfiguredBody(body BodyConfig) bool {
|
||
switch body.Mode {
|
||
case bodyModeReader:
|
||
return isReplayableBodyReader(body.Reader)
|
||
case bodyModeMultipart:
|
||
for _, file := range body.Files {
|
||
if file.FileData != nil || file.FilePath == "" {
|
||
return false
|
||
}
|
||
}
|
||
}
|
||
|
||
return true
|
||
}
|
||
|
||
// isReplayableBodyReader reports whether reader has one of the concrete types
// the retry logic accepts as replayable: *bytes.Buffer, *bytes.Reader or
// *strings.Reader. Any other reader, including nil, is rejected.
func isReplayableBodyReader(reader io.Reader) bool {
	switch reader.(type) {
	case *bytes.Buffer:
	case *bytes.Reader:
	case *strings.Reader:
	default:
		return false
	}
	return true
}
|
||
|
||
func (p *retryPolicy) shouldRetry(resp *Response, err error, attempt, maxAttempts int, ctx context.Context) bool {
|
||
if attempt >= maxAttempts-1 {
|
||
return false
|
||
}
|
||
if ctx != nil && ctx.Err() != nil {
|
||
return false
|
||
}
|
||
|
||
if err != nil {
|
||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||
return false
|
||
}
|
||
if p.onError != nil {
|
||
return p.onError(err)
|
||
}
|
||
return isRetryableError(err)
|
||
}
|
||
|
||
if resp == nil || resp.Response == nil {
|
||
return false
|
||
}
|
||
_, ok := p.statuses[resp.StatusCode]
|
||
return ok
|
||
}
|
||
|
||
// isRetryableError is the default error classifier. An error is retryable when
// its chain contains a net.Error reporting Timeout or Temporary, or when it
// wraps io.EOF or io.ErrUnexpectedEOF.
func isRetryableError(err error) bool {
	var netErr net.Error
	if errors.As(err, &netErr) && (netErr.Timeout() || netErr.Temporary()) {
		return true
	}

	for _, target := range []error{io.EOF, io.ErrUnexpectedEOF} {
		if errors.Is(err, target) {
			return true
		}
	}
	return false
}
|
||
|
||
func (p *retryPolicy) nextDelay(attempt int) time.Duration {
|
||
if p.baseDelay <= 0 {
|
||
return 0
|
||
}
|
||
|
||
delay := time.Duration(float64(p.baseDelay) * math.Pow(p.factor, float64(attempt)))
|
||
if p.maxDelay > 0 && delay > p.maxDelay {
|
||
delay = p.maxDelay
|
||
}
|
||
|
||
if p.jitter <= 0 {
|
||
return delay
|
||
}
|
||
|
||
low := 1 - p.jitter
|
||
if low < 0 {
|
||
low = 0
|
||
}
|
||
high := 1 + p.jitter
|
||
scale := low + rand.Float64()*(high-low)
|
||
return time.Duration(float64(delay) * scale)
|
||
}
|