- 分离 Request 的配置态与执行态,修复二次 Do、raw 模式网络配置失效和 body 来源互斥问题 - 新增 starnet trace 抽象,补齐 DNS/连接/TLS/重试事件,并优化动态 transport 缓存与代理解析路径 - 收紧非法代理为 fail-fast,多目标回退仅限幂等请求,修复 Host/TLS/SNI 等语义边界 - 补充防御性拷贝、专项回归测试、本地代理/TLS 用例与 README 行为说明
239 lines
5.5 KiB
Go
239 lines
5.5 KiB
Go
package starnet
|
||
|
||
import (
|
||
"context"
|
||
"crypto/tls"
|
||
"fmt"
|
||
"net"
|
||
"strings"
|
||
"time"
|
||
)
|
||
|
||
func traceDNSLookup(traceState *traceState, host string, lookup func() ([]net.IPAddr, error)) ([]net.IPAddr, error) {
|
||
if traceState != nil {
|
||
traceState.beginManualDNS()
|
||
defer traceState.endManualDNS()
|
||
traceState.dnsStart(TraceDNSStartInfo{Host: host})
|
||
}
|
||
ipAddrs, err := lookup()
|
||
if traceState != nil {
|
||
traceState.dnsDone(TraceDNSDoneInfo{
|
||
Addrs: append([]net.IPAddr(nil), ipAddrs...),
|
||
Err: err,
|
||
})
|
||
}
|
||
return ipAddrs, err
|
||
}
|
||
|
||
func resolveDialAddresses(ctx context.Context, reqCtx *RequestContext, host, port string, traceState *traceState) ([]string, error) {
|
||
if reqCtx == nil {
|
||
reqCtx = &RequestContext{}
|
||
}
|
||
|
||
var addrs []string
|
||
|
||
if len(reqCtx.CustomIP) > 0 {
|
||
for _, ip := range reqCtx.CustomIP {
|
||
addrs = append(addrs, joinResolvedHostPort(ip, port))
|
||
}
|
||
return addrs, nil
|
||
}
|
||
|
||
var (
|
||
ipAddrs []net.IPAddr
|
||
err error
|
||
)
|
||
|
||
if reqCtx.LookupIPFn != nil {
|
||
ipAddrs, err = traceDNSLookup(traceState, host, func() ([]net.IPAddr, error) {
|
||
return reqCtx.LookupIPFn(ctx, host)
|
||
})
|
||
} else if len(reqCtx.CustomDNS) > 0 {
|
||
dialTimeout := reqCtx.DialTimeout
|
||
if dialTimeout == 0 {
|
||
dialTimeout = DefaultDialTimeout
|
||
}
|
||
dialer := &net.Dialer{Timeout: dialTimeout}
|
||
resolver := &net.Resolver{
|
||
PreferGo: true,
|
||
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||
var lastErr error
|
||
for _, dnsServer := range reqCtx.CustomDNS {
|
||
conn, err := dialer.DialContext(ctx, "udp", net.JoinHostPort(dnsServer, "53"))
|
||
if err != nil {
|
||
lastErr = err
|
||
continue
|
||
}
|
||
return conn, nil
|
||
}
|
||
return nil, lastErr
|
||
},
|
||
}
|
||
ipAddrs, err = traceDNSLookup(traceState, host, func() ([]net.IPAddr, error) {
|
||
return resolver.LookupIPAddr(ctx, host)
|
||
})
|
||
} else {
|
||
ipAddrs, err = traceDNSLookup(traceState, host, func() ([]net.IPAddr, error) {
|
||
return net.DefaultResolver.LookupIPAddr(ctx, host)
|
||
})
|
||
}
|
||
|
||
if err != nil {
|
||
return nil, wrapError(err, "lookup ip")
|
||
}
|
||
|
||
for _, ipAddr := range ipAddrs {
|
||
addrs = append(addrs, joinResolvedHostPort(ipAddr.String(), port))
|
||
}
|
||
return addrs, nil
|
||
}
|
||
|
||
// joinResolvedHostPort combines a resolved host (an IP literal or a name) with
// port into a dialable address.
//
// With a non-empty port it is exactly net.JoinHostPort. With an empty port the
// host is returned on its own, except that IPv6 literals are wrapped in
// brackets so the result stays unambiguous. This includes literals with a
// "%zone" suffix and IPv4-mapped forms such as "::ffff:1.2.3.4". IPv4 literals,
// names and anything that is not an IP literal are returned untouched.
func joinResolvedHostPort(host, port string) string {
	if port != "" {
		return net.JoinHostPort(host, port)
	}

	// net.ParseIP rejects zoned addresses, so strip the zone before parsing.
	literal := host
	if i := strings.IndexByte(literal, '%'); i >= 0 {
		literal = literal[:i]
	}
	// Only the IPv6 textual form contains ':'. Checking for it (rather than
	// ip.To4() == nil) also catches IPv4-mapped literals written in IPv6 form.
	if strings.Contains(literal, ":") && net.ParseIP(literal) != nil {
		return "[" + host + "]"
	}
	return host
}
|
||
|
||
// defaultDialFunc 默认 Dial 函数(支持自定义 IP 和 DNS)
|
||
func defaultDialFunc(ctx context.Context, network, addr string) (net.Conn, error) {
|
||
// 提取配置
|
||
reqCtx := getRequestContext(ctx)
|
||
traceState := getTraceState(ctx)
|
||
|
||
dialTimeout := reqCtx.DialTimeout
|
||
if dialTimeout == 0 {
|
||
dialTimeout = DefaultDialTimeout
|
||
}
|
||
|
||
// 解析地址
|
||
host, port, err := net.SplitHostPort(addr)
|
||
if err != nil {
|
||
return nil, wrapError(err, "split host port")
|
||
}
|
||
|
||
addrs, err := resolveDialAddresses(ctx, reqCtx, host, port, traceState)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 尝试连接所有地址
|
||
dialer := &net.Dialer{Timeout: dialTimeout}
|
||
var lastErr error
|
||
for _, addr := range addrs {
|
||
conn, err := dialer.DialContext(ctx, network, addr)
|
||
if err != nil {
|
||
lastErr = err
|
||
continue
|
||
}
|
||
|
||
return conn, nil
|
||
}
|
||
|
||
if lastErr != nil {
|
||
return nil, wrapError(lastErr, "dial all addresses failed")
|
||
}
|
||
|
||
return nil, fmt.Errorf("no addresses to dial")
|
||
}
|
||
|
||
// defaultDialTLSFunc 默认 TLS Dial 函数
|
||
func defaultDialTLSFunc(ctx context.Context, network, addr string) (net.Conn, error) {
|
||
// 先建立 TCP 连接
|
||
conn, err := defaultDialFunc(ctx, network, addr)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 提取 TLS 配置
|
||
reqCtx := getRequestContext(ctx)
|
||
traceState := getTraceState(ctx)
|
||
tlsConfig := reqCtx.TLSConfig
|
||
if tlsConfig == nil {
|
||
tlsConfig = &tls.Config{}
|
||
}
|
||
|
||
serverName := tlsConfig.ServerName
|
||
if serverName == "" {
|
||
serverName = reqCtx.TLSServerName
|
||
}
|
||
if serverName == "" && !tlsConfig.InsecureSkipVerify {
|
||
host, _, err := net.SplitHostPort(addr)
|
||
if err != nil {
|
||
if idx := strings.LastIndex(addr, ":"); idx > 0 {
|
||
host = addr[:idx]
|
||
} else {
|
||
host = addr
|
||
}
|
||
}
|
||
serverName = host
|
||
}
|
||
if serverName != "" && tlsConfig.ServerName != serverName {
|
||
tlsConfig = tlsConfig.Clone() // 避免修改原 config
|
||
tlsConfig.ServerName = serverName
|
||
}
|
||
if traceState != nil {
|
||
traceState.markCustomTLS()
|
||
traceState.tlsHandshakeStart(TraceTLSHandshakeStartInfo{
|
||
Network: network,
|
||
Addr: addr,
|
||
ServerName: serverName,
|
||
})
|
||
}
|
||
|
||
// 执行 TLS 握手
|
||
if deadline, ok := ctx.Deadline(); ok {
|
||
_ = conn.SetDeadline(deadline)
|
||
defer conn.SetDeadline(time.Time{})
|
||
}
|
||
|
||
tlsConn := tls.Client(conn, tlsConfig)
|
||
if err := tlsConn.Handshake(); err != nil {
|
||
if traceState != nil {
|
||
traceState.tlsHandshakeDone(TraceTLSHandshakeDoneInfo{
|
||
Network: network,
|
||
Addr: addr,
|
||
ServerName: serverName,
|
||
Err: err,
|
||
})
|
||
}
|
||
conn.Close()
|
||
return nil, wrapError(err, "tls handshake")
|
||
}
|
||
if traceState != nil {
|
||
traceState.tlsHandshakeDone(TraceTLSHandshakeDoneInfo{
|
||
Network: network,
|
||
Addr: addr,
|
||
ServerName: serverName,
|
||
ConnectionState: tlsConn.ConnectionState(),
|
||
})
|
||
}
|
||
|
||
return tlsConn, nil
|
||
}
|
||
|
||
/*
|
||
// defaultProxyFunc 默认代理函数
|
||
func defaultProxyFunc(req *http.Request) (*url.URL, error) {
|
||
if req == nil {
|
||
return nil, fmt.Errorf("request is nil")
|
||
}
|
||
|
||
reqCtx := getRequestContext(req.Context())
|
||
if reqCtx.Proxy == "" {
|
||
return nil, nil
|
||
}
|
||
|
||
proxyURL, err := url.Parse(reqCtx.Proxy)
|
||
if err != nil {
|
||
return nil, wrapError(err, "parse proxy url")
|
||
}
|
||
|
||
return proxyURL, nil
|
||
}
|
||
|
||
*/
|