package starnet import ( "context" "crypto/tls" "fmt" "net" "strings" "time" ) func traceDNSLookup(traceState *traceState, host string, lookup func() ([]net.IPAddr, error)) ([]net.IPAddr, error) { if traceState != nil { traceState.beginManualDNS() defer traceState.endManualDNS() traceState.dnsStart(TraceDNSStartInfo{Host: host}) } ipAddrs, err := lookup() if traceState != nil { traceState.dnsDone(TraceDNSDoneInfo{ Addrs: append([]net.IPAddr(nil), ipAddrs...), Err: err, }) } return ipAddrs, err } func resolveDialAddresses(ctx context.Context, reqCtx *RequestContext, host, port string, traceState *traceState) ([]string, error) { if reqCtx == nil { reqCtx = &RequestContext{} } var addrs []string if len(reqCtx.CustomIP) > 0 { for _, ip := range reqCtx.CustomIP { addrs = append(addrs, joinResolvedHostPort(ip, port)) } return addrs, nil } var ( ipAddrs []net.IPAddr err error ) if reqCtx.LookupIPFn != nil { ipAddrs, err = traceDNSLookup(traceState, host, func() ([]net.IPAddr, error) { return reqCtx.LookupIPFn(ctx, host) }) } else if len(reqCtx.CustomDNS) > 0 { dialTimeout := reqCtx.DialTimeout if dialTimeout == 0 { dialTimeout = DefaultDialTimeout } dialer := &net.Dialer{Timeout: dialTimeout} resolver := &net.Resolver{ PreferGo: true, Dial: func(ctx context.Context, network, address string) (net.Conn, error) { var lastErr error for _, dnsServer := range reqCtx.CustomDNS { conn, err := dialer.DialContext(ctx, "udp", net.JoinHostPort(dnsServer, "53")) if err != nil { lastErr = err continue } return conn, nil } return nil, lastErr }, } ipAddrs, err = traceDNSLookup(traceState, host, func() ([]net.IPAddr, error) { return resolver.LookupIPAddr(ctx, host) }) } else { ipAddrs, err = traceDNSLookup(traceState, host, func() ([]net.IPAddr, error) { return net.DefaultResolver.LookupIPAddr(ctx, host) }) } if err != nil { return nil, wrapError(err, "lookup ip") } for _, ipAddr := range ipAddrs { addrs = append(addrs, joinResolvedHostPort(ipAddr.String(), port)) } return addrs, nil } func joinResolvedHostPort(host, port string) string { if port == "" { if ip := net.ParseIP(host); ip != nil && ip.To4() == nil { return "[" + host + "]" } return host } return net.JoinHostPort(host, port) } // defaultDialFunc 默认 Dial 函数(支持自定义 IP 和 DNS) func defaultDialFunc(ctx context.Context, network, addr string) (net.Conn, error) { // 提取配置 reqCtx := getRequestContext(ctx) traceState := getTraceState(ctx) dialTimeout := reqCtx.DialTimeout if dialTimeout == 0 { dialTimeout = DefaultDialTimeout } // 解析地址 host, port, err := net.SplitHostPort(addr) if err != nil { return nil, wrapError(err, "split host port") } addrs, err := resolveDialAddresses(ctx, reqCtx, host, port, traceState) if err != nil { return nil, err } // 尝试连接所有地址 dialer := &net.Dialer{Timeout: dialTimeout} var lastErr error for _, addr := range addrs { conn, err := dialer.DialContext(ctx, network, addr) if err != nil { lastErr = err continue } return conn, nil } if lastErr != nil { return nil, wrapError(lastErr, "dial all addresses failed") } return nil, fmt.Errorf("no addresses to dial") } // defaultDialTLSFunc 默认 TLS Dial 函数 func defaultDialTLSFunc(ctx context.Context, network, addr string) (net.Conn, error) { // 先建立 TCP 连接 conn, err := defaultDialFunc(ctx, network, addr) if err != nil { return nil, err } // 提取 TLS 配置 reqCtx := getRequestContext(ctx) traceState := getTraceState(ctx) tlsConfig := reqCtx.TLSConfig if tlsConfig == nil { tlsConfig = &tls.Config{} } serverName := tlsConfig.ServerName if serverName == "" { serverName = reqCtx.TLSServerName } if serverName == "" && !tlsConfig.InsecureSkipVerify { host, _, err := net.SplitHostPort(addr) if err != nil { if idx := strings.LastIndex(addr, ":"); idx > 0 { host = addr[:idx] } else { host = addr } } serverName = host } if serverName != "" && tlsConfig.ServerName != serverName { tlsConfig = tlsConfig.Clone() // 避免修改原 config tlsConfig.ServerName = serverName } if traceState != nil { traceState.markCustomTLS() traceState.tlsHandshakeStart(TraceTLSHandshakeStartInfo{ Network: network, Addr: addr, ServerName: serverName, }) } // 执行 TLS 握手 if deadline, ok := ctx.Deadline(); ok { _ = conn.SetDeadline(deadline) defer conn.SetDeadline(time.Time{}) } tlsConn := tls.Client(conn, tlsConfig) if err := tlsConn.Handshake(); err != nil { if traceState != nil { traceState.tlsHandshakeDone(TraceTLSHandshakeDoneInfo{ Network: network, Addr: addr, ServerName: serverName, Err: err, }) } conn.Close() return nil, wrapError(err, "tls handshake") } if traceState != nil { traceState.tlsHandshakeDone(TraceTLSHandshakeDoneInfo{ Network: network, Addr: addr, ServerName: serverName, ConnectionState: tlsConn.ConnectionState(), }) } return tlsConn, nil } /* // defaultProxyFunc 默认代理函数 func defaultProxyFunc(req *http.Request) (*url.URL, error) { if req == nil { return nil, fmt.Errorf("request is nil") } reqCtx := getRequestContext(req.Context()) if reqCtx.Proxy == "" { return nil, nil } proxyURL, err := url.Parse(reqCtx.Proxy) if err != nil { return nil, wrapError(err, "parse proxy url") } return proxyURL, nil } */