starnet/dialer.go
2026-03-08 20:19:40 +08:00

149 lines
3.1 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package starnet
import (
"context"
"crypto/tls"
"fmt"
"net"
"time"
)
// defaultDialFunc is the default Dial function. It supports per-request
// custom target IPs, a custom lookup function and custom DNS servers.
//
// Target selection priority:
//  1. reqCtx.CustomIP: dial these IPs directly with the port from addr
//     (no DNS lookup).
//  2. reqCtx.LookupIPFn: resolve the host with the caller-supplied function.
//  3. reqCtx.CustomDNS: resolve the host with the pure-Go resolver, sending
//     queries to the listed servers. An entry may be "ip", "host" or
//     "host:port"; port 53 is used when no port is given.
//  4. Otherwise net.DefaultResolver.
//
// Candidates are dialed in order, each attempt bounded by reqCtx.DialTimeout
// (DefaultDialTimeout when zero) and by ctx. Once ctx is done, no further
// candidates are tried. The first successful connection gets a deadline of
// now+reqCtx.Timeout (DefaultTimeout when zero; no deadline if the result
// is not positive) and is returned.
func defaultDialFunc(ctx context.Context, network, addr string) (net.Conn, error) {
	reqCtx := getRequestContext(ctx)

	dialTimeout := reqCtx.DialTimeout
	if dialTimeout == 0 {
		dialTimeout = DefaultDialTimeout
	}
	timeout := reqCtx.Timeout
	if timeout == 0 {
		timeout = DefaultTimeout
	}

	host, port, err := net.SplitHostPort(addr)
	if err != nil {
		return nil, wrapError(err, "split host port")
	}

	// Build the list of "ip:port" candidates to dial.
	var targets []string
	if len(reqCtx.CustomIP) > 0 {
		// Priority 1: explicitly configured IPs.
		targets = make([]string, 0, len(reqCtx.CustomIP))
		for _, ip := range reqCtx.CustomIP {
			targets = append(targets, net.JoinHostPort(ip, port))
		}
	} else {
		// Priority 2-4: DNS resolution.
		var ipAddrs []net.IPAddr
		switch {
		case reqCtx.LookupIPFn != nil:
			ipAddrs, err = reqCtx.LookupIPFn(ctx, host)
		case len(reqCtx.CustomDNS) > 0:
			resolver := &net.Resolver{
				PreferGo: true,
				// The resolver tells us which network to use ("udp" normally,
				// "tcp" for truncated replies) and passes its own context. Both
				// must be honored, so the original nameserver address is ignored.
				Dial: func(dnsCtx context.Context, dnsNetwork, _ string) (net.Conn, error) {
					var d net.Dialer
					var lastErr error
					for _, server := range reqCtx.CustomDNS {
						serverAddr := server
						if _, _, splitErr := net.SplitHostPort(server); splitErr != nil {
							// No port given (plain IPv4/IPv6 or hostname): use 53.
							serverAddr = net.JoinHostPort(server, "53")
						}
						conn, dialErr := d.DialContext(dnsCtx, dnsNetwork, serverAddr)
						if dialErr != nil {
							lastErr = dialErr
							continue
						}
						return conn, nil
					}
					return nil, lastErr
				},
			}
			ipAddrs, err = resolver.LookupIPAddr(ctx, host)
		default:
			ipAddrs, err = net.DefaultResolver.LookupIPAddr(ctx, host)
		}
		if err != nil {
			return nil, wrapError(err, "lookup ip")
		}
		targets = make([]string, 0, len(ipAddrs))
		for _, ipAddr := range ipAddrs {
			targets = append(targets, net.JoinHostPort(ipAddr.String(), port))
		}
	}

	// Try each candidate in turn; DialContext makes cancellation of ctx
	// abort an in-flight attempt instead of waiting for dialTimeout.
	dialer := &net.Dialer{Timeout: dialTimeout}
	var lastErr error
	for _, target := range targets {
		conn, dialErr := dialer.DialContext(ctx, network, target)
		if dialErr != nil {
			lastErr = dialErr
			if ctx.Err() != nil {
				// The request is cancelled or expired; remaining attempts would fail too.
				break
			}
			continue
		}
		// Apply the overall request timeout to the connection.
		if timeout > 0 {
			if deadlineErr := conn.SetDeadline(time.Now().Add(timeout)); deadlineErr != nil {
				conn.Close()
				return nil, wrapError(deadlineErr, "set deadline")
			}
		}
		return conn, nil
	}
	if lastErr != nil {
		return nil, wrapError(lastErr, "dial all addresses failed")
	}
	return nil, fmt.Errorf("no addresses to dial")
}
// defaultDialTLSFunc is the default TLS Dial function. It opens the
// underlying connection with defaultDialFunc, so the custom IP / DNS settings
// apply. It then performs a client TLS handshake using reqCtx.TLSConfig, or
// an empty config when none is set.
//
// When the config has no ServerName, the host part of addr is used. That
// host is the requested hostname even when a custom IP is dialed. Without
// it, crypto/tls refuses to verify the peer and every handshake fails unless
// InsecureSkipVerify is set. The caller's config is cloned for this, never
// mutated. The handshake is bounded by ctx. If the handshake fails, the
// underlying connection is closed.
func defaultDialTLSFunc(ctx context.Context, network, addr string) (net.Conn, error) {
	// Establish the underlying transport connection first.
	conn, err := defaultDialFunc(ctx, network, addr)
	if err != nil {
		return nil, err
	}

	reqCtx := getRequestContext(ctx)
	tlsConfig := reqCtx.TLSConfig
	if tlsConfig == nil {
		tlsConfig = &tls.Config{}
	}
	if tlsConfig.ServerName == "" {
		if host, _, splitErr := net.SplitHostPort(addr); splitErr == nil {
			// Clone so a config shared across requests is not modified.
			tlsConfig = tlsConfig.Clone()
			tlsConfig.ServerName = host
		}
	}

	tlsConn := tls.Client(conn, tlsConfig)
	if err := tlsConn.HandshakeContext(ctx); err != nil {
		conn.Close()
		return nil, wrapError(err, "tls handshake")
	}
	return tlsConn, nil
}
/*
// defaultProxyFunc 默认代理函数
func defaultProxyFunc(req *http.Request) (*url.URL, error) {
if req == nil {
return nil, fmt.Errorf("request is nil")
}
reqCtx := getRequestContext(req.Context())
if reqCtx.Proxy == "" {
return nil, nil
}
proxyURL, err := url.Parse(reqCtx.Proxy)
if err != nil {
return nil, wrapError(err, "parse proxy url")
}
return proxyURL, nil
}
*/