diff --git a/client.go b/client.go index 3afc593..14aed5b 100644 --- a/client.go +++ b/client.go @@ -46,7 +46,7 @@ func NewClient(opts ...RequestOpt) (*Client, error) { if req.config.Network.Timeout > 0 { httpClient.Timeout = req.config.Network.Timeout } - + */ // 如果有自定义 Transport @@ -172,10 +172,11 @@ func (c *Client) Clone() *Client { func (c *Client) SetDefaultTLSConfig(tlsConfig *tls.Config) *Client { if transport, ok := c.client.Transport.(*Transport); ok { transport.mu.Lock() - if transport.base.TLSClientConfig == nil { - transport.base.TLSClientConfig = &tls.Config{} + if tlsConfig != nil { + transport.base.TLSClientConfig = tlsConfig.Clone() + } else { + transport.base.TLSClientConfig = nil } - transport.base.TLSClientConfig = tlsConfig transport.mu.Unlock() } return c @@ -187,6 +188,8 @@ func (c *Client) SetDefaultSkipTLSVerify(skip bool) *Client { transport.mu.Lock() if transport.base.TLSClientConfig == nil { transport.base.TLSClientConfig = &tls.Config{} + } else { + transport.base.TLSClientConfig = transport.base.TLSClientConfig.Clone() } transport.base.TLSClientConfig.InsecureSkipVerify = skip transport.mu.Unlock() diff --git a/dialer.go b/dialer.go index 1f84667..6fb64ec 100644 --- a/dialer.go +++ b/dialer.go @@ -115,6 +115,17 @@ func defaultDialTLSFunc(ctx context.Context, network, addr string) (net.Conn, er tlsConfig = &tls.Config{} } + // ← 新增:如果 ServerName 为空且没有 InsecureSkipVerify,自动设置 + if tlsConfig.ServerName == "" && !tlsConfig.InsecureSkipVerify { + host, _, err := net.SplitHostPort(addr) + if err != nil { + // addr 可能没有端口,直接用 addr + host = addr + } + tlsConfig = tlsConfig.Clone() // 避免修改原 config + tlsConfig.ServerName = host + } + // 执行 TLS 握手 tlsConn := tls.Client(conn, tlsConfig) if err := tlsConn.Handshake(); err != nil { diff --git a/tls_test.go b/tls_test.go index 220b6db..01e4bb5 100644 --- a/tls_test.go +++ b/tls_test.go @@ -1,10 +1,12 @@ package starnet import ( + "context" "crypto/tls" "net/http" "net/http/httptest" "testing" + "time" ) func TestRequestSkipTLSVerify(t *testing.T) { @@ -100,3 +102,128 @@ func TestRequestLevelTLSOverride(t *testing.T) { t.Errorf("StatusCode = %v; want %v", resp.StatusCode, http.StatusOK) } } + +func TestRequestTls(t *testing.T) { + resp, err := NewSimpleRequest("https://www.b612.me", "GET").Do() + if err != nil { + t.Fatalf("Do() error: %v", err) + } + defer resp.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("StatusCode = %v; want %v", resp.StatusCode, http.StatusOK) + } + t.Logf("Response: %v", resp.Body().MustString()) + client, err := NewClient() + if err != nil { + t.Fatalf("NewClient() error: %v", err) + } + resp, err = client.NewSimpleRequest("https://www.b612.me", "GET", + WithHeader("hello", "world"), + WithContext(context.Background()), + WithBearerToken("ddddddd")).Do() + if err != nil { + t.Fatalf("Do() error: %v", err) + } + defer resp.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("StatusCode = %v; want %v", resp.StatusCode, http.StatusOK) + } + t.Logf("Response: %v", resp.Body().MustString()) +} + +func TestTLSWithProxyPath(t *testing.T) { + client, err := NewClient() + if err != nil { + t.Fatal(err) + } + + req, err := client.NewRequest("https://registry-1.docker.io/v2/", "GET", + WithTimeout(10*time.Second), + WithProxy("http://127.0.0.1:29992"), + ) + if err != nil { + t.Fatal(err) + } + + resp, err := req.Do() + if err != nil { + t.Fatalf("Do error: %v", err) + } + defer resp.Close() + t.Log(resp.Status) +} + +func TestTLSWithProxyBug(t *testing.T) { + client, err := NewClient() + if err != nil { + t.Fatal(err) + } + + // 关键:使用 WithProxy 触发 needsDynamicTransport + // 即使 proxy 是空串或无效地址,只要设置了就会走 buildDynamicTransport 分支 + req, err := client.NewRequest("https://registry-1.docker.io/v2/", "GET", + WithTimeout(10*time.Second), + WithProxy("http://127.0.0.1:29992"), // 随便一个 proxy 地址,触发动态 transport + ) + if err != nil { + t.Fatal(err) + } + + resp, err := req.Do() + if err != nil { + // 修复前会报:tls: either ServerName or InsecureSkipVerify must be specified + t.Fatalf("Do error: %v", err) + } + defer resp.Close() + t.Logf("Status: %s", resp.Status) +} + +// 更精准的复现:直接测试有问题的分支 +func TestTLSDialWithoutServerName(t *testing.T) { + client, err := NewClient() + if err != nil { + t.Fatal(err) + } + + // 使用 WithCustomIP 也能触发 defaultDialTLSFunc + req, err := client.NewRequest("https://www.google.com", "GET", + WithTimeout(10*time.Second), + WithCustomIP([]string{"142.250.185.46"}), // Google 的一个 IP + ) + if err != nil { + t.Fatal(err) + } + + resp, err := req.Do() + if err != nil { + t.Fatalf("Do error: %v", err) + } + defer resp.Close() + t.Logf("Status: %s", resp.Status) +} + +// 最小复现:只要触发 needsDynamicTransport 即可 +func TestMinimalTLSBug(t *testing.T) { + client, err := NewClient() + if err != nil { + t.Fatal(err) + } + + // WithDialTimeout 也会触发动态 transport + req, err := client.NewRequest("https://www.baidu.com", "GET", + WithDialTimeout(5*time.Second), + ) + if err != nil { + t.Fatal(err) + } + + resp, err := req.Do() + if err != nil { + // 修复前必现:tls handshake: tls: either ServerName or InsecureSkipVerify must be specified + t.Fatalf("Do error: %v", err) + } + defer resp.Close() + t.Logf("Status: %s", resp.Status) +}