230 lines
5.5 KiB
Go
230 lines
5.5 KiB
Go
package starnet
|
||
|
||
import (
	"context"
	"crypto/tls"
	"net"
	"net/http"
	"net/http/httptest"
	"testing"
	"time"
)
|
||
|
||
func TestRequestSkipTLSVerify(t *testing.T) {
|
||
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
w.WriteHeader(http.StatusOK)
|
||
w.Write([]byte("OK"))
|
||
}))
|
||
defer server.Close()
|
||
|
||
// Without skip verify (should fail)
|
||
req := NewSimpleRequest(server.URL, "GET")
|
||
_, err := req.Do()
|
||
if err == nil {
|
||
t.Error("Expected TLS error without skip verify, got nil")
|
||
}
|
||
|
||
// With skip verify (should succeed)
|
||
req2 := NewSimpleRequest(server.URL, "GET").SetSkipTLSVerify(true)
|
||
resp, err := req2.Do()
|
||
if err != nil {
|
||
t.Fatalf("Do() with skip verify error: %v", err)
|
||
}
|
||
defer resp.Close()
|
||
|
||
body, _ := resp.Body().String()
|
||
if body != "OK" {
|
||
t.Errorf("Body = %v; want OK", body)
|
||
}
|
||
}
|
||
|
||
func TestRequestCustomTLSConfig(t *testing.T) {
|
||
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
w.WriteHeader(http.StatusOK)
|
||
w.Write([]byte("OK"))
|
||
}))
|
||
defer server.Close()
|
||
|
||
tlsConfig := &tls.Config{
|
||
InsecureSkipVerify: true,
|
||
MinVersion: tls.VersionTLS12,
|
||
}
|
||
|
||
req := NewSimpleRequest(server.URL, "GET").SetTLSConfig(tlsConfig)
|
||
resp, err := req.Do()
|
||
if err != nil {
|
||
t.Fatalf("Do() error: %v", err)
|
||
}
|
||
defer resp.Close()
|
||
|
||
if resp.StatusCode != http.StatusOK {
|
||
t.Errorf("StatusCode = %v; want %v", resp.StatusCode, http.StatusOK)
|
||
}
|
||
}
|
||
|
||
func TestClientDefaultTLSConfig(t *testing.T) {
|
||
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
w.WriteHeader(http.StatusOK)
|
||
}))
|
||
defer server.Close()
|
||
|
||
client := NewClientNoErr()
|
||
client.SetDefaultSkipTLSVerify(true)
|
||
|
||
resp, err := client.Get(server.URL)
|
||
if err != nil {
|
||
t.Fatalf("Get() error: %v", err)
|
||
}
|
||
defer resp.Close()
|
||
|
||
if resp.StatusCode != http.StatusOK {
|
||
t.Errorf("StatusCode = %v; want %v", resp.StatusCode, http.StatusOK)
|
||
}
|
||
}
|
||
|
||
func TestRequestLevelTLSOverride(t *testing.T) {
|
||
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
w.WriteHeader(http.StatusOK)
|
||
}))
|
||
defer server.Close()
|
||
|
||
// Client level: skip verify = false
|
||
client := NewClientNoErr()
|
||
client.SetDefaultSkipTLSVerify(false)
|
||
|
||
// Request level: skip verify = true (should override)
|
||
resp, err := client.Get(server.URL, WithSkipTLSVerify(true))
|
||
if err != nil {
|
||
t.Fatalf("Get() error: %v", err)
|
||
}
|
||
defer resp.Close()
|
||
|
||
if resp.StatusCode != http.StatusOK {
|
||
t.Errorf("StatusCode = %v; want %v", resp.StatusCode, http.StatusOK)
|
||
}
|
||
}
|
||
|
||
func TestRequestTls(t *testing.T) {
|
||
resp, err := NewSimpleRequest("https://www.b612.me", "GET").Do()
|
||
if err != nil {
|
||
t.Fatalf("Do() error: %v", err)
|
||
}
|
||
defer resp.Close()
|
||
|
||
if resp.StatusCode != http.StatusOK {
|
||
t.Errorf("StatusCode = %v; want %v", resp.StatusCode, http.StatusOK)
|
||
}
|
||
t.Logf("Response: %v", resp.Body().MustString())
|
||
client, err := NewClient()
|
||
if err != nil {
|
||
t.Fatalf("NewClient() error: %v", err)
|
||
}
|
||
resp, err = client.NewSimpleRequest("https://www.b612.me", "GET",
|
||
WithHeader("hello", "world"),
|
||
WithContext(context.Background()),
|
||
WithBearerToken("ddddddd")).Do()
|
||
if err != nil {
|
||
t.Fatalf("Do() error: %v", err)
|
||
}
|
||
defer resp.Close()
|
||
|
||
if resp.StatusCode != http.StatusOK {
|
||
t.Errorf("StatusCode = %v; want %v", resp.StatusCode, http.StatusOK)
|
||
}
|
||
t.Logf("Response: %v", resp.Body().MustString())
|
||
}
|
||
|
||
func TestTLSWithProxyPath(t *testing.T) {
|
||
client, err := NewClient()
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
req, err := client.NewRequest("https://registry-1.docker.io/v2/", "GET",
|
||
WithTimeout(10*time.Second),
|
||
WithProxy("http://127.0.0.1:29992"),
|
||
)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
resp, err := req.Do()
|
||
if err != nil {
|
||
t.Fatalf("Do error: %v", err)
|
||
}
|
||
defer resp.Close()
|
||
t.Log(resp.Status)
|
||
}
|
||
|
||
func TestTLSWithProxyBug(t *testing.T) {
|
||
client, err := NewClient()
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
// 关键:使用 WithProxy 触发 needsDynamicTransport
|
||
// 即使 proxy 是空串或无效地址,只要设置了就会走 buildDynamicTransport 分支
|
||
req, err := client.NewRequest("https://registry-1.docker.io/v2/", "GET",
|
||
WithTimeout(10*time.Second),
|
||
WithProxy("http://127.0.0.1:29992"), // 随便一个 proxy 地址,触发动态 transport
|
||
)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
resp, err := req.Do()
|
||
if err != nil {
|
||
// 修复前会报:tls: either ServerName or InsecureSkipVerify must be specified
|
||
t.Fatalf("Do error: %v", err)
|
||
}
|
||
defer resp.Close()
|
||
t.Logf("Status: %s", resp.Status)
|
||
}
|
||
|
||
// 更精准的复现:直接测试有问题的分支
|
||
func TestTLSDialWithoutServerName(t *testing.T) {
|
||
client, err := NewClient()
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
// 使用 WithCustomIP 也能触发 defaultDialTLSFunc
|
||
req, err := client.NewRequest("https://www.google.com", "GET",
|
||
WithTimeout(10*time.Second),
|
||
WithCustomIP([]string{"142.250.185.46"}), // Google 的一个 IP
|
||
)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
resp, err := req.Do()
|
||
if err != nil {
|
||
t.Fatalf("Do error: %v", err)
|
||
}
|
||
defer resp.Close()
|
||
t.Logf("Status: %s", resp.Status)
|
||
}
|
||
|
||
// 最小复现:只要触发 needsDynamicTransport 即可
|
||
func TestMinimalTLSBug(t *testing.T) {
|
||
client, err := NewClient()
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
// WithDialTimeout 也会触发动态 transport
|
||
req, err := client.NewRequest("https://www.baidu.com", "GET",
|
||
WithDialTimeout(5*time.Second),
|
||
)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
resp, err := req.Do()
|
||
if err != nil {
|
||
// 修复前必现:tls handshake: tls: either ServerName or InsecureSkipVerify must be specified
|
||
t.Fatalf("Do error: %v", err)
|
||
}
|
||
defer resp.Close()
|
||
t.Logf("Status: %s", resp.Status)
|
||
}
|