2026-03-08 20:19:40 +08:00
|
|
|
|
package starnet
|
|
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
|
"context"
|
|
|
|
|
|
"crypto/tls"
|
|
|
|
|
|
"fmt"
|
|
|
|
|
|
"net"
|
2026-03-10 19:55:37 +08:00
|
|
|
|
"strings"
|
2026-03-08 20:19:40 +08:00
|
|
|
|
"time"
|
|
|
|
|
|
)
|
|
|
|
|
|
|
2026-04-19 15:39:51 +08:00
|
|
|
|
func traceDNSLookup(traceState *traceState, host string, lookup func() ([]net.IPAddr, error)) ([]net.IPAddr, error) {
|
|
|
|
|
|
if traceState != nil {
|
|
|
|
|
|
traceState.beginManualDNS()
|
|
|
|
|
|
defer traceState.endManualDNS()
|
|
|
|
|
|
traceState.dnsStart(TraceDNSStartInfo{Host: host})
|
|
|
|
|
|
}
|
|
|
|
|
|
ipAddrs, err := lookup()
|
|
|
|
|
|
if traceState != nil {
|
|
|
|
|
|
traceState.dnsDone(TraceDNSDoneInfo{
|
|
|
|
|
|
Addrs: append([]net.IPAddr(nil), ipAddrs...),
|
|
|
|
|
|
Err: err,
|
|
|
|
|
|
})
|
|
|
|
|
|
}
|
|
|
|
|
|
return ipAddrs, err
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func resolveDialAddresses(ctx context.Context, reqCtx *RequestContext, host, port string, traceState *traceState) ([]string, error) {
|
|
|
|
|
|
if reqCtx == nil {
|
|
|
|
|
|
reqCtx = &RequestContext{}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
var addrs []string
|
|
|
|
|
|
|
|
|
|
|
|
if len(reqCtx.CustomIP) > 0 {
|
|
|
|
|
|
for _, ip := range reqCtx.CustomIP {
|
|
|
|
|
|
addrs = append(addrs, joinResolvedHostPort(ip, port))
|
|
|
|
|
|
}
|
|
|
|
|
|
return addrs, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
var (
|
|
|
|
|
|
ipAddrs []net.IPAddr
|
|
|
|
|
|
err error
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
if reqCtx.LookupIPFn != nil {
|
|
|
|
|
|
ipAddrs, err = traceDNSLookup(traceState, host, func() ([]net.IPAddr, error) {
|
|
|
|
|
|
return reqCtx.LookupIPFn(ctx, host)
|
|
|
|
|
|
})
|
|
|
|
|
|
} else if len(reqCtx.CustomDNS) > 0 {
|
|
|
|
|
|
dialTimeout := reqCtx.DialTimeout
|
|
|
|
|
|
if dialTimeout == 0 {
|
|
|
|
|
|
dialTimeout = DefaultDialTimeout
|
|
|
|
|
|
}
|
|
|
|
|
|
dialer := &net.Dialer{Timeout: dialTimeout}
|
|
|
|
|
|
resolver := &net.Resolver{
|
|
|
|
|
|
PreferGo: true,
|
|
|
|
|
|
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
|
|
|
|
|
|
var lastErr error
|
|
|
|
|
|
for _, dnsServer := range reqCtx.CustomDNS {
|
|
|
|
|
|
conn, err := dialer.DialContext(ctx, "udp", net.JoinHostPort(dnsServer, "53"))
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
lastErr = err
|
|
|
|
|
|
continue
|
|
|
|
|
|
}
|
|
|
|
|
|
return conn, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
return nil, lastErr
|
|
|
|
|
|
},
|
|
|
|
|
|
}
|
|
|
|
|
|
ipAddrs, err = traceDNSLookup(traceState, host, func() ([]net.IPAddr, error) {
|
|
|
|
|
|
return resolver.LookupIPAddr(ctx, host)
|
|
|
|
|
|
})
|
|
|
|
|
|
} else {
|
|
|
|
|
|
ipAddrs, err = traceDNSLookup(traceState, host, func() ([]net.IPAddr, error) {
|
|
|
|
|
|
return net.DefaultResolver.LookupIPAddr(ctx, host)
|
|
|
|
|
|
})
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return nil, wrapError(err, "lookup ip")
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
for _, ipAddr := range ipAddrs {
|
|
|
|
|
|
addrs = append(addrs, joinResolvedHostPort(ipAddr.String(), port))
|
|
|
|
|
|
}
|
|
|
|
|
|
return addrs, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// joinResolvedHostPort formats a resolved host (normally an IP literal) and a
// port as a dial address.
//
// With a non-empty port this is exactly net.JoinHostPort. With an empty port
// the host is returned on its own, except that IPv6 literals are wrapped in
// brackets so the address can never be confused with a trailing ":port".
// Zoned forms ("fe80::1%eth0") and IPv4-mapped forms ("::ffff:1.2.3.4") count
// as IPv6 literals. Non-IP hosts are returned unchanged.
func joinResolvedHostPort(host, port string) string {
	if port != "" {
		return net.JoinHostPort(host, port)
	}

	// net.ParseIP rejects a "%zone" suffix, so validate only the address part.
	addrPart := host
	if i := strings.IndexByte(addrPart, '%'); i >= 0 {
		addrPart = addrPart[:i]
	}
	// Any textual IP containing a colon is IPv6 syntax, even when it maps an
	// IPv4 address.
	if net.ParseIP(addrPart) != nil && strings.Contains(addrPart, ":") {
		return "[" + host + "]"
	}
	return host
}
|
|
|
|
|
|
|
2026-03-08 20:19:40 +08:00
|
|
|
|
// defaultDialFunc 默认 Dial 函数(支持自定义 IP 和 DNS)
|
|
|
|
|
|
func defaultDialFunc(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
|
|
|
|
// 提取配置
|
|
|
|
|
|
reqCtx := getRequestContext(ctx)
|
2026-04-19 15:39:51 +08:00
|
|
|
|
traceState := getTraceState(ctx)
|
2026-03-08 20:19:40 +08:00
|
|
|
|
|
|
|
|
|
|
dialTimeout := reqCtx.DialTimeout
|
|
|
|
|
|
if dialTimeout == 0 {
|
|
|
|
|
|
dialTimeout = DefaultDialTimeout
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 解析地址
|
|
|
|
|
|
host, port, err := net.SplitHostPort(addr)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return nil, wrapError(err, "split host port")
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-04-19 15:39:51 +08:00
|
|
|
|
addrs, err := resolveDialAddresses(ctx, reqCtx, host, port, traceState)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return nil, err
|
2026-03-08 20:19:40 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 尝试连接所有地址
|
2026-03-19 16:42:45 +08:00
|
|
|
|
dialer := &net.Dialer{Timeout: dialTimeout}
|
2026-03-08 20:19:40 +08:00
|
|
|
|
var lastErr error
|
|
|
|
|
|
for _, addr := range addrs {
|
2026-03-19 16:42:45 +08:00
|
|
|
|
conn, err := dialer.DialContext(ctx, network, addr)
|
2026-03-08 20:19:40 +08:00
|
|
|
|
if err != nil {
|
|
|
|
|
|
lastErr = err
|
|
|
|
|
|
continue
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
return conn, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if lastErr != nil {
|
|
|
|
|
|
return nil, wrapError(lastErr, "dial all addresses failed")
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
return nil, fmt.Errorf("no addresses to dial")
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// defaultDialTLSFunc 默认 TLS Dial 函数
|
|
|
|
|
|
func defaultDialTLSFunc(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
|
|
|
|
// 先建立 TCP 连接
|
|
|
|
|
|
conn, err := defaultDialFunc(ctx, network, addr)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return nil, err
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 提取 TLS 配置
|
|
|
|
|
|
reqCtx := getRequestContext(ctx)
|
2026-04-19 15:39:51 +08:00
|
|
|
|
traceState := getTraceState(ctx)
|
2026-03-08 20:19:40 +08:00
|
|
|
|
tlsConfig := reqCtx.TLSConfig
|
|
|
|
|
|
if tlsConfig == nil {
|
|
|
|
|
|
tlsConfig = &tls.Config{}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-04-19 15:39:51 +08:00
|
|
|
|
serverName := tlsConfig.ServerName
|
|
|
|
|
|
if serverName == "" {
|
|
|
|
|
|
serverName = reqCtx.TLSServerName
|
|
|
|
|
|
}
|
|
|
|
|
|
if serverName == "" && !tlsConfig.InsecureSkipVerify {
|
2026-03-08 21:38:45 +08:00
|
|
|
|
host, _, err := net.SplitHostPort(addr)
|
|
|
|
|
|
if err != nil {
|
2026-03-10 19:55:37 +08:00
|
|
|
|
if idx := strings.LastIndex(addr, ":"); idx > 0 {
|
|
|
|
|
|
host = addr[:idx]
|
|
|
|
|
|
} else {
|
|
|
|
|
|
host = addr
|
|
|
|
|
|
}
|
2026-03-08 21:38:45 +08:00
|
|
|
|
}
|
2026-04-19 15:39:51 +08:00
|
|
|
|
serverName = host
|
|
|
|
|
|
}
|
|
|
|
|
|
if serverName != "" && tlsConfig.ServerName != serverName {
|
2026-03-08 21:38:45 +08:00
|
|
|
|
tlsConfig = tlsConfig.Clone() // 避免修改原 config
|
2026-04-19 15:39:51 +08:00
|
|
|
|
tlsConfig.ServerName = serverName
|
|
|
|
|
|
}
|
|
|
|
|
|
if traceState != nil {
|
|
|
|
|
|
traceState.markCustomTLS()
|
|
|
|
|
|
traceState.tlsHandshakeStart(TraceTLSHandshakeStartInfo{
|
|
|
|
|
|
Network: network,
|
|
|
|
|
|
Addr: addr,
|
|
|
|
|
|
ServerName: serverName,
|
|
|
|
|
|
})
|
2026-03-08 21:38:45 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
2026-03-08 20:19:40 +08:00
|
|
|
|
// 执行 TLS 握手
|
2026-03-19 16:42:45 +08:00
|
|
|
|
if deadline, ok := ctx.Deadline(); ok {
|
|
|
|
|
|
_ = conn.SetDeadline(deadline)
|
|
|
|
|
|
defer conn.SetDeadline(time.Time{})
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-03-08 20:19:40 +08:00
|
|
|
|
tlsConn := tls.Client(conn, tlsConfig)
|
|
|
|
|
|
if err := tlsConn.Handshake(); err != nil {
|
2026-04-19 15:39:51 +08:00
|
|
|
|
if traceState != nil {
|
|
|
|
|
|
traceState.tlsHandshakeDone(TraceTLSHandshakeDoneInfo{
|
|
|
|
|
|
Network: network,
|
|
|
|
|
|
Addr: addr,
|
|
|
|
|
|
ServerName: serverName,
|
|
|
|
|
|
Err: err,
|
|
|
|
|
|
})
|
|
|
|
|
|
}
|
2026-03-08 20:19:40 +08:00
|
|
|
|
conn.Close()
|
|
|
|
|
|
return nil, wrapError(err, "tls handshake")
|
|
|
|
|
|
}
|
2026-04-19 15:39:51 +08:00
|
|
|
|
if traceState != nil {
|
|
|
|
|
|
traceState.tlsHandshakeDone(TraceTLSHandshakeDoneInfo{
|
|
|
|
|
|
Network: network,
|
|
|
|
|
|
Addr: addr,
|
|
|
|
|
|
ServerName: serverName,
|
|
|
|
|
|
ConnectionState: tlsConn.ConnectionState(),
|
|
|
|
|
|
})
|
|
|
|
|
|
}
|
2026-03-08 20:19:40 +08:00
|
|
|
|
|
|
|
|
|
|
return tlsConn, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
/*
|
|
|
|
|
|
// defaultProxyFunc 默认代理函数
|
|
|
|
|
|
func defaultProxyFunc(req *http.Request) (*url.URL, error) {
|
|
|
|
|
|
if req == nil {
|
|
|
|
|
|
return nil, fmt.Errorf("request is nil")
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
reqCtx := getRequestContext(req.Context())
|
|
|
|
|
|
if reqCtx.Proxy == "" {
|
|
|
|
|
|
return nil, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
proxyURL, err := url.Parse(reqCtx.Proxy)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return nil, wrapError(err, "parse proxy url")
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
return proxyURL, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
*/
|