starnet/dialer.go
starainrt 732e81316c
fix(starnet): 重构请求执行链路并补齐代理/重试/trace边界
- 分离 Request 的配置态与执行态,修复二次 Do、raw 模式网络配置失效和 body 来源互斥问题
  - 新增 starnet trace 抽象,补齐 DNS/连接/TLS/重试事件,并优化动态 transport 缓存与代理解析路径
  - 收紧非法代理为 fail-fast,多目标回退仅限幂等请求,修复 Host/TLS/SNI 等语义边界
  - 补充防御性拷贝、专项回归测试、本地代理/TLS 用例与 README 行为说明
2026-04-19 15:39:51 +08:00

239 lines
5.5 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package starnet
import (
"context"
"crypto/tls"
"fmt"
"net"
"strings"
"time"
)
// traceDNSLookup runs lookup and returns its result unchanged. When traceState
// is non-nil, the lookup is wrapped in beginManualDNS/endManualDNS and reported
// through dnsStart (before) and dnsDone (after, carrying the addresses and the
// error). The addresses handed to dnsDone are a copy, so trace consumers never
// share the slice returned to the caller. endManualDNS runs after dnsDone.
func traceDNSLookup(traceState *traceState, host string, lookup func() ([]net.IPAddr, error)) ([]net.IPAddr, error) {
	if traceState == nil {
		return lookup()
	}

	traceState.beginManualDNS()
	defer traceState.endManualDNS()
	traceState.dnsStart(TraceDNSStartInfo{Host: host})

	resolved, lookupErr := lookup()
	snapshot := append([]net.IPAddr(nil), resolved...)
	traceState.dnsDone(TraceDNSDoneInfo{
		Addrs: snapshot,
		Err:   lookupErr,
	})
	return resolved, lookupErr
}
// resolveDialAddresses returns the candidate dial addresses for host:port.
//
// Precedence, all read from reqCtx (a nil reqCtx is treated as empty):
//  1. CustomIP: each entry is paired with port and returned without any DNS lookup.
//  2. LookupIPFn: the caller-supplied lookup function.
//  3. CustomDNS: a pure-Go resolver that queries the listed DNS servers.
//  4. net.DefaultResolver.
//
// Cases 2-4 run through traceDNSLookup so DNS trace events are emitted.
// Lookup failures are wrapped with "lookup ip".
func resolveDialAddresses(ctx context.Context, reqCtx *RequestContext, host, port string, traceState *traceState) ([]string, error) {
	if reqCtx == nil {
		reqCtx = &RequestContext{}
	}
	if len(reqCtx.CustomIP) > 0 {
		addrs := make([]string, 0, len(reqCtx.CustomIP))
		for _, ip := range reqCtx.CustomIP {
			addrs = append(addrs, joinResolvedHostPort(ip, port))
		}
		return addrs, nil
	}

	var lookup func() ([]net.IPAddr, error)
	switch {
	case reqCtx.LookupIPFn != nil:
		lookup = func() ([]net.IPAddr, error) {
			return reqCtx.LookupIPFn(ctx, host)
		}
	case len(reqCtx.CustomDNS) > 0:
		resolver := newCustomDNSResolver(reqCtx)
		lookup = func() ([]net.IPAddr, error) {
			return resolver.LookupIPAddr(ctx, host)
		}
	default:
		lookup = func() ([]net.IPAddr, error) {
			return net.DefaultResolver.LookupIPAddr(ctx, host)
		}
	}

	ipAddrs, err := traceDNSLookup(traceState, host, lookup)
	if err != nil {
		return nil, wrapError(err, "lookup ip")
	}
	var addrs []string
	for _, ipAddr := range ipAddrs {
		addrs = append(addrs, joinResolvedHostPort(ipAddr.String(), port))
	}
	return addrs, nil
}

// newCustomDNSResolver builds a pure-Go resolver whose queries go to the
// servers in reqCtx.CustomDNS, tried in order until one can be dialed. Each
// dial uses reqCtx.DialTimeout, or DefaultDialTimeout when that is zero.
func newCustomDNSResolver(reqCtx *RequestContext) *net.Resolver {
	dialTimeout := reqCtx.DialTimeout
	if dialTimeout == 0 {
		dialTimeout = DefaultDialTimeout
	}
	dialer := &net.Dialer{Timeout: dialTimeout}
	return &net.Resolver{
		PreferGo: true,
		// The built-in resolver asks for "udp" first and re-dials with "tcp"
		// when a response is truncated, so the requested network must be
		// honored. The address argument (the system-configured server) is
		// deliberately replaced by the CustomDNS servers.
		Dial: func(ctx context.Context, network, _ string) (net.Conn, error) {
			var lastErr error
			for _, server := range reqCtx.CustomDNS {
				conn, err := dialer.DialContext(ctx, network, customDNSServerAddr(server))
				if err == nil {
					return conn, nil
				}
				lastErr = err
			}
			if lastErr == nil {
				// Never hand the resolver a nil conn together with a nil error.
				lastErr = fmt.Errorf("no custom dns servers configured")
			}
			return nil, lastErr
		},
	}
}

// customDNSServerAddr normalizes a CustomDNS entry to host:port. Entries that
// already carry a port are used as-is; otherwise port 53 is appended. Bare or
// bracketed IPv6 literals are bracketed exactly once.
func customDNSServerAddr(server string) string {
	if _, _, err := net.SplitHostPort(server); err == nil {
		return server
	}
	return net.JoinHostPort(strings.Trim(server, "[]"), "53")
}
// joinResolvedHostPort combines a resolved host (IP literal or hostname) with
// a port to form a dial address.
//
// With a non-empty port it is exactly net.JoinHostPort. With an empty port the
// host is returned alone, wrapped in brackets when it is an IPv6 literal.
// IPv6 detection uses the same rule as net.JoinHostPort: the host contains a
// colon. This also covers zoned literals such as "fe80::1%eth0" (rejected by
// net.ParseIP) and IPv4-mapped forms such as "::ffff:1.2.3.4" (which have a
// non-nil To4). Hosts that are already bracketed are returned unchanged.
func joinResolvedHostPort(host, port string) string {
	if port != "" {
		return net.JoinHostPort(host, port)
	}
	if strings.Contains(host, ":") && !strings.HasPrefix(host, "[") {
		return "[" + host + "]"
	}
	return host
}
// defaultDialFunc is the default dial function (supports custom IPs and custom DNS).
// It splits addr into host and port, expands them into candidate addresses via
// resolveDialAddresses (CustomIP, LookupIPFn, CustomDNS or the default
// resolver), and tries the candidates in order, returning the first successful
// connection. Each attempt uses the request's DialTimeout, or
// DefaultDialTimeout when that is zero. If every attempt fails, the last error
// is returned wrapped with "dial all addresses failed".
func defaultDialFunc(ctx context.Context, network, addr string) (net.Conn, error) {
	reqCtx := getRequestContext(ctx)
	if reqCtx == nil {
		// resolveDialAddresses already treats a missing request context as an
		// empty configuration; do the same here rather than panic on the
		// field reads below.
		reqCtx = &RequestContext{}
	}
	traceState := getTraceState(ctx)

	dialTimeout := reqCtx.DialTimeout
	if dialTimeout == 0 {
		dialTimeout = DefaultDialTimeout
	}

	host, port, err := net.SplitHostPort(addr)
	if err != nil {
		return nil, wrapError(err, "split host port")
	}
	candidates, err := resolveDialAddresses(ctx, reqCtx, host, port, traceState)
	if err != nil {
		return nil, err
	}

	dialer := &net.Dialer{Timeout: dialTimeout}
	var lastErr error
	for _, candidate := range candidates {
		conn, err := dialer.DialContext(ctx, network, candidate)
		if err == nil {
			return conn, nil
		}
		lastErr = err
		// Once ctx is done every remaining attempt would fail immediately;
		// stop and report the error already collected.
		if ctx.Err() != nil {
			break
		}
	}
	if lastErr != nil {
		return nil, wrapError(lastErr, "dial all addresses failed")
	}
	return nil, fmt.Errorf("no addresses to dial")
}
// defaultDialTLSFunc is the default TLS dial function. It opens a TCP
// connection with defaultDialFunc and performs a client TLS handshake on it.
//
// The server name used for SNI and certificate verification is chosen in this
// order: TLSConfig.ServerName, then the request's TLSServerName, then (only
// when InsecureSkipVerify is false) the host part of addr. The request's
// *tls.Config is cloned before ServerName is changed, so the caller's config
// is never mutated. Handshake start/done events are reported to the trace
// state when one is present. On handshake failure the underlying connection
// is closed and the error is wrapped with "tls handshake".
func defaultDialTLSFunc(ctx context.Context, network, addr string) (net.Conn, error) {
	// Establish the TCP connection first.
	conn, err := defaultDialFunc(ctx, network, addr)
	if err != nil {
		return nil, err
	}

	reqCtx := getRequestContext(ctx)
	if reqCtx == nil {
		// Treat a missing request context as an empty configuration.
		reqCtx = &RequestContext{}
	}
	traceState := getTraceState(ctx)
	tlsConfig := reqCtx.TLSConfig
	if tlsConfig == nil {
		tlsConfig = &tls.Config{}
	}
	serverName := tlsConfig.ServerName
	if serverName == "" {
		serverName = reqCtx.TLSServerName
	}
	if serverName == "" && !tlsConfig.InsecureSkipVerify {
		host, _, err := net.SplitHostPort(addr)
		if err != nil {
			// defaultDialFunc has already split addr successfully, so this
			// fallback is purely defensive.
			if idx := strings.LastIndex(addr, ":"); idx > 0 {
				host = addr[:idx]
			} else {
				host = addr
			}
		}
		serverName = host
	}
	if serverName != "" && tlsConfig.ServerName != serverName {
		tlsConfig = tlsConfig.Clone() // never modify the caller's config
		tlsConfig.ServerName = serverName
	}
	if traceState != nil {
		traceState.markCustomTLS()
		traceState.tlsHandshakeStart(TraceTLSHandshakeStartInfo{
			Network:    network,
			Addr:       addr,
			ServerName: serverName,
		})
	}

	// The socket deadline bounds blocking I/O during the handshake.
	// HandshakeContext additionally aborts the handshake when ctx is canceled,
	// which a plain Handshake would ignore if ctx carries no deadline.
	if deadline, ok := ctx.Deadline(); ok {
		_ = conn.SetDeadline(deadline)
		defer conn.SetDeadline(time.Time{})
	}
	tlsConn := tls.Client(conn, tlsConfig)
	if err := tlsConn.HandshakeContext(ctx); err != nil {
		if traceState != nil {
			traceState.tlsHandshakeDone(TraceTLSHandshakeDoneInfo{
				Network:    network,
				Addr:       addr,
				ServerName: serverName,
				Err:        err,
			})
		}
		conn.Close()
		return nil, wrapError(err, "tls handshake")
	}
	if traceState != nil {
		traceState.tlsHandshakeDone(TraceTLSHandshakeDoneInfo{
			Network:         network,
			Addr:            addr,
			ServerName:      serverName,
			ConnectionState: tlsConn.ConnectionState(),
		})
	}
	return tlsConn, nil
}
/*
// defaultProxyFunc 默认代理函数
func defaultProxyFunc(req *http.Request) (*url.URL, error) {
if req == nil {
return nil, fmt.Errorf("request is nil")
}
reqCtx := getRequestContext(req.Context())
if reqCtx.Proxy == "" {
return nil, nil
}
proxyURL, err := url.Parse(reqCtx.Proxy)
if err != nil {
return nil, wrapError(err, "parse proxy url")
}
return proxyURL, nil
}
*/