package starnet

import (
	"context"
	"crypto/tls"
	"encoding/json"
	"fmt"
	"io"
	"net"
	"net/http"
	"net/http/httptest"
	"strings"
	"sync"
	"sync/atomic"
	"testing"
	"time"
)

// TestComplexScenario1_RequestLevelConfigOverride verifies that a request-level
// timeout overrides the client-level timeout for that single request only.
func TestComplexScenario1_RequestLevelConfigOverride(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		time.Sleep(150 * time.Millisecond)
		w.WriteHeader(http.StatusOK)
		w.Write([]byte("OK"))
	}))
	defer server.Close()

	// Client level: 5 second timeout.
	client := NewClientNoErr(WithTimeout(5 * time.Second))

	// Request 1: uses the client timeout (should succeed).
	resp1, err := client.Get(server.URL)
	if err != nil {
		t.Fatalf("Request 1 error: %v", err)
	}
	resp1.Close()

	// Request 2: request-level override to 100ms (should time out).
	start := time.Now()
	resp2, err := client.Get(server.URL, WithTimeout(100*time.Millisecond))
	elapsed := time.Since(start)
	if err == nil {
		// Unexpected success: still release the response so it is not leaked.
		resp2.Close()
		t.Error("Request 2 should timeout, got nil error")
	}
	if elapsed > 500*time.Millisecond {
		t.Errorf("Request 2 timeout took too long: %v", elapsed)
	}

	// Request 3: uses the client timeout again (should succeed; proves request 2
	// left no side effects on the client).
	resp3, err := client.Get(server.URL)
	if err != nil {
		t.Fatalf("Request 3 error: %v", err)
	}
	resp3.Close()
}

// TestComplexScenario2_TLSConfigPriority exercises client-level and request-level
// TLS settings against a self-signed TLS test server.
func TestComplexScenario2_TLSConfigPriority(t *testing.T) {
	server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		w.Write([]byte("OK"))
	}))
	defer server.Close()

	// Scenario 1: SkipVerify set at the client level.
	client := NewClientNoErr()
	client.SetDefaultSkipTLSVerify(true)
	resp1, err := client.Get(server.URL)
	if err != nil {
		t.Fatalf("Scenario 1 error: %v", err)
	}
	resp1.Close()

	// Scenario 2: a full request-level TLS config (should override the client level).
	customTLS := &tls.Config{
		InsecureSkipVerify: true,
		MinVersion:         tls.VersionTLS12,
	}
	resp2, err := client.Get(server.URL, WithTLSConfig(customTLS))
	if err != nil {
		t.Fatalf("Scenario 2 error: %v", err)
	}
	resp2.Close()

	// Scenario 3: request-level SkipVerify only (no full TLS config).
	resp3, err := client.Get(server.URL, WithSkipTLSVerify(true))
	if err != nil {
		t.Fatalf("Scenario 3 error: %v", err)
	}
	resp3.Close()

	// Scenario 4: a fresh client with no TLS settings must fail certificate verification.
	client2 := NewClientNoErr()
	resp4, err := client2.Get(server.URL)
	if err == nil {
		resp4.Close()
		t.Error("Scenario 4 should fail with TLS error, got nil")
	}
}

// TestComplexScenario3_ConnectionPoolReuse sends sequential requests through one
// client and reports how many TCP connections the server actually accepted.
func TestComplexScenario3_ConnectionPoolReuse(t *testing.T) {
	var handlerCalls, newConns int64
	server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		atomic.AddInt64(&handlerCalls, 1)
		w.WriteHeader(http.StatusOK)
		w.Write([]byte("OK"))
	}))
	// The handler-call count always equals the request count, so it says nothing
	// about reuse; count connections accepted by the server instead.
	// httptest.Server.Start chains this hook rather than replacing it.
	server.Config.ConnState = func(_ net.Conn, state http.ConnState) {
		if state == http.StateNew {
			atomic.AddInt64(&newConns, 1)
		}
	}
	server.Start()
	defer server.Close()

	client := NewClientNoErr()

	// Send 10 requests; the pooled connection should be reused.
	for i := 0; i < 10; i++ {
		resp, err := client.Get(server.URL)
		if err != nil {
			t.Fatalf("Request %d error: %v", i, err)
		}
		// The body must be drained and closed for the connection to be reused.
		io.ReadAll(resp.Body().raw)
		resp.Close()
	}

	// Pool behavior depends on timing and system state, so this is logged rather
	// than asserted; newConns should normally be far below the request count.
	t.Logf("Total handler calls: %d", atomic.LoadInt64(&handlerCalls))
	t.Logf("New connections accepted: %d", atomic.LoadInt64(&newConns))
}

// TestComplexScenario4_CustomDNSWithFallback checks that a hostname can be routed
// to the test server either via a fixed custom IP or via a custom lookup function.
func TestComplexScenario4_CustomDNSWithFallback(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		w.Write([]byte("OK"))
	}))
	defer server.Close()

	// Extract the server's real IP and port. SplitHostPort also handles an
	// IPv6 loopback address such as "[::1]:port".
	serverURL := server.URL
	host := strings.TrimPrefix(serverURL, "http://")
	ip, port, err := net.SplitHostPort(host)
	if err != nil {
		t.Fatalf("Invalid server URL: %s", serverURL)
	}

	// Build a URL that uses a domain name instead of the IP.
	testURL := fmt.Sprintf("http://test.example.com:%s", port)

	// Scenario 1: pin the domain to a custom IP directly.
	req := NewSimpleRequest(testURL, "GET").SetCustomIP([]string{ip})
	resp, err := req.Do()
	if err != nil {
		t.Fatalf("Custom IP request error: %v", err)
	}
	resp.Close()

	// Scenario 2: resolve through a custom lookup function. The flag is atomic
	// because the lookup may run on a transport/dialer goroutine.
	var lookupCalled int32
	customLookup := func(ctx context.Context, host string) ([]net.IPAddr, error) {
		atomic.StoreInt32(&lookupCalled, 1)
		// Return the real test-server IP.
		return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
	}
	req2 := NewSimpleRequest(testURL, "GET").SetLookupFunc(customLookup)
	resp2, err := req2.Do()
	if err != nil {
		t.Fatalf("Custom lookup request error: %v", err)
	}
	resp2.Close()

	if atomic.LoadInt32(&lookupCalled) == 0 {
		t.Error("Custom lookup function was not called")
	}
}

// TestComplexScenario5_ConcurrentRequestsWithDifferentConfigs issues concurrent
// requests through one client, each with its own per-request timeout.
func TestComplexScenario5_ConcurrentRequestsWithDifferentConfigs(t *testing.T) {
	// Create several servers with different response delays.
	servers := make([]*httptest.Server, 3)
	for i := range servers {
		delay := time.Duration(i*50) * time.Millisecond
		idx := i // per-iteration copy captured by the handler (pre-Go 1.22 safe)
		servers[i] = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			time.Sleep(delay)
			w.WriteHeader(http.StatusOK)
			w.Write([]byte(fmt.Sprintf("Server %d", idx)))
		}))
		defer servers[i].Close()
	}

	client := NewClientNoErr()

	var wg sync.WaitGroup
	results := make([]string, 3)
	errs := make([]error, 3)

	// Each goroutine writes only its own slot, so no extra locking is needed.
	for i := 0; i < 3; i++ {
		wg.Add(1)
		go func(idx int) {
			defer wg.Done()
			timeout := time.Duration((idx+1)*100) * time.Millisecond
			resp, err := client.Get(servers[idx].URL, WithTimeout(timeout))
			if err != nil {
				errs[idx] = err
				return
			}
			defer resp.Close()
			body, _ := resp.Body().String()
			results[idx] = body
		}(i)
	}
	wg.Wait()

	for i := 0; i < 3; i++ {
		if errs[i] != nil {
			t.Errorf("Request %d error: %v", i, errs[i])
		}
		expected := fmt.Sprintf("Server %d", i)
		if results[i] != expected {
			t.Errorf("Request %d result = %v; want %v", i, results[i], expected)
		}
	}
}

// TestComplexScenario6_RequestCloneIndependence checks that modifying clones of a
// request does not affect the base request or each other.
func TestComplexScenario6_RequestCloneIndependence(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Echo every request header back as a response header.
		for k, v := range r.Header {
			w.Header().Set(k, strings.Join(v, ","))
		}
		w.WriteHeader(http.StatusOK)
	}))
	defer server.Close()

	// Base request.
	baseReq := NewSimpleRequest(server.URL, "GET").
		SetHeader("X-Base", "base-value").
		SetTimeout(5 * time.Second)

	// Clone and modify.
	req1 := baseReq.Clone().
		SetHeader("X-Request", "request-1").
		SetTimeout(1 * time.Second)
	req2 := baseReq.Clone().
		SetHeader("X-Request", "request-2").
		SetTimeout(2 * time.Second)

	resp1, err := req1.Do()
	if err != nil {
		t.Fatalf("Request 1 error: %v", err)
	}
	defer resp1.Close()

	resp2, err := req2.Do()
	if err != nil {
		t.Fatalf("Request 2 error: %v", err)
	}
	defer resp2.Close()

	// Headers must be independent per clone.
	if resp1.Header.Get("X-Request") != "request-1" {
		t.Errorf("Request 1 header = %v; want request-1", resp1.Header.Get("X-Request"))
	}
	if resp2.Header.Get("X-Request") != "request-2" {
		t.Errorf("Request 2 header = %v; want request-2", resp2.Header.Get("X-Request"))
	}

	// The base request must be untouched by the clones.
	resp3, err := baseReq.Do()
	if err != nil {
		t.Fatalf("Base request error: %v", err)
	}
	defer resp3.Close()
	if resp3.Header.Get("X-Request") != "" {
		t.Errorf("Base request should not have X-Request header, got %v", resp3.Header.Get("X-Request"))
	}
}

// TestComplexScenario7_ErrorAccumulation checks that builder errors are recorded
// on the request and surfaced by Err() and Do() instead of panicking.
func TestComplexScenario7_ErrorAccumulation(t *testing.T) {
	// Scenario 1: an error in a chained call is accumulated.
	req := NewSimpleRequest("://invalid-url", "GET").
		SetHeader("X-Test", "value").
		AddQuery("key", "value")

	// The error should be recorded, not panic.
	if req.Err() == nil {
		t.Error("Expected error for invalid URL, got nil")
	}

	// Later operations on the failed request are expected to be ignored.
	req.SetTimeout(5 * time.Second)

	// Do() should return the accumulated error.
	_, err := req.Do()
	if err == nil {
		t.Error("Do() should return accumulated error, got nil")
	}

	// Scenario 2: invalid method.
	req2 := NewSimpleRequest("http://example.com", "INVALID METHOD!")
	if req2.Err() == nil {
		t.Error("Expected error for invalid method, got nil")
	}

	// Scenario 3: invalid IP.
	req3 := NewSimpleRequest("http://example.com", "GET").
		SetCustomIP([]string{"invalid-ip"})
	if req3.Err() == nil {
		t.Error("Expected error for invalid IP, got nil")
	}
}

// TestComplexScenario8_DialTimeoutVsRequestTimeout contrasts the connect-only
// DialTimeout with the overall request Timeout.
func TestComplexScenario8_DialTimeoutVsRequestTimeout(t *testing.T) {
	// Scenario 1: DialTimeout against an unroutable TEST-NET address.
	start := time.Now()
	req := NewSimpleRequest("http://192.0.2.1:80", "GET").
		SetDialTimeout(100 * time.Millisecond)
	resp, err := req.Do()
	elapsed := time.Since(start)
	if err == nil {
		resp.Close()
		t.Error("Expected dial timeout error, got nil")
	}
	if elapsed > 2*time.Second {
		t.Errorf("Dial timeout took too long: %v", elapsed)
	}

	// Scenario 2: Timeout covers the whole exchange, including waiting for the response.
	slowServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		time.Sleep(200 * time.Millisecond)
		w.WriteHeader(http.StatusOK)
	}))
	defer slowServer.Close()

	start2 := time.Now()
	req2 := NewSimpleRequest(slowServer.URL, "GET").
		SetTimeout(100 * time.Millisecond)
	resp2, err2 := req2.Do()
	elapsed2 := time.Since(start2)
	if err2 == nil {
		resp2.Close()
		t.Error("Expected request timeout error, got nil")
	}
	if elapsed2 > 500*time.Millisecond {
		t.Errorf("Request timeout took too long: %v", elapsed2)
	}
}

// TestComplexScenario9_MultipartUploadWithProgress uploads a multipart form with a
// streamed file and checks the upload-progress callback.
func TestComplexScenario9_MultipartUploadWithProgress(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		err := r.ParseMultipartForm(10 << 20)
		if err != nil {
			t.Errorf("ParseMultipartForm error: %v", err)
			w.WriteHeader(http.StatusBadRequest)
			return
		}

		// Check the plain form field.
		if r.FormValue("name") != "test" {
			t.Errorf("name = %v; want test", r.FormValue("name"))
		}

		// Check the file part.
		file, header, err := r.FormFile("file")
		if err != nil {
			t.Errorf("FormFile error: %v", err)
			w.WriteHeader(http.StatusBadRequest)
			return
		}
		defer file.Close()

		content, _ := io.ReadAll(file)
		w.WriteHeader(http.StatusOK)
		w.Write([]byte(fmt.Sprintf("Received: %s (%d bytes)", header.Filename, len(content))))
	}))
	defer server.Close()

	// Test payload: 1000 x "test data " = 10000 bytes.
	fileContent := strings.Repeat("test data ", 1000)
	reader := strings.NewReader(fileContent)

	// Progress tracking; atomics because the callback may run on another goroutine.
	var progressCalls int64
	var lastUploaded int64

	req := NewSimpleRequest(server.URL, "POST").
		AddFormData("name", "test").
		AddFileStream("file", "test.txt", int64(len(fileContent)), reader).
		SetUploadProgress(func(filename string, uploaded, total int64) {
			atomic.AddInt64(&progressCalls, 1)
			atomic.StoreInt64(&lastUploaded, uploaded)

			if filename != "test.txt" {
				t.Errorf("filename = %v; want test.txt", filename)
			}
			if total != int64(len(fileContent)) {
				t.Errorf("total = %v; want %v", total, len(fileContent))
			}
		})

	resp, err := req.Do()
	if err != nil {
		t.Fatalf("Upload error: %v", err)
	}
	defer resp.Close()

	// The progress callback must have fired at least once.
	if atomic.LoadInt64(&progressCalls) == 0 {
		t.Error("Progress callback was not called")
	}

	// The final reported amount must equal the full file size. Load once so the
	// value compared and the value reported are the same atomic read.
	if got := atomic.LoadInt64(&lastUploaded); got != int64(len(fileContent)) {
		t.Errorf("lastUploaded = %v; want %v", got, len(fileContent))
	}

	body, _ := resp.Body().String()
	if !strings.Contains(body, "test.txt") {
		t.Errorf("Response should contain filename, got: %v", body)
	}
}

// TestComplexScenario10_ClientCloneWithOptions checks that a cloned client inherits
// the original's options and that adding options to the clone leaves the original alone.
func TestComplexScenario10_ClientCloneWithOptions(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		w.Write([]byte(r.Header.Get("X-Client-ID")))
	}))
	defer server.Close()

	// Client with options.
	client1 := NewClientNoErr(
		WithTimeout(5*time.Second),
		WithHeader("X-Client-ID", "client-1"),
	)

	// Clone the client and extend only the clone.
	client2 := client1.Clone()
	client2.AddOptions(WithHeader("X-Extra", "extra-value"))

	resp1, err := client1.Get(server.URL)
	if err != nil {
		t.Fatalf("Client 1 error: %v", err)
	}
	defer resp1.Close()

	body1, _ := resp1.Body().String()
	if body1 != "client-1" {
		t.Errorf("Client 1 response = %v; want client-1", body1)
	}

	// client2 should inherit client1's options.
	resp2, err := client2.Get(server.URL)
	if err != nil {
		t.Fatalf("Client 2 error: %v", err)
	}
	defer resp2.Close()

	body2, _ := resp2.Body().String()
	if body2 != "client-1" {
		t.Errorf("Client 2 response = %v; want client-1", body2)
	}

	// client1 must not have picked up the option added to client2.
	opts1 := client1.RequestOptions()
	opts2 := client2.RequestOptions()
	if len(opts1) >= len(opts2) {
		t.Errorf("Client 2 should have more options than Client 1")
	}
}

// TestComplexScenario11_ContextCancellation checks that cancelling the request
// context aborts an in-flight request promptly.
func TestComplexScenario11_ContextCancellation(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Return as soon as the client goes away; otherwise server.Close would
		// block for the full delay waiting on this handler.
		select {
		case <-time.After(2 * time.Second):
		case <-r.Context().Done():
			return
		}
		w.WriteHeader(http.StatusOK)
	}))
	defer server.Close()

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Cancel after 500ms.
	timer := time.AfterFunc(500*time.Millisecond, cancel)
	defer timer.Stop()

	req := NewSimpleRequestWithContext(ctx, server.URL, "GET")

	start := time.Now()
	resp, err := req.Do()
	elapsed := time.Since(start)

	if err == nil {
		resp.Close()
		t.Error("Expected context cancellation error, got nil")
	}
	if elapsed > 1*time.Second {
		t.Errorf("Context cancellation took too long: %v", elapsed)
	}
}

// TestComplexScenario12_RedirectWithCookies checks redirect following and, with
// redirects disabled, that the 302 response and its Set-Cookie are returned as-is.
func TestComplexScenario12_RedirectWithCookies(t *testing.T) {
	// Shared between handler goroutines and the test goroutine, hence atomic.
	var redirectCount int32
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if atomic.LoadInt32(&redirectCount) < 2 {
			n := atomic.AddInt32(&redirectCount, 1)
			// Set a cookie, then redirect.
			http.SetCookie(w, &http.Cookie{
				Name:  fmt.Sprintf("cookie%d", n),
				Value: fmt.Sprintf("value%d", n),
				Path:  "/",
			})
			http.Redirect(w, r, "/final", http.StatusFound)
			return
		}
		// Final response.
		w.WriteHeader(http.StatusOK)
		w.Write([]byte("final"))
	}))
	defer server.Close()

	// Redirects are followed automatically by default.
	client := NewClientNoErr()
	resp, err := client.Get(server.URL)
	if err != nil {
		t.Fatalf("Get error: %v", err)
	}
	defer resp.Close()

	if resp.StatusCode != http.StatusOK {
		t.Errorf("StatusCode = %v; want %v", resp.StatusCode, http.StatusOK)
	}

	body, _ := resp.Body().String()
	if body != "final" {
		t.Errorf("Body = %v; want final", body)
	}

	// With redirects disabled the first 302 is returned to the caller.
	atomic.StoreInt32(&redirectCount, 0)
	client.DisableRedirect()
	resp2, err := client.Get(server.URL)
	if err != nil {
		t.Fatalf("Get error: %v", err)
	}
	defer resp2.Close()

	if resp2.StatusCode != http.StatusFound {
		t.Errorf("StatusCode = %v; want %v", resp2.StatusCode, http.StatusFound)
	}

	// The redirect response carries a Set-Cookie header.
	cookies := resp2.Cookies()
	if len(cookies) == 0 {
		t.Error("Expected cookies in redirect response")
	}
}

// TestDefaultsSetDefaultClient checks that SetDefaultClient replaces the package default.
func TestDefaultsSetDefaultClient(t *testing.T) {
	// Restore via defer so the global default is reset however the test exits.
	originalClient := DefaultClient()
	defer SetDefaultClient(originalClient)

	customClient := NewClientNoErr(WithTimeout(1 * time.Second))
	SetDefaultClient(customClient)

	if DefaultClient() != customClient {
		t.Error("SetDefaultClient did not update default client")
	}
}

// TestDefaultsSetDefaultHTTPClient checks that SetDefaultHTTPClient replaces the
// package default *http.Client.
func TestDefaultsSetDefaultHTTPClient(t *testing.T) {
	// Restore via defer so the global default is reset however the test exits.
	originalHTTPClient := DefaultHTTPClient()
	defer SetDefaultHTTPClient(originalHTTPClient)

	customHTTPClient := &http.Client{
		Timeout: 2 * time.Second,
	}
	SetDefaultHTTPClient(customHTTPClient)

	if DefaultHTTPClient() != customHTTPClient {
		t.Error("SetDefaultHTTPClient did not update default http client")
	}
}

// TestDefaultsHeadMethod checks the package-level Head helper.
func TestDefaultsHeadMethod(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodHead {
			t.Errorf("Method = %v; want HEAD", r.Method)
		}
		w.Header().Set("X-Custom", "test-value")
		w.WriteHeader(http.StatusOK)
	}))
	defer server.Close()

	resp, err := Head(server.URL)
	if err != nil {
		t.Fatalf("Head() error: %v", err)
	}
	defer resp.Close()

	if resp.StatusCode != http.StatusOK {
		t.Errorf("StatusCode = %v; want %v", resp.StatusCode, http.StatusOK)
	}

	// A HEAD response carries headers but no body.
	if resp.Header.Get("X-Custom") != "test-value" {
		t.Errorf("Header X-Custom = %v; want test-value", resp.Header.Get("X-Custom"))
	}
}

// TestProxyConfiguration checks that SetProxy records the proxy URL in the request
// config. It does not route traffic through a proxy.
func TestProxyConfiguration(t *testing.T) {
	// Target server.
	targetServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		w.Write([]byte("target"))
	}))
	defer targetServer.Close()

	// Stand-in "proxy" server (simple fixed response, not a real forwarding proxy).
	proxyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("X-Proxied", "true")
		w.WriteHeader(http.StatusOK)
		w.Write([]byte("proxied"))
	}))
	defer proxyServer.Close()

	req := NewSimpleRequest(targetServer.URL, "GET").SetProxy(proxyServer.URL)

	if req.config.Network.Proxy != proxyServer.URL {
		t.Errorf("Proxy = %v; want %v", req.config.Network.Proxy, proxyServer.URL)
	}

	// NOTE: exercising real proxying would need a real forwarding proxy; this
	// only verifies the configuration is stored.
}

// TestWithRawRequest checks that WithRawRequest sends the supplied *http.Request.
func TestWithRawRequest(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Header.Get("X-Custom") != "raw-value" {
			t.Errorf("X-Custom header = %v; want raw-value", r.Header.Get("X-Custom"))
		}
		w.WriteHeader(http.StatusOK)
		w.Write([]byte("OK"))
	}))
	defer server.Close()

	rawReq, err := http.NewRequest("GET", server.URL, nil)
	if err != nil {
		t.Fatalf("http.NewRequest error: %v", err)
	}
	rawReq.Header.Set("X-Custom", "raw-value")

	req := NewSimpleRequest("", "GET", WithRawRequest(rawReq))
	resp, err := req.Do()
	if err != nil {
		t.Fatalf("Do() error: %v", err)
	}
	defer resp.Close()

	body, _ := resp.Body().String()
	if body != "OK" {
		t.Errorf("Body = %v; want OK", body)
	}
}

// TestWithContentLength checks that an explicit Content-Length matching the body is sent.
func TestWithContentLength(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.ContentLength != 9 {
			t.Errorf("ContentLength = %v; want 9", r.ContentLength)
		}
		w.WriteHeader(http.StatusOK)
	}))
	defer server.Close()

	data := []byte("test data")
	// Content length matches len(data).
	resp, err := Post(server.URL, WithBody(data), WithContentLength(int64(len(data))))
	if err != nil {
		t.Fatalf("Post() error: %v", err)
	}
	defer resp.Close()
}

// TestWithAutoCalcContentLength checks that a reader body gets a positive
// Content-Length when auto-calculation is enabled.
func TestWithAutoCalcContentLength(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.ContentLength <= 0 {
			t.Errorf("ContentLength = %v; want > 0", r.ContentLength)
		}
		w.WriteHeader(http.StatusOK)
	}))
	defer server.Close()

	data := strings.NewReader("test data for auto calc")
	resp, err := Post(server.URL, WithBodyReader(data), WithAutoCalcContentLength(true))
	if err != nil {
		t.Fatalf("Post() error: %v", err)
	}
	defer resp.Close()
}

// TestChunkedTransferEncoding sends a body with Content-Length -1 and records
// whether the server saw chunked transfer encoding.
func TestChunkedTransferEncoding(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Report back whether the request arrived chunked.
		if len(r.TransferEncoding) > 0 && r.TransferEncoding[0] == "chunked" {
			w.Header().Set("X-Chunked", "true")
		}
		w.WriteHeader(http.StatusOK)
	}))
	defer server.Close()

	data := []byte("test data")
	// -1 is intended to force chunked transfer encoding.
	resp, err := Post(server.URL, WithBody(data), WithContentLength(-1))
	if err != nil {
		t.Fatalf("Post() error: %v", err)
	}
	defer resp.Close()

	// NOTE(review): not asserted; whether -1 yields chunked encoding depends on
	// how the library builds the request. Logged so the outcome is visible.
	t.Logf("server observed chunked transfer encoding: %q", resp.Header.Get("X-Chunked"))
}

// TestWithFormDataMap checks url-encoded form fields supplied as map[string]string.
func TestWithFormDataMap(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if err := r.ParseForm(); err != nil {
			t.Errorf("ParseForm error: %v", err)
		}
		if r.FormValue("key1") != "value1" {
			t.Errorf("key1 = %v; want value1", r.FormValue("key1"))
		}
		if r.FormValue("key2") != "value2" {
			t.Errorf("key2 = %v; want value2", r.FormValue("key2"))
		}
		w.WriteHeader(http.StatusOK)
	}))
	defer server.Close()

	resp, err := Post(server.URL, WithFormDataMap(map[string]string{
		"key1": "value1",
		"key2": "value2",
	}))
	if err != nil {
		t.Fatalf("Post() error: %v", err)
	}
	defer resp.Close()
}

// TestWithFormData checks multi-valued form fields supplied as map[string][]string.
func TestWithFormData(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if err := r.ParseForm(); err != nil {
			t.Errorf("ParseForm error: %v", err)
		}
		values := r.Form["tags"]
		if len(values) != 2 {
			t.Errorf("tags length = %v; want 2", len(values))
		}
		w.WriteHeader(http.StatusOK)
	}))
	defer server.Close()

	resp, err := Post(server.URL, WithFormData(map[string][]string{
		"tags": {"tag1", "tag2"},
	}))
	if err != nil {
		t.Fatalf("Post() error: %v", err)
	}
	defer resp.Close()
}

// TestWithAddFormData checks a single form field added via WithAddFormData.
func TestWithAddFormData(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if err := r.ParseForm(); err != nil {
			t.Errorf("ParseForm error: %v", err)
		}
		if r.FormValue("name") != "test" {
			t.Errorf("name = %v; want test", r.FormValue("name"))
		}
		w.WriteHeader(http.StatusOK)
	}))
	defer server.Close()

	resp, err := Post(server.URL, WithAddFormData("name", "test"))
	if err != nil {
		t.Fatalf("Post() error: %v", err)
	}
	defer resp.Close()
}

// TestHeaderOperations covers AddHeader, SetHeader, DeleteHeader and ResetHeaders
// by echoing the received headers back as JSON.
func TestHeaderOperations(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// http.Header is a map[string][]string, so it encodes as such.
		json.NewEncoder(w).Encode(r.Header)
	}))
	defer server.Close()

	// first returns the first value under key, or "" when absent, so a missing
	// header fails the assertion instead of panicking on [0].
	first := func(m map[string][]string, key string) string {
		if vs := m[key]; len(vs) > 0 {
			return vs[0]
		}
		return ""
	}

	req := NewSimpleRequest(server.URL, "GET")

	// AddHeader appends values.
	req.AddHeader("X-Multi", "value1")
	req.AddHeader("X-Multi", "value2")

	// SetHeader sets a single value.
	req.SetHeader("X-Single", "single-value")

	// DeleteHeader removes a header.
	req.SetHeader("X-Delete", "will-be-deleted")
	req.DeleteHeader("X-Delete")

	// ResetHeaders clears everything set before it.
	req2 := NewSimpleRequest(server.URL, "GET")
	req2.SetHeader("X-Test", "test")
	req2.ResetHeaders()
	req2.SetHeader("X-After-Reset", "value")

	resp, err := req.Do()
	if err != nil {
		t.Fatalf("Do() error: %v", err)
	}
	defer resp.Close()

	var headers map[string][]string
	resp.Body().JSON(&headers)

	if len(headers["X-Multi"]) != 2 {
		t.Errorf("X-Multi length = %v; want 2", len(headers["X-Multi"]))
	}
	if got := first(headers, "X-Single"); got != "single-value" {
		t.Errorf("X-Single = %v; want single-value", got)
	}
	if _, exists := headers["X-Delete"]; exists {
		t.Error("X-Delete should be deleted")
	}

	resp2, err := req2.Do()
	if err != nil {
		t.Fatalf("Do() error: %v", err)
	}
	defer resp2.Close()

	var headers2 map[string][]string
	resp2.Body().JSON(&headers2)

	if _, exists := headers2["X-Test"]; exists {
		t.Error("X-Test should not exist after reset")
	}
	if got := first(headers2, "X-After-Reset"); got != "value" {
		t.Errorf("X-After-Reset = %v; want value", got)
	}
}

// TestCookieOperations covers AddSimpleCookie, AddCookieKV, AddCookie and
// ResetCookies by echoing received cookies back as JSON.
func TestCookieOperations(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		cookies := make(map[string]string)
		for _, cookie := range r.Cookies() {
			cookies[cookie.Name] = cookie.Value
		}
		json.NewEncoder(w).Encode(cookies)
	}))
	defer server.Close()

	req := NewSimpleRequest(server.URL, "GET")

	// AddSimpleCookie
	req.AddSimpleCookie("simple", "simple-value")

	// AddCookieKV
	req.AddCookieKV("custom", "custom-value", "/path")

	// AddCookie
	req.AddCookie(&http.Cookie{
		Name:  "full",
		Value: "full-value",
		Path:  "/",
	})

	resp, err := req.Do()
	if err != nil {
		t.Fatalf("Do() error: %v", err)
	}
	defer resp.Close()

	var cookies map[string]string
	resp.Body().JSON(&cookies)

	if cookies["simple"] != "simple-value" {
		t.Errorf("simple = %v; want simple-value", cookies["simple"])
	}
	if cookies["custom"] != "custom-value" {
		t.Errorf("custom = %v; want custom-value", cookies["custom"])
	}
	if cookies["full"] != "full-value" {
		t.Errorf("full = %v; want full-value", cookies["full"])
	}

	// ResetCookies drops cookies added before it.
	req2 := NewSimpleRequest(server.URL, "GET")
	req2.AddSimpleCookie("before", "before-value")
	req2.ResetCookies()
	req2.AddSimpleCookie("after", "after-value")

	resp2, err := req2.Do()
	if err != nil {
		t.Fatalf("Do() error: %v", err)
	}
	defer resp2.Close()

	var cookies2 map[string]string
	resp2.Body().JSON(&cookies2)

	if _, exists := cookies2["before"]; exists {
		t.Error("before cookie should not exist after reset")
	}
	if cookies2["after"] != "after-value" {
		t.Errorf("after = %v; want after-value", cookies2["after"])
	}
}

// TestQueryOperations covers AddQuery, SetQuery, AddQueries, DeleteQuery and
// DeleteQueryValue by echoing the received query back as JSON.
func TestQueryOperations(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// url.Values is a map[string][]string, so it encodes as such.
		json.NewEncoder(w).Encode(r.URL.Query())
	}))
	defer server.Close()

	// first returns the first value under key, or "" when absent.
	first := func(m map[string][]string, key string) string {
		if vs := m[key]; len(vs) > 0 {
			return vs[0]
		}
		return ""
	}

	req := NewSimpleRequest(server.URL, "GET")

	// AddQuery appends values.
	req.AddQuery("multi", "value1")
	req.AddQuery("multi", "value2")

	// SetQuery sets a single value.
	req.SetQuery("single", "single-value")

	// AddQueries adds several keys at once.
	req.AddQueries(map[string]string{
		"batch1": "batch-value1",
		"batch2": "batch-value2",
	})

	// DeleteQuery removes a key.
	req.AddQuery("delete-me", "will-be-deleted")
	req.DeleteQuery("delete-me")

	// DeleteQueryValue removes one value of a multi-valued key.
	req.AddQuery("partial", "keep")
	req.AddQuery("partial", "delete")
	req.DeleteQueryValue("partial", "delete")

	resp, err := req.Do()
	if err != nil {
		t.Fatalf("Do() error: %v", err)
	}
	defer resp.Close()

	var result map[string][]string
	resp.Body().JSON(&result)

	if len(result["multi"]) != 2 {
		t.Errorf("multi length = %v; want 2", len(result["multi"]))
	}
	if len(result["single"]) != 1 || result["single"][0] != "single-value" {
		t.Errorf("single = %v; want [single-value]", result["single"])
	}
	if got := first(result, "batch1"); got != "batch-value1" {
		t.Errorf("batch1 = %v; want batch-value1", got)
	}
	if got := first(result, "batch2"); got != "batch-value2" {
		t.Errorf("batch2 = %v; want batch-value2", got)
	}
	if _, exists := result["delete-me"]; exists {
		t.Error("delete-me should not exist")
	}
	if len(result["partial"]) != 1 || result["partial"][0] != "keep" {
		t.Errorf("partial = %v; want [keep]", result["partial"])
	}
}

// TestWithCookies checks cookies supplied through the WithCookies option.
func TestWithCookies(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		cookies := make(map[string]string)
		for _, cookie := range r.Cookies() {
			cookies[cookie.Name] = cookie.Value
		}
		json.NewEncoder(w).Encode(cookies)
	}))
	defer server.Close()

	resp, err := Get(server.URL, WithCookies(map[string]string{
		"cookie1": "value1",
		"cookie2": "value2",
	}))
	if err != nil {
		t.Fatalf("Get() error: %v", err)
	}
	defer resp.Close()

	var cookies map[string]string
	resp.Body().JSON(&cookies)

	if cookies["cookie1"] != "value1" {
		t.Errorf("cookie1 = %v; want value1", cookies["cookie1"])
	}
	if cookies["cookie2"] != "value2" {
		t.Errorf("cookie2 = %v; want value2", cookies["cookie2"])
	}
}

// TestWithHeaders checks headers supplied through the WithHeaders option.
func TestWithHeaders(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Header.Get("X-Header1") != "value1" {
			t.Errorf("X-Header1 = %v; want value1", r.Header.Get("X-Header1"))
		}
		if r.Header.Get("X-Header2") != "value2" {
			t.Errorf("X-Header2 = %v; want value2", r.Header.Get("X-Header2"))
		}
		w.WriteHeader(http.StatusOK)
	}))
	defer server.Close()

	resp, err := Get(server.URL, WithHeaders(map[string]string{
		"X-Header1": "value1",
		"X-Header2": "value2",
	}))
	if err != nil {
		t.Fatalf("Get() error: %v", err)
	}
	defer resp.Close()
}

// TestWithQueries checks query parameters supplied through the WithQueries option.
func TestWithQueries(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		json.NewEncoder(w).Encode(r.URL.Query())
	}))
	defer server.Close()

	// first returns the first value under key, or "" when absent.
	first := func(m map[string][]string, key string) string {
		if vs := m[key]; len(vs) > 0 {
			return vs[0]
		}
		return ""
	}

	resp, err := Get(server.URL, WithQueries(map[string]string{
		"key1": "value1",
		"key2": "value2",
	}))
	if err != nil {
		t.Fatalf("Get() error: %v", err)
	}
	defer resp.Close()

	var result map[string][]string
	resp.Body().JSON(&result)

	if got := first(result, "key1"); got != "value1" {
		t.Errorf("key1 = %v; want value1", got)
	}
	if got := first(result, "key2"); got != "value2" {
		t.Errorf("key2 = %v; want value2", got)
	}
}

// TestSetReferer checks that SetReferer sets the Referer header.
func TestSetReferer(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Referer() != "https://example.com" {
			t.Errorf("Referer = %v; want https://example.com", r.Referer())
		}
		w.WriteHeader(http.StatusOK)
	}))
	defer server.Close()

	req := NewSimpleRequest(server.URL, "GET").
		SetReferer("https://example.com")

	resp, err := req.Do()
	if err != nil {
		t.Fatalf("Do() error: %v", err)
	}
	defer resp.Close()
}

// TestSetBearerToken checks that SetBearerToken sets "Authorization: Bearer <token>".
func TestSetBearerToken(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		auth := r.Header.Get("Authorization")
		if auth != "Bearer test-token-123" {
			t.Errorf("Authorization = %v; want Bearer test-token-123", auth)
		}
		w.WriteHeader(http.StatusOK)
	}))
	defer server.Close()

	req := NewSimpleRequest(server.URL, "GET").
		SetBearerToken("test-token-123")

	resp, err := req.Do()
	if err != nil {
		t.Fatalf("Do() error: %v", err)
	}
	defer resp.Close()
}

// TestGetHeader checks that GetHeader returns a value set by SetHeader.
func TestGetHeader(t *testing.T) {
	req := NewSimpleRequest("http://example.com", "GET")
	req.SetHeader("X-Test", "test-value")

	value := req.GetHeader("X-Test")
	if value != "test-value" {
		t.Errorf("GetHeader = %v; want test-value", value)
	}
}

// TestEnableDisableRawMode checks toggling of the request's raw mode flag.
func TestEnableDisableRawMode(t *testing.T) {
	req := NewSimpleRequest("http://example.com", "GET")

	// Raw mode is off by default.
	if req.doRaw {
		t.Error("Request should not be in raw mode by default")
	}

	req.EnableRawMode()
	if !req.doRaw {
		t.Error("EnableRawMode should enable raw mode")
	}

	req.DisableRawMode()
	if req.doRaw {
		t.Error("DisableRawMode should disable raw mode")
	}
}

// TestContextOperations checks that SetContext stores the exact context given.
func TestContextOperations(t *testing.T) {
	// A dedicated key type avoids collisions with other packages' context values
	// (staticcheck SA1029 flags built-in string keys).
	type testContextKey string
	key := testContextKey("test-key")
	ctx := context.WithValue(context.Background(), key, "test-value")

	req := NewSimpleRequest("http://example.com", "GET")
	req.SetContext(ctx)

	if req.Context() != ctx {
		t.Error("SetContext did not set context correctly")
	}

	// Values stored on the context must still be reachable.
	if req.Context().Value(key) != "test-value" {
		t.Error("Context value not preserved")
	}
}

// TestRawRequestOperations checks SetRawRequest / RawRequest and that the raw
// request can be sent.
func TestRawRequestOperations(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	defer server.Close()

	rawReq, err := http.NewRequest("GET", server.URL, nil)
	if err != nil {
		t.Fatalf("http.NewRequest error: %v", err)
	}
	rawReq.Header.Set("X-Raw", "raw-value")

	req := NewSimpleRequest("", "GET")
	req.SetRawRequest(rawReq)

	if req.RawRequest() != rawReq {
		t.Error("SetRawRequest did not set raw request correctly")
	}

	resp, err := req.Do()
	if err != nil {
		t.Fatalf("Do() error: %v", err)
	}
	defer resp.Close()
}

// TestURLOperations checks the URL getter and SetURL.
func TestURLOperations(t *testing.T) {
	req := NewSimpleRequest("http://example.com", "GET")

	if req.URL() != "http://example.com" {
		t.Errorf("URL() = %v; want http://example.com", req.URL())
	}

	req.SetURL("http://newexample.com")
	if req.URL() != "http://newexample.com" {
		t.Errorf("URL() after SetURL = %v; want http://newexample.com", req.URL())
	}
}

// TestMethodOperations checks the Method getter and SetMethod.
func TestMethodOperations(t *testing.T) {
	req := NewSimpleRequest("http://example.com", "GET")

	if req.Method() != "GET" {
		t.Errorf("Method() = %v; want GET", req.Method())
	}

	req.SetMethod("POST")
	if req.Method() != "POST" {
		t.Errorf("Method() after SetMethod = %v; want POST", req.Method())
	}
}

// ---- Client: SetDefaultTLSConfig / EnableRedirect / Options / NewClientFromHTTP ----

// TestClientSetDefaultTLSConfig checks that a client-level TLS config is applied.
func TestClientSetDefaultTLSConfig(t *testing.T) {
	ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	defer ts.Close()

	c := NewClientNoErr()
	c.SetDefaultTLSConfig(&tls.Config{InsecureSkipVerify: true})

	resp, err := c.Get(ts.URL)
	if err != nil {
		t.Fatalf("Get() error: %v", err)
	}
	defer resp.Close()

	if resp.StatusCode != http.StatusOK {
		t.Fatalf("StatusCode=%d", resp.StatusCode)
	}
}

// TestClientEnableRedirect checks that DisableRedirect returns the 302 and that
// EnableRedirect makes the client follow it again.
func TestClientEnableRedirect(t *testing.T) {
	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Every path except /ok redirects, so the request made after
		// EnableRedirect only reaches 200 if the redirect is actually followed.
		if r.URL.Path != "/ok" {
			http.Redirect(w, r, "/ok", http.StatusFound)
			return
		}
		w.WriteHeader(http.StatusOK)
	}))
	defer s.Close()

	c := NewClientNoErr()
	c.DisableRedirect()

	resp, err := c.Get(s.URL)
	if err != nil {
		t.Fatalf("Get() error: %v", err)
	}
	resp.Close()
	if resp.StatusCode != http.StatusFound {
		t.Fatalf("want 302, got %d", resp.StatusCode)
	}

	c.EnableRedirect()
	resp2, err := c.Get(s.URL)
	if err != nil {
		t.Fatalf("Get() after EnableRedirect error: %v", err)
	}
	defer resp2.Close()
	if resp2.StatusCode != http.StatusOK {
		t.Fatalf("want 200, got %d", resp2.StatusCode)
	}
}

// TestClientOptionsMethod checks Client.Options sends an OPTIONS request.
func TestClientOptionsMethod(t *testing.T) {
	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Errorf, not Fatalf: FailNow must not be called from a handler goroutine.
		if r.Method != http.MethodOptions {
			t.Errorf("method=%s", r.Method)
		}
		w.WriteHeader(http.StatusNoContent)
	}))
	defer s.Close()

	c := NewClientNoErr()
	resp, err := c.Options(s.URL)
	if err != nil {
		t.Fatalf("Options() error: %v", err)
	}
	defer resp.Close()

	if resp.StatusCode != http.StatusNoContent {
		t.Fatalf("status=%d", resp.StatusCode)
	}
}

// TestNewClientFromHTTP_WithConfiguredTransport checks that an *http.Client with an
// existing *http.Transport has its transport wrapped in the package's *Transport.
func TestNewClientFromHTTP_WithConfiguredTransport(t *testing.T) {
	hc := &http.Client{
		Transport: &http.Transport{
			MaxIdleConns: 17,
		},
		Timeout: 3 * time.Second,
	}

	c, err := NewClientFromHTTP(hc)
	if err != nil {
		t.Fatalf("NewClientFromHTTP error: %v", err)
	}
	if c == nil || c.HTTPClient() == nil {
		t.Fatal("client nil")
	}

	// Covers the "http.Client already has an *http.Transport" wrapping path.
	if _, ok := c.HTTPClient().Transport.(*Transport); !ok {
		t.Fatalf("transport not wrapped to *Transport, got %T", c.HTTPClient().Transport)
	}
}

// ---- context / getRequestContext coverage gaps ----

// TestGetRequestContext_AllMissingBranches checks that getRequestContext extracts
// the transport, proxy, custom DNS and dial function stored on a context.
func TestGetRequestContext_AllMissingBranches(t *testing.T) {
	dialFn := func(ctx context.Context, network, addr string) (net.Conn, error) {
		return nil, nil
	}
	tr := &http.Transport{}

	ctx := context.Background()
	ctx = context.WithValue(ctx, ctxKeyTransport, tr)
	ctx = context.WithValue(ctx, ctxKeyProxy, "http://127.0.0.1:29992")
	ctx = context.WithValue(ctx, ctxKeyCustomDNS, []string{"8.8.8.8"})
	ctx = context.WithValue(ctx, ctxKeyDialFunc, dialFn)

	rc := getRequestContext(ctx)
	if rc.Transport != tr {
		t.Fatal("transport not extracted")
	}
	if rc.Proxy != "http://127.0.0.1:29992" {
		t.Fatal("proxy not extracted")
	}
	if len(rc.CustomDNS) != 1 || rc.CustomDNS[0] != "8.8.8.8" {
		t.Fatal("custom dns not extracted")
	}
	if rc.DialFn == nil {
		t.Fatal("dialFn not extracted")
	}
}

// ---- Package-level helpers: put/delete/patch/options/trace/connect ----

// TestDefaultMethodsCoverage checks that each package-level method helper sends
// the matching HTTP method (the server echoes r.Method as the body).
func TestDefaultMethodsCoverage(t *testing.T) {
	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		w.Write([]byte(r.Method))
	}))
	defer s.Close()

	cases := []struct {
		name string
		fn   func(string, ...RequestOpt) (*Response, error)
		want string
	}{
		{"PUT", Put, http.MethodPut},
		{"DELETE", Delete, http.MethodDelete},
		{"PATCH", Patch, http.MethodPatch},
		{"OPTIONS", Options, http.MethodOptions},
		{"TRACE", Trace, http.MethodTrace},
		{"CONNECT", Connect, http.MethodConnect},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			resp, err := tc.fn(s.URL)
			if err != nil {
				t.Fatalf("%s error: %v", tc.name, err)
			}
			defer resp.Close()

			body, _ := resp.Body().String()
			if body != tc.want {
				t.Fatalf("body=%q want=%q", body, tc.want)
			}
		})
	}
}

// ---- Request: SetQueries / SetTransport / SetAutoCalcContentLength / SetContentLength ----

// TestRequestSetQueries checks that SetQueries sends the given query parameters.
func TestRequestSetQueries(t *testing.T) {
	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Errorf, not Fatalf: FailNow must not be called from a handler goroutine.
		q := r.URL.Query()
		if q.Get("a") != "1" || q.Get("b") != "2" {
			t.Errorf("query not set: %v", q)
		}
		w.WriteHeader(http.StatusOK)
	}))
	defer s.Close()

	req := NewSimpleRequest(s.URL, "GET").
		SetQueries(map[string][]string{"a": {"1"}, "b": {"2"}})

	resp, err := req.Do()
	if err != nil {
		t.Fatalf("Do() error: %v", err)
	}
	resp.Close()
}

// TestRequestSetTransport checks that a request with a caller-supplied transport succeeds.
func TestRequestSetTransport(t *testing.T) {
	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	defer s.Close()

	base := &http.Transport{}
	// Release pooled connections held by this test-only transport.
	defer base.CloseIdleConnections()

	req := NewSimpleRequest(s.URL, "GET").SetTransport(base)

	resp, err := req.Do()
	if err != nil {
		t.Fatalf("Do() error: %v", err)
	}
	resp.Close()
}

// TestRequestSetAutoCalcContentLength checks that a reader body gets a positive
// Content-Length when SetAutoCalcContentLength(true) is used.
func TestRequestSetAutoCalcContentLength(t *testing.T) {
	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Errorf, not Fatalf: FailNow must not be called from a handler goroutine.
		if r.ContentLength <= 0 {
			t.Errorf("content-length not auto calculated: %d", r.ContentLength)
		}
		w.WriteHeader(http.StatusOK)
	}))
	defer s.Close()

	req := NewSimpleRequest(s.URL, "POST").
		SetBodyReader(stringsNewReaderCompat("hello-autocalc")).
		SetAutoCalcContentLength(true)

	resp, err := req.Do()
	if err != nil {
		t.Fatalf("Do() error: %v", err)
	}
	resp.Close()
}

// TestRequestSetContentLength checks that an explicit Content-Length is sent.
func TestRequestSetContentLength(t *testing.T) {
	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Errorf, not Fatalf: FailNow must not be called from a handler goroutine.
		if r.ContentLength != 5 {
			t.Errorf("content-length=%d", r.ContentLength)
		}
		w.WriteHeader(http.StatusOK)
	}))
	defer s.Close()

	req := NewSimpleRequest(s.URL, "POST").
		SetBody([]byte("hello")).
		SetContentLength(5)

	resp, err := req.Do()
	if err != nil {
		t.Fatalf("Do() error: %v", err)
	}
	resp.Close()
}

// ---- Request: AddCustomDNS / AddCustomIP / SetDialFunc ----

func TestRequestAddCustomDNSAndIP(t *testing.T) {
	req := NewSimpleRequest("http://example.com", "GET").
		AddCustomDNS("8.8.8.8").
AddCustomIP("1.1.1.1") if req.Err() != nil { t.Fatalf("unexpected err: %v", req.Err()) } if len(req.config.DNS.CustomDNS) != 1 || req.config.DNS.CustomDNS[0] != "8.8.8.8" { t.Fatal("custom dns not added") } if len(req.config.DNS.CustomIP) != 1 || req.config.DNS.CustomIP[0] != "1.1.1.1" { t.Fatal("custom ip not added") } } func TestRequestSetDialFunc(t *testing.T) { called := false fn := func(ctx context.Context, network, addr string) (net.Conn, error) { called = true return nil, io.EOF } req := NewSimpleRequest("http://example.com", "GET").SetDialFunc(fn) if req.config.Network.DialFunc == nil { t.Fatal("dial func not set") } _, _ = req.config.Network.DialFunc(context.Background(), "tcp", "x:1") if !called { t.Fatal("dial func not callable") } } // ---- Request header/cookie bulk APIs ---- func TestRequestSetHeadersAndAddHeaders(t *testing.T) { s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Header.Get("X-A") != "1" || r.Header.Get("X-B") != "2" || r.Header.Get("X-C") != "3" { t.Fatalf("headers not correct: %v", r.Header) } w.WriteHeader(http.StatusOK) })) defer s.Close() h := http.Header{} h.Set("X-A", "1") h.Set("X-B", "2") req := NewSimpleRequest(s.URL, "GET"). SetHeaders(h). AddHeaders(map[string]string{"X-C": "3"}) resp, err := req.Do() if err != nil { t.Fatalf("Do() error: %v", err) } resp.Close() } func TestRequestSetCookiesAndAddCookies(t *testing.T) { s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { got := map[string]string{} for _, c := range r.Cookies() { got[c.Name] = c.Value } if got["a"] != "1" || got["b"] != "2" || got["c"] != "3" { t.Fatalf("cookies=%v", got) } w.WriteHeader(http.StatusOK) })) defer s.Close() req := NewSimpleRequest(s.URL, "GET"). SetCookies([]*http.Cookie{ {Name: "a", Value: "1", Path: "/"}, {Name: "b", Value: "2", Path: "/"}, }). 
AddCookies(map[string]string{"c": "3"}) resp, err := req.Do() if err != nil { t.Fatalf("Do() error: %v", err) } resp.Close() } // ---- Body.Close / Response.CloseWithClient ---- func TestBodyClose(t *testing.T) { s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("ok")) })) defer s.Close() resp, err := Get(s.URL) if err != nil { t.Fatalf("Get() error: %v", err) } // 直接测 Body.Close if err := resp.Body().Close(); err != nil { t.Fatalf("Body.Close() error: %v", err) } } func TestResponseCloseWithClient(t *testing.T) { s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("ok")) })) defer s.Close() resp, err := Get(s.URL) if err != nil { t.Fatalf("Get() error: %v", err) } if err := resp.CloseWithClient(); err != nil { t.Fatalf("CloseWithClient() error: %v", err) } } // 小兼容函数,避免你当前文件没引 strings 包时报错(可直接替换成 strings.NewReader) func stringsNewReaderCompat(s string) io.Reader { return io.NopCloser(io.MultiReader(io.LimitReader(io.NopCloser(stringsReader(s)), int64(len(s))))) } // 纯标准库最小 reader type stringsReader string func (sr stringsReader) Read(p []byte) (int, error) { if len(sr) == 0 { return 0, io.EOF } n := copy(p, []byte(sr)) return n, nil }