starnet/context.go
2026-03-08 20:19:40 +08:00

150 lines
3.9 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package starnet
import (
"context"
"crypto/tls"
"net"
"net/http"
"time"
)
// contextKey is the unexported type used for every key this package stores in
// a context.Context. Because no other package can name this type, its keys can
// never collide with context keys defined elsewhere.
type contextKey int

// Keys for the per-request settings carried in a context. Values are numbered
// from 0 in declaration order; the trailing comment on each key names the value
// type getRequestContext expects to find under it.
const (
	ctxKeyTransport   = contextKey(iota) // *http.Transport
	ctxKeyTLSConfig                      // *tls.Config
	ctxKeyProxy                          // string
	ctxKeyCustomIP                       // []string
	ctxKeyCustomDNS                      // []string
	ctxKeyDialTimeout                    // time.Duration
	ctxKeyTimeout                        // time.Duration
	ctxKeyLookupIP                       // func(context.Context, string) ([]net.IPAddr, error)
	ctxKeyDialFunc                       // func(context.Context, string, string) (net.Conn, error)
)
// RequestContext is the set of per-request settings that getRequestContext
// reads back out of a context.Context. A field left at its zero value means the
// corresponding key was absent from the context, or held a value whose dynamic
// type did not match the field's type exactly.
type RequestContext struct {
	// Transport is a caller-supplied transport to use for the request.
	Transport *http.Transport
	// TLSConfig is the TLS configuration to use for the request.
	TLSConfig *tls.Config
	// Proxy is the proxy setting; empty when none was provided.
	Proxy string
	// CustomIP holds caller-supplied IP addresses (presumably used in place
	// of normal name resolution — confirm with the dialing code).
	CustomIP []string
	// CustomDNS holds caller-supplied DNS settings (presumably resolver
	// addresses — confirm expected format with the resolving code).
	CustomDNS []string
	// DialTimeout is the per-request dial timeout; zero when unset.
	DialTimeout time.Duration
	// Timeout is the per-request timeout; zero when unset.
	Timeout time.Duration
	// LookupIPFn, when non-nil, is a caller-supplied host-name lookup function.
	LookupIPFn func(ctx context.Context, host string) ([]net.IPAddr, error)
	// DialFn, when non-nil, is a caller-supplied dial function.
	DialFn func(ctx context.Context, network, addr string) (net.Conn, error)
}
// getRequestContext gathers the per-request settings stored in ctx (as written
// by injectRequestConfig) into a newly allocated RequestContext.
//
// Every key is read with a checked (comma-ok) type assertion, so a missing key
// or a value whose dynamic type is not exactly the expected one leaves the
// matching field at its zero value. The result is never nil.
func getRequestContext(ctx context.Context) *RequestContext {
	var out RequestContext

	// A comma-ok assertion on a nil interface simply reports false and yields
	// the zero value, so no separate nil check is needed before each lookup.
	out.Transport, _ = ctx.Value(ctxKeyTransport).(*http.Transport)
	out.TLSConfig, _ = ctx.Value(ctxKeyTLSConfig).(*tls.Config)
	out.Proxy, _ = ctx.Value(ctxKeyProxy).(string)
	out.CustomIP, _ = ctx.Value(ctxKeyCustomIP).([]string)
	out.CustomDNS, _ = ctx.Value(ctxKeyCustomDNS).([]string)
	out.DialTimeout, _ = ctx.Value(ctxKeyDialTimeout).(time.Duration)
	out.Timeout, _ = ctx.Value(ctxKeyTimeout).(time.Duration)
	out.LookupIPFn, _ = ctx.Value(ctxKeyLookupIP).(func(context.Context, string) ([]net.IPAddr, error))
	out.DialFn, _ = ctx.Value(ctxKeyDialFunc).(func(context.Context, string, string) (net.Conn, error))

	return &out
}
// needsDynamicTransport reports whether rc carries any request-specific setting
// that calls for a dynamic Transport. It returns true when any of these hold:
//   - a Transport, TLSConfig, Proxy or DialFn was supplied;
//   - DialTimeout is positive and differs from DefaultDialTimeout;
//   - Timeout is positive and differs from DefaultTimeout;
//   - CustomIP or CustomDNS is non-empty, or a LookupIPFn was supplied.
//
// rc must be non-nil; a nil rc panics on the first field access.
func needsDynamicTransport(rc *RequestContext) bool {
	switch {
	case rc.Transport != nil, rc.TLSConfig != nil, rc.Proxy != "", rc.DialFn != nil:
		return true
	case rc.DialTimeout > 0 && rc.DialTimeout != DefaultDialTimeout:
		return true
	case rc.Timeout > 0 && rc.Timeout != DefaultTimeout:
		return true
	case len(rc.CustomIP) > 0, len(rc.CustomDNS) > 0, rc.LookupIPFn != nil:
		return true
	default:
		return false
	}
}
// injectRequestConfig returns a context derived from ctx that carries the
// per-request settings from config, stored under this package's private
// contextKey values so that getRequestContext can read them back.
//
// Only settings that are actually set are stored: empty strings and slices,
// non-positive durations, and nil functions are skipped, and the transport is
// stored only when config.CustomTransport is true and config.Transport is
// non-nil. A nil config leaves ctx unchanged.
//
// TLS handling:
//   - if config.TLS.Config is set, a clone of it is stored (the caller's
//     config is never mutated), with InsecureSkipVerify forced on when
//     config.TLS.SkipVerify is true;
//   - otherwise, if SkipVerify is true, a fresh tls.Config is stored that
//     skips verification and offers "h2" and "http/1.1" via NextProtos (ALPN).
//
// CustomIP and CustomDNS are copied, so later changes to config's slices do
// not affect values already attached to the returned context.
func injectRequestConfig(ctx context.Context, config *RequestConfig) context.Context {
	// Tolerate a nil config instead of panicking on the field accesses below.
	if config == nil {
		return ctx
	}
	execCtx := ctx

	// TLS: clone any caller-supplied config so the caller's value is never
	// mutated, then apply SkipVerify on top of it.
	var tlsConfig *tls.Config
	switch {
	case config.TLS.Config != nil:
		tlsConfig = config.TLS.Config.Clone()
		if config.TLS.SkipVerify {
			tlsConfig.InsecureSkipVerify = true
		}
	case config.TLS.SkipVerify:
		tlsConfig = &tls.Config{
			NextProtos:         []string{"h2", "http/1.1"},
			InsecureSkipVerify: true,
		}
	}
	if tlsConfig != nil {
		execCtx = context.WithValue(execCtx, ctxKeyTLSConfig, tlsConfig)
	}

	if config.Network.Proxy != "" {
		execCtx = context.WithValue(execCtx, ctxKeyProxy, config.Network.Proxy)
	}

	// Copy the slices: context values should be treated as immutable, and
	// storing the caller's slice would let later edits to config (possibly
	// from another goroutine) change what an in-flight request sees.
	if len(config.DNS.CustomIP) > 0 {
		execCtx = context.WithValue(execCtx, ctxKeyCustomIP, append([]string(nil), config.DNS.CustomIP...))
	}
	if len(config.DNS.CustomDNS) > 0 {
		execCtx = context.WithValue(execCtx, ctxKeyCustomDNS, append([]string(nil), config.DNS.CustomDNS...))
	}

	// Timeouts are stored whenever they are positive, even if equal to the
	// package defaults; needsDynamicTransport decides whether a value differs
	// from the default enough to matter.
	if config.Network.DialTimeout > 0 {
		execCtx = context.WithValue(execCtx, ctxKeyDialTimeout, config.Network.DialTimeout)
	}
	if config.Network.Timeout > 0 {
		execCtx = context.WithValue(execCtx, ctxKeyTimeout, config.Network.Timeout)
	}

	// NOTE(review): getRequestContext reads these two back with exact type
	// assertions against unnamed func types. If RequestConfig declares these
	// fields with a named func type, those assertions fail and the custom
	// functions are silently ignored — confirm the declared field types.
	if config.DNS.LookupFunc != nil {
		execCtx = context.WithValue(execCtx, ctxKeyLookupIP, config.DNS.LookupFunc)
	}
	if config.Network.DialFunc != nil {
		execCtx = context.WithValue(execCtx, ctxKeyDialFunc, config.Network.DialFunc)
	}

	// A custom transport is honored only when explicitly enabled.
	if config.CustomTransport && config.Transport != nil {
		execCtx = context.WithValue(execCtx, ctxKeyTransport, config.Transport)
	}
	return execCtx
}