package starnet

import (
	"context"
	"crypto/tls"
	"net"
	"net/http"
	"time"
)

// contextKey is a private context key type, used to avoid key collisions.
type contextKey int

// Context keys under which the individual request settings are stored.
const (
	ctxKeyTransport contextKey = iota
	ctxKeyTLSConfig
	ctxKeyProxy
	ctxKeyCustomIP
	ctxKeyCustomDNS
	ctxKeyDialTimeout
	ctxKeyTimeout
	ctxKeyLookupIP
	ctxKeyDialFunc
)

// RequestContext holds the request configuration extracted from a context.
type RequestContext struct {
	Transport   *http.Transport
	TLSConfig   *tls.Config
	Proxy       string
	CustomIP    []string
	CustomDNS   []string
	DialTimeout time.Duration
	Timeout     time.Duration
	LookupIPFn  func(ctx context.Context, host string) ([]net.IPAddr, error)
	DialFn      func(ctx context.Context, network, addr string) (net.Conn, error)
}

// getRequestContext extracts the request configuration from ctx.
//
// Every field is read from its own context key. A key that is absent, or
// whose stored value has an unexpected type, leaves the corresponding field
// at its zero value: the comma-ok assertion yields the zero value on failure,
// including when the looked-up value is a nil interface.
func getRequestContext(ctx context.Context) *RequestContext {
	out := new(RequestContext)

	out.Transport, _ = ctx.Value(ctxKeyTransport).(*http.Transport)
	out.TLSConfig, _ = ctx.Value(ctxKeyTLSConfig).(*tls.Config)
	out.Proxy, _ = ctx.Value(ctxKeyProxy).(string)
	out.CustomIP, _ = ctx.Value(ctxKeyCustomIP).([]string)
	out.CustomDNS, _ = ctx.Value(ctxKeyCustomDNS).([]string)
	out.DialTimeout, _ = ctx.Value(ctxKeyDialTimeout).(time.Duration)
	out.Timeout, _ = ctx.Value(ctxKeyTimeout).(time.Duration)
	out.LookupIPFn, _ = ctx.Value(ctxKeyLookupIP).(func(context.Context, string) ([]net.IPAddr, error))
	out.DialFn, _ = ctx.Value(ctxKeyDialFunc).(func(context.Context, string, string) (net.Conn, error))

	return out
}

// needsDynamicTransport reports whether a dynamic Transport is needed.
//
// It returns true when rc carries a transport, TLS config, proxy, dial
// function, lookup function, custom IPs or custom DNS servers, or when a
// positive timeout differs from the corresponding package default.
func needsDynamicTransport(rc *RequestContext) bool {
	switch {
	case rc.Transport != nil, rc.TLSConfig != nil:
		return true
	case rc.Proxy != "":
		return true
	case rc.DialFn != nil, rc.LookupIPFn != nil:
		return true
	case rc.DialTimeout > 0 && rc.DialTimeout != DefaultDialTimeout:
		return true
	case rc.Timeout > 0 && rc.Timeout != DefaultTimeout:
		return true
	case len(rc.CustomIP) > 0, len(rc.CustomDNS) > 0:
		return true
	}
	return false
}
injectRequestConfig 将请求配置注入到 context func injectRequestConfig(ctx context.Context, config *RequestConfig) context.Context { execCtx := ctx // 处理 TLS 配置 var tlsConfig *tls.Config if config.TLS.Config != nil { tlsConfig = config.TLS.Config.Clone() if config.TLS.SkipVerify { tlsConfig.InsecureSkipVerify = true } } else if config.TLS.SkipVerify { tlsConfig = &tls.Config{ NextProtos: []string{"h2", "http/1.1"}, InsecureSkipVerify: true, } } if tlsConfig != nil { execCtx = context.WithValue(execCtx, ctxKeyTLSConfig, tlsConfig) } // 注入代理 if config.Network.Proxy != "" { execCtx = context.WithValue(execCtx, ctxKeyProxy, config.Network.Proxy) } // 注入自定义 IP if len(config.DNS.CustomIP) > 0 { execCtx = context.WithValue(execCtx, ctxKeyCustomIP, config.DNS.CustomIP) } // 注入自定义 DNS if len(config.DNS.CustomDNS) > 0 { execCtx = context.WithValue(execCtx, ctxKeyCustomDNS, config.DNS.CustomDNS) } // 总是注入 DialTimeout 和 Timeout(与原始代码一致) if config.Network.DialTimeout > 0 { execCtx = context.WithValue(execCtx, ctxKeyDialTimeout, config.Network.DialTimeout) } if config.Network.Timeout > 0 { execCtx = context.WithValue(execCtx, ctxKeyTimeout, config.Network.Timeout) } // 注入 DNS 解析函数 if config.DNS.LookupFunc != nil { execCtx = context.WithValue(execCtx, ctxKeyLookupIP, config.DNS.LookupFunc) } // 注入 Dial 函数 if config.Network.DialFunc != nil { execCtx = context.WithValue(execCtx, ctxKeyDialFunc, config.Network.DialFunc) } // 注入自定义 Transport if config.CustomTransport && config.Transport != nil { execCtx = context.WithValue(execCtx, ctxKeyTransport, config.Transport) } return execCtx }