149 lines
3.1 KiB
Go
149 lines
3.1 KiB
Go
package starnet
|
||
|
||
import (
|
||
"context"
|
||
"crypto/tls"
|
||
"fmt"
|
||
"net"
|
||
"time"
|
||
)
|
||
|
||
// defaultDialFunc 默认 Dial 函数(支持自定义 IP 和 DNS)
|
||
func defaultDialFunc(ctx context.Context, network, addr string) (net.Conn, error) {
|
||
// 提取配置
|
||
reqCtx := getRequestContext(ctx)
|
||
|
||
dialTimeout := reqCtx.DialTimeout
|
||
if dialTimeout == 0 {
|
||
dialTimeout = DefaultDialTimeout
|
||
}
|
||
|
||
timeout := reqCtx.Timeout
|
||
if timeout == 0 {
|
||
timeout = DefaultTimeout
|
||
}
|
||
|
||
// 解析地址
|
||
host, port, err := net.SplitHostPort(addr)
|
||
if err != nil {
|
||
return nil, wrapError(err, "split host port")
|
||
}
|
||
|
||
// 获取 IP 地址列表
|
||
var addrs []string
|
||
|
||
// 优先级1:直接指定的 IP
|
||
if len(reqCtx.CustomIP) > 0 {
|
||
for _, ip := range reqCtx.CustomIP {
|
||
addrs = append(addrs, net.JoinHostPort(ip, port))
|
||
}
|
||
} else {
|
||
// 优先级2:DNS 解析
|
||
var ipAddrs []net.IPAddr
|
||
|
||
// 使用自定义解析函数
|
||
if reqCtx.LookupIPFn != nil {
|
||
ipAddrs, err = reqCtx.LookupIPFn(ctx, host)
|
||
} else if len(reqCtx.CustomDNS) > 0 {
|
||
// 使用自定义 DNS 服务器
|
||
resolver := &net.Resolver{
|
||
PreferGo: true,
|
||
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||
var lastErr error
|
||
for _, dnsServer := range reqCtx.CustomDNS {
|
||
conn, err := net.Dial("udp", net.JoinHostPort(dnsServer, "53"))
|
||
if err != nil {
|
||
lastErr = err
|
||
continue
|
||
}
|
||
return conn, nil
|
||
}
|
||
return nil, lastErr
|
||
},
|
||
}
|
||
ipAddrs, err = resolver.LookupIPAddr(ctx, host)
|
||
} else {
|
||
// 使用默认解析器
|
||
ipAddrs, err = net.DefaultResolver.LookupIPAddr(ctx, host)
|
||
}
|
||
|
||
if err != nil {
|
||
return nil, wrapError(err, "lookup ip")
|
||
}
|
||
|
||
for _, ipAddr := range ipAddrs {
|
||
addrs = append(addrs, net.JoinHostPort(ipAddr.String(), port))
|
||
}
|
||
}
|
||
|
||
// 尝试连接所有地址
|
||
var lastErr error
|
||
for _, addr := range addrs {
|
||
conn, err := net.DialTimeout(network, addr, dialTimeout)
|
||
if err != nil {
|
||
lastErr = err
|
||
continue
|
||
}
|
||
|
||
// 设置总超时
|
||
if timeout > 0 {
|
||
conn.SetDeadline(time.Now().Add(timeout))
|
||
}
|
||
|
||
return conn, nil
|
||
}
|
||
|
||
if lastErr != nil {
|
||
return nil, wrapError(lastErr, "dial all addresses failed")
|
||
}
|
||
|
||
return nil, fmt.Errorf("no addresses to dial")
|
||
}
|
||
|
||
// defaultDialTLSFunc 默认 TLS Dial 函数
|
||
func defaultDialTLSFunc(ctx context.Context, network, addr string) (net.Conn, error) {
|
||
// 先建立 TCP 连接
|
||
conn, err := defaultDialFunc(ctx, network, addr)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 提取 TLS 配置
|
||
reqCtx := getRequestContext(ctx)
|
||
tlsConfig := reqCtx.TLSConfig
|
||
if tlsConfig == nil {
|
||
tlsConfig = &tls.Config{}
|
||
}
|
||
|
||
// 执行 TLS 握手
|
||
tlsConn := tls.Client(conn, tlsConfig)
|
||
if err := tlsConn.Handshake(); err != nil {
|
||
conn.Close()
|
||
return nil, wrapError(err, "tls handshake")
|
||
}
|
||
|
||
return tlsConn, nil
|
||
}
|
||
|
||
/*
|
||
// defaultProxyFunc 默认代理函数
|
||
func defaultProxyFunc(req *http.Request) (*url.URL, error) {
|
||
if req == nil {
|
||
return nil, fmt.Errorf("request is nil")
|
||
}
|
||
|
||
reqCtx := getRequestContext(req.Context())
|
||
if reqCtx.Proxy == "" {
|
||
return nil, nil
|
||
}
|
||
|
||
proxyURL, err := url.Parse(reqCtx.Proxy)
|
||
if err != nil {
|
||
return nil, wrapError(err, "parse proxy url")
|
||
}
|
||
|
||
return proxyURL, nil
|
||
}
|
||
|
||
*/
|