starnet/transport.go
starainrt 732e81316c
fix(starnet): 重构请求执行链路并补齐代理/重试/trace边界
- 分离 Request 的配置态与执行态,修复二次 Do、raw 模式网络配置失效和 body 来源互斥问题
  - 新增 starnet trace 抽象,补齐 DNS/连接/TLS/重试事件,并优化动态 transport 缓存与代理解析路径
  - 收紧非法代理为 fail-fast,多目标地址回退仅限幂等请求,修复 Host/TLS/SNI 等语义边界
  - 补充防御性拷贝、专项回归测试、本地代理/TLS 用例与 README 行为说明
2026-04-19 15:39:51 +08:00

417 lines
11 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package starnet
import (
	"context"
	"crypto/tls"
	"errors"
	"net"
	"net/http"
	"net/url"
	"strings"
	"sync"
	"time"
)
// dynamicTransportCacheMaxEntries caps how many cached dynamic transports a
// Transport keeps. Once the cap is reached, the oldest inserted entry is
// evicted (first-in, first-out) and its idle connections are closed.
const dynamicTransportCacheMaxEntries = 64
// dynamicTransportCacheKey identifies a reusable dynamic *http.Transport.
// newDynamicTransportCacheKey builds it from the request-level settings that
// shape the transport, and it is used directly as the key of
// Transport.dynamicCache.
type dynamicTransportCacheKey struct {
// proxyKey is the normalized proxy URL ("" for no proxy; unparsable proxies
// get a NUL-prefixed "invalid:" marker plus the raw value).
proxyKey string
// dialTimeout mirrors RequestContext.DialTimeout.
dialTimeout time.Duration
// customIPs and customDNS are NUL-terminated concatenations of the
// RequestContext.CustomIP / CustomDNS lists.
customIPs string
customDNS string
// tlsServerName is TLSConfig.ServerName when set, else RequestContext.TLSServerName.
tlsServerName string
// skipVerify mirrors TLSConfig.InsecureSkipVerify (false when there is no TLS config).
skipVerify bool
}
// Transport is a custom http.RoundTripper that supports per-request
// configuration carried in the request context.
//
// The zero value is usable: the base transport is created lazily on first use
// and the dynamic cache map is allocated on first insert.
// mu guards base, dynamicCache and dynamicCacheOrder.
type Transport struct {
// base is the shared transport used when a request needs no per-request overrides.
base *http.Transport
// dynamicCache holds reusable transports keyed by their per-request configuration.
dynamicCache map[dynamicTransportCacheKey]*http.Transport
// dynamicCacheOrder records dynamicCache keys in insertion order for FIFO eviction.
dynamicCacheOrder []dynamicTransportCacheKey
mu sync.RWMutex
}
// RoundTrip implements http.RoundTripper.
//
// The transport used for a request is chosen in priority order:
//  1. RequestContext.Transport, when the request carries one, is used as-is.
//  2. A dynamic transport (cached when the configuration allows it) when
//     needsDynamicTransport reports that per-request settings apply.
//  3. The shared base transport otherwise.
//
// Before steps 2 and 3, prepareProxyTargetRequest may resolve the target host
// up front (proxy combined with custom IP/DNS/lookup). When that yields several
// addresses, they are tried via roundTripResolvedTargets.
//
// NOTE(review): dynamic transports built for non-cacheable configurations are
// created per request and never closed here, so their idle connections linger
// until the transport's IdleConnTimeout. Confirm this is acceptable under
// high request rates.
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
	t.ensureBase()

	reqCtx := getRequestContext(req.Context())
	traceState := getTraceState(req.Context())

	// Priority 1: a fully custom transport bypasses everything else. The nil
	// check keeps this consistent with the helpers below, which all accept a
	// nil *RequestContext.
	if reqCtx != nil && reqCtx.Transport != nil {
		return reqCtx.Transport.RoundTrip(req)
	}

	execReq, execReqCtx, targetAddrs, err := prepareProxyTargetRequest(req, reqCtx, traceState)
	if err != nil {
		return nil, err
	}

	var rt *http.Transport
	if execReqCtx != nil && needsDynamicTransport(execReqCtx) {
		// Priority 2: per-request configuration.
		rt = t.getDynamicTransport(execReqCtx, traceState)
	} else {
		// Priority 3: the shared base transport.
		t.mu.RLock()
		rt = t.base
		t.mu.RUnlock()
	}

	if len(targetAddrs) > 0 {
		return roundTripResolvedTargets(rt, execReq, targetAddrs)
	}
	return rt.RoundTrip(execReq)
}
// newBaseHTTPTransport returns a fresh *http.Transport carrying the package's
// default pooling and timeout settings. Proxy, dialer and TLS fields are left
// unset.
func newBaseHTTPTransport() *http.Transport {
	tr := new(http.Transport)
	tr.ForceAttemptHTTP2 = true

	// Connection pooling.
	tr.MaxIdleConns = 100
	tr.MaxIdleConnsPerHost = 10
	tr.IdleConnTimeout = 90 * time.Second

	// Handshake / protocol timeouts.
	tr.TLSHandshakeTimeout = 10 * time.Second
	tr.ExpectContinueTimeout = time.Second

	return tr
}
// ensureBase lazily installs the default base transport so that a zero-value
// Transport is usable.
//
// The fast-path check takes the read lock because SetBase writes t.base under
// t.mu. An unsynchronized read here would be a data race.
func (t *Transport) ensureBase() {
	t.mu.RLock()
	ready := t.base != nil
	t.mu.RUnlock()
	if ready {
		return
	}

	t.mu.Lock()
	defer t.mu.Unlock()
	// Re-checked under the write lock inside ensureBaseLocked.
	t.ensureBaseLocked()
}
// ensureBaseLocked installs the default base transport if none is set.
// The caller must hold t.mu for writing.
func (t *Transport) ensureBaseLocked() {
	if t.base != nil {
		return
	}
	t.base = newBaseHTTPTransport()
}
// getDynamicTransport returns a transport configured for rc.
//
// Configurations that newDynamicTransportCacheKey can describe are served
// from, and stored in, the dynamic cache. Anything else gets a freshly built
// transport that also receives traceState.
func (t *Transport) getDynamicTransport(rc *RequestContext, traceState *traceState) *http.Transport {
	key, cacheable := newDynamicTransportCacheKey(rc)
	if !cacheable {
		return t.buildDynamicTransport(rc, traceState)
	}
	return t.getOrCreateCachedDynamicTransport(key, rc)
}
// getOrCreateCachedDynamicTransport returns the cached transport for key. On a
// miss it builds one from the base transport and rc, then caches it.
//
// The lookup first uses the read lock. On a miss it re-checks under the write
// lock, so concurrent callers with the same key share one transport. The cache
// holds at most dynamicTransportCacheMaxEntries entries. It evicts in
// insertion order and closes the evicted transport's idle connections.
// Cached transports are built without a traceState.
func (t *Transport) getOrCreateCachedDynamicTransport(key dynamicTransportCacheKey, rc *RequestContext) *http.Transport {
	t.mu.RLock()
	hit := t.dynamicCache[key]
	t.mu.RUnlock()
	if hit != nil {
		return hit
	}

	t.mu.Lock()
	defer t.mu.Unlock()

	t.ensureBaseLocked()
	// Another goroutine may have filled the slot between the two locks.
	if hit = t.dynamicCache[key]; hit != nil {
		return hit
	}

	built := buildDynamicTransportFromBase(t.base, rc, nil)
	if t.dynamicCache == nil {
		t.dynamicCache = make(map[dynamicTransportCacheKey]*http.Transport)
	}

	if len(t.dynamicCacheOrder) >= dynamicTransportCacheMaxEntries {
		victimKey := t.dynamicCacheOrder[0]
		t.dynamicCacheOrder = t.dynamicCacheOrder[1:]
		if victim := t.dynamicCache[victimKey]; victim != nil {
			victim.CloseIdleConnections()
			delete(t.dynamicCache, victimKey)
		}
	}

	t.dynamicCache[key] = built
	t.dynamicCacheOrder = append(t.dynamicCacheOrder, key)
	return built
}
// resetDynamicTransportCacheLocked closes the idle connections of every cached
// dynamic transport, in insertion order, and then drops the cache entirely.
// The caller must hold t.mu for writing.
func (t *Transport) resetDynamicTransportCacheLocked() {
	for i := 0; i < len(t.dynamicCacheOrder); i++ {
		if cached := t.dynamicCache[t.dynamicCacheOrder[i]]; cached != nil {
			cached.CloseIdleConnections()
		}
	}
	t.dynamicCache, t.dynamicCacheOrder = nil, nil
}
// newDynamicTransportCacheKey derives the cache key for rc's dynamic
// transport. The boolean result is false when rc cannot be cached, because
// its configuration is not fully represented by dynamicTransportCacheKey.
func newDynamicTransportCacheKey(rc *RequestContext) (dynamicTransportCacheKey, bool) {
	var none dynamicTransportCacheKey
	switch {
	case rc == nil:
		return none, false
	case rc.Transport != nil, rc.DialFn != nil, rc.LookupIPFn != nil:
		// Custom transports and function hooks are not part of the key,
		// so these configurations are never cached.
		return none, false
	case rc.TLSConfig != nil && !rc.TLSConfigCacheable:
		// A caller-supplied TLS config is cached only when flagged as such;
		// the key captures just its InsecureSkipVerify and ServerName.
		return none, false
	}

	skipVerify := false
	if rc.TLSConfig != nil {
		skipVerify = rc.TLSConfig.InsecureSkipVerify
	}

	return dynamicTransportCacheKey{
		proxyKey:      normalizeProxyCacheKey(rc.Proxy),
		dialTimeout:   rc.DialTimeout,
		customIPs:     serializeTransportCacheList(rc.CustomIP),
		customDNS:     serializeTransportCacheList(rc.CustomDNS),
		tlsServerName: effectiveTLSServerName(rc),
		skipVerify:    skipVerify,
	}, true
}
// normalizeProxyCacheKey canonicalizes a proxy setting for use in a cache key.
// An empty proxy maps to "" and a parsable one maps to its URL string form.
// An unparsable one maps to the raw value behind a "\x00invalid:" prefix; the
// NUL byte keeps it apart from ordinary URL strings.
func normalizeProxyCacheKey(proxy string) string {
	if proxy == "" {
		return ""
	}
	if proxyURL, err := parseProxyURL(proxy); err == nil {
		return proxyURL.String()
	}
	return "\x00invalid:" + proxy
}
// serializeTransportCacheList flattens values into one comparable string.
// Every element, including the last, is terminated by a NUL byte. Lists whose
// elements contain no NUL therefore map to distinct strings. Empty input
// yields "".
func serializeTransportCacheList(values []string) string {
	if len(values) == 0 {
		return ""
	}
	return strings.Join(values, "\x00") + "\x00"
}
// effectiveTLSServerName reports the TLS server name implied by rc. It prefers
// an explicit TLSConfig.ServerName, then falls back to rc.TLSServerName. A nil
// rc yields "".
func effectiveTLSServerName(rc *RequestContext) string {
	switch {
	case rc == nil:
		return ""
	case rc.TLSConfig != nil && rc.TLSConfig.ServerName != "":
		return rc.TLSConfig.ServerName
	default:
		return rc.TLSServerName
	}
}
// buildDynamicTransport builds a new, uncached transport for rc from the
// current base transport, passing traceState through for connect tracing.
func (t *Transport) buildDynamicTransport(rc *RequestContext, traceState *traceState) *http.Transport {
	t.ensureBase()
	return buildDynamicTransportFromBase(t.Base(), rc, traceState)
}
// buildDynamicTransportFromBase clones baseTransport and applies the
// per-request overrides carried by rc: TLS config, proxy and dial functions.
//
// traceState may be nil (the cached path passes nil). It is consulted only
// when rc.DialFn is set, to emit ConnectStart/ConnectDone around the custom
// dialer.
func buildDynamicTransportFromBase(baseTransport *http.Transport, rc *RequestContext, traceState *traceState) *http.Transport {
	transport := baseTransport.Clone()

	// Only override TLS when the request supplies a config. Otherwise the
	// clone keeps the base transport's TLSClientConfig, which Clone already
	// copied. The request config is cloned too: with HTTP/2 enabled, net/http
	// may append to TLSClientConfig.NextProtos in place. That would mutate the
	// caller's config and race when one config backs several transports.
	if rc.TLSConfig != nil {
		transport.TLSClientConfig = rc.TLSConfig.Clone()
	}

	if rc.Proxy != "" {
		proxyURL, err := parseProxyURL(rc.Proxy)
		if err != nil {
			// Fail every request through this transport with the parse error
			// instead of falling back to the base transport's proxy setting.
			transport.Proxy = func(*http.Request) (*url.URL, error) {
				return nil, err
			}
		} else {
			transport.Proxy = http.ProxyURL(proxyURL)
		}
	}

	switch {
	case rc.DialFn != nil:
		hooksWanted := traceState != nil && traceState.hooks != nil &&
			(traceState.hooks.ConnectStart != nil || traceState.hooks.ConnectDone != nil)
		if !hooksWanted {
			transport.DialContext = rc.DialFn
			break
		}
		// Wrap the custom dialer so connect trace events still fire.
		dialFn := rc.DialFn
		transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
			if traceState.hooks.ConnectStart != nil {
				traceState.hooks.ConnectStart(TraceConnectStartInfo{Network: network, Addr: addr})
			}
			conn, err := dialFn(ctx, network, addr)
			if traceState.hooks.ConnectDone != nil {
				traceState.hooks.ConnectDone(TraceConnectDoneInfo{Network: network, Addr: addr, Err: err})
			}
			return conn, err
		}
	case len(rc.CustomIP) > 0 || len(rc.CustomDNS) > 0 || rc.DialTimeout > 0 || rc.LookupIPFn != nil:
		// Use the package default dialers, which read their configuration
		// from the request context.
		transport.DialContext = defaultDialFunc
		transport.DialTLSContext = defaultDialTLSFunc
	}

	return transport
}
// Base returns the current base transport. It returns nil if no base has been
// set or lazily created yet.
func (t *Transport) Base() *http.Transport {
	t.mu.RLock()
	current := t.base
	t.mu.RUnlock()
	return current
}
// SetBase replaces the base transport and discards every cached dynamic
// transport, closing their idle connections, because those were cloned from
// the previous base. Passing nil makes the next request lazily recreate the
// default base.
func (t *Transport) SetBase(base *http.Transport) {
	t.mu.Lock()
	defer t.mu.Unlock()
	t.base = base
	t.resetDynamicTransportCacheLocked()
}
// prepareProxyTargetRequest resolves the request's target host up front. It
// does so only when a proxy is configured together with CustomIP, CustomDNS
// or LookupIPFn overrides, and no custom DialFn. (Presumably the transport
// dials the proxy rather than the target, so the overrides cannot be applied
// at dial time and the resolved address is written into the URL instead.)
//
// Results:
//   - nothing to resolve: the original req/reqCtx and nil targets;
//   - exactly one resolved address: a cloned request whose URL.Host is that
//     address, plus a copied RequestContext with the resolution overrides cleared;
//   - several addresses: a cloned request that keeps the original URL.Host,
//     plus the address list for roundTripResolvedTargets to try in order;
//   - resolution failure: a non-nil error.
//
// For https targets, the original hostname becomes the TLS config's default
// ServerName, so SNI and certificate checks use the hostname rather than the
// IP. The Host header is pinned to the original host for the same reason.
func prepareProxyTargetRequest(req *http.Request, reqCtx *RequestContext, traceState *traceState) (*http.Request, *RequestContext, []string, error) {
	if req == nil || req.URL == nil || reqCtx == nil {
		return req, reqCtx, nil, nil
	}
	// Only the proxy path needs pre-resolution; a custom DialFn owns dialing.
	if reqCtx.Proxy == "" || reqCtx.DialFn != nil {
		return req, reqCtx, nil, nil
	}
	// Without a resolution override there is nothing to resolve.
	if len(reqCtx.CustomIP) == 0 && len(reqCtx.CustomDNS) == 0 && reqCtx.LookupIPFn == nil {
		return req, reqCtx, nil, nil
	}
	host := req.URL.Hostname()
	if host == "" {
		return req, reqCtx, nil, nil
	}

	targetAddrs, err := resolveDialAddresses(req.Context(), reqCtx, host, req.URL.Port(), traceState)
	if err != nil {
		return nil, nil, nil, err
	}
	if len(targetAddrs) == 0 {
		return req, reqCtx, nil, nil
	}

	// Work on a copy so the caller's RequestContext is left untouched.
	execReqCtx := *reqCtx
	execReqCtx.CustomIP = nil
	execReqCtx.CustomDNS = nil
	execReqCtx.LookupIPFn = nil
	if req.URL.Scheme == "https" {
		execReqCtx.TLSConfig = withDefaultServerName(execReqCtx.TLSConfig, host)
		// A config synthesized here (no original config) varies only in
		// ServerName, which the cache key captures, so it stays cacheable.
		if reqCtx.TLSConfig == nil {
			execReqCtx.TLSConfigCacheable = true
		}
	}

	execReq := req.Clone(clearTargetResolutionContext(req.Context()))
	// The Host header must keep naming the logical host after URL.Host is
	// rewritten to an IP. An empty req.Host (hand-built requests) would make
	// net/http fall back to URL.Host, i.e. the resolved address.
	execReq.Host = req.Host
	if execReq.Host == "" {
		execReq.Host = req.URL.Host
	}

	if len(targetAddrs) == 1 {
		execReq.URL.Host = targetAddrs[0]
		return execReq, &execReqCtx, nil, nil
	}
	return execReq, &execReqCtx, targetAddrs, nil
}
// clearTargetResolutionContext returns a derived context with the CustomIP,
// CustomDNS and LookupIP overrides blanked out. This applies both to the
// stored *RequestContext (via a clone, leaving the original untouched) and to
// the standalone context keys. prepareProxyTargetRequest calls it after the
// target addresses have been resolved, presumably so downstream dialing does
// not apply the overrides again.
func clearTargetResolutionContext(ctx context.Context) context.Context {
	if rc, ok := ctx.Value(ctxKeyRequestContext).(*RequestContext); ok && rc != nil {
		stripped := cloneRequestContext(rc)
		stripped.CustomIP = nil
		stripped.CustomDNS = nil
		stripped.LookupIPFn = nil
		ctx = context.WithValue(ctx, ctxKeyRequestContext, stripped)
	}

	// Typed nils, so readers doing type assertions still get the expected types.
	var (
		noIPs    []string
		noDNS    []string
		noLookup func(context.Context, string) ([]net.IPAddr, error)
	)
	ctx = context.WithValue(ctx, ctxKeyCustomIP, noIPs)
	ctx = context.WithValue(ctx, ctxKeyCustomDNS, noDNS)
	ctx = context.WithValue(ctx, ctxKeyLookupIP, noLookup)
	return ctx
}
// withDefaultServerName returns a TLS config whose ServerName defaults to
// serverName. cfg is never modified:
//   - an empty serverName, or a cfg that already names a server, returns cfg as-is;
//   - a nil cfg yields a new config offering h2 and http/1.1 via ALPN;
//   - otherwise a clone of cfg with ServerName filled in is returned.
func withDefaultServerName(cfg *tls.Config, serverName string) *tls.Config {
	switch {
	case serverName == "":
		return cfg
	case cfg == nil:
		return &tls.Config{
			ServerName: serverName,
			NextProtos: []string{"h2", "http/1.1"},
		}
	case cfg.ServerName != "":
		return cfg
	}
	withName := cfg.Clone()
	withName.ServerName = serverName
	return withName
}
// roundTripResolvedTargets sends baseReq through rt, trying each resolved
// target address in order until one round trip succeeds.
//
// Later addresses are tried only when requestAllowsResolvedTargetFallback
// permits it (idempotent method with an absent or replayable body); otherwise
// only the first address is used. Fallback also stops once baseReq's context
// is done, because further attempts would run with an already-cancelled
// context. When every attempt fails, the last round-trip error is returned.
//
// A nil rt or nil baseReq yields an error instead of a nil-pointer panic.
// With no target addresses, baseReq is sent through rt unchanged.
//
// NOTE(review): when attempts use GetBody copies, baseReq.Body itself is never
// closed here, yet http.RoundTripper expects request bodies to be closed.
// Confirm the original body is closed elsewhere.
func roundTripResolvedTargets(rt http.RoundTripper, baseReq *http.Request, targetAddrs []string) (*http.Response, error) {
	if rt == nil {
		return nil, errors.New("starnet: nil round tripper for resolved targets")
	}
	if baseReq == nil {
		return nil, errors.New("starnet: nil request for resolved targets")
	}
	if len(targetAddrs) == 0 {
		return rt.RoundTrip(baseReq)
	}
	if len(targetAddrs) > 1 && !requestAllowsResolvedTargetFallback(baseReq) {
		targetAddrs = targetAddrs[:1]
	}

	ctx := baseReq.Context()
	var lastErr error
	for _, targetAddr := range targetAddrs {
		attemptReq, err := cloneRequestForResolvedTarget(baseReq, targetAddr)
		if err != nil {
			return nil, err
		}
		resp, err := rt.RoundTrip(attemptReq)
		if err == nil {
			return resp, nil
		}
		lastErr = err
		if ctx.Err() != nil {
			// Cancelled or timed out: no point trying the remaining addresses.
			break
		}
	}
	return nil, lastErr
}
// requestAllowsResolvedTargetFallback reports whether req may be retried
// against another resolved address. The method must be idempotent, and the
// body must be either absent or replayable through GetBody. A nil req never
// qualifies.
func requestAllowsResolvedTargetFallback(req *http.Request) bool {
	switch {
	case req == nil, !isIdempotentMethod(req.Method):
		return false
	case req.Body == nil, req.Body == http.NoBody:
		return true
	default:
		return req.GetBody != nil
	}
}
// cloneRequestForResolvedTarget returns a copy of baseReq whose URL.Host is
// targetAddr. targetAddr is presumably a resolved "ip:port" from
// resolveDialAddresses; confirm against that helper.
//
// The clone keeps baseReq's context and headers. If baseReq has a real body
// and a GetBody func, the clone gets a fresh body from GetBody, so every
// attempt reads the payload from the start. Otherwise it shares baseReq.Body,
// which is safe only for a single attempt; requestAllowsResolvedTargetFallback
// enforces that for roundTripResolvedTargets. A GetBody failure is wrapped and
// returned.
//
// The Host header is pinned to the logical host. If baseReq.Host is empty, the
// original URL host is used, so rewriting URL.Host cannot leak the resolved
// address into the Host header.
func cloneRequestForResolvedTarget(baseReq *http.Request, targetAddr string) (*http.Request, error) {
	// Clone deep-copies URL and headers but shares the Body reader.
	req := baseReq.Clone(baseReq.Context())

	hasBody := baseReq.Body != nil && baseReq.Body != http.NoBody
	if hasBody && baseReq.GetBody != nil {
		body, err := baseReq.GetBody()
		if err != nil {
			return nil, wrapError(err, "clone request body for resolved target")
		}
		req.Body = body
	}

	if req.Host == "" {
		req.Host = baseReq.URL.Host
	}
	req.URL.Host = targetAddr
	return req, nil
}