150 lines
3.9 KiB
Go
150 lines
3.9 KiB
Go
|
|
package starnet
|
|||
|
|
|
|||
|
|
import (
|
|||
|
|
"context"
|
|||
|
|
"crypto/tls"
|
|||
|
|
"net"
|
|||
|
|
"net/http"
|
|||
|
|
"time"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
// contextKey is the unexported type used for every context key declared in
// this file. Because the type is private to the package, these keys cannot
// collide with context keys defined by other packages.
type contextKey int

// Context keys under which per-request settings are stored. The values are
// assigned by iota, so the order of this list must stay stable.
const (
	// ctxKeyTransport keys a *http.Transport.
	ctxKeyTransport contextKey = iota
	// ctxKeyTLSConfig keys a *tls.Config.
	ctxKeyTLSConfig
	// ctxKeyProxy keys the proxy setting (a string).
	ctxKeyProxy
	// ctxKeyCustomIP keys the custom IP list (a []string).
	ctxKeyCustomIP
	// ctxKeyCustomDNS keys the custom DNS list (a []string).
	ctxKeyCustomDNS
	// ctxKeyDialTimeout keys the dial timeout (a time.Duration).
	ctxKeyDialTimeout
	// ctxKeyTimeout keys the timeout (a time.Duration).
	ctxKeyTimeout
	// ctxKeyLookupIP keys a host lookup function.
	ctxKeyLookupIP
	// ctxKeyDialFunc keys a dial function.
	ctxKeyDialFunc
)
|
|||
|
|
|
|||
|
|
// RequestContext holds the per-request settings that getRequestContext reads
// back out of a context.Context. A field left at its zero value means no
// usable value was stored under the corresponding context key.
type RequestContext struct {
	// Transport is the value stored under ctxKeyTransport.
	Transport *http.Transport
	// TLSConfig is the value stored under ctxKeyTLSConfig.
	TLSConfig *tls.Config
	// Proxy is the value stored under ctxKeyProxy; empty means unset.
	Proxy string
	// CustomIP is the value stored under ctxKeyCustomIP.
	CustomIP []string
	// CustomDNS is the value stored under ctxKeyCustomDNS.
	CustomDNS []string
	// DialTimeout is the value stored under ctxKeyDialTimeout.
	DialTimeout time.Duration
	// Timeout is the value stored under ctxKeyTimeout.
	Timeout time.Duration
	// LookupIPFn is the host lookup function stored under ctxKeyLookupIP.
	LookupIPFn func(ctx context.Context, host string) ([]net.IPAddr, error)
	// DialFn is the dial function stored under ctxKeyDialFunc.
	DialFn func(ctx context.Context, network, addr string) (net.Conn, error)
}
|
|||
|
|
|
|||
|
|
// getRequestContext 从 context 中提取请求配置
|
|||
|
|
func getRequestContext(ctx context.Context) *RequestContext {
|
|||
|
|
rc := &RequestContext{}
|
|||
|
|
|
|||
|
|
if v := ctx.Value(ctxKeyTransport); v != nil {
|
|||
|
|
rc.Transport, _ = v.(*http.Transport)
|
|||
|
|
}
|
|||
|
|
if v := ctx.Value(ctxKeyTLSConfig); v != nil {
|
|||
|
|
rc.TLSConfig, _ = v.(*tls.Config)
|
|||
|
|
}
|
|||
|
|
if v := ctx.Value(ctxKeyProxy); v != nil {
|
|||
|
|
rc.Proxy, _ = v.(string)
|
|||
|
|
}
|
|||
|
|
if v := ctx.Value(ctxKeyCustomIP); v != nil {
|
|||
|
|
rc.CustomIP, _ = v.([]string)
|
|||
|
|
}
|
|||
|
|
if v := ctx.Value(ctxKeyCustomDNS); v != nil {
|
|||
|
|
rc.CustomDNS, _ = v.([]string)
|
|||
|
|
}
|
|||
|
|
if v := ctx.Value(ctxKeyDialTimeout); v != nil {
|
|||
|
|
rc.DialTimeout, _ = v.(time.Duration)
|
|||
|
|
}
|
|||
|
|
if v := ctx.Value(ctxKeyTimeout); v != nil {
|
|||
|
|
rc.Timeout, _ = v.(time.Duration)
|
|||
|
|
}
|
|||
|
|
if v := ctx.Value(ctxKeyLookupIP); v != nil {
|
|||
|
|
rc.LookupIPFn, _ = v.(func(context.Context, string) ([]net.IPAddr, error))
|
|||
|
|
}
|
|||
|
|
if v := ctx.Value(ctxKeyDialFunc); v != nil {
|
|||
|
|
rc.DialFn, _ = v.(func(context.Context, string, string) (net.Conn, error))
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return rc
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// needsDynamicTransport 判断是否需要动态 Transport
|
|||
|
|
func needsDynamicTransport(rc *RequestContext) bool {
|
|||
|
|
return rc.Transport != nil ||
|
|||
|
|
rc.TLSConfig != nil ||
|
|||
|
|
rc.Proxy != "" ||
|
|||
|
|
rc.DialFn != nil ||
|
|||
|
|
(rc.DialTimeout > 0 && rc.DialTimeout != DefaultDialTimeout) ||
|
|||
|
|
(rc.Timeout > 0 && rc.Timeout != DefaultTimeout) ||
|
|||
|
|
len(rc.CustomIP) > 0 ||
|
|||
|
|
len(rc.CustomDNS) > 0 ||
|
|||
|
|
rc.LookupIPFn != nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// injectRequestConfig 将请求配置注入到 context
|
|||
|
|
func injectRequestConfig(ctx context.Context, config *RequestConfig) context.Context {
|
|||
|
|
execCtx := ctx
|
|||
|
|
|
|||
|
|
// 处理 TLS 配置
|
|||
|
|
var tlsConfig *tls.Config
|
|||
|
|
|
|||
|
|
if config.TLS.Config != nil {
|
|||
|
|
tlsConfig = config.TLS.Config.Clone()
|
|||
|
|
if config.TLS.SkipVerify {
|
|||
|
|
tlsConfig.InsecureSkipVerify = true
|
|||
|
|
}
|
|||
|
|
} else if config.TLS.SkipVerify {
|
|||
|
|
tlsConfig = &tls.Config{
|
|||
|
|
NextProtos: []string{"h2", "http/1.1"},
|
|||
|
|
InsecureSkipVerify: true,
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if tlsConfig != nil {
|
|||
|
|
execCtx = context.WithValue(execCtx, ctxKeyTLSConfig, tlsConfig)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 注入代理
|
|||
|
|
if config.Network.Proxy != "" {
|
|||
|
|
execCtx = context.WithValue(execCtx, ctxKeyProxy, config.Network.Proxy)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 注入自定义 IP
|
|||
|
|
if len(config.DNS.CustomIP) > 0 {
|
|||
|
|
execCtx = context.WithValue(execCtx, ctxKeyCustomIP, config.DNS.CustomIP)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 注入自定义 DNS
|
|||
|
|
if len(config.DNS.CustomDNS) > 0 {
|
|||
|
|
execCtx = context.WithValue(execCtx, ctxKeyCustomDNS, config.DNS.CustomDNS)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 总是注入 DialTimeout 和 Timeout(与原始代码一致)
|
|||
|
|
if config.Network.DialTimeout > 0 {
|
|||
|
|
execCtx = context.WithValue(execCtx, ctxKeyDialTimeout, config.Network.DialTimeout)
|
|||
|
|
}
|
|||
|
|
if config.Network.Timeout > 0 {
|
|||
|
|
execCtx = context.WithValue(execCtx, ctxKeyTimeout, config.Network.Timeout)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 注入 DNS 解析函数
|
|||
|
|
if config.DNS.LookupFunc != nil {
|
|||
|
|
execCtx = context.WithValue(execCtx, ctxKeyLookupIP, config.DNS.LookupFunc)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 注入 Dial 函数
|
|||
|
|
if config.Network.DialFunc != nil {
|
|||
|
|
execCtx = context.WithValue(execCtx, ctxKeyDialFunc, config.Network.DialFunc)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 注入自定义 Transport
|
|||
|
|
if config.CustomTransport && config.Transport != nil {
|
|||
|
|
execCtx = context.WithValue(execCtx, ctxKeyTransport, config.Transport)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return execCtx
|
|||
|
|
}
|