package starnet import ( "context" "crypto/tls" "fmt" "net" "time" ) // defaultDialFunc 默认 Dial 函数(支持自定义 IP 和 DNS) func defaultDialFunc(ctx context.Context, network, addr string) (net.Conn, error) { // 提取配置 reqCtx := getRequestContext(ctx) dialTimeout := reqCtx.DialTimeout if dialTimeout == 0 { dialTimeout = DefaultDialTimeout } timeout := reqCtx.Timeout if timeout == 0 { timeout = DefaultTimeout } // 解析地址 host, port, err := net.SplitHostPort(addr) if err != nil { return nil, wrapError(err, "split host port") } // 获取 IP 地址列表 var addrs []string // 优先级1:直接指定的 IP if len(reqCtx.CustomIP) > 0 { for _, ip := range reqCtx.CustomIP { addrs = append(addrs, net.JoinHostPort(ip, port)) } } else { // 优先级2:DNS 解析 var ipAddrs []net.IPAddr // 使用自定义解析函数 if reqCtx.LookupIPFn != nil { ipAddrs, err = reqCtx.LookupIPFn(ctx, host) } else if len(reqCtx.CustomDNS) > 0 { // 使用自定义 DNS 服务器 resolver := &net.Resolver{ PreferGo: true, Dial: func(ctx context.Context, network, address string) (net.Conn, error) { var lastErr error for _, dnsServer := range reqCtx.CustomDNS { conn, err := net.Dial("udp", net.JoinHostPort(dnsServer, "53")) if err != nil { lastErr = err continue } return conn, nil } return nil, lastErr }, } ipAddrs, err = resolver.LookupIPAddr(ctx, host) } else { // 使用默认解析器 ipAddrs, err = net.DefaultResolver.LookupIPAddr(ctx, host) } if err != nil { return nil, wrapError(err, "lookup ip") } for _, ipAddr := range ipAddrs { addrs = append(addrs, net.JoinHostPort(ipAddr.String(), port)) } } // 尝试连接所有地址 var lastErr error for _, addr := range addrs { conn, err := net.DialTimeout(network, addr, dialTimeout) if err != nil { lastErr = err continue } // 设置总超时 if timeout > 0 { conn.SetDeadline(time.Now().Add(timeout)) } return conn, nil } if lastErr != nil { return nil, wrapError(lastErr, "dial all addresses failed") } return nil, fmt.Errorf("no addresses to dial") } // defaultDialTLSFunc 默认 TLS Dial 函数 func defaultDialTLSFunc(ctx context.Context, network, addr string) (net.Conn, error) { // 先建立 TCP 连接 conn, err := defaultDialFunc(ctx, network, addr) if err != nil { return nil, err } // 提取 TLS 配置 reqCtx := getRequestContext(ctx) tlsConfig := reqCtx.TLSConfig if tlsConfig == nil { tlsConfig = &tls.Config{} } // 执行 TLS 握手 tlsConn := tls.Client(conn, tlsConfig) if err := tlsConn.Handshake(); err != nil { conn.Close() return nil, wrapError(err, "tls handshake") } return tlsConn, nil } /* // defaultProxyFunc 默认代理函数 func defaultProxyFunc(req *http.Request) (*url.URL, error) { if req == nil { return nil, fmt.Errorf("request is nil") } reqCtx := getRequestContext(req.Context()) if reqCtx.Proxy == "" { return nil, nil } proxyURL, err := url.Parse(reqCtx.Proxy) if err != nil { return nil, wrapError(err, "parse proxy url") } return proxyURL, nil } */