starnet/context.go

189 lines
4.7 KiB
Go
Raw Normal View History

2026-03-08 20:19:40 +08:00
package starnet
import (
"context"
"crypto/tls"
"net"
"net/http"
"time"
)
// contextKey is the unexported type used for keys this package stores in a
// context.Context. Because the type is private, no other package can construct
// a key that collides with these.
type contextKey int

// Context keys. Values are assigned by iota, so inserting or reordering entries
// changes their numeric values. The trailing comment on each key names the
// value type getRequestContext expects to find under it.
const (
	ctxKeyTransport          = contextKey(iota) // *http.Transport
	ctxKeyTLSConfig                             // *tls.Config
	ctxKeyTLSConfigCacheable                    // bool
	ctxKeyTLSServerName                         // string
	ctxKeyProxy                                 // string
	ctxKeyCustomIP                              // []string
	ctxKeyCustomDNS                             // []string
	ctxKeyDialTimeout                           // time.Duration
	ctxKeyTimeout                               // time.Duration
	ctxKeyLookupIP                              // func(context.Context, string) ([]net.IPAddr, error)
	ctxKeyDialFunc                              // func(context.Context, string, string) (net.Conn, error)
	ctxKeyRequestContext                        // *RequestContext (the whole bundle)
)
// RequestContext is the set of per-request transport settings this package
// carries in a context.Context and reads back with getRequestContext.
type RequestContext struct {
	// Transport is a caller-supplied transport. buildRequestContext sets it
	// only when config.CustomTransport is true and config.Transport is non-nil.
	Transport *http.Transport

	// TLSConfig is the TLS configuration for this request. buildRequestContext
	// stores either a clone of the caller's tls.Config or one it built itself.
	TLSConfig *tls.Config

	// TLSConfigCacheable is true when TLSConfig was built by this package from
	// the SkipVerify/ServerName flags rather than cloned from a caller config.
	// NOTE(review): presumably lets the consumer reuse transports across
	// requests with equivalent settings — confirm against the consumer.
	TLSConfigCacheable bool

	// TLSServerName is the explicit config.TLS.ServerName, or the default
	// server name given to buildRequestContext when that is empty.
	TLSServerName string

	// Proxy is the configured proxy setting; empty when unset.
	Proxy string

	// CustomIP lists caller-supplied IP addresses for the target host.
	// NOTE(review): presumably dialed instead of resolving the host — confirm.
	CustomIP []string

	// CustomDNS lists caller-supplied DNS servers.
	// NOTE(review): presumably used for resolution instead of the system
	// resolver — confirm.
	CustomDNS []string

	// DialTimeout is copied from config.Network.DialTimeout.
	DialTimeout time.Duration

	// Timeout is copied from config.Network.Timeout.
	Timeout time.Duration

	// LookupIPFn is an optional custom host lookup (config.DNS.LookupFunc).
	LookupIPFn func(ctx context.Context, host string) ([]net.IPAddr, error)

	// DialFn is an optional custom dialer (config.Network.DialFunc).
	DialFn func(ctx context.Context, network, addr string) (net.Conn, error)
}
var emptyRequestContext = &RequestContext{}
2026-03-08 20:19:40 +08:00
// getRequestContext 从 context 中提取请求配置
func getRequestContext(ctx context.Context) *RequestContext {
if v := ctx.Value(ctxKeyRequestContext); v != nil {
if rc, ok := v.(*RequestContext); ok && rc != nil {
return rc
}
}
2026-03-08 20:19:40 +08:00
var rc *RequestContext
ensure := func() *RequestContext {
if rc == nil {
rc = &RequestContext{}
}
return rc
}
2026-03-08 20:19:40 +08:00
if v := ctx.Value(ctxKeyTransport); v != nil {
ensure().Transport, _ = v.(*http.Transport)
2026-03-08 20:19:40 +08:00
}
if v := ctx.Value(ctxKeyTLSConfig); v != nil {
ensure().TLSConfig, _ = v.(*tls.Config)
}
if v := ctx.Value(ctxKeyTLSConfigCacheable); v != nil {
ensure().TLSConfigCacheable, _ = v.(bool)
}
if v := ctx.Value(ctxKeyTLSServerName); v != nil {
ensure().TLSServerName, _ = v.(string)
2026-03-08 20:19:40 +08:00
}
if v := ctx.Value(ctxKeyProxy); v != nil {
ensure().Proxy, _ = v.(string)
2026-03-08 20:19:40 +08:00
}
if v := ctx.Value(ctxKeyCustomIP); v != nil {
ensure().CustomIP, _ = v.([]string)
2026-03-08 20:19:40 +08:00
}
if v := ctx.Value(ctxKeyCustomDNS); v != nil {
ensure().CustomDNS, _ = v.([]string)
2026-03-08 20:19:40 +08:00
}
if v := ctx.Value(ctxKeyDialTimeout); v != nil {
ensure().DialTimeout, _ = v.(time.Duration)
2026-03-08 20:19:40 +08:00
}
if v := ctx.Value(ctxKeyTimeout); v != nil {
ensure().Timeout, _ = v.(time.Duration)
2026-03-08 20:19:40 +08:00
}
if v := ctx.Value(ctxKeyLookupIP); v != nil {
ensure().LookupIPFn, _ = v.(func(context.Context, string) ([]net.IPAddr, error))
2026-03-08 20:19:40 +08:00
}
if v := ctx.Value(ctxKeyDialFunc); v != nil {
ensure().DialFn, _ = v.(func(context.Context, string, string) (net.Conn, error))
}
if rc == nil {
return emptyRequestContext
2026-03-08 20:19:40 +08:00
}
return rc
}
func cloneRequestContext(rc *RequestContext) *RequestContext {
if rc == nil {
return nil
}
cloned := *rc
cloned.CustomIP = cloneStringSlice(rc.CustomIP)
cloned.CustomDNS = cloneStringSlice(rc.CustomDNS)
return &cloned
}
2026-03-08 20:19:40 +08:00
// needsDynamicTransport 判断是否需要动态 Transport
func needsDynamicTransport(rc *RequestContext) bool {
if rc == nil {
return false
}
2026-03-08 20:19:40 +08:00
return rc.Transport != nil ||
rc.TLSConfig != nil ||
rc.Proxy != "" ||
rc.DialFn != nil ||
(rc.DialTimeout > 0 && rc.DialTimeout != DefaultDialTimeout) ||
len(rc.CustomIP) > 0 ||
len(rc.CustomDNS) > 0 ||
rc.LookupIPFn != nil
}
func buildRequestContext(config *RequestConfig, defaultTLSServerName string) *RequestContext {
if config == nil {
return nil
}
rc := &RequestContext{
DialTimeout: config.Network.DialTimeout,
Timeout: config.Network.Timeout,
}
2026-03-08 20:19:40 +08:00
// 处理 TLS 配置
var tlsConfig *tls.Config
tlsConfigCacheable := false
2026-03-08 20:19:40 +08:00
if config.TLS.Config != nil {
tlsConfig = config.TLS.Config.Clone()
} else if config.TLS.SkipVerify || config.TLS.ServerName != "" {
2026-03-08 20:19:40 +08:00
tlsConfig = &tls.Config{
NextProtos: []string{"h2", "http/1.1"},
2026-03-08 20:19:40 +08:00
}
tlsConfigCacheable = true
2026-03-08 20:19:40 +08:00
}
if config.TLS.SkipVerify && tlsConfig != nil {
tlsConfig.InsecureSkipVerify = true
2026-03-08 20:19:40 +08:00
}
if config.TLS.ServerName != "" && tlsConfig != nil {
tlsConfig.ServerName = config.TLS.ServerName
2026-03-08 20:19:40 +08:00
}
if tlsConfig != nil {
rc.TLSConfig = tlsConfig
rc.TLSConfigCacheable = tlsConfigCacheable
2026-03-08 20:19:40 +08:00
}
if config.TLS.ServerName != "" {
rc.TLSServerName = config.TLS.ServerName
} else if defaultTLSServerName != "" {
rc.TLSServerName = defaultTLSServerName
2026-03-08 20:19:40 +08:00
}
rc.Proxy = config.Network.Proxy
rc.CustomIP = cloneStringSlice(config.DNS.CustomIP)
rc.CustomDNS = cloneStringSlice(config.DNS.CustomDNS)
rc.LookupIPFn = config.DNS.LookupFunc
rc.DialFn = config.Network.DialFunc
2026-03-08 20:19:40 +08:00
if config.CustomTransport && config.Transport != nil {
rc.Transport = config.Transport
2026-03-08 20:19:40 +08:00
}
if !needsDynamicTransport(rc) {
return nil
2026-03-08 20:19:40 +08:00
}
return rc
}
2026-03-08 20:19:40 +08:00
// injectRequestConfig 将请求配置注入到 context
func injectRequestConfig(ctx context.Context, config *RequestConfig, defaultTLSServerName string) context.Context {
rc := buildRequestContext(config, defaultTLSServerName)
if rc == nil {
return ctx
}
return context.WithValue(ctx, ctxKeyRequestContext, rc)
2026-03-08 20:19:40 +08:00
}