// File: starnet/dialer.go

package starnet
import (
"context"
"crypto/tls"
"fmt"
"net"
"strings"
"time"
)
// traceDNSLookup runs lookup for host and returns its result unchanged.
//
// When traceState is non-nil, the call is bracketed by beginManualDNS /
// endManualDNS. dnsStart is reported before lookup runs. dnsDone is reported
// afterwards, carrying a copy of the returned addresses and the lookup error.
// When traceState is nil, lookup is simply invoked.
func traceDNSLookup(traceState *traceState, host string, lookup func() ([]net.IPAddr, error)) ([]net.IPAddr, error) {
	if traceState == nil {
		return lookup()
	}

	traceState.beginManualDNS()
	defer traceState.endManualDNS()
	traceState.dnsStart(TraceDNSStartInfo{Host: host})

	resolved, lookupErr := lookup()

	// Give the tracer its own copy so it does not share the caller's backing array.
	snapshot := append([]net.IPAddr(nil), resolved...)
	traceState.dnsDone(TraceDNSDoneInfo{
		Addrs: snapshot,
		Err:   lookupErr,
	})
	return resolved, lookupErr
}
// resolveDialAddresses returns the dial targets for host and port, according
// to the per-request settings in reqCtx. A target is a "host:port" string,
// produced by joinResolvedHostPort. A nil reqCtx is treated as an empty
// configuration.
//
// Address sources are consulted in this order:
//
//   - reqCtx.CustomIP: the listed IPs are used verbatim and no lookup is done.
//   - reqCtx.LookupIPFn: called to resolve host.
//   - reqCtx.CustomDNS: host is resolved with the pure-Go resolver. It dials
//     the listed servers in order until one accepts a connection. A server
//     given without a port is contacted on port 53.
//   - otherwise: net.DefaultResolver.
//
// Lookups are routed through traceDNSLookup, so they are reported to
// traceState when it is non-nil. Lookup failures are wrapped with "lookup ip".
func resolveDialAddresses(ctx context.Context, reqCtx *RequestContext, host, port string, traceState *traceState) ([]string, error) {
	if reqCtx == nil {
		reqCtx = &RequestContext{}
	}
	if len(reqCtx.CustomIP) > 0 {
		var addrs []string
		for _, ip := range reqCtx.CustomIP {
			addrs = append(addrs, joinResolvedHostPort(ip, port))
		}
		return addrs, nil
	}

	var lookup func() ([]net.IPAddr, error)
	switch {
	case reqCtx.LookupIPFn != nil:
		lookup = func() ([]net.IPAddr, error) {
			return reqCtx.LookupIPFn(ctx, host)
		}
	case len(reqCtx.CustomDNS) > 0:
		dialTimeout := reqCtx.DialTimeout
		if dialTimeout == 0 {
			dialTimeout = DefaultDialTimeout
		}
		dialer := &net.Dialer{Timeout: dialTimeout}
		resolver := &net.Resolver{
			PreferGo: true,
			// The Go resolver requests "udp" normally and "tcp" when it must
			// retry a truncated response. The requested network is therefore
			// honoured; forcing UDP would break that TCP fallback. The resolver's
			// own choice of server (the address argument) is replaced by the
			// custom servers.
			Dial: func(ctx context.Context, network, _ string) (net.Conn, error) {
				var lastErr error
				for _, dnsServer := range reqCtx.CustomDNS {
					conn, err := dialer.DialContext(ctx, network, customDNSServerAddr(dnsServer))
					if err != nil {
						lastErr = err
						continue
					}
					return conn, nil
				}
				return nil, lastErr
			},
		}
		lookup = func() ([]net.IPAddr, error) {
			return resolver.LookupIPAddr(ctx, host)
		}
	default:
		lookup = func() ([]net.IPAddr, error) {
			return net.DefaultResolver.LookupIPAddr(ctx, host)
		}
	}

	ipAddrs, err := traceDNSLookup(traceState, host, lookup)
	if err != nil {
		return nil, wrapError(err, "lookup ip")
	}
	var addrs []string
	for _, ipAddr := range ipAddrs {
		addrs = append(addrs, joinResolvedHostPort(ipAddr.String(), port))
	}
	return addrs, nil
}

// customDNSServerAddr turns a configured DNS server into a dialable address.
// A value that already parses as "host:port" is returned unchanged. Otherwise
// port 53 is appended; bare IPv6 literals get bracketed by net.JoinHostPort.
func customDNSServerAddr(server string) string {
	if _, _, err := net.SplitHostPort(server); err == nil {
		return server
	}
	return net.JoinHostPort(server, "53")
}
// joinResolvedHostPort combines a resolved host (typically an IP literal) with
// port into a dialable address.
//
// With a non-empty port it delegates to net.JoinHostPort. With an empty port
// the host is returned as is, except that a pure IPv6 literal (one with no
// IPv4 form) is wrapped in square brackets.
func joinResolvedHostPort(host, port string) string {
	if port != "" {
		return net.JoinHostPort(host, port)
	}
	parsed := net.ParseIP(host)
	isIPv6Only := parsed != nil && parsed.To4() == nil
	if !isIPv6Only {
		return host
	}
	return "[" + host + "]"
}
// defaultDialFunc 默认 Dial 函数(支持自定义 IP 和 DNS
func defaultDialFunc(ctx context.Context, network, addr string) (net.Conn, error) {
// 提取配置
reqCtx := getRequestContext(ctx)
traceState := getTraceState(ctx)
2026-03-08 20:19:40 +08:00
dialTimeout := reqCtx.DialTimeout
if dialTimeout == 0 {
dialTimeout = DefaultDialTimeout
}
// 解析地址
host, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, wrapError(err, "split host port")
}
addrs, err := resolveDialAddresses(ctx, reqCtx, host, port, traceState)
if err != nil {
return nil, err
2026-03-08 20:19:40 +08:00
}
// 尝试连接所有地址
dialer := &net.Dialer{Timeout: dialTimeout}
2026-03-08 20:19:40 +08:00
var lastErr error
for _, addr := range addrs {
conn, err := dialer.DialContext(ctx, network, addr)
2026-03-08 20:19:40 +08:00
if err != nil {
lastErr = err
continue
}
return conn, nil
}
if lastErr != nil {
return nil, wrapError(lastErr, "dial all addresses failed")
}
return nil, fmt.Errorf("no addresses to dial")
}
// defaultDialTLSFunc 默认 TLS Dial 函数
func defaultDialTLSFunc(ctx context.Context, network, addr string) (net.Conn, error) {
// 先建立 TCP 连接
conn, err := defaultDialFunc(ctx, network, addr)
if err != nil {
return nil, err
}
// 提取 TLS 配置
reqCtx := getRequestContext(ctx)
traceState := getTraceState(ctx)
2026-03-08 20:19:40 +08:00
tlsConfig := reqCtx.TLSConfig
if tlsConfig == nil {
tlsConfig = &tls.Config{}
}
serverName := tlsConfig.ServerName
if serverName == "" {
serverName = reqCtx.TLSServerName
}
if serverName == "" && !tlsConfig.InsecureSkipVerify {
host, _, err := net.SplitHostPort(addr)
if err != nil {
if idx := strings.LastIndex(addr, ":"); idx > 0 {
host = addr[:idx]
} else {
host = addr
}
}
serverName = host
}
if serverName != "" && tlsConfig.ServerName != serverName {
tlsConfig = tlsConfig.Clone() // 避免修改原 config
tlsConfig.ServerName = serverName
}
if traceState != nil {
traceState.markCustomTLS()
traceState.tlsHandshakeStart(TraceTLSHandshakeStartInfo{
Network: network,
Addr: addr,
ServerName: serverName,
})
}
2026-03-08 20:19:40 +08:00
// 执行 TLS 握手
if deadline, ok := ctx.Deadline(); ok {
_ = conn.SetDeadline(deadline)
defer conn.SetDeadline(time.Time{})
}
2026-03-08 20:19:40 +08:00
tlsConn := tls.Client(conn, tlsConfig)
if err := tlsConn.Handshake(); err != nil {
if traceState != nil {
traceState.tlsHandshakeDone(TraceTLSHandshakeDoneInfo{
Network: network,
Addr: addr,
ServerName: serverName,
Err: err,
})
}
2026-03-08 20:19:40 +08:00
conn.Close()
return nil, wrapError(err, "tls handshake")
}
if traceState != nil {
traceState.tlsHandshakeDone(TraceTLSHandshakeDoneInfo{
Network: network,
Addr: addr,
ServerName: serverName,
ConnectionState: tlsConn.ConnectionState(),
})
}
2026-03-08 20:19:40 +08:00
return tlsConn, nil
}
/*
// defaultProxyFunc is the default proxy function.
func defaultProxyFunc(req *http.Request) (*url.URL, error) {
if req == nil {
return nil, fmt.Errorf("request is nil")
}
reqCtx := getRequestContext(req.Context())
if reqCtx.Proxy == "" {
return nil, nil
}
proxyURL, err := url.Parse(reqCtx.Proxy)
if err != nil {
return nil, wrapError(err, "parse proxy url")
}
return proxyURL, nil
}
*/