package starnet

import (
	"context"
	"crypto/tls"
	"net"
	"net/http"
	"time"
)

// contextKey is an unexported type for this package's context keys, so they
// cannot collide with keys defined by other packages.
type contextKey int

// Context keys. Every key except ctxKeyRequestContext maps to a single field
// of RequestContext and is read individually by getRequestContext;
// ctxKeyRequestContext carries a complete *RequestContext.
const (
	ctxKeyTransport contextKey = iota
	ctxKeyTLSConfig
	ctxKeyTLSConfigCacheable
	ctxKeyTLSServerName
	ctxKeyProxy
	ctxKeyCustomIP
	ctxKeyCustomDNS
	ctxKeyDialTimeout
	ctxKeyTimeout
	ctxKeyLookupIP
	ctxKeyDialFunc
	ctxKeyRequestContext
)

// RequestContext is the per-request configuration extracted from a
// context.Context.
type RequestContext struct {
	Transport *http.Transport
	TLSConfig *tls.Config
	// TLSConfigCacheable is set by buildRequestContext only when TLSConfig was
	// synthesized there rather than cloned from a caller-supplied config.
	TLSConfigCacheable bool
	TLSServerName      string
	Proxy              string
	CustomIP           []string
	CustomDNS          []string
	DialTimeout        time.Duration
	Timeout            time.Duration
	LookupIPFn         func(ctx context.Context, host string) ([]net.IPAddr, error)
	DialFn             func(ctx context.Context, network, addr string) (net.Conn, error)
}

// emptyRequestContext is the shared zero-value result handed out by
// getRequestContext when the context carries no request settings.
// NOTE(review): being shared, it is presumably meant to be read-only.
var emptyRequestContext = &RequestContext{}

// getRequestContext extracts the request configuration from ctx.
//
// A non-nil *RequestContext stored under ctxKeyRequestContext is returned
// as-is. Otherwise the individual per-field keys are consulted; if at least
// one of them holds a non-nil value a freshly allocated RequestContext is
// returned, and if none do the shared emptyRequestContext is returned.
// A value of an unexpected type leaves its field at the zero value.
func getRequestContext(ctx context.Context) *RequestContext {
	if rc, ok := ctx.Value(ctxKeyRequestContext).(*RequestContext); ok && rc != nil {
		return rc
	}

	var (
		legacy RequestContext
		seen   bool
	)
	// pick returns the value stored under key and records whether any key
	// held a non-nil value at all.
	pick := func(key contextKey) interface{} {
		v := ctx.Value(key)
		if v != nil {
			seen = true
		}
		return v
	}

	// A failed (or nil) assertion yields the zero value, which is what the
	// field already holds.
	legacy.Transport, _ = pick(ctxKeyTransport).(*http.Transport)
	legacy.TLSConfig, _ = pick(ctxKeyTLSConfig).(*tls.Config)
	legacy.TLSConfigCacheable, _ = pick(ctxKeyTLSConfigCacheable).(bool)
	legacy.TLSServerName, _ = pick(ctxKeyTLSServerName).(string)
	legacy.Proxy, _ = pick(ctxKeyProxy).(string)
	legacy.CustomIP, _ = pick(ctxKeyCustomIP).([]string)
	legacy.CustomDNS, _ = pick(ctxKeyCustomDNS).([]string)
	legacy.DialTimeout, _ = pick(ctxKeyDialTimeout).(time.Duration)
	legacy.Timeout, _ = pick(ctxKeyTimeout).(time.Duration)
	legacy.LookupIPFn, _ = pick(ctxKeyLookupIP).(func(context.Context, string) ([]net.IPAddr, error))
	legacy.DialFn, _ = pick(ctxKeyDialFunc).(func(context.Context, string, string) (net.Conn, error))

	if !seen {
		return emptyRequestContext
	}
	// Copy out only on this path so the common "nothing set" case does not
	// pay for a heap allocation.
	out := legacy
	return &out
}

// cloneRequestContext returns a copy of rc whose CustomIP and CustomDNS
// slices are duplicated via cloneStringSlice; all other fields are copied by
// plain assignment (pointers and funcs are shared). A nil rc yields nil.
func cloneRequestContext(rc *RequestContext) *RequestContext {
	if rc == nil {
		return nil
	}

	dup := new(RequestContext)
	*dup = *rc
	dup.CustomIP = cloneStringSlice(rc.CustomIP)
	dup.CustomDNS = cloneStringSlice(rc.CustomDNS)
	return dup
}

// needsDynamicTransport reports whether rc carries any setting that calls
// for a dynamic Transport: a custom transport, TLS config, proxy, dial or
// lookup function, custom IP/DNS entries, or a positive dial timeout that
// differs from DefaultDialTimeout. A nil rc reports false.
func needsDynamicTransport(rc *RequestContext) bool {
	switch {
	case rc == nil:
		return false
	case rc.Transport != nil, rc.TLSConfig != nil:
		return true
	case rc.Proxy != "":
		return true
	case rc.DialFn != nil, rc.LookupIPFn != nil:
		return true
	case len(rc.CustomIP) > 0, len(rc.CustomDNS) > 0:
		return true
	}
	return rc.DialTimeout > 0 && rc.DialTimeout != DefaultDialTimeout
}

// buildRequestContext converts config into a RequestContext.
//
// defaultTLSServerName is used for TLSServerName when config does not name a
// server itself. The result is nil when config is nil or when the resulting
// settings do not satisfy needsDynamicTransport.
func buildRequestContext(config *RequestConfig, defaultTLSServerName string) *RequestContext {
	if config == nil {
		return nil
	}

	rc := &RequestContext{
		DialTimeout: config.Network.DialTimeout,
		Timeout:     config.Network.Timeout,
		Proxy:       config.Network.Proxy,
		CustomIP:    cloneStringSlice(config.DNS.CustomIP),
		CustomDNS:   cloneStringSlice(config.DNS.CustomDNS),
		LookupIPFn:  config.DNS.LookupFunc,
		DialFn:      config.Network.DialFunc,
	}

	// TLS: clone an explicit config, or synthesize one when only the
	// SkipVerify / ServerName shortcuts are set. Only the synthesized config
	// is flagged as cacheable.
	serverName := config.TLS.ServerName
	switch {
	case config.TLS.Config != nil:
		rc.TLSConfig = config.TLS.Config.Clone()
	case config.TLS.SkipVerify || serverName != "":
		rc.TLSConfig = &tls.Config{
			NextProtos: []string{"h2", "http/1.1"},
		}
		rc.TLSConfigCacheable = true
	}
	if rc.TLSConfig != nil {
		// The shortcuts override whatever the selected config carried.
		if config.TLS.SkipVerify {
			rc.TLSConfig.InsecureSkipVerify = true
		}
		if serverName != "" {
			rc.TLSConfig.ServerName = serverName
		}
	}

	// An explicit server name wins over the supplied default.
	rc.TLSServerName = serverName
	if serverName == "" {
		rc.TLSServerName = defaultTLSServerName
	}

	if config.CustomTransport && config.Transport != nil {
		rc.Transport = config.Transport
	}

	if needsDynamicTransport(rc) {
		return rc
	}
	return nil
}

// injectRequestConfig stores the RequestContext built from config in ctx
// under ctxKeyRequestContext. When buildRequestContext yields nil, ctx is
// returned unchanged.
func injectRequestConfig(ctx context.Context, config *RequestConfig, defaultTLSServerName string) context.Context {
	if rc := buildRequestContext(config, defaultTLSServerName); rc != nil {
		return context.WithValue(ctx, ctxKeyRequestContext, rc)
	}
	return ctx
}