starnet/context.go

150 lines
3.9 KiB
Go
Raw Permalink Normal View History

2026-03-08 20:19:40 +08:00
package starnet
import (
"context"
"crypto/tls"
"net"
"net/http"
"time"
)
// contextKey is the unexported type used for every context key in this
// package. Because no other package can name it, these keys cannot collide
// with keys that other code stores on the same context.
type contextKey int

// Context keys for per-request settings. Values are assigned by iota in
// declaration order; the trailing comment on each key names the dynamic type
// getRequestContext expects to find stored under it.
const (
	ctxKeyTransport   = contextKey(iota) // *http.Transport
	ctxKeyTLSConfig                      // *tls.Config
	ctxKeyProxy                          // string
	ctxKeyCustomIP                       // []string
	ctxKeyCustomDNS                      // []string
	ctxKeyDialTimeout                    // time.Duration
	ctxKeyTimeout                        // time.Duration
	ctxKeyLookupIP                       // func(context.Context, string) ([]net.IPAddr, error)
	ctxKeyDialFunc                       // func(context.Context, string, string) (net.Conn, error)
)
// RequestContext is the per-request configuration recovered from a
// context.Context by getRequestContext. A zero-valued field means the
// corresponding setting was absent from the context (or stored with an
// unexpected type).
type RequestContext struct {
	// Transport is an explicitly supplied transport, or nil.
	Transport *http.Transport
	// TLSConfig is the TLS configuration for this request, or nil.
	TLSConfig *tls.Config
	// Proxy is the configured proxy; empty means none was set.
	Proxy string
	// CustomIP and CustomDNS are caller-provided lists (presumably target IPs
	// and DNS server addresses — confirm with the dialer code); nil if unset.
	CustomIP, CustomDNS []string
	// DialTimeout and Timeout are the timeouts read from the context; zero
	// means unset.
	DialTimeout, Timeout time.Duration
	// LookupIPFn is an optional resolver hook with the same signature as
	// (*net.Resolver).LookupIPAddr.
	LookupIPFn func(ctx context.Context, host string) ([]net.IPAddr, error)
	// DialFn is an optional dialer hook with the same signature as
	// (*net.Dialer).DialContext.
	DialFn func(ctx context.Context, network, addr string) (net.Conn, error)
}
// getRequestContext gathers the per-request settings stored in ctx (as done by
// injectRequestConfig) into a freshly allocated RequestContext.
//
// Each lookup uses a comma-ok type assertion. That yields the zero value both
// when the key is absent (ctx.Value returns nil) and when the stored value has
// a different dynamic type, so such fields are left unset instead of panicking.
//
// NOTE(review): the two function-typed settings are asserted against unnamed
// func types; a value stored under a named func type will not match and is
// silently dropped — confirm against the RequestConfig field types.
func getRequestContext(ctx context.Context) *RequestContext {
	transport, _ := ctx.Value(ctxKeyTransport).(*http.Transport)
	tlsCfg, _ := ctx.Value(ctxKeyTLSConfig).(*tls.Config)
	proxy, _ := ctx.Value(ctxKeyProxy).(string)
	customIP, _ := ctx.Value(ctxKeyCustomIP).([]string)
	customDNS, _ := ctx.Value(ctxKeyCustomDNS).([]string)
	dialTimeout, _ := ctx.Value(ctxKeyDialTimeout).(time.Duration)
	timeout, _ := ctx.Value(ctxKeyTimeout).(time.Duration)
	lookupIP, _ := ctx.Value(ctxKeyLookupIP).(func(context.Context, string) ([]net.IPAddr, error))
	dial, _ := ctx.Value(ctxKeyDialFunc).(func(context.Context, string, string) (net.Conn, error))

	return &RequestContext{
		Transport:   transport,
		TLSConfig:   tlsCfg,
		Proxy:       proxy,
		CustomIP:    customIP,
		CustomDNS:   customDNS,
		DialTimeout: dialTimeout,
		Timeout:     timeout,
		LookupIPFn:  lookupIP,
		DialFn:      dial,
	}
}
// needsDynamicTransport reports whether rc carries any setting that calls for
// a dynamically built transport.
//
// That is true when any override is present: an explicit transport, TLS
// config, proxy, or dial hook; a positive dial or request timeout that differs
// from DefaultDialTimeout / DefaultTimeout; or any custom IPs, custom DNS
// entries, or lookup hook.
func needsDynamicTransport(rc *RequestContext) bool {
	// Explicit overrides: their mere presence is enough.
	if rc.Transport != nil || rc.TLSConfig != nil || rc.Proxy != "" || rc.DialFn != nil {
		return true
	}
	// Timeouts only count when set (positive) and different from the defaults.
	if rc.DialTimeout > 0 && rc.DialTimeout != DefaultDialTimeout {
		return true
	}
	if rc.Timeout > 0 && rc.Timeout != DefaultTimeout {
		return true
	}
	// Name-resolution overrides.
	return len(rc.CustomIP) > 0 || len(rc.CustomDNS) > 0 || rc.LookupIPFn != nil
}
// injectRequestConfig returns a context derived from ctx that carries the
// per-request settings from config. Each setting goes under its own private
// key so getRequestContext can read it back.
//
// Only settings that are actually set are stored:
//   - TLS: a non-nil config.TLS.Config is cloned, so the caller's value is never
//     mutated, and InsecureSkipVerify is forced on when SkipVerify is set. With
//     SkipVerify alone, a new config advertising "h2" and "http/1.1" is built.
//   - Proxy, CustomIP, CustomDNS: only when non-empty. The slices are stored
//     as-is, not copied.
//   - DialTimeout, Timeout: only when positive.
//   - LookupFunc, DialFunc: only when non-nil.
//   - Transport: only when CustomTransport is true and Transport is non-nil.
//
// When nothing is set, ctx itself is returned.
func injectRequestConfig(ctx context.Context, config *RequestConfig) context.Context {
	execCtx := ctx
	put := func(key contextKey, val any) {
		execCtx = context.WithValue(execCtx, key, val)
	}

	var tlsConfig *tls.Config
	switch {
	case config.TLS.Config != nil:
		// Clone so forcing InsecureSkipVerify never leaks into the caller's config.
		tlsConfig = config.TLS.Config.Clone()
		if config.TLS.SkipVerify {
			tlsConfig.InsecureSkipVerify = true
		}
	case config.TLS.SkipVerify:
		tlsConfig = &tls.Config{
			NextProtos:         []string{"h2", "http/1.1"},
			InsecureSkipVerify: true,
		}
	}
	if tlsConfig != nil {
		put(ctxKeyTLSConfig, tlsConfig)
	}

	if config.Network.Proxy != "" {
		put(ctxKeyProxy, config.Network.Proxy)
	}
	if len(config.DNS.CustomIP) > 0 {
		put(ctxKeyCustomIP, config.DNS.CustomIP)
	}
	if len(config.DNS.CustomDNS) > 0 {
		put(ctxKeyCustomDNS, config.DNS.CustomDNS)
	}

	// Timeouts are injected only when positive; zero/negative means "not set".
	if config.Network.DialTimeout > 0 {
		put(ctxKeyDialTimeout, config.Network.DialTimeout)
	}
	if config.Network.Timeout > 0 {
		put(ctxKeyTimeout, config.Network.Timeout)
	}

	if config.DNS.LookupFunc != nil {
		put(ctxKeyLookupIP, config.DNS.LookupFunc)
	}
	if config.Network.DialFunc != nil {
		put(ctxKeyDialFunc, config.Network.DialFunc)
	}

	// A supplied Transport is honored only when explicitly enabled.
	if config.CustomTransport && config.Transport != nil {
		put(ctxKeyTransport, config.Transport)
	}

	return execCtx
}