bug fix:tls自定义时,没有设置servername的问题
This commit is contained in:
parent
50aef48d49
commit
1bb30514ec
11
client.go
11
client.go
@ -46,7 +46,7 @@ func NewClient(opts ...RequestOpt) (*Client, error) {
|
|||||||
if req.config.Network.Timeout > 0 {
|
if req.config.Network.Timeout > 0 {
|
||||||
httpClient.Timeout = req.config.Network.Timeout
|
httpClient.Timeout = req.config.Network.Timeout
|
||||||
}
|
}
|
||||||
|
|
||||||
*/
|
*/
|
||||||
|
|
||||||
// 如果有自定义 Transport
|
// 如果有自定义 Transport
|
||||||
@ -172,10 +172,11 @@ func (c *Client) Clone() *Client {
|
|||||||
func (c *Client) SetDefaultTLSConfig(tlsConfig *tls.Config) *Client {
|
func (c *Client) SetDefaultTLSConfig(tlsConfig *tls.Config) *Client {
|
||||||
if transport, ok := c.client.Transport.(*Transport); ok {
|
if transport, ok := c.client.Transport.(*Transport); ok {
|
||||||
transport.mu.Lock()
|
transport.mu.Lock()
|
||||||
if transport.base.TLSClientConfig == nil {
|
if tlsConfig != nil {
|
||||||
transport.base.TLSClientConfig = &tls.Config{}
|
transport.base.TLSClientConfig = tlsConfig.Clone()
|
||||||
|
} else {
|
||||||
|
transport.base.TLSClientConfig = nil
|
||||||
}
|
}
|
||||||
transport.base.TLSClientConfig = tlsConfig
|
|
||||||
transport.mu.Unlock()
|
transport.mu.Unlock()
|
||||||
}
|
}
|
||||||
return c
|
return c
|
||||||
@ -187,6 +188,8 @@ func (c *Client) SetDefaultSkipTLSVerify(skip bool) *Client {
|
|||||||
transport.mu.Lock()
|
transport.mu.Lock()
|
||||||
if transport.base.TLSClientConfig == nil {
|
if transport.base.TLSClientConfig == nil {
|
||||||
transport.base.TLSClientConfig = &tls.Config{}
|
transport.base.TLSClientConfig = &tls.Config{}
|
||||||
|
} else {
|
||||||
|
transport.base.TLSClientConfig = transport.base.TLSClientConfig.Clone()
|
||||||
}
|
}
|
||||||
transport.base.TLSClientConfig.InsecureSkipVerify = skip
|
transport.base.TLSClientConfig.InsecureSkipVerify = skip
|
||||||
transport.mu.Unlock()
|
transport.mu.Unlock()
|
||||||
|
|||||||
11
dialer.go
11
dialer.go
@ -115,6 +115,17 @@ func defaultDialTLSFunc(ctx context.Context, network, addr string) (net.Conn, er
|
|||||||
tlsConfig = &tls.Config{}
|
tlsConfig = &tls.Config{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ← 新增:如果 ServerName 为空且没有 InsecureSkipVerify,自动设置
|
||||||
|
if tlsConfig.ServerName == "" && !tlsConfig.InsecureSkipVerify {
|
||||||
|
host, _, err := net.SplitHostPort(addr)
|
||||||
|
if err != nil {
|
||||||
|
// addr 可能没有端口,直接用 addr
|
||||||
|
host = addr
|
||||||
|
}
|
||||||
|
tlsConfig = tlsConfig.Clone() // 避免修改原 config
|
||||||
|
tlsConfig.ServerName = host
|
||||||
|
}
|
||||||
|
|
||||||
// 执行 TLS 握手
|
// 执行 TLS 握手
|
||||||
tlsConn := tls.Client(conn, tlsConfig)
|
tlsConn := tls.Client(conn, tlsConfig)
|
||||||
if err := tlsConn.Handshake(); err != nil {
|
if err := tlsConn.Handshake(); err != nil {
|
||||||
|
|||||||
127
tls_test.go
127
tls_test.go
@ -1,10 +1,12 @@
|
|||||||
package starnet
|
package starnet
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestRequestSkipTLSVerify(t *testing.T) {
|
func TestRequestSkipTLSVerify(t *testing.T) {
|
||||||
@ -100,3 +102,128 @@ func TestRequestLevelTLSOverride(t *testing.T) {
|
|||||||
t.Errorf("StatusCode = %v; want %v", resp.StatusCode, http.StatusOK)
|
t.Errorf("StatusCode = %v; want %v", resp.StatusCode, http.StatusOK)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRequestTls(t *testing.T) {
|
||||||
|
resp, err := NewSimpleRequest("https://www.b612.me", "GET").Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("StatusCode = %v; want %v", resp.StatusCode, http.StatusOK)
|
||||||
|
}
|
||||||
|
t.Logf("Response: %v", resp.Body().MustString())
|
||||||
|
client, err := NewClient()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewClient() error: %v", err)
|
||||||
|
}
|
||||||
|
resp, err = client.NewSimpleRequest("https://www.b612.me", "GET",
|
||||||
|
WithHeader("hello", "world"),
|
||||||
|
WithContext(context.Background()),
|
||||||
|
WithBearerToken("ddddddd")).Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("StatusCode = %v; want %v", resp.StatusCode, http.StatusOK)
|
||||||
|
}
|
||||||
|
t.Logf("Response: %v", resp.Body().MustString())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTLSWithProxyPath(t *testing.T) {
|
||||||
|
client, err := NewClient()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := client.NewRequest("https://registry-1.docker.io/v2/", "GET",
|
||||||
|
WithTimeout(10*time.Second),
|
||||||
|
WithProxy("http://127.0.0.1:29992"),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := req.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
t.Log(resp.Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTLSWithProxyBug(t *testing.T) {
|
||||||
|
client, err := NewClient()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 关键:使用 WithProxy 触发 needsDynamicTransport
|
||||||
|
// 即使 proxy 是空串或无效地址,只要设置了就会走 buildDynamicTransport 分支
|
||||||
|
req, err := client.NewRequest("https://registry-1.docker.io/v2/", "GET",
|
||||||
|
WithTimeout(10*time.Second),
|
||||||
|
WithProxy("http://127.0.0.1:29992"), // 随便一个 proxy 地址,触发动态 transport
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := req.Do()
|
||||||
|
if err != nil {
|
||||||
|
// 修复前会报:tls: either ServerName or InsecureSkipVerify must be specified
|
||||||
|
t.Fatalf("Do error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
t.Logf("Status: %s", resp.Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更精准的复现:直接测试有问题的分支
|
||||||
|
func TestTLSDialWithoutServerName(t *testing.T) {
|
||||||
|
client, err := NewClient()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 使用 WithCustomIP 也能触发 defaultDialTLSFunc
|
||||||
|
req, err := client.NewRequest("https://www.google.com", "GET",
|
||||||
|
WithTimeout(10*time.Second),
|
||||||
|
WithCustomIP([]string{"142.250.185.46"}), // Google 的一个 IP
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := req.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
t.Logf("Status: %s", resp.Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 最小复现:只要触发 needsDynamicTransport 即可
|
||||||
|
func TestMinimalTLSBug(t *testing.T) {
|
||||||
|
client, err := NewClient()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithDialTimeout 也会触发动态 transport
|
||||||
|
req, err := client.NewRequest("https://www.baidu.com", "GET",
|
||||||
|
WithDialTimeout(5*time.Second),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := req.Do()
|
||||||
|
if err != nil {
|
||||||
|
// 修复前必现:tls handshake: tls: either ServerName or InsecureSkipVerify must be specified
|
||||||
|
t.Fatalf("Do error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
t.Logf("Status: %s", resp.Status)
|
||||||
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user