150 lines
3.9 KiB
Go
150 lines
3.9 KiB
Go
package starnet
|
||
|
||
import (
|
||
"context"
|
||
"crypto/tls"
|
||
"net"
|
||
"net/http"
|
||
"time"
|
||
)
|
||
|
||
// contextKey is an unexported key type for the values this package stores in
// a context.Context. Because the type is private, no other package can build
// an equal key, so these entries cannot collide with foreign context values.
type contextKey int

// Keys under which injectRequestConfig stores per-request overrides and from
// which getRequestContext reads them back.
const (
	ctxKeyTransport   contextKey = iota // *http.Transport supplied by the caller
	ctxKeyTLSConfig                     // *tls.Config for this request
	ctxKeyProxy                         // proxy address string
	ctxKeyCustomIP                      // []string of caller-chosen IPs
	ctxKeyCustomDNS                     // []string of caller-chosen DNS servers
	ctxKeyDialTimeout                   // time.Duration dial timeout
	ctxKeyTimeout                       // time.Duration overall timeout
	ctxKeyLookupIP                      // custom IP lookup function
	ctxKeyDialFunc                      // custom dial function
)
|
||
|
||
// RequestContext holds the per-request configuration extracted from a
// context.Context. A zero-valued field means the corresponding key was not
// present in the context (or held a value of an unexpected type).
type RequestContext struct {
	// Transport is a caller-supplied transport, if one was injected.
	Transport *http.Transport
	// TLSConfig is the TLS configuration to use for this request.
	TLSConfig *tls.Config
	// Proxy is the proxy address; empty when no proxy was injected.
	Proxy string
	// CustomIP lists caller-chosen IP addresses.
	CustomIP []string
	// CustomDNS lists caller-chosen DNS servers.
	CustomDNS []string
	// DialTimeout is the dial timeout; 0 when none was injected.
	DialTimeout time.Duration
	// Timeout is the overall timeout; 0 when none was injected.
	Timeout time.Duration
	// LookupIPFn, when non-nil, is a caller-provided IP lookup function.
	LookupIPFn func(ctx context.Context, host string) ([]net.IPAddr, error)
	// DialFn, when non-nil, is a caller-provided dial function.
	DialFn func(ctx context.Context, network, addr string) (net.Conn, error)
}
|
||
|
||
// getRequestContext 从 context 中提取请求配置
|
||
func getRequestContext(ctx context.Context) *RequestContext {
|
||
rc := &RequestContext{}
|
||
|
||
if v := ctx.Value(ctxKeyTransport); v != nil {
|
||
rc.Transport, _ = v.(*http.Transport)
|
||
}
|
||
if v := ctx.Value(ctxKeyTLSConfig); v != nil {
|
||
rc.TLSConfig, _ = v.(*tls.Config)
|
||
}
|
||
if v := ctx.Value(ctxKeyProxy); v != nil {
|
||
rc.Proxy, _ = v.(string)
|
||
}
|
||
if v := ctx.Value(ctxKeyCustomIP); v != nil {
|
||
rc.CustomIP, _ = v.([]string)
|
||
}
|
||
if v := ctx.Value(ctxKeyCustomDNS); v != nil {
|
||
rc.CustomDNS, _ = v.([]string)
|
||
}
|
||
if v := ctx.Value(ctxKeyDialTimeout); v != nil {
|
||
rc.DialTimeout, _ = v.(time.Duration)
|
||
}
|
||
if v := ctx.Value(ctxKeyTimeout); v != nil {
|
||
rc.Timeout, _ = v.(time.Duration)
|
||
}
|
||
if v := ctx.Value(ctxKeyLookupIP); v != nil {
|
||
rc.LookupIPFn, _ = v.(func(context.Context, string) ([]net.IPAddr, error))
|
||
}
|
||
if v := ctx.Value(ctxKeyDialFunc); v != nil {
|
||
rc.DialFn, _ = v.(func(context.Context, string, string) (net.Conn, error))
|
||
}
|
||
|
||
return rc
|
||
}
|
||
|
||
// needsDynamicTransport 判断是否需要动态 Transport
|
||
func needsDynamicTransport(rc *RequestContext) bool {
|
||
return rc.Transport != nil ||
|
||
rc.TLSConfig != nil ||
|
||
rc.Proxy != "" ||
|
||
rc.DialFn != nil ||
|
||
(rc.DialTimeout > 0 && rc.DialTimeout != DefaultDialTimeout) ||
|
||
(rc.Timeout > 0 && rc.Timeout != DefaultTimeout) ||
|
||
len(rc.CustomIP) > 0 ||
|
||
len(rc.CustomDNS) > 0 ||
|
||
rc.LookupIPFn != nil
|
||
}
|
||
|
||
// injectRequestConfig 将请求配置注入到 context
|
||
func injectRequestConfig(ctx context.Context, config *RequestConfig) context.Context {
|
||
execCtx := ctx
|
||
|
||
// 处理 TLS 配置
|
||
var tlsConfig *tls.Config
|
||
|
||
if config.TLS.Config != nil {
|
||
tlsConfig = config.TLS.Config.Clone()
|
||
if config.TLS.SkipVerify {
|
||
tlsConfig.InsecureSkipVerify = true
|
||
}
|
||
} else if config.TLS.SkipVerify {
|
||
tlsConfig = &tls.Config{
|
||
NextProtos: []string{"h2", "http/1.1"},
|
||
InsecureSkipVerify: true,
|
||
}
|
||
}
|
||
|
||
if tlsConfig != nil {
|
||
execCtx = context.WithValue(execCtx, ctxKeyTLSConfig, tlsConfig)
|
||
}
|
||
|
||
// 注入代理
|
||
if config.Network.Proxy != "" {
|
||
execCtx = context.WithValue(execCtx, ctxKeyProxy, config.Network.Proxy)
|
||
}
|
||
|
||
// 注入自定义 IP
|
||
if len(config.DNS.CustomIP) > 0 {
|
||
execCtx = context.WithValue(execCtx, ctxKeyCustomIP, config.DNS.CustomIP)
|
||
}
|
||
|
||
// 注入自定义 DNS
|
||
if len(config.DNS.CustomDNS) > 0 {
|
||
execCtx = context.WithValue(execCtx, ctxKeyCustomDNS, config.DNS.CustomDNS)
|
||
}
|
||
|
||
// 总是注入 DialTimeout 和 Timeout(与原始代码一致)
|
||
if config.Network.DialTimeout > 0 {
|
||
execCtx = context.WithValue(execCtx, ctxKeyDialTimeout, config.Network.DialTimeout)
|
||
}
|
||
if config.Network.Timeout > 0 {
|
||
execCtx = context.WithValue(execCtx, ctxKeyTimeout, config.Network.Timeout)
|
||
}
|
||
|
||
// 注入 DNS 解析函数
|
||
if config.DNS.LookupFunc != nil {
|
||
execCtx = context.WithValue(execCtx, ctxKeyLookupIP, config.DNS.LookupFunc)
|
||
}
|
||
|
||
// 注入 Dial 函数
|
||
if config.Network.DialFunc != nil {
|
||
execCtx = context.WithValue(execCtx, ctxKeyDialFunc, config.Network.DialFunc)
|
||
}
|
||
|
||
// 注入自定义 Transport
|
||
if config.CustomTransport && config.Transport != nil {
|
||
execCtx = context.WithValue(execCtx, ctxKeyTransport, config.Transport)
|
||
}
|
||
|
||
return execCtx
|
||
}
|