- 分离 Request 的配置态与执行态,修复二次 Do、raw 模式网络配置失效和 body 来源互斥问题 - 新增 starnet trace 抽象,补齐 DNS/连接/TLS/重试事件,并优化动态 transport 缓存与代理解析路径 - 收紧非法代理为 fail-fast,多目标目标回退仅限幂等请求,修复 Host/TLS/SNI 等语义边界 - 补充防御性拷贝、专项回归测试、本地代理/TLS 用例与 README 行为说明
417 lines
11 KiB
Go
417 lines
11 KiB
Go
package starnet
|
||
|
||
import (
	"context"
	"crypto/tls"
	"errors"
	"net"
	"net/http"
	"net/url"
	"strings"
	"sync"
	"time"
)
|
||
|
||
// dynamicTransportCacheMaxEntries caps how many derived *http.Transport values
// getOrCreateCachedDynamicTransport keeps cached. When the cap is reached the
// oldest entry (by insertion order) is evicted and its idle connections closed.
const dynamicTransportCacheMaxEntries = 64
|
||
|
||
// dynamicTransportCacheKey identifies a shareable dynamic transport: requests
// whose RequestContext yields equal keys (see newDynamicTransportCacheKey)
// reuse the same derived *http.Transport. It is used as a map key, so every
// field must stay comparable.
type dynamicTransportCacheKey struct {
	// proxyKey is rc.Proxy as normalized by normalizeProxyCacheKey.
	proxyKey      string
	// dialTimeout mirrors RequestContext.DialTimeout.
	dialTimeout   time.Duration
	// customIPs is RequestContext.CustomIP flattened by serializeTransportCacheList.
	customIPs     string
	// customDNS is RequestContext.CustomDNS flattened by serializeTransportCacheList.
	customDNS     string
	// tlsServerName is the result of effectiveTLSServerName.
	tlsServerName string
	// skipVerify mirrors TLSConfig.InsecureSkipVerify (false when no TLS config).
	skipVerify    bool
}
|
||
|
||
// Transport is a custom http.RoundTripper that honours per-request
// configuration carried in the request context.
//
// RoundTrip picks, in order of precedence: a transport supplied by the
// request's RequestContext, a dynamic transport derived from the base one, or
// the base transport itself. The base transport is created lazily, so the
// zero value is usable.
type Transport struct {
	// base is the shared transport; created on demand by ensureBase when nil.
	base *http.Transport
	// dynamicCache holds derived transports keyed by their cacheable settings.
	dynamicCache map[dynamicTransportCacheKey]*http.Transport
	// dynamicCacheOrder records cache insertion order for FIFO eviction.
	dynamicCacheOrder []dynamicTransportCacheKey
	// mu guards base, dynamicCache and dynamicCacheOrder.
	mu sync.RWMutex
}
|
||
|
||
// RoundTrip implements http.RoundTripper.
//
// A transport supplied in the request's RequestContext is used verbatim.
// Otherwise the request is first prepared for proxy-side target resolution
// (which may rewrite the URL to resolved addresses), then sent through either
// a dynamic transport, when per-request settings require one, or the shared
// base transport. With several resolved targets the attempt is delegated to
// roundTripResolvedTargets.
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
	t.ensureBase()

	rc := getRequestContext(req.Context())
	trace := getTraceState(req.Context())

	// Priority 1: a fully custom transport for this request wins outright.
	if custom := rc.Transport; custom != nil {
		return custom.RoundTrip(req)
	}

	outReq, outCtx, targets, err := prepareProxyTargetRequest(req, rc, trace)
	if err != nil {
		return nil, err
	}

	var selected *http.Transport
	if needsDynamicTransport(outCtx) {
		// Priority 2: per-request settings need a derived transport.
		selected = t.getDynamicTransport(outCtx, trace)
	} else {
		// Priority 3: the shared base transport.
		t.mu.RLock()
		selected = t.base
		t.mu.RUnlock()
	}

	if len(targets) == 0 {
		return selected.RoundTrip(outReq)
	}
	return roundTripResolvedTargets(selected, outReq, targets)
}
|
||
|
||
// newBaseHTTPTransport builds the default shared transport. It attempts
// HTTP/2, pools idle connections (100 in total, 10 per host, dropped after
// 90s idle), and bounds TLS handshakes to 10s and 100-continue waits to 1s.
// No proxy or custom dialer is configured.
func newBaseHTTPTransport() *http.Transport {
	tr := new(http.Transport)
	tr.ForceAttemptHTTP2 = true
	tr.MaxIdleConns = 100
	tr.MaxIdleConnsPerHost = 10
	tr.IdleConnTimeout = 90 * time.Second
	tr.TLSHandshakeTimeout = 10 * time.Second
	tr.ExpectContinueTimeout = time.Second
	return tr
}
|
||
|
||
func (t *Transport) ensureBase() {
|
||
if t.base != nil {
|
||
return
|
||
}
|
||
t.mu.Lock()
|
||
defer t.mu.Unlock()
|
||
t.ensureBaseLocked()
|
||
}
|
||
|
||
func (t *Transport) ensureBaseLocked() {
|
||
if t.base == nil {
|
||
t.base = newBaseHTTPTransport()
|
||
}
|
||
}
|
||
|
||
// getDynamicTransport returns a transport carrying rc's per-request settings.
// Settings expressible as a dynamicTransportCacheKey share a cached transport
// (built without trace hooks). Anything else gets a freshly built transport
// that also receives traceState.
func (t *Transport) getDynamicTransport(rc *RequestContext, traceState *traceState) *http.Transport {
	key, cacheable := newDynamicTransportCacheKey(rc)
	if !cacheable {
		return t.buildDynamicTransport(rc, traceState)
	}
	return t.getOrCreateCachedDynamicTransport(key, rc)
}
|
||
|
||
// getOrCreateCachedDynamicTransport returns the cached transport for key. On a
// miss it builds one from the base transport (without trace hooks) and caches
// it.
//
// Lookups try the read lock first. On a miss the write lock is taken and the
// cache re-checked, so concurrent callers build at most one transport per key.
// The cache holds at most dynamicTransportCacheMaxEntries entries and evicts
// in insertion (FIFO) order.
func (t *Transport) getOrCreateCachedDynamicTransport(key dynamicTransportCacheKey, rc *RequestContext) *http.Transport {
	t.mu.RLock()
	cached := t.dynamicCache[key]
	t.mu.RUnlock()
	if cached != nil {
		return cached
	}

	t.mu.Lock()
	defer t.mu.Unlock()

	t.ensureBaseLocked()
	if cached := t.dynamicCache[key]; cached != nil {
		return cached
	}

	transport := buildDynamicTransportFromBase(t.base, rc, nil)
	if t.dynamicCache == nil {
		t.dynamicCache = make(map[dynamicTransportCacheKey]*http.Transport)
	}
	if len(t.dynamicCacheOrder) >= dynamicTransportCacheMaxEntries {
		t.evictOldestDynamicTransportLocked()
	}
	t.dynamicCache[key] = transport
	t.dynamicCacheOrder = append(t.dynamicCacheOrder, key)
	return transport
}

// evictOldestDynamicTransportLocked removes the oldest cache entry and closes
// its idle connections. The order slice is shifted in place and its vacated
// tail slot zeroed. This keeps the evicted key from being retained by the
// backing array and stops the slice from drifting through repeated
// reallocations. The caller must hold t.mu for writing.
func (t *Transport) evictOldestDynamicTransportLocked() {
	order := t.dynamicCacheOrder
	if len(order) == 0 {
		return
	}
	oldestKey := order[0]
	n := copy(order, order[1:])
	order[n] = dynamicTransportCacheKey{}
	t.dynamicCacheOrder = order[:n]

	if oldest := t.dynamicCache[oldestKey]; oldest != nil {
		// Only idle connections are closed; in-flight requests on the evicted
		// transport are left to finish.
		oldest.CloseIdleConnections()
		delete(t.dynamicCache, oldestKey)
	}
}
|
||
|
||
func (t *Transport) resetDynamicTransportCacheLocked() {
|
||
for _, key := range t.dynamicCacheOrder {
|
||
if transport := t.dynamicCache[key]; transport != nil {
|
||
transport.CloseIdleConnections()
|
||
}
|
||
}
|
||
t.dynamicCache = nil
|
||
t.dynamicCacheOrder = nil
|
||
}
|
||
|
||
func newDynamicTransportCacheKey(rc *RequestContext) (dynamicTransportCacheKey, bool) {
|
||
if rc == nil {
|
||
return dynamicTransportCacheKey{}, false
|
||
}
|
||
if rc.Transport != nil || rc.DialFn != nil || rc.LookupIPFn != nil {
|
||
return dynamicTransportCacheKey{}, false
|
||
}
|
||
if rc.TLSConfig != nil && !rc.TLSConfigCacheable {
|
||
return dynamicTransportCacheKey{}, false
|
||
}
|
||
|
||
key := dynamicTransportCacheKey{
|
||
proxyKey: normalizeProxyCacheKey(rc.Proxy),
|
||
dialTimeout: rc.DialTimeout,
|
||
customIPs: serializeTransportCacheList(rc.CustomIP),
|
||
customDNS: serializeTransportCacheList(rc.CustomDNS),
|
||
tlsServerName: effectiveTLSServerName(rc),
|
||
}
|
||
if rc.TLSConfig != nil {
|
||
key.skipVerify = rc.TLSConfig.InsecureSkipVerify
|
||
}
|
||
return key, true
|
||
}
|
||
|
||
func normalizeProxyCacheKey(proxy string) string {
|
||
if proxy == "" {
|
||
return ""
|
||
}
|
||
proxyURL, err := parseProxyURL(proxy)
|
||
if err != nil {
|
||
return "\x00invalid:" + proxy
|
||
}
|
||
return proxyURL.String()
|
||
}
|
||
|
||
func serializeTransportCacheList(values []string) string {
|
||
if len(values) == 0 {
|
||
return ""
|
||
}
|
||
var builder strings.Builder
|
||
for _, value := range values {
|
||
builder.WriteString(value)
|
||
builder.WriteByte(0)
|
||
}
|
||
return builder.String()
|
||
}
|
||
|
||
func effectiveTLSServerName(rc *RequestContext) string {
|
||
if rc == nil {
|
||
return ""
|
||
}
|
||
if rc.TLSConfig != nil && rc.TLSConfig.ServerName != "" {
|
||
return rc.TLSConfig.ServerName
|
||
}
|
||
return rc.TLSServerName
|
||
}
|
||
|
||
// buildDynamicTransport derives an uncached transport from the current base
// transport, creating the base on demand. It applies rc's settings and
// passes traceState through for dial tracing.
func (t *Transport) buildDynamicTransport(rc *RequestContext, traceState *traceState) *http.Transport {
	t.ensureBase()
	base := func() *http.Transport {
		t.mu.RLock()
		defer t.mu.RUnlock()
		return t.base
	}()
	return buildDynamicTransportFromBase(base, rc, traceState)
}
|
||
|
||
// buildDynamicTransportFromBase clones baseTransport and applies rc's
// per-request settings to the clone:
//
//   - TLS: rc.TLSConfig is installed as a defensive copy. When net/http sets
//     up HTTP/2 on a transport it appends "h2"/"http/1.1" to
//     TLSClientConfig.NextProtos in place. Installing the caller's pointer
//     directly would therefore mutate the caller's config and race when one
//     config is shared by concurrently built transports.
//   - Proxy: a parsable rc.Proxy is used for every request. An unparsable one
//     makes every request fail with the parse error.
//   - Dialing: rc.DialFn takes precedence, wrapped with the ConnectStart /
//     ConnectDone trace hooks when traceState provides either. Otherwise,
//     when custom IPs/DNS, a dial timeout or a lookup function are set, the
//     package's default dial functions (which read their configuration from
//     the request context) are installed.
//
// A nil baseTransport falls back to a fresh default transport instead of
// panicking inside Clone.
func buildDynamicTransportFromBase(baseTransport *http.Transport, rc *RequestContext, traceState *traceState) *http.Transport {
	if baseTransport == nil {
		baseTransport = newBaseHTTPTransport()
	}
	transport := baseTransport.Clone()

	// Apply the TLS configuration (copied; see above).
	if rc.TLSConfig != nil {
		transport.TLSClientConfig = rc.TLSConfig.Clone()
	}

	// Apply the proxy configuration.
	if rc.Proxy != "" {
		proxyURL, err := parseProxyURL(rc.Proxy)
		if err != nil {
			transport.Proxy = func(*http.Request) (*url.URL, error) {
				return nil, err
			}
		} else {
			transport.Proxy = http.ProxyURL(proxyURL)
		}
	}

	// Apply the dial configuration.
	if rc.DialFn != nil {
		if traceState != nil && traceState.hooks != nil && (traceState.hooks.ConnectStart != nil || traceState.hooks.ConnectDone != nil) {
			dialFn := rc.DialFn
			transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
				if traceState.hooks.ConnectStart != nil {
					traceState.hooks.ConnectStart(TraceConnectStartInfo{Network: network, Addr: addr})
				}
				conn, err := dialFn(ctx, network, addr)
				if traceState.hooks.ConnectDone != nil {
					traceState.hooks.ConnectDone(TraceConnectDoneInfo{Network: network, Addr: addr, Err: err})
				}
				return conn, err
			}
		} else {
			transport.DialContext = rc.DialFn
		}
	} else if len(rc.CustomIP) > 0 || len(rc.CustomDNS) > 0 || rc.DialTimeout > 0 || rc.LookupIPFn != nil {
		// The default dial functions read their configuration from the context.
		transport.DialContext = defaultDialFunc
		transport.DialTLSContext = defaultDialTLSFunc
	}

	return transport
}
|
||
|
||
// Base returns the shared base transport. It returns nil if no base has been
// set and none has been created lazily yet.
func (t *Transport) Base() *http.Transport {
	t.mu.RLock()
	base := t.base
	t.mu.RUnlock()
	return base
}
|
||
|
||
// SetBase replaces the base transport. Every cached dynamic transport was
// cloned from the previous base, so the cache is discarded and their idle
// connections are closed. Passing nil lets the next request lazily create a
// default base transport.
func (t *Transport) SetBase(base *http.Transport) {
	t.mu.Lock()
	defer t.mu.Unlock()
	t.base = base
	t.resetDynamicTransportCacheLocked()
}
|
||
|
||
// prepareProxyTargetRequest handles requests sent through a proxy that also
// ask for custom target resolution (CustomIP, CustomDNS or LookupIPFn).
// The target hostname is resolved here and the request URL is rewritten to
// the resolved address(es). The Host header and the TLS server name keep the
// original hostname.
//
// It returns the request and RequestContext to execute, plus the candidate
// addresses when more than one was resolved. With exactly one, the URL is
// rewritten directly and no list is returned. When the preconditions do not
// hold, or resolution yields no addresses, the inputs are returned unchanged.
// Resolution errors are returned as-is with nil request and context.
func prepareProxyTargetRequest(req *http.Request, reqCtx *RequestContext, traceState *traceState) (*http.Request, *RequestContext, []string, error) {
	if req == nil || req.URL == nil || reqCtx == nil {
		return req, reqCtx, nil, nil
	}
	// Only proxied requests are handled; a custom DialFn owns dialing itself.
	if reqCtx.Proxy == "" || reqCtx.DialFn != nil {
		return req, reqCtx, nil, nil
	}
	if len(reqCtx.CustomIP) == 0 && len(reqCtx.CustomDNS) == 0 && reqCtx.LookupIPFn == nil {
		return req, reqCtx, nil, nil
	}

	host := req.URL.Hostname()
	if host == "" {
		return req, reqCtx, nil, nil
	}

	targetAddrs, err := resolveDialAddresses(req.Context(), reqCtx, host, req.URL.Port(), traceState)
	if err != nil {
		return nil, nil, nil, err
	}
	if len(targetAddrs) == 0 {
		return req, reqCtx, nil, nil
	}

	// Resolution is done here; clear the inputs so it is not repeated downstream.
	execReqCtx := *reqCtx
	execReqCtx.CustomIP = nil
	execReqCtx.CustomDNS = nil
	execReqCtx.LookupIPFn = nil

	if req.URL.Scheme == "https" {
		// The URL host becomes a resolved address, so pin the TLS server name
		// to the original hostname unless the caller set one explicitly.
		execReqCtx.TLSConfig = withDefaultServerName(execReqCtx.TLSConfig, host)
		// A config synthesized here (no caller config) is safe to cache.
		if reqCtx.TLSConfig == nil {
			execReqCtx.TLSConfigCacheable = true
		}
	}

	execReq := req.Clone(clearTargetResolutionContext(req.Context()))
	// Pin the Host header to the original authority before the URL host is
	// rewritten. With an empty Request.Host, net/http would derive the Host
	// header from the rewritten URL and send the resolved address instead.
	if execReq.Host == "" {
		execReq.Host = req.URL.Host
	}
	if len(targetAddrs) == 1 {
		execReq.URL.Host = targetAddrs[0]
		return execReq, &execReqCtx, nil, nil
	}
	return execReq, &execReqCtx, targetAddrs, nil
}
|
||
|
||
func clearTargetResolutionContext(ctx context.Context) context.Context {
|
||
if v := ctx.Value(ctxKeyRequestContext); v != nil {
|
||
if rc, ok := v.(*RequestContext); ok && rc != nil {
|
||
cloned := cloneRequestContext(rc)
|
||
cloned.CustomIP = nil
|
||
cloned.CustomDNS = nil
|
||
cloned.LookupIPFn = nil
|
||
ctx = context.WithValue(ctx, ctxKeyRequestContext, cloned)
|
||
}
|
||
}
|
||
ctx = context.WithValue(ctx, ctxKeyCustomIP, []string(nil))
|
||
ctx = context.WithValue(ctx, ctxKeyCustomDNS, []string(nil))
|
||
ctx = context.WithValue(ctx, ctxKeyLookupIP, (func(context.Context, string) ([]net.IPAddr, error))(nil))
|
||
return ctx
|
||
}
|
||
|
||
// withDefaultServerName makes sure a TLS config names serverName, without
// overriding an explicit choice:
//
//   - empty serverName: cfg is returned unchanged (possibly nil);
//   - nil cfg: a new config with ServerName set and NextProtos h2, http/1.1;
//   - cfg with a ServerName already: cfg is returned unchanged;
//   - otherwise: a clone of cfg (never cfg itself) with ServerName set.
func withDefaultServerName(cfg *tls.Config, serverName string) *tls.Config {
	switch {
	case serverName == "":
		return cfg
	case cfg == nil:
		return &tls.Config{
			ServerName: serverName,
			NextProtos: []string{"h2", "http/1.1"},
		}
	case cfg.ServerName != "":
		return cfg
	}
	out := cfg.Clone()
	out.ServerName = serverName
	return out
}
|
||
|
||
// roundTripResolvedTargets sends baseReq to each resolved target address in
// turn and returns the first successful response.
//
// Later addresses are only tried when the request can be safely replayed
// (see requestAllowsResolvedTargetFallback); otherwise only the first address
// is used. Once the request's context is done, no further addresses are
// tried, because every remaining attempt would fail the same way. If every
// attempt fails, the last error is returned.
//
// With no target addresses, baseReq is sent to rt unchanged. A nil rt or nil
// baseReq yields an error rather than a nil-pointer panic.
func roundTripResolvedTargets(rt http.RoundTripper, baseReq *http.Request, targetAddrs []string) (*http.Response, error) {
	if rt == nil {
		return nil, errors.New("starnet: nil round tripper")
	}
	if baseReq == nil {
		return nil, errors.New("starnet: nil request")
	}
	if len(targetAddrs) == 0 {
		return rt.RoundTrip(baseReq)
	}

	if len(targetAddrs) > 1 && !requestAllowsResolvedTargetFallback(baseReq) {
		targetAddrs = targetAddrs[:1]
	}

	var lastErr error
	for _, targetAddr := range targetAddrs {
		attemptReq, err := cloneRequestForResolvedTarget(baseReq, targetAddr)
		if err != nil {
			return nil, err
		}

		resp, err := rt.RoundTrip(attemptReq)
		if err == nil {
			return resp, nil
		}
		lastErr = err
		// A cancelled or expired context dooms every remaining attempt.
		if baseReq.Context().Err() != nil {
			break
		}
	}
	return nil, lastErr
}
|
||
|
||
// requestAllowsResolvedTargetFallback reports whether req may be replayed
// against another resolved address after a failed attempt. The method must
// be idempotent, and the body must be absent or re-creatable via GetBody.
func requestAllowsResolvedTargetFallback(req *http.Request) bool {
	switch {
	case req == nil, !isIdempotentMethod(req.Method):
		return false
	case req.Body == nil, req.Body == http.NoBody:
		return true
	default:
		return req.GetBody != nil
	}
}
|
||
|
||
// cloneRequestForResolvedTarget returns a copy of baseReq whose URL host is
// targetAddr, while the Host header keeps the original authority.
//
// If baseReq has a real body and GetBody is set, the copy gets a fresh body
// from GetBody so each attempt sends it in full. Otherwise the copy shares
// baseReq.Body. An empty Host is filled from the original URL host before the
// URL is rewritten; otherwise net/http would send targetAddr as the Host
// header. GetBody failures are returned wrapped.
func cloneRequestForResolvedTarget(baseReq *http.Request, targetAddr string) (*http.Request, error) {
	// Clone deep-copies URL and copies Host, but shares Body with baseReq.
	req := baseReq.Clone(baseReq.Context())

	hasBody := baseReq.Body != nil && baseReq.Body != http.NoBody
	if hasBody && baseReq.GetBody != nil {
		body, err := baseReq.GetBody()
		if err != nil {
			return nil, wrapError(err, "clone request body for resolved target")
		}
		req.Body = body
	}

	if req.Host == "" {
		req.Host = baseReq.URL.Host
	}
	req.URL.Host = targetAddr
	return req, nil
}
|