starnet/transport.go

417 lines
11 KiB
Go
Raw Normal View History

2026-03-08 20:19:40 +08:00
package starnet
import (
	"context"
	"crypto/tls"
	"errors"
	"net"
	"net/http"
	"net/url"
	"strings"
	"sync"
	"time"
)
// dynamicTransportCacheMaxEntries bounds how many cached dynamic transports a
// Transport keeps; once the bound is reached, the oldest inserted entry is
// evicted (and its idle connections closed) before a new one is added.
const dynamicTransportCacheMaxEntries = 64
// dynamicTransportCacheKey identifies a reusable dynamic *http.Transport.
// Request contexts that produce equal keys share one cached transport, so the
// key is meant to capture every cacheable setting that shapes the transport.
// The struct is comparable and is used directly as a map key.
type dynamicTransportCacheKey struct {
	proxyKey      string        // normalized proxy URL (normalizeProxyCacheKey)
	dialTimeout   time.Duration // RequestContext.DialTimeout
	customIPs     string        // NUL-terminated join of RequestContext.CustomIP
	customDNS     string        // NUL-terminated join of RequestContext.CustomDNS
	tlsServerName string        // effective TLS server name (effectiveTLSServerName)
	skipVerify    bool          // TLSConfig.InsecureSkipVerify; false when TLSConfig is nil
}
// Transport 自定义 Transport支持请求级配置
type Transport struct {
base *http.Transport
dynamicCache map[dynamicTransportCacheKey]*http.Transport
dynamicCacheOrder []dynamicTransportCacheKey
mu sync.RWMutex
2026-03-08 20:19:40 +08:00
}
// RoundTrip 实现 http.RoundTripper 接口
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
t.ensureBase()
2026-03-08 20:19:40 +08:00
// 提取请求级别的配置
reqCtx := getRequestContext(req.Context())
traceState := getTraceState(req.Context())
execReq := req
execReqCtx := reqCtx
var targetAddrs []string
2026-03-08 20:19:40 +08:00
// 优先级1完全自定义的 transport
if execReqCtx.Transport != nil {
return execReqCtx.Transport.RoundTrip(execReq)
}
var err error
execReq, execReqCtx, targetAddrs, err = prepareProxyTargetRequest(execReq, execReqCtx, traceState)
if err != nil {
return nil, err
2026-03-08 20:19:40 +08:00
}
// 优先级2需要动态配置
if needsDynamicTransport(execReqCtx) {
dynamicTransport := t.getDynamicTransport(execReqCtx, traceState)
if len(targetAddrs) > 0 {
return roundTripResolvedTargets(dynamicTransport, execReq, targetAddrs)
}
return dynamicTransport.RoundTrip(execReq)
2026-03-08 20:19:40 +08:00
}
// 优先级3使用基础 transport
t.mu.RLock()
baseTransport := t.base
t.mu.RUnlock()
if len(targetAddrs) > 0 {
return roundTripResolvedTargets(baseTransport, execReq, targetAddrs)
}
return baseTransport.RoundTrip(execReq)
}
// newBaseHTTPTransport builds the default base transport. It attempts HTTP/2,
// pools up to 100 idle connections (10 per host, dropped after 90s idle), and
// uses a 10s TLS handshake timeout and a 1s Expect: 100-continue timeout.
// Proxy and DialContext are left unset, so no environment proxy is consulted.
func newBaseHTTPTransport() *http.Transport {
	tr := new(http.Transport)
	tr.ForceAttemptHTTP2 = true
	tr.MaxIdleConns = 100
	tr.MaxIdleConnsPerHost = 10
	tr.IdleConnTimeout = 90 * time.Second
	tr.TLSHandshakeTimeout = 10 * time.Second
	tr.ExpectContinueTimeout = time.Second
	return tr
}
// ensureBase lazily installs the default base transport when none is set.
// The fast path checks base under the read lock so it does not race with
// SetBase writing it; the slow path re-checks under the write lock.
func (t *Transport) ensureBase() {
	t.mu.RLock()
	hasBase := t.base != nil
	t.mu.RUnlock()
	if hasBase {
		return
	}
	t.mu.Lock()
	defer t.mu.Unlock()
	t.ensureBaseLocked()
}
// ensureBaseLocked installs newBaseHTTPTransport() as base when base is nil.
// Callers hold t.mu for writing.
func (t *Transport) ensureBaseLocked() {
	if t.base != nil {
		return
	}
	t.base = newBaseHTTPTransport()
}
// getDynamicTransport returns a transport configured for rc. When rc yields a
// cache key, the transport comes from (or is added to) the dynamic cache;
// otherwise a fresh transport is built, which may wrap rc.DialFn with the
// connect trace hooks from traceState.
func (t *Transport) getDynamicTransport(rc *RequestContext, traceState *traceState) *http.Transport {
	key, cacheable := newDynamicTransportCacheKey(rc)
	if !cacheable {
		return t.buildDynamicTransport(rc, traceState)
	}
	return t.getOrCreateCachedDynamicTransport(key, rc)
}
// getOrCreateCachedDynamicTransport returns the cached transport for key,
// building one from base and rc on a miss. Cached transports are built with a
// nil traceState. The cache holds at most dynamicTransportCacheMaxEntries
// transports; when full, the oldest inserted entry is evicted and its idle
// connections are closed before the new transport is stored.
func (t *Transport) getOrCreateCachedDynamicTransport(key dynamicTransportCacheKey, rc *RequestContext) *http.Transport {
	// Fast path: lookup under the shared lock only.
	t.mu.RLock()
	hit := t.dynamicCache[key]
	t.mu.RUnlock()
	if hit != nil {
		return hit
	}

	t.mu.Lock()
	defer t.mu.Unlock()
	t.ensureBaseLocked()
	// Another goroutine may have filled the entry while we waited for the lock.
	if hit = t.dynamicCache[key]; hit != nil {
		return hit
	}

	built := buildDynamicTransportFromBase(t.base, rc, nil)
	if t.dynamicCache == nil {
		t.dynamicCache = make(map[dynamicTransportCacheKey]*http.Transport)
	}
	if len(t.dynamicCacheOrder) >= dynamicTransportCacheMaxEntries {
		evictKey := t.dynamicCacheOrder[0]
		t.dynamicCacheOrder = t.dynamicCacheOrder[1:]
		if evicted := t.dynamicCache[evictKey]; evicted != nil {
			evicted.CloseIdleConnections()
			delete(t.dynamicCache, evictKey)
		}
	}
	t.dynamicCache[key] = built
	t.dynamicCacheOrder = append(t.dynamicCacheOrder, key)
	return built
}
// resetDynamicTransportCacheLocked empties the dynamic transport cache and
// closes the idle connections of every transport it held. Callers hold t.mu
// for writing.
func (t *Transport) resetDynamicTransportCacheLocked() {
	order, cache := t.dynamicCacheOrder, t.dynamicCache
	t.dynamicCache, t.dynamicCacheOrder = nil, nil
	for i := range order {
		if tr := cache[order[i]]; tr != nil {
			tr.CloseIdleConnections()
		}
	}
}
// newDynamicTransportCacheKey derives the cache key for rc. It reports false,
// meaning rc must get an uncached transport, when rc is nil, when it carries
// its own Transport, DialFn or LookupIPFn, or when it has a TLSConfig that is
// not marked TLSConfigCacheable.
func newDynamicTransportCacheKey(rc *RequestContext) (dynamicTransportCacheKey, bool) {
	switch {
	case rc == nil,
		rc.Transport != nil,
		rc.DialFn != nil,
		rc.LookupIPFn != nil,
		rc.TLSConfig != nil && !rc.TLSConfigCacheable:
		return dynamicTransportCacheKey{}, false
	}

	skip := false
	if rc.TLSConfig != nil {
		skip = rc.TLSConfig.InsecureSkipVerify
	}
	return dynamicTransportCacheKey{
		proxyKey:      normalizeProxyCacheKey(rc.Proxy),
		dialTimeout:   rc.DialTimeout,
		customIPs:     serializeTransportCacheList(rc.CustomIP),
		customDNS:     serializeTransportCacheList(rc.CustomDNS),
		tlsServerName: effectiveTLSServerName(rc),
		skipVerify:    skip,
	}, true
}
// normalizeProxyCacheKey returns a canonical form of proxy for cache keys:
// "" when no proxy is set, the parsed URL's String() when parseProxyURL
// accepts it, and otherwise the raw value behind a NUL-prefixed "invalid:"
// marker, which keeps invalid inputs apart from parsed URLs.
func normalizeProxyCacheKey(proxy string) string {
	if proxy == "" {
		return ""
	}
	if u, err := parseProxyURL(proxy); err == nil {
		return u.String()
	}
	return "\x00invalid:" + proxy
}
// serializeTransportCacheList encodes values as one comparable string in which
// every element, in order, is followed by a NUL byte. A nil or empty slice
// encodes as "".
func serializeTransportCacheList(values []string) string {
	if len(values) == 0 {
		return ""
	}
	return strings.Join(values, "\x00") + "\x00"
}
func effectiveTLSServerName(rc *RequestContext) string {
if rc == nil {
return ""
}
if rc.TLSConfig != nil && rc.TLSConfig.ServerName != "" {
return rc.TLSConfig.ServerName
}
return rc.TLSServerName
2026-03-08 20:19:40 +08:00
}
// buildDynamicTransport 构建动态 Transport
func (t *Transport) buildDynamicTransport(rc *RequestContext, traceState *traceState) *http.Transport {
t.ensureBase()
2026-03-08 20:19:40 +08:00
t.mu.RLock()
baseTransport := t.base
2026-03-08 20:19:40 +08:00
t.mu.RUnlock()
return buildDynamicTransportFromBase(baseTransport, rc, traceState)
}
func buildDynamicTransportFromBase(baseTransport *http.Transport, rc *RequestContext, traceState *traceState) *http.Transport {
transport := baseTransport.Clone()
2026-03-08 20:19:40 +08:00
// 应用 TLS 配置(即使为 nil 也要检查 SkipVerify
if rc.TLSConfig != nil {
transport.TLSClientConfig = rc.TLSConfig
}
// 应用代理配置
if rc.Proxy != "" {
proxyURL, err := parseProxyURL(rc.Proxy)
if err != nil {
transport.Proxy = func(*http.Request) (*url.URL, error) {
return nil, err
}
} else {
2026-03-08 20:19:40 +08:00
transport.Proxy = http.ProxyURL(proxyURL)
}
}
// 应用自定义 Dial 函数
if rc.DialFn != nil {
if traceState != nil && traceState.hooks != nil && (traceState.hooks.ConnectStart != nil || traceState.hooks.ConnectDone != nil) {
dialFn := rc.DialFn
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
if traceState.hooks.ConnectStart != nil {
traceState.hooks.ConnectStart(TraceConnectStartInfo{Network: network, Addr: addr})
}
conn, err := dialFn(ctx, network, addr)
if traceState.hooks.ConnectDone != nil {
traceState.hooks.ConnectDone(TraceConnectDoneInfo{Network: network, Addr: addr, Err: err})
}
return conn, err
}
} else {
transport.DialContext = rc.DialFn
}
2026-03-08 20:19:40 +08:00
} else if len(rc.CustomIP) > 0 || len(rc.CustomDNS) > 0 || rc.DialTimeout > 0 || rc.LookupIPFn != nil {
// 使用默认 Dial 函数(会从 context 读取配置)
transport.DialContext = defaultDialFunc
transport.DialTLSContext = defaultDialTLSFunc
}
return transport
}
// Base returns the current base transport. It is nil until a base has been
// set with SetBase or created lazily by the first request.
func (t *Transport) Base() *http.Transport {
	t.mu.RLock()
	base := t.base
	t.mu.RUnlock()
	return base
}
// SetBase 设置基础 Transport
func (t *Transport) SetBase(base *http.Transport) {
t.mu.Lock()
t.base = base
t.resetDynamicTransportCacheLocked()
2026-03-08 20:19:40 +08:00
t.mu.Unlock()
}
// prepareProxyTargetRequest resolves the target host up front for requests
// that go through a proxy but also carry custom IP, custom DNS or LookupIPFn
// settings. Presumably this exists because, with a proxy, the local dialer only
// connects to the proxy, so the target-resolution settings would not apply.
//
// When no resolution is needed (no proxy, a custom DialFn, no resolution
// settings, no host, or no addresses resolved), req and reqCtx are returned
// unchanged with no target list. Otherwise it returns a clone of req whose
// context and RequestContext copy have the resolution settings cleared. For
// https the TLS ServerName defaults to the original host. The clone is either
// addressed to the single resolved address (nil target list) or left on the
// original host with the full address list for roundTripResolvedTargets. The
// original Host header is preserved in both cases.
func prepareProxyTargetRequest(req *http.Request, reqCtx *RequestContext, traceState *traceState) (*http.Request, *RequestContext, []string, error) {
	if req == nil || req.URL == nil || reqCtx == nil {
		return req, reqCtx, nil, nil
	}
	if reqCtx.Proxy == "" || reqCtx.DialFn != nil {
		return req, reqCtx, nil, nil
	}
	if len(reqCtx.CustomIP) == 0 && len(reqCtx.CustomDNS) == 0 && reqCtx.LookupIPFn == nil {
		return req, reqCtx, nil, nil
	}
	host := req.URL.Hostname()
	if host == "" {
		return req, reqCtx, nil, nil
	}
	targetAddrs, err := resolveDialAddresses(req.Context(), reqCtx, host, req.URL.Port(), traceState)
	if err != nil {
		return nil, nil, nil, err
	}
	if len(targetAddrs) == 0 {
		return req, reqCtx, nil, nil
	}

	// Resolution is done; strip the settings so nothing downstream re-resolves.
	execReqCtx := *reqCtx
	execReqCtx.CustomIP = nil
	execReqCtx.CustomDNS = nil
	execReqCtx.LookupIPFn = nil
	if req.URL.Scheme == "https" {
		// The URL host becomes an address, so TLS must keep using the original
		// host name for SNI / certificate verification.
		execReqCtx.TLSConfig = withDefaultServerName(execReqCtx.TLSConfig, host)
		// A config created here from nil is library-owned and cacheable; a
		// caller-supplied config keeps its original cacheability.
		if reqCtx.TLSConfig == nil {
			execReqCtx.TLSConfigCacheable = true
		}
	}

	execReq := req.Clone(clearTargetResolutionContext(req.Context()))
	// Clone copies Host. When it is empty, net/http derives the Host header
	// from URL.Host, which is about to become a resolved address, so pin the
	// original authority.
	if execReq.Host == "" {
		execReq.Host = req.URL.Host
	}
	if len(targetAddrs) == 1 {
		execReq.URL.Host = targetAddrs[0]
		return execReq, &execReqCtx, nil, nil
	}
	return execReq, &execReqCtx, targetAddrs, nil
}
// clearTargetResolutionContext returns a context derived from ctx in which all
// target-resolution settings are cleared. A stored *RequestContext is replaced
// by a clone with nil CustomIP, CustomDNS and LookupIPFn. The standalone
// custom-IP, custom-DNS and lookup-IP context values are shadowed with typed nils.
func clearTargetResolutionContext(ctx context.Context) context.Context {
	if rc, ok := ctx.Value(ctxKeyRequestContext).(*RequestContext); ok && rc != nil {
		stripped := cloneRequestContext(rc)
		stripped.CustomIP = nil
		stripped.CustomDNS = nil
		stripped.LookupIPFn = nil
		ctx = context.WithValue(ctx, ctxKeyRequestContext, stripped)
	}

	var (
		noIPs    []string
		noDNS    []string
		noLookup func(context.Context, string) ([]net.IPAddr, error)
	)
	ctx = context.WithValue(ctx, ctxKeyCustomIP, noIPs)
	ctx = context.WithValue(ctx, ctxKeyCustomDNS, noDNS)
	ctx = context.WithValue(ctx, ctxKeyLookupIP, noLookup)
	return ctx
}
// withDefaultServerName returns a TLS config whose ServerName is serverName
// unless one is already present. cfg is never modified:
//   - an empty serverName, or a cfg that already has a ServerName, returns cfg as-is;
//   - a nil cfg yields a fresh config that also advertises h2 and http/1.1 via ALPN;
//   - otherwise a clone of cfg with ServerName filled in is returned.
func withDefaultServerName(cfg *tls.Config, serverName string) *tls.Config {
	switch {
	case serverName == "":
		return cfg
	case cfg == nil:
		return &tls.Config{
			ServerName: serverName,
			NextProtos: []string{"h2", "http/1.1"},
		}
	case cfg.ServerName != "":
		return cfg
	}
	out := cfg.Clone()
	out.ServerName = serverName
	return out
}
// roundTripResolvedTargets sends baseReq to the addresses in targetAddrs, in
// order, and returns the first successful response or the last error. Each
// attempt replaces the URL host and keeps the Host header
// (cloneRequestForResolvedTarget).
//
// Later addresses are tried only when requestAllowsResolvedTargetFallback
// says the request can be replayed; otherwise only the first address is used.
// Retrying also stops once baseReq's context is done, since every further
// attempt would fail the same way. With no addresses (or a nil request) the
// request is handed to rt unchanged. A nil rt is reported as an error instead
// of panicking.
func roundTripResolvedTargets(rt http.RoundTripper, baseReq *http.Request, targetAddrs []string) (*http.Response, error) {
	if rt == nil {
		return nil, errors.New("starnet: nil round tripper for resolved targets")
	}
	if baseReq == nil || len(targetAddrs) == 0 {
		return rt.RoundTrip(baseReq)
	}
	if len(targetAddrs) > 1 && !requestAllowsResolvedTargetFallback(baseReq) {
		targetAddrs = targetAddrs[:1]
	}

	var lastErr error
	for _, targetAddr := range targetAddrs {
		attemptReq, err := cloneRequestForResolvedTarget(baseReq, targetAddr)
		if err != nil {
			return nil, err
		}
		resp, err := rt.RoundTrip(attemptReq)
		if err == nil {
			return resp, nil
		}
		lastErr = err
		if baseReq.Context().Err() != nil {
			break
		}
	}
	return nil, lastErr
}
// requestAllowsResolvedTargetFallback reports whether req may be replayed
// against another resolved address after a failed attempt. The method must be
// idempotent, and the body must be absent (nil or http.NoBody) or re-creatable
// through GetBody. A nil request is never replayable.
func requestAllowsResolvedTargetFallback(req *http.Request) bool {
	if req == nil || !isIdempotentMethod(req.Method) {
		return false
	}
	hasBody := req.Body != nil && req.Body != http.NoBody
	return !hasBody || req.GetBody != nil
}
// cloneRequestForResolvedTarget returns a copy of baseReq that is sent to
// targetAddr (its URL host is replaced) while keeping the original Host
// header. When baseReq has a body and GetBody, the copy gets a fresh body from
// GetBody so each attempt can be sent independently. Otherwise the copy shares
// baseReq.Body, as http.Request.Clone does. A GetBody failure is returned
// wrapped.
func cloneRequestForResolvedTarget(baseReq *http.Request, targetAddr string) (*http.Request, error) {
	req := baseReq.Clone(baseReq.Context())
	if baseReq.Body != nil && baseReq.Body != http.NoBody && baseReq.GetBody != nil {
		body, err := baseReq.GetBody()
		if err != nil {
			return nil, wrapError(err, "clone request body for resolved target")
		}
		req.Body = body
	}
	// Clone copies Host. When it is empty, net/http derives the Host header
	// from URL.Host, which is rewritten below, so pin the original authority.
	if req.Host == "" {
		req.Host = baseReq.URL.Host
	}
	req.URL.Host = targetAddr
	return req, nil
}