starnet/dialer.go

160 lines
3.5 KiB
Go
Raw Permalink Normal View History

2026-03-08 20:19:40 +08:00
package starnet
import (
"context"
"crypto/tls"
"fmt"
"net"
"time"
)
// defaultDialFunc 默认 Dial 函数(支持自定义 IP 和 DNS
func defaultDialFunc(ctx context.Context, network, addr string) (net.Conn, error) {
// 提取配置
reqCtx := getRequestContext(ctx)
dialTimeout := reqCtx.DialTimeout
if dialTimeout == 0 {
dialTimeout = DefaultDialTimeout
}
timeout := reqCtx.Timeout
if timeout == 0 {
timeout = DefaultTimeout
}
// 解析地址
host, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, wrapError(err, "split host port")
}
// 获取 IP 地址列表
var addrs []string
// 优先级1直接指定的 IP
if len(reqCtx.CustomIP) > 0 {
for _, ip := range reqCtx.CustomIP {
addrs = append(addrs, net.JoinHostPort(ip, port))
}
} else {
// 优先级2DNS 解析
var ipAddrs []net.IPAddr
// 使用自定义解析函数
if reqCtx.LookupIPFn != nil {
ipAddrs, err = reqCtx.LookupIPFn(ctx, host)
} else if len(reqCtx.CustomDNS) > 0 {
// 使用自定义 DNS 服务器
resolver := &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
var lastErr error
for _, dnsServer := range reqCtx.CustomDNS {
conn, err := net.Dial("udp", net.JoinHostPort(dnsServer, "53"))
if err != nil {
lastErr = err
continue
}
return conn, nil
}
return nil, lastErr
},
}
ipAddrs, err = resolver.LookupIPAddr(ctx, host)
} else {
// 使用默认解析器
ipAddrs, err = net.DefaultResolver.LookupIPAddr(ctx, host)
}
if err != nil {
return nil, wrapError(err, "lookup ip")
}
for _, ipAddr := range ipAddrs {
addrs = append(addrs, net.JoinHostPort(ipAddr.String(), port))
}
}
// 尝试连接所有地址
var lastErr error
for _, addr := range addrs {
conn, err := net.DialTimeout(network, addr, dialTimeout)
if err != nil {
lastErr = err
continue
}
// 设置总超时
if timeout > 0 {
conn.SetDeadline(time.Now().Add(timeout))
}
return conn, nil
}
if lastErr != nil {
return nil, wrapError(lastErr, "dial all addresses failed")
}
return nil, fmt.Errorf("no addresses to dial")
}
// defaultDialTLSFunc 默认 TLS Dial 函数
func defaultDialTLSFunc(ctx context.Context, network, addr string) (net.Conn, error) {
// 先建立 TCP 连接
conn, err := defaultDialFunc(ctx, network, addr)
if err != nil {
return nil, err
}
// 提取 TLS 配置
reqCtx := getRequestContext(ctx)
tlsConfig := reqCtx.TLSConfig
if tlsConfig == nil {
tlsConfig = &tls.Config{}
}
// ← 新增:如果 ServerName 为空且没有 InsecureSkipVerify自动设置
if tlsConfig.ServerName == "" && !tlsConfig.InsecureSkipVerify {
host, _, err := net.SplitHostPort(addr)
if err != nil {
// addr 可能没有端口,直接用 addr
host = addr
}
tlsConfig = tlsConfig.Clone() // 避免修改原 config
tlsConfig.ServerName = host
}
2026-03-08 20:19:40 +08:00
// 执行 TLS 握手
tlsConn := tls.Client(conn, tlsConfig)
if err := tlsConn.Handshake(); err != nil {
conn.Close()
return nil, wrapError(err, "tls handshake")
}
return tlsConn, nil
}
/*
// defaultProxyFunc 默认代理函数
func defaultProxyFunc(req *http.Request) (*url.URL, error) {
if req == nil {
return nil, fmt.Errorf("request is nil")
}
reqCtx := getRequestContext(req.Context())
if reqCtx.Proxy == "" {
return nil, nil
}
proxyURL, err := url.Parse(reqCtx.Proxy)
if err != nil {
return nil, wrapError(err, "parse proxy url")
}
return proxyURL, nil
}
*/