package starnet import ( "context" "crypto/tls" "net" "net/http" "net/url" "strings" "sync" "time" ) const dynamicTransportCacheMaxEntries = 64 type dynamicTransportCacheKey struct { proxyKey string dialTimeout time.Duration customIPs string customDNS string tlsServerName string skipVerify bool } // Transport 自定义 Transport(支持请求级配置) type Transport struct { base *http.Transport dynamicCache map[dynamicTransportCacheKey]*http.Transport dynamicCacheOrder []dynamicTransportCacheKey mu sync.RWMutex } // RoundTrip 实现 http.RoundTripper 接口 func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { t.ensureBase() // 提取请求级别的配置 reqCtx := getRequestContext(req.Context()) traceState := getTraceState(req.Context()) execReq := req execReqCtx := reqCtx var targetAddrs []string // 优先级1:完全自定义的 transport if execReqCtx.Transport != nil { return execReqCtx.Transport.RoundTrip(execReq) } var err error execReq, execReqCtx, targetAddrs, err = prepareProxyTargetRequest(execReq, execReqCtx, traceState) if err != nil { return nil, err } // 优先级2:需要动态配置 if needsDynamicTransport(execReqCtx) { dynamicTransport := t.getDynamicTransport(execReqCtx, traceState) if len(targetAddrs) > 0 { return roundTripResolvedTargets(dynamicTransport, execReq, targetAddrs) } return dynamicTransport.RoundTrip(execReq) } // 优先级3:使用基础 transport t.mu.RLock() baseTransport := t.base t.mu.RUnlock() if len(targetAddrs) > 0 { return roundTripResolvedTargets(baseTransport, execReq, targetAddrs) } return baseTransport.RoundTrip(execReq) } func newBaseHTTPTransport() *http.Transport { return &http.Transport{ ForceAttemptHTTP2: true, MaxIdleConns: 100, MaxIdleConnsPerHost: 10, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, } } func (t *Transport) ensureBase() { if t.base != nil { return } t.mu.Lock() defer t.mu.Unlock() t.ensureBaseLocked() } func (t *Transport) ensureBaseLocked() { if t.base == nil { t.base = newBaseHTTPTransport() } } func (t *Transport) getDynamicTransport(rc *RequestContext, traceState *traceState) *http.Transport { if key, ok := newDynamicTransportCacheKey(rc); ok { return t.getOrCreateCachedDynamicTransport(key, rc) } return t.buildDynamicTransport(rc, traceState) } func (t *Transport) getOrCreateCachedDynamicTransport(key dynamicTransportCacheKey, rc *RequestContext) *http.Transport { t.mu.RLock() if transport := t.dynamicCache[key]; transport != nil { t.mu.RUnlock() return transport } t.mu.RUnlock() t.mu.Lock() defer t.mu.Unlock() t.ensureBaseLocked() if transport := t.dynamicCache[key]; transport != nil { return transport } transport := buildDynamicTransportFromBase(t.base, rc, nil) if t.dynamicCache == nil { t.dynamicCache = make(map[dynamicTransportCacheKey]*http.Transport) } if len(t.dynamicCacheOrder) >= dynamicTransportCacheMaxEntries { oldestKey := t.dynamicCacheOrder[0] t.dynamicCacheOrder = t.dynamicCacheOrder[1:] if oldest := t.dynamicCache[oldestKey]; oldest != nil { oldest.CloseIdleConnections() delete(t.dynamicCache, oldestKey) } } t.dynamicCache[key] = transport t.dynamicCacheOrder = append(t.dynamicCacheOrder, key) return transport } func (t *Transport) resetDynamicTransportCacheLocked() { for _, key := range t.dynamicCacheOrder { if transport := t.dynamicCache[key]; transport != nil { transport.CloseIdleConnections() } } t.dynamicCache = nil t.dynamicCacheOrder = nil } func newDynamicTransportCacheKey(rc *RequestContext) (dynamicTransportCacheKey, bool) { if rc == nil { return dynamicTransportCacheKey{}, false } if rc.Transport != nil || rc.DialFn != nil || rc.LookupIPFn != nil { return dynamicTransportCacheKey{}, false } if rc.TLSConfig != nil && !rc.TLSConfigCacheable { return dynamicTransportCacheKey{}, false } key := dynamicTransportCacheKey{ proxyKey: normalizeProxyCacheKey(rc.Proxy), dialTimeout: rc.DialTimeout, customIPs: serializeTransportCacheList(rc.CustomIP), customDNS: serializeTransportCacheList(rc.CustomDNS), tlsServerName: effectiveTLSServerName(rc), } if rc.TLSConfig != nil { key.skipVerify = rc.TLSConfig.InsecureSkipVerify } return key, true } func normalizeProxyCacheKey(proxy string) string { if proxy == "" { return "" } proxyURL, err := parseProxyURL(proxy) if err != nil { return "\x00invalid:" + proxy } return proxyURL.String() } func serializeTransportCacheList(values []string) string { if len(values) == 0 { return "" } var builder strings.Builder for _, value := range values { builder.WriteString(value) builder.WriteByte(0) } return builder.String() } func effectiveTLSServerName(rc *RequestContext) string { if rc == nil { return "" } if rc.TLSConfig != nil && rc.TLSConfig.ServerName != "" { return rc.TLSConfig.ServerName } return rc.TLSServerName } // buildDynamicTransport 构建动态 Transport func (t *Transport) buildDynamicTransport(rc *RequestContext, traceState *traceState) *http.Transport { t.ensureBase() t.mu.RLock() baseTransport := t.base t.mu.RUnlock() return buildDynamicTransportFromBase(baseTransport, rc, traceState) } func buildDynamicTransportFromBase(baseTransport *http.Transport, rc *RequestContext, traceState *traceState) *http.Transport { transport := baseTransport.Clone() // 应用 TLS 配置(即使为 nil 也要检查 SkipVerify) if rc.TLSConfig != nil { transport.TLSClientConfig = rc.TLSConfig } // 应用代理配置 if rc.Proxy != "" { proxyURL, err := parseProxyURL(rc.Proxy) if err != nil { transport.Proxy = func(*http.Request) (*url.URL, error) { return nil, err } } else { transport.Proxy = http.ProxyURL(proxyURL) } } // 应用自定义 Dial 函数 if rc.DialFn != nil { if traceState != nil && traceState.hooks != nil && (traceState.hooks.ConnectStart != nil || traceState.hooks.ConnectDone != nil) { dialFn := rc.DialFn transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { if traceState.hooks.ConnectStart != nil { traceState.hooks.ConnectStart(TraceConnectStartInfo{Network: network, Addr: addr}) } conn, err := dialFn(ctx, network, addr) if traceState.hooks.ConnectDone != nil { traceState.hooks.ConnectDone(TraceConnectDoneInfo{Network: network, Addr: addr, Err: err}) } return conn, err } } else { transport.DialContext = rc.DialFn } } else if len(rc.CustomIP) > 0 || len(rc.CustomDNS) > 0 || rc.DialTimeout > 0 || rc.LookupIPFn != nil { // 使用默认 Dial 函数(会从 context 读取配置) transport.DialContext = defaultDialFunc transport.DialTLSContext = defaultDialTLSFunc } return transport } // Base 获取基础 Transport func (t *Transport) Base() *http.Transport { t.mu.RLock() defer t.mu.RUnlock() return t.base } // SetBase 设置基础 Transport func (t *Transport) SetBase(base *http.Transport) { t.mu.Lock() t.base = base t.resetDynamicTransportCacheLocked() t.mu.Unlock() } func prepareProxyTargetRequest(req *http.Request, reqCtx *RequestContext, traceState *traceState) (*http.Request, *RequestContext, []string, error) { if req == nil || req.URL == nil || reqCtx == nil { return req, reqCtx, nil, nil } if reqCtx.Proxy == "" || reqCtx.DialFn != nil { return req, reqCtx, nil, nil } if len(reqCtx.CustomIP) == 0 && len(reqCtx.CustomDNS) == 0 && reqCtx.LookupIPFn == nil { return req, reqCtx, nil, nil } host := req.URL.Hostname() if host == "" { return req, reqCtx, nil, nil } targetAddrs, err := resolveDialAddresses(req.Context(), reqCtx, host, req.URL.Port(), traceState) if err != nil { return nil, nil, nil, err } if len(targetAddrs) == 0 { return req, reqCtx, nil, nil } execReqCtx := *reqCtx execReqCtx.CustomIP = nil execReqCtx.CustomDNS = nil execReqCtx.LookupIPFn = nil if req.URL.Scheme == "https" { execReqCtx.TLSConfig = withDefaultServerName(execReqCtx.TLSConfig, host) if execReqCtx.TLSConfigCacheable || reqCtx.TLSConfig == nil { execReqCtx.TLSConfigCacheable = true } } execCtx := clearTargetResolutionContext(req.Context()) execReq := req.Clone(execCtx) execReq.Host = req.Host if len(targetAddrs) == 1 { execReq.URL.Host = targetAddrs[0] return execReq, &execReqCtx, nil, nil } return execReq, &execReqCtx, targetAddrs, nil } func clearTargetResolutionContext(ctx context.Context) context.Context { if v := ctx.Value(ctxKeyRequestContext); v != nil { if rc, ok := v.(*RequestContext); ok && rc != nil { cloned := cloneRequestContext(rc) cloned.CustomIP = nil cloned.CustomDNS = nil cloned.LookupIPFn = nil ctx = context.WithValue(ctx, ctxKeyRequestContext, cloned) } } ctx = context.WithValue(ctx, ctxKeyCustomIP, []string(nil)) ctx = context.WithValue(ctx, ctxKeyCustomDNS, []string(nil)) ctx = context.WithValue(ctx, ctxKeyLookupIP, (func(context.Context, string) ([]net.IPAddr, error))(nil)) return ctx } func withDefaultServerName(cfg *tls.Config, serverName string) *tls.Config { if serverName == "" { return cfg } if cfg != nil { if cfg.ServerName != "" { return cfg } cloned := cfg.Clone() cloned.ServerName = serverName return cloned } return &tls.Config{ ServerName: serverName, NextProtos: []string{"h2", "http/1.1"}, } } func roundTripResolvedTargets(rt http.RoundTripper, baseReq *http.Request, targetAddrs []string) (*http.Response, error) { if rt == nil || baseReq == nil || len(targetAddrs) == 0 { return rt.RoundTrip(baseReq) } if !requestAllowsResolvedTargetFallback(baseReq) && len(targetAddrs) > 1 { targetAddrs = targetAddrs[:1] } var lastErr error for _, targetAddr := range targetAddrs { attemptReq, err := cloneRequestForResolvedTarget(baseReq, targetAddr) if err != nil { return nil, err } resp, err := rt.RoundTrip(attemptReq) if err == nil { return resp, nil } lastErr = err } return nil, lastErr } func requestAllowsResolvedTargetFallback(req *http.Request) bool { if req == nil { return false } if !isIdempotentMethod(req.Method) { return false } if req.Body == nil || req.Body == http.NoBody { return true } return req.GetBody != nil } func cloneRequestForResolvedTarget(baseReq *http.Request, targetAddr string) (*http.Request, error) { req := baseReq.Clone(baseReq.Context()) switch { case baseReq.Body == nil || baseReq.Body == http.NoBody: req.Body = baseReq.Body case baseReq.GetBody != nil: body, err := baseReq.GetBody() if err != nil { return nil, wrapError(err, "clone request body for resolved target") } req.Body = body default: req.Body = baseReq.Body } req.URL.Host = targetAddr req.Host = baseReq.Host return req, nil }