98 lines
2.2 KiB
Go
98 lines
2.2 KiB
Go
package starnet
|
||
|
||
import (
|
||
"net/http"
|
||
"net/url"
|
||
"sync"
|
||
"time"
|
||
)
|
||
|
||
// Transport is a custom http.RoundTripper that supports request-level
// configuration. For each request it chooses between a transport supplied
// with the request, a transport customised from request-level settings, and
// a shared base *http.Transport.
//
// The zero value is usable: RoundTrip creates the base transport on first use.
type Transport struct {
	// base is the shared transport for requests without request-level
	// overrides. It may be nil until first use; access it while holding mu.
	base *http.Transport

	// mu guards base.
	mu sync.RWMutex
}
|
||
|
||
// RoundTrip 实现 http.RoundTripper 接口
|
||
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||
// 确保 base 已初始化
|
||
if t.base == nil {
|
||
t.mu.Lock()
|
||
if t.base == nil {
|
||
t.base = &http.Transport{
|
||
ForceAttemptHTTP2: true,
|
||
MaxIdleConns: 100,
|
||
MaxIdleConnsPerHost: 10,
|
||
IdleConnTimeout: 90 * time.Second,
|
||
TLSHandshakeTimeout: 10 * time.Second,
|
||
ExpectContinueTimeout: 1 * time.Second,
|
||
}
|
||
}
|
||
t.mu.Unlock()
|
||
}
|
||
|
||
// 提取请求级别的配置
|
||
reqCtx := getRequestContext(req.Context())
|
||
|
||
// 优先级1:完全自定义的 transport
|
||
if reqCtx.Transport != nil {
|
||
return reqCtx.Transport.RoundTrip(req)
|
||
}
|
||
|
||
// 优先级2:需要动态配置
|
||
if needsDynamicTransport(reqCtx) {
|
||
dynamicTransport := t.buildDynamicTransport(reqCtx)
|
||
return dynamicTransport.RoundTrip(req)
|
||
}
|
||
|
||
// 优先级3:使用基础 transport
|
||
t.mu.RLock()
|
||
defer t.mu.RUnlock()
|
||
return t.base.RoundTrip(req)
|
||
}
|
||
|
||
// buildDynamicTransport 构建动态 Transport
|
||
func (t *Transport) buildDynamicTransport(rc *RequestContext) *http.Transport {
|
||
t.mu.RLock()
|
||
transport := t.base.Clone()
|
||
t.mu.RUnlock()
|
||
|
||
// 应用 TLS 配置(即使为 nil 也要检查 SkipVerify)
|
||
if rc.TLSConfig != nil {
|
||
transport.TLSClientConfig = rc.TLSConfig
|
||
}
|
||
|
||
// 应用代理配置
|
||
if rc.Proxy != "" {
|
||
proxyURL, err := url.Parse(rc.Proxy)
|
||
if err == nil {
|
||
transport.Proxy = http.ProxyURL(proxyURL)
|
||
}
|
||
}
|
||
|
||
// 应用自定义 Dial 函数
|
||
if rc.DialFn != nil {
|
||
transport.DialContext = rc.DialFn
|
||
} else if len(rc.CustomIP) > 0 || len(rc.CustomDNS) > 0 || rc.DialTimeout > 0 || rc.LookupIPFn != nil {
|
||
// 使用默认 Dial 函数(会从 context 读取配置)
|
||
transport.DialContext = defaultDialFunc
|
||
transport.DialTLSContext = defaultDialTLSFunc
|
||
}
|
||
|
||
return transport
|
||
}
|
||
|
||
// Base 获取基础 Transport
|
||
func (t *Transport) Base() *http.Transport {
|
||
t.mu.RLock()
|
||
defer t.mu.RUnlock()
|
||
return t.base
|
||
}
|
||
|
||
// SetBase 设置基础 Transport
|
||
func (t *Transport) SetBase(base *http.Transport) {
|
||
t.mu.Lock()
|
||
t.base = base
|
||
t.mu.Unlock()
|
||
}
|