diff --git a/addon_test.go b/addon_test.go new file mode 100644 index 0000000..dfa301b --- /dev/null +++ b/addon_test.go @@ -0,0 +1,1657 @@ +package starnet + +import ( + "context" + "crypto/tls" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "net/http/httptest" + "strings" + "sync" + "sync/atomic" + "testing" + "time" +) + +// TestComplexScenario1_RequestLevelConfigOverride 测试请求级配置覆盖 Client 级配置 +func TestComplexScenario1_RequestLevelConfigOverride(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(150 * time.Millisecond) + w.WriteHeader(http.StatusOK) + w.Write([]byte("OK")) + })) + defer server.Close() + + // Client 级别:5 秒超时 + client := NewClientNoErr(WithTimeout(5 * time.Second)) + + // 请求 1:使用 Client 的超时(应该成功) + resp1, err := client.Get(server.URL) + if err != nil { + t.Fatalf("Request 1 error: %v", err) + } + resp1.Close() + + // 请求 2:请求级别覆盖为 100ms(应该超时) + start := time.Now() + _, err = client.Get(server.URL, WithTimeout(100*time.Millisecond)) + elapsed := time.Since(start) + + if err == nil { + t.Error("Request 2 should timeout, got nil error") + } + + if elapsed > 500*time.Millisecond { + t.Errorf("Request 2 timeout took too long: %v", elapsed) + } + + // 请求 3:再次使用 Client 的超时(应该成功,验证没有副作用) + resp3, err := client.Get(server.URL) + if err != nil { + t.Fatalf("Request 3 error: %v", err) + } + resp3.Close() +} + +// TestComplexScenario2_TLSConfigPriority 测试 TLS 配置的优先级 +func TestComplexScenario2_TLSConfigPriority(t *testing.T) { + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("OK")) + })) + defer server.Close() + + // 场景 1:Client 级别设置 SkipVerify + client := NewClientNoErr() + client.SetDefaultSkipTLSVerify(true) + + resp1, err := client.Get(server.URL) + if err != nil { + t.Fatalf("Scenario 1 error: %v", err) + } + resp1.Close() + + // 场景 2:请求级别设置自定义 TLS Config(应该覆盖 Client 级别) + customTLS := &tls.Config{ + InsecureSkipVerify: true, + MinVersion: tls.VersionTLS12, + } + + resp2, err := client.Get(server.URL, WithTLSConfig(customTLS)) + if err != nil { + t.Fatalf("Scenario 2 error: %v", err) + } + resp2.Close() + + // 场景 3:请求级别只设置 SkipVerify(不设置完整 TLS Config) + resp3, err := client.Get(server.URL, WithSkipTLSVerify(true)) + if err != nil { + t.Fatalf("Scenario 3 error: %v", err) + } + resp3.Close() + + // 场景 4:新 Client 不设置任何 TLS 配置(应该失败) + client2 := NewClientNoErr() + _, err = client2.Get(server.URL) + if err == nil { + t.Error("Scenario 4 should fail with TLS error, got nil") + } +} + +// TestComplexScenario3_ConnectionPoolReuse 测试连接池复用 +func TestComplexScenario3_ConnectionPoolReuse(t *testing.T) { + var connCount int64 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt64(&connCount, 1) + w.WriteHeader(http.StatusOK) + w.Write([]byte("OK")) + })) + defer server.Close() + + client := NewClientNoErr() + + // 发送 10 个请求,应该复用连接 + for i := 0; i < 10; i++ { + resp, err := client.Get(server.URL) + if err != nil { + t.Fatalf("Request %d error: %v", i, err) + } + // 必须读取并关闭 body 才能复用连接 + io.ReadAll(resp.Body().raw) + resp.Close() + } + + // 验证连接被复用(实际连接数应该远小于请求数) + // 注意:这个测试可能不稳定,因为连接池行为依赖于时间和系统状态 + t.Logf("Total handler calls: %d", atomic.LoadInt64(&connCount)) +} + +// TestComplexScenario4_CustomDNSWithFallback 测试自定义 DNS 和回退机制 +func TestComplexScenario4_CustomDNSWithFallback(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("OK")) + })) + defer server.Close() + + // 提取服务器的实际 IP 和端口 + serverURL := server.URL + host := strings.TrimPrefix(serverURL, "http://") + + // 场景 1:使用自定义 IP(直接指定) + parts := strings.Split(host, ":") + if len(parts) != 2 { + t.Fatalf("Invalid server URL: %s", serverURL) + } + ip := parts[0] + port := parts[1] + + // 构造一个使用域名的 URL + testURL := fmt.Sprintf("http://test.example.com:%s", port) + + req := NewSimpleRequest(testURL, "GET").SetCustomIP([]string{ip}) + resp, err := req.Do() + if err != nil { + t.Fatalf("Custom IP request error: %v", err) + } + resp.Close() + + // 场景 2:使用自定义 DNS 解析函数 + lookupCalled := false + customLookup := func(ctx context.Context, host string) ([]net.IPAddr, error) { + lookupCalled = true + // 返回实际的 IP + return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil + } + + req2 := NewSimpleRequest(testURL, "GET").SetLookupFunc(customLookup) + resp2, err := req2.Do() + if err != nil { + t.Fatalf("Custom lookup request error: %v", err) + } + resp2.Close() + + if !lookupCalled { + t.Error("Custom lookup function was not called") + } +} + +// TestComplexScenario5_ConcurrentRequestsWithDifferentConfigs 测试并发请求使用不同配置 +func TestComplexScenario5_ConcurrentRequestsWithDifferentConfigs(t *testing.T) { + // 创建多个服务器,模拟不同的延迟 + servers := make([]*httptest.Server, 3) + for i := range servers { + delay := time.Duration(i*50) * time.Millisecond + idx := i // ← 修复:创建局部变量 + servers[i] = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(delay) + w.WriteHeader(http.StatusOK) + w.Write([]byte(fmt.Sprintf("Server %d", idx))) // ← 使用局部变量 + })) + defer servers[i].Close() + } + + client := NewClientNoErr() + + var wg sync.WaitGroup + results := make([]string, 3) + errors := make([]error, 3) + + // 并发发送请求,每个请求使用不同的超时 + for i := 0; i < 3; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + + timeout := time.Duration((idx+1)*100) * time.Millisecond + resp, err := client.Get(servers[idx].URL, WithTimeout(timeout)) + if err != nil { + errors[idx] = err + return + } + defer resp.Close() + + body, _ := resp.Body().String() + results[idx] = body + }(i) + } + + wg.Wait() + + // 验证结果 + for i := 0; i < 3; i++ { + if errors[i] != nil { + t.Errorf("Request %d error: %v", i, errors[i]) + } + expected := fmt.Sprintf("Server %d", i) + if results[i] != expected { + t.Errorf("Request %d result = %v; want %v", i, results[i], expected) + } + } +} + +// TestComplexScenario6_RequestCloneIndependence 测试克隆请求的独立性 +func TestComplexScenario6_RequestCloneIndependence(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // 返回所有 headers + for k, v := range r.Header { + w.Header().Set(k, strings.Join(v, ",")) + } + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + // 创建基础请求 + baseReq := NewSimpleRequest(server.URL, "GET"). + SetHeader("X-Base", "base-value"). + SetTimeout(5 * time.Second) + + // 克隆并修改 + req1 := baseReq.Clone(). + SetHeader("X-Request", "request-1"). + SetTimeout(1 * time.Second) + + req2 := baseReq.Clone(). + SetHeader("X-Request", "request-2"). + SetTimeout(2 * time.Second) + + // 执行请求 + resp1, err := req1.Do() + if err != nil { + t.Fatalf("Request 1 error: %v", err) + } + defer resp1.Close() + + resp2, err := req2.Do() + if err != nil { + t.Fatalf("Request 2 error: %v", err) + } + defer resp2.Close() + + // 验证 headers 独立 + if resp1.Header.Get("X-Request") != "request-1" { + t.Errorf("Request 1 header = %v; want request-1", resp1.Header.Get("X-Request")) + } + + if resp2.Header.Get("X-Request") != "request-2" { + t.Errorf("Request 2 header = %v; want request-2", resp2.Header.Get("X-Request")) + } + + // 验证基础请求未被修改 + resp3, err := baseReq.Do() + if err != nil { + t.Fatalf("Base request error: %v", err) + } + defer resp3.Close() + + if resp3.Header.Get("X-Request") != "" { + t.Errorf("Base request should not have X-Request header, got %v", resp3.Header.Get("X-Request")) + } +} + +// TestComplexScenario7_ErrorAccumulation 测试错误累积机制 +func TestComplexScenario7_ErrorAccumulation(t *testing.T) { + // 场景 1:链式调用中的错误累积 + req := NewSimpleRequest("://invalid-url", "GET"). + SetHeader("X-Test", "value"). + AddQuery("key", "value") + + // 错误应该被累积,不会 panic + if req.Err() == nil { + t.Error("Expected error for invalid URL, got nil") + } + + // 后续操作应该被忽略 + req.SetTimeout(5 * time.Second) + + // Do() 应该返回累积的错误 + _, err := req.Do() + if err == nil { + t.Error("Do() should return accumulated error, got nil") + } + + // 场景 2:无效的方法 + req2 := NewSimpleRequest("http://example.com", "INVALID METHOD!") + if req2.Err() == nil { + t.Error("Expected error for invalid method, got nil") + } + + // 场景 3:无效的 IP + req3 := NewSimpleRequest("http://example.com", "GET"). + SetCustomIP([]string{"invalid-ip"}) + + if req3.Err() == nil { + t.Error("Expected error for invalid IP, got nil") + } +} + +// TestComplexScenario8_DialTimeoutVsRequestTimeout 测试 DialTimeout 和 Timeout 的区别 +func TestComplexScenario8_DialTimeoutVsRequestTimeout(t *testing.T) { + // 场景 1:DialTimeout - 连接超时 + start := time.Now() + req := NewSimpleRequest("http://192.0.2.1:80", "GET"). + SetDialTimeout(100 * time.Millisecond) + + _, err := req.Do() + elapsed := time.Since(start) + + if err == nil { + t.Error("Expected dial timeout error, got nil") + } + + if elapsed > 2*time.Second { + t.Errorf("Dial timeout took too long: %v", elapsed) + } + + // 场景 2:Timeout - 总超时(包括响应读取) + slowServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(200 * time.Millisecond) + w.WriteHeader(http.StatusOK) + })) + defer slowServer.Close() + + start2 := time.Now() + req2 := NewSimpleRequest(slowServer.URL, "GET"). + SetTimeout(100 * time.Millisecond) + + _, err2 := req2.Do() + elapsed2 := time.Since(start2) + + if err2 == nil { + t.Error("Expected request timeout error, got nil") + } + + if elapsed2 > 500*time.Millisecond { + t.Errorf("Request timeout took too long: %v", elapsed2) + } +} + +// TestComplexScenario9_MultipartUploadWithProgress 测试带进度的文件上传 +func TestComplexScenario9_MultipartUploadWithProgress(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + err := r.ParseMultipartForm(10 << 20) + if err != nil { + t.Errorf("ParseMultipartForm error: %v", err) + w.WriteHeader(http.StatusBadRequest) + return + } + + // 验证表单字段 + if r.FormValue("name") != "test" { + t.Errorf("name = %v; want test", r.FormValue("name")) + } + + // 验证文件 + file, header, err := r.FormFile("file") + if err != nil { + t.Errorf("FormFile error: %v", err) + w.WriteHeader(http.StatusBadRequest) + return + } + defer file.Close() + + content, _ := io.ReadAll(file) + w.WriteHeader(http.StatusOK) + w.Write([]byte(fmt.Sprintf("Received: %s (%d bytes)", header.Filename, len(content)))) + })) + defer server.Close() + + // 创建测试数据 + fileContent := strings.Repeat("test data ", 1000) // ~10KB + reader := strings.NewReader(fileContent) + + // 跟踪进度 + var progressCalls int64 + var lastUploaded int64 + + req := NewSimpleRequest(server.URL, "POST"). + AddFormData("name", "test"). + AddFileStream("file", "test.txt", int64(len(fileContent)), reader). + SetUploadProgress(func(filename string, uploaded, total int64) { + atomic.AddInt64(&progressCalls, 1) + atomic.StoreInt64(&lastUploaded, uploaded) + + if filename != "test.txt" { + t.Errorf("filename = %v; want test.txt", filename) + } + if total != int64(len(fileContent)) { + t.Errorf("total = %v; want %v", total, len(fileContent)) + } + }) + + resp, err := req.Do() + if err != nil { + t.Fatalf("Upload error: %v", err) + } + defer resp.Close() + + // 验证进度回调被调用 + if atomic.LoadInt64(&progressCalls) == 0 { + t.Error("Progress callback was not called") + } + + // 验证最终上传量 + if atomic.LoadInt64(&lastUploaded) != int64(len(fileContent)) { + t.Errorf("lastUploaded = %v; want %v", lastUploaded, len(fileContent)) + } + + body, _ := resp.Body().String() + if !strings.Contains(body, "test.txt") { + t.Errorf("Response should contain filename, got: %v", body) + } +} + +// TestComplexScenario10_ClientCloneWithOptions 测试 Client 克隆和选项继承 +func TestComplexScenario10_ClientCloneWithOptions(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(r.Header.Get("X-Client-ID"))) + })) + defer server.Close() + + // 创建带选项的 Client + client1 := NewClientNoErr( + WithTimeout(5*time.Second), + WithHeader("X-Client-ID", "client-1"), + ) + + // 克隆 Client + client2 := client1.Clone() + client2.AddOptions(WithHeader("X-Extra", "extra-value")) + + // 测试 client1 + resp1, err := client1.Get(server.URL) + if err != nil { + t.Fatalf("Client 1 error: %v", err) + } + defer resp1.Close() + + body1, _ := resp1.Body().String() + if body1 != "client-1" { + t.Errorf("Client 1 response = %v; want client-1", body1) + } + + // 测试 client2(应该继承 client1 的选项) + resp2, err := client2.Get(server.URL) + if err != nil { + t.Fatalf("Client 2 error: %v", err) + } + defer resp2.Close() + + body2, _ := resp2.Body().String() + if body2 != "client-1" { + t.Errorf("Client 2 response = %v; want client-1", body2) + } + + // 验证 client1 未被修改 + opts1 := client1.RequestOptions() + opts2 := client2.RequestOptions() + + if len(opts1) >= len(opts2) { + t.Errorf("Client 2 should have more options than Client 1") + } +} + +// TestComplexScenario11_ContextCancellation 测试 Context 取消 +func TestComplexScenario11_ContextCancellation(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(2 * time.Second) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + ctx, cancel := context.WithCancel(context.Background()) + + // 在 500ms 后取消 + go func() { + time.Sleep(500 * time.Millisecond) + cancel() + }() + + req := NewSimpleRequestWithContext(ctx, server.URL, "GET") + + start := time.Now() + _, err := req.Do() + elapsed := time.Since(start) + + if err == nil { + t.Error("Expected context cancellation error, got nil") + } + + if elapsed > 1*time.Second { + t.Errorf("Context cancellation took too long: %v", elapsed) + } +} + +// TestComplexScenario12_RedirectWithCookies 测试重定向时的 Cookie 处理 +func TestComplexScenario12_RedirectWithCookies(t *testing.T) { + var redirectCount int + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if redirectCount < 2 { + redirectCount++ + // 设置 Cookie 并重定向 + http.SetCookie(w, &http.Cookie{ + Name: fmt.Sprintf("cookie%d", redirectCount), + Value: fmt.Sprintf("value%d", redirectCount), + Path: "/", + }) + http.Redirect(w, r, "/final", http.StatusFound) + return + } + + // 最终响应 + w.WriteHeader(http.StatusOK) + w.Write([]byte("final")) + })) + defer server.Close() + + // 测试自动跟随重定向 + client := NewClientNoErr() + resp, err := client.Get(server.URL) + if err != nil { + t.Fatalf("Get error: %v", err) + } + defer resp.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("StatusCode = %v; want %v", resp.StatusCode, http.StatusOK) + } + + body, _ := resp.Body().String() + if body != "final" { + t.Errorf("Body = %v; want final", body) + } + + // 测试禁用重定向 + redirectCount = 0 + client.DisableRedirect() + + resp2, err := client.Get(server.URL) + if err != nil { + t.Fatalf("Get error: %v", err) + } + defer resp2.Close() + + if resp2.StatusCode != http.StatusFound { + t.Errorf("StatusCode = %v; want %v", resp2.StatusCode, http.StatusFound) + } + + // 验证 Set-Cookie + cookies := resp2.Cookies() + if len(cookies) == 0 { + t.Error("Expected cookies in redirect response") + } +} + +// TestDefaultsSetDefaultClient 测试设置默认 Client +func TestDefaultsSetDefaultClient(t *testing.T) { + // 保存原始的默认 Client + originalClient := DefaultClient() + + // 创建自定义 Client + customClient := NewClientNoErr(WithTimeout(1 * time.Second)) + SetDefaultClient(customClient) + + // 验证默认 Client 已更改 + if DefaultClient() != customClient { + t.Error("SetDefaultClient did not update default client") + } + + // 恢复原始 Client + SetDefaultClient(originalClient) +} + +// TestDefaultsSetDefaultHTTPClient 测试设置默认 HTTP Client +func TestDefaultsSetDefaultHTTPClient(t *testing.T) { + // 保存原始的默认 HTTP Client + originalHTTPClient := DefaultHTTPClient() + + // 创建自定义 HTTP Client + customHTTPClient := &http.Client{ + Timeout: 2 * time.Second, + } + SetDefaultHTTPClient(customHTTPClient) + + // 验证默认 HTTP Client 已更改 + if DefaultHTTPClient() != customHTTPClient { + t.Error("SetDefaultHTTPClient did not update default http client") + } + + // 恢复原始 HTTP Client + SetDefaultHTTPClient(originalHTTPClient) +} + +// TestDefaultsHeadMethod 测试 Head 方法 +func TestDefaultsHeadMethod(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodHead { + t.Errorf("Method = %v; want HEAD", r.Method) + } + w.Header().Set("X-Custom", "test-value") + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + resp, err := Head(server.URL) + if err != nil { + t.Fatalf("Head() error: %v", err) + } + defer resp.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("StatusCode = %v; want %v", resp.StatusCode, http.StatusOK) + } + + // HEAD 请求应该有 headers 但没有 body + if resp.Header.Get("X-Custom") != "test-value" { + t.Errorf("Header X-Custom = %v; want test-value", resp.Header.Get("X-Custom")) + } +} + +// TestProxyConfiguration 测试代理配置 +func TestProxyConfiguration(t *testing.T) { + // 创建目标服务器 + targetServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("target")) + })) + defer targetServer.Close() + + // 创建代理服务器 + proxyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // 简单的代理逻辑 + w.Header().Set("X-Proxied", "true") + w.WriteHeader(http.StatusOK) + w.Write([]byte("proxied")) + })) + defer proxyServer.Close() + + // 测试 WithProxy + req := NewSimpleRequest(targetServer.URL, "GET").SetProxy(proxyServer.URL) + + // 验证代理配置被设置 + if req.config.Network.Proxy != proxyServer.URL { + t.Errorf("Proxy = %v; want %v", req.config.Network.Proxy, proxyServer.URL) + } + + // 注意:实际的代理测试需要真实的代理服务器 + // 这里只验证配置是否正确设置 +} + +// TestWithRawRequest 测试 WithRawRequest +func TestWithRawRequest(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("X-Custom") != "raw-value" { + t.Errorf("X-Custom header = %v; want raw-value", r.Header.Get("X-Custom")) + } + w.WriteHeader(http.StatusOK) + w.Write([]byte("OK")) + })) + defer server.Close() + + // 创建原始 http.Request + rawReq, _ := http.NewRequest("GET", server.URL, nil) + rawReq.Header.Set("X-Custom", "raw-value") + + // 使用 WithRawRequest + req := NewSimpleRequest("", "GET", WithRawRequest(rawReq)) + resp, err := req.Do() + if err != nil { + t.Fatalf("Do() error: %v", err) + } + defer resp.Close() + + body, _ := resp.Body().String() + if body != "OK" { + t.Errorf("Body = %v; want OK", body) + } +} + +// TestWithContentLength 测试 WithContentLength +func TestWithContentLength(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.ContentLength != 9 { + t.Errorf("ContentLength = %v; want 9", r.ContentLength) + } + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + data := []byte("test data") + resp, err := Post(server.URL, + WithBody(data), + WithContentLength(int64(len(data)))) // 一致 + + if err != nil { + t.Fatalf("Post() error: %v", err) + } + defer resp.Close() +} + +// TestWithAutoCalcContentLength 测试自动计算 Content-Length +func TestWithAutoCalcContentLength(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // 验证 Content-Length 被正确设置 + if r.ContentLength <= 0 { + t.Errorf("ContentLength = %v; want > 0", r.ContentLength) + } + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + data := strings.NewReader("test data for auto calc") + resp, err := Post(server.URL, + WithBodyReader(data), + WithAutoCalcContentLength(true)) + + if err != nil { + t.Fatalf("Post() error: %v", err) + } + defer resp.Close() +} + +// TestChunkedTransferEncoding 测试 Chunked 传输编码 +func TestChunkedTransferEncoding(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // 验证使用了 chunked 编码 + if len(r.TransferEncoding) > 0 && r.TransferEncoding[0] == "chunked" { + w.Header().Set("X-Chunked", "true") + } + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + data := []byte("test data") + resp, err := Post(server.URL, + WithBody(data), + WithContentLength(-1)) // -1 强制使用 chunked + + if err != nil { + t.Fatalf("Post() error: %v", err) + } + defer resp.Close() +} + +// TestWithFormDataMap 测试 WithFormDataMap +func TestWithFormDataMap(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.ParseForm() + if r.FormValue("key1") != "value1" { + t.Errorf("key1 = %v; want value1", r.FormValue("key1")) + } + if r.FormValue("key2") != "value2" { + t.Errorf("key2 = %v; want value2", r.FormValue("key2")) + } + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + resp, err := Post(server.URL, + WithFormDataMap(map[string]string{ + "key1": "value1", + "key2": "value2", + })) + + if err != nil { + t.Fatalf("Post() error: %v", err) + } + defer resp.Close() +} + +// TestWithFormData 测试 WithFormData +func TestWithFormData(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.ParseForm() + values := r.Form["tags"] + if len(values) != 2 { + t.Errorf("tags length = %v; want 2", len(values)) + } + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + resp, err := Post(server.URL, + WithFormData(map[string][]string{ + "tags": {"tag1", "tag2"}, + })) + + if err != nil { + t.Fatalf("Post() error: %v", err) + } + defer resp.Close() +} + +// TestWithAddFormData 测试 WithAddFormData +func TestWithAddFormData(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.ParseForm() + if r.FormValue("name") != "test" { + t.Errorf("name = %v; want test", r.FormValue("name")) + } + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + resp, err := Post(server.URL, + WithAddFormData("name", "test")) + + if err != nil { + t.Fatalf("Post() error: %v", err) + } + defer resp.Close() +} + +// TestHeaderOperations 测试 Header 操作 +func TestHeaderOperations(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + headers := make(map[string][]string) + for k, v := range r.Header { + headers[k] = v + } + json.NewEncoder(w).Encode(headers) + })) + defer server.Close() + + req := NewSimpleRequest(server.URL, "GET") + + // AddHeader + req.AddHeader("X-Multi", "value1") + req.AddHeader("X-Multi", "value2") + + // SetHeader + req.SetHeader("X-Single", "single-value") + + // DeleteHeader + req.SetHeader("X-Delete", "will-be-deleted") + req.DeleteHeader("X-Delete") + + // ResetHeaders + req2 := NewSimpleRequest(server.URL, "GET") + req2.SetHeader("X-Test", "test") + req2.ResetHeaders() + req2.SetHeader("X-After-Reset", "value") + + resp, err := req.Do() + if err != nil { + t.Fatalf("Do() error: %v", err) + } + defer resp.Close() + + var headers map[string][]string + resp.Body().JSON(&headers) + + // 验证 AddHeader + if len(headers["X-Multi"]) != 2 { + t.Errorf("X-Multi length = %v; want 2", len(headers["X-Multi"])) + } + + // 验证 SetHeader + if headers["X-Single"][0] != "single-value" { + t.Errorf("X-Single = %v; want single-value", headers["X-Single"][0]) + } + + // 验证 DeleteHeader + if _, exists := headers["X-Delete"]; exists { + t.Error("X-Delete should be deleted") + } + + // 测试 ResetHeaders + resp2, err := req2.Do() + if err != nil { + t.Fatalf("Do() error: %v", err) + } + defer resp2.Close() + + var headers2 map[string][]string + resp2.Body().JSON(&headers2) + + if _, exists := headers2["X-Test"]; exists { + t.Error("X-Test should not exist after reset") + } + if headers2["X-After-Reset"][0] != "value" { + t.Errorf("X-After-Reset = %v; want value", headers2["X-After-Reset"][0]) + } +} + +// TestCookieOperations 测试 Cookie 操作 +func TestCookieOperations(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + cookies := make(map[string]string) + for _, cookie := range r.Cookies() { + cookies[cookie.Name] = cookie.Value + } + json.NewEncoder(w).Encode(cookies) + })) + defer server.Close() + + req := NewSimpleRequest(server.URL, "GET") + + // AddSimpleCookie + req.AddSimpleCookie("simple", "simple-value") + + // AddCookieKV + req.AddCookieKV("custom", "custom-value", "/path") + + // AddCookie + req.AddCookie(&http.Cookie{ + Name: "full", + Value: "full-value", + Path: "/", + }) + + resp, err := req.Do() + if err != nil { + t.Fatalf("Do() error: %v", err) + } + defer resp.Close() + + var cookies map[string]string + resp.Body().JSON(&cookies) + + if cookies["simple"] != "simple-value" { + t.Errorf("simple = %v; want simple-value", cookies["simple"]) + } + if cookies["custom"] != "custom-value" { + t.Errorf("custom = %v; want custom-value", cookies["custom"]) + } + if cookies["full"] != "full-value" { + t.Errorf("full = %v; want full-value", cookies["full"]) + } + + // 测试 ResetCookies + req2 := NewSimpleRequest(server.URL, "GET") + req2.AddSimpleCookie("before", "before-value") + req2.ResetCookies() + req2.AddSimpleCookie("after", "after-value") + + resp2, err := req2.Do() + if err != nil { + t.Fatalf("Do() error: %v", err) + } + defer resp2.Close() + + var cookies2 map[string]string + resp2.Body().JSON(&cookies2) + + if _, exists := cookies2["before"]; exists { + t.Error("before cookie should not exist after reset") + } + if cookies2["after"] != "after-value" { + t.Errorf("after = %v; want after-value", cookies2["after"]) + } +} + +// TestQueryOperations 测试 Query 操作 +func TestQueryOperations(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + query := r.URL.Query() + result := make(map[string][]string) + for k, v := range query { + result[k] = v + } + json.NewEncoder(w).Encode(result) + })) + defer server.Close() + + req := NewSimpleRequest(server.URL, "GET") + + // AddQuery + req.AddQuery("multi", "value1") + req.AddQuery("multi", "value2") + + // SetQuery + req.SetQuery("single", "single-value") + + // AddQueries + req.AddQueries(map[string]string{ + "batch1": "batch-value1", + "batch2": "batch-value2", + }) + + // DeleteQuery + req.AddQuery("delete-me", "will-be-deleted") + req.DeleteQuery("delete-me") + + // DeleteQueryValue + req.AddQuery("partial", "keep") + req.AddQuery("partial", "delete") + req.DeleteQueryValue("partial", "delete") + + resp, err := req.Do() + if err != nil { + t.Fatalf("Do() error: %v", err) + } + defer resp.Close() + + var result map[string][]string + resp.Body().JSON(&result) + + // 验证 AddQuery + if len(result["multi"]) != 2 { + t.Errorf("multi length = %v; want 2", len(result["multi"])) + } + + // 验证 SetQuery + if len(result["single"]) != 1 || result["single"][0] != "single-value" { + t.Errorf("single = %v; want [single-value]", result["single"]) + } + + // 验证 AddQueries + if result["batch1"][0] != "batch-value1" { + t.Errorf("batch1 = %v; want batch-value1", result["batch1"][0]) + } + + // 验证 DeleteQuery + if _, exists := result["delete-me"]; exists { + t.Error("delete-me should not exist") + } + + // 验证 DeleteQueryValue + if len(result["partial"]) != 1 || result["partial"][0] != "keep" { + t.Errorf("partial = %v; want [keep]", result["partial"]) + } +} + +// TestWithCookies 测试 WithCookies +func TestWithCookies(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + cookies := make(map[string]string) + for _, cookie := range r.Cookies() { + cookies[cookie.Name] = cookie.Value + } + json.NewEncoder(w).Encode(cookies) + })) + defer server.Close() + + resp, err := Get(server.URL, + WithCookies(map[string]string{ + "cookie1": "value1", + "cookie2": "value2", + })) + + if err != nil { + t.Fatalf("Get() error: %v", err) + } + defer resp.Close() + + var cookies map[string]string + resp.Body().JSON(&cookies) + + if cookies["cookie1"] != "value1" { + t.Errorf("cookie1 = %v; want value1", cookies["cookie1"]) + } + if cookies["cookie2"] != "value2" { + t.Errorf("cookie2 = %v; want value2", cookies["cookie2"]) + } +} + +// TestWithHeaders 测试 WithHeaders +func TestWithHeaders(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("X-Header1") != "value1" { + t.Errorf("X-Header1 = %v; want value1", r.Header.Get("X-Header1")) + } + if r.Header.Get("X-Header2") != "value2" { + t.Errorf("X-Header2 = %v; want value2", r.Header.Get("X-Header2")) + } + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + resp, err := Get(server.URL, + WithHeaders(map[string]string{ + "X-Header1": "value1", + "X-Header2": "value2", + })) + + if err != nil { + t.Fatalf("Get() error: %v", err) + } + defer resp.Close() +} + +// TestWithQueries 测试 WithQueries +func TestWithQueries(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + query := r.URL.Query() + json.NewEncoder(w).Encode(query) + })) + defer server.Close() + + resp, err := Get(server.URL, + WithQueries(map[string]string{ + "key1": "value1", + "key2": "value2", + })) + + if err != nil { + t.Fatalf("Get() error: %v", err) + } + defer resp.Close() + + var result map[string][]string + resp.Body().JSON(&result) + + if result["key1"][0] != "value1" { + t.Errorf("key1 = %v; want value1", result["key1"][0]) + } + if result["key2"][0] != "value2" { + t.Errorf("key2 = %v; want value2", result["key2"][0]) + } +} + +// TestSetReferer 测试 SetReferer +func TestSetReferer(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Referer() != "https://example.com" { + t.Errorf("Referer = %v; want https://example.com", r.Referer()) + } + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + req := NewSimpleRequest(server.URL, "GET"). + SetReferer("https://example.com") + + resp, err := req.Do() + if err != nil { + t.Fatalf("Do() error: %v", err) + } + defer resp.Close() +} + +// TestSetBearerToken 测试 SetBearerToken +func TestSetBearerToken(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + auth := r.Header.Get("Authorization") + if auth != "Bearer test-token-123" { + t.Errorf("Authorization = %v; want Bearer test-token-123", auth) + } + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + req := NewSimpleRequest(server.URL, "GET"). + SetBearerToken("test-token-123") + + resp, err := req.Do() + if err != nil { + t.Fatalf("Do() error: %v", err) + } + defer resp.Close() +} + +// TestGetHeader 测试 GetHeader +func TestGetHeader(t *testing.T) { + req := NewSimpleRequest("http://example.com", "GET") + req.SetHeader("X-Test", "test-value") + + value := req.GetHeader("X-Test") + if value != "test-value" { + t.Errorf("GetHeader = %v; want test-value", value) + } +} + +// TestEnableDisableRawMode 测试 EnableRawMode 和 DisableRawMode +func TestEnableDisableRawMode(t *testing.T) { + req := NewSimpleRequest("http://example.com", "GET") + + // 默认不是 raw 模式 + if req.doRaw { + t.Error("Request should not be in raw mode by default") + } + + // 启用 raw 模式 + req.EnableRawMode() + if !req.doRaw { + t.Error("EnableRawMode should enable raw mode") + } + + // 禁用 raw 模式 + req.DisableRawMode() + if req.doRaw { + t.Error("DisableRawMode should disable raw mode") + } +} + +// TestContextOperations 测试 Context 操作 +func TestContextOperations(t *testing.T) { + ctx := context.WithValue(context.Background(), "test-key", "test-value") + + req := NewSimpleRequest("http://example.com", "GET") + req.SetContext(ctx) + + if req.Context() != ctx { + t.Error("SetContext did not set context correctly") + } + + // 验证 context 中的值 + if req.Context().Value("test-key") != "test-value" { + t.Error("Context value not preserved") + } +} + +// TestRawRequestOperations 测试 RawRequest 操作 +func TestRawRequestOperations(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + rawReq, _ := http.NewRequest("GET", server.URL, nil) + rawReq.Header.Set("X-Raw", "raw-value") + + req := NewSimpleRequest("", "GET") + req.SetRawRequest(rawReq) + + if req.RawRequest() != rawReq { + t.Error("SetRawRequest did not set raw request correctly") + } + + resp, err := req.Do() + if err != nil { + t.Fatalf("Do() error: %v", err) + } + defer resp.Close() +} + +// TestURLOperations 测试 URL 操作 +func TestURLOperations(t *testing.T) { + req := NewSimpleRequest("http://example.com", "GET") + + if req.URL() != "http://example.com" { + t.Errorf("URL() = %v; want http://example.com", req.URL()) + } + + req.SetURL("http://newexample.com") + if req.URL() != "http://newexample.com" { + t.Errorf("URL() after SetURL = %v; want http://newexample.com", req.URL()) + } +} + +// TestMethodOperations 测试 Method 操作 +func TestMethodOperations(t *testing.T) { + req := NewSimpleRequest("http://example.com", "GET") + + if req.Method() != "GET" { + t.Errorf("Method() = %v; want GET", req.Method()) + } + + req.SetMethod("POST") + if req.Method() != "POST" { + t.Errorf("Method() after SetMethod = %v; want POST", req.Method()) + } +} + +// ---- Client: SetDefaultTLSConfig / EnableRedirect / Options / NewClientFromHTTP ---- + +func TestClientSetDefaultTLSConfig(t *testing.T) { + ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer ts.Close() + + c := NewClientNoErr() + c.SetDefaultTLSConfig(&tls.Config{InsecureSkipVerify: true}) + + resp, err := c.Get(ts.URL) + if err != nil { + t.Fatalf("Get() error: %v", err) + } + defer resp.Close() + if resp.StatusCode != http.StatusOK { + t.Fatalf("StatusCode=%d", resp.StatusCode) + } +} + +func TestClientEnableRedirect(t *testing.T) { + n := 0 + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if n == 0 { + n++ + http.Redirect(w, r, "/ok", http.StatusFound) + return + } + w.WriteHeader(http.StatusOK) + })) + defer s.Close() + + c := NewClientNoErr() + c.DisableRedirect() + resp, err := c.Get(s.URL) + if err != nil { + t.Fatalf("Get() error: %v", err) + } + resp.Close() + if resp.StatusCode != http.StatusFound { + t.Fatalf("want 302, got %d", resp.StatusCode) + } + + c.EnableRedirect() + resp2, err := c.Get(s.URL) + if err != nil { + t.Fatalf("Get() after EnableRedirect error: %v", err) + } + defer resp2.Close() + if resp2.StatusCode != http.StatusOK { + t.Fatalf("want 200, got %d", resp2.StatusCode) + } +} + +func TestClientOptionsMethod(t *testing.T) { + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodOptions { + t.Fatalf("method=%s", r.Method) + } + w.WriteHeader(http.StatusNoContent) + })) + defer s.Close() + + c := NewClientNoErr() + resp, err := c.Options(s.URL) + if err != nil { + t.Fatalf("Options() error: %v", err) + } + defer resp.Close() + if resp.StatusCode != http.StatusNoContent { + t.Fatalf("status=%d", resp.StatusCode) + } +} + +func TestNewClientFromHTTP_WithConfiguredTransport(t *testing.T) { + hc := &http.Client{ + Transport: &http.Transport{ + MaxIdleConns: 17, + }, + Timeout: 3 * time.Second, + } + c, err := NewClientFromHTTP(hc) + if err != nil { + t.Fatalf("NewClientFromHTTP error: %v", err) + } + if c == nil || c.HTTPClient() == nil { + t.Fatal("client nil") + } + // 覆盖“http.Client 已有 *http.Transport 的包装路径” + if _, ok := c.HTTPClient().Transport.(*Transport); !ok { + t.Fatalf("transport not wrapped to *Transport, got %T", c.HTTPClient().Transport) + } +} + +// ---- context / getRequestContext 覆盖缺口 ---- + +func TestGetRequestContext_AllMissingBranches(t *testing.T) { + dialFn := func(ctx context.Context, network, addr string) (net.Conn, error) { return nil, nil } + tr := &http.Transport{} + ctx := context.Background() + ctx = context.WithValue(ctx, ctxKeyTransport, tr) + ctx = context.WithValue(ctx, ctxKeyProxy, "http://127.0.0.1:29992") + ctx = context.WithValue(ctx, ctxKeyCustomDNS, []string{"8.8.8.8"}) + ctx = context.WithValue(ctx, ctxKeyDialFunc, dialFn) + + rc := getRequestContext(ctx) + if rc.Transport != tr { + t.Fatal("transport not extracted") + } + if rc.Proxy != "http://127.0.0.1:29992" { + t.Fatal("proxy not extracted") + } + if len(rc.CustomDNS) != 1 || rc.CustomDNS[0] != "8.8.8.8" { + t.Fatal("custom dns not extracted") + } + if rc.DialFn == nil { + t.Fatal("dialFn not extracted") + } +} + +// ---- 默认函数: put/delete/patch/options/trace/connect ---- + +func TestDefaultMethodsCoverage(t *testing.T) { + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(r.Method)) + })) + defer s.Close() + + cases := []struct { + name string + fn func(string, ...RequestOpt) (*Response, error) + want string + }{ + {"PUT", Put, http.MethodPut}, + {"DELETE", Delete, http.MethodDelete}, + {"PATCH", Patch, http.MethodPatch}, + {"OPTIONS", Options, http.MethodOptions}, + {"TRACE", Trace, http.MethodTrace}, + {"CONNECT", Connect, http.MethodConnect}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + resp, err := tc.fn(s.URL) + if err != nil { + t.Fatalf("%s error: %v", tc.name, err) + } + defer resp.Close() + body, _ := resp.Body().String() + if body != tc.want { + t.Fatalf("body=%q want=%q", body, tc.want) + } + }) + } +} + +// ---- Request: SetQueries / SetTransport / SetAutoCalcContentLength / SetContentLength ---- + +func TestRequestSetQueries(t *testing.T) { + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + q := r.URL.Query() + if q.Get("a") != "1" || q.Get("b") != "2" { + t.Fatalf("query not set: %v", q) + } + w.WriteHeader(http.StatusOK) + })) + defer s.Close() + + req := NewSimpleRequest(s.URL, "GET"). + SetQueries(map[string][]string{"a": {"1"}, "b": {"2"}}) + resp, err := req.Do() + if err != nil { + t.Fatalf("Do() error: %v", err) + } + resp.Close() +} + +func TestRequestSetTransport(t *testing.T) { + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer s.Close() + + base := &http.Transport{} + req := NewSimpleRequest(s.URL, "GET").SetTransport(base) + resp, err := req.Do() + if err != nil { + t.Fatalf("Do() error: %v", err) + } + resp.Close() +} + +func TestRequestSetAutoCalcContentLength(t *testing.T) { + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.ContentLength <= 0 { + t.Fatalf("content-length not auto calculated: %d", r.ContentLength) + } + w.WriteHeader(http.StatusOK) + })) + defer s.Close() + + req := NewSimpleRequest(s.URL, "POST"). + SetBodyReader(stringsNewReaderCompat("hello-autocalc")). + SetAutoCalcContentLength(true) + + resp, err := req.Do() + if err != nil { + t.Fatalf("Do() error: %v", err) + } + resp.Close() +} + +func TestRequestSetContentLength(t *testing.T) { + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.ContentLength != 5 { + t.Fatalf("content-length=%d", r.ContentLength) + } + w.WriteHeader(http.StatusOK) + })) + defer s.Close() + + req := NewSimpleRequest(s.URL, "POST"). + SetBody([]byte("hello")). + SetContentLength(5) + + resp, err := req.Do() + if err != nil { + t.Fatalf("Do() error: %v", err) + } + resp.Close() +} + +// ---- Request: AddCustomDNS / AddCustomIP / SetDialFunc ---- + +func TestRequestAddCustomDNSAndIP(t *testing.T) { + req := NewSimpleRequest("http://example.com", "GET"). + AddCustomDNS("8.8.8.8"). + AddCustomIP("1.1.1.1") + + if req.Err() != nil { + t.Fatalf("unexpected err: %v", req.Err()) + } + if len(req.config.DNS.CustomDNS) != 1 || req.config.DNS.CustomDNS[0] != "8.8.8.8" { + t.Fatal("custom dns not added") + } + if len(req.config.DNS.CustomIP) != 1 || req.config.DNS.CustomIP[0] != "1.1.1.1" { + t.Fatal("custom ip not added") + } +} + +func TestRequestSetDialFunc(t *testing.T) { + called := false + fn := func(ctx context.Context, network, addr string) (net.Conn, error) { + called = true + return nil, io.EOF + } + req := NewSimpleRequest("http://example.com", "GET").SetDialFunc(fn) + if req.config.Network.DialFunc == nil { + t.Fatal("dial func not set") + } + _, _ = req.config.Network.DialFunc(context.Background(), "tcp", "x:1") + if !called { + t.Fatal("dial func not callable") + } +} + +// ---- Request header/cookie bulk APIs ---- + +func TestRequestSetHeadersAndAddHeaders(t *testing.T) { + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("X-A") != "1" || r.Header.Get("X-B") != "2" || r.Header.Get("X-C") != "3" { + t.Fatalf("headers not correct: %v", r.Header) + } + w.WriteHeader(http.StatusOK) + })) + defer s.Close() + + h := http.Header{} + h.Set("X-A", "1") + h.Set("X-B", "2") + + req := NewSimpleRequest(s.URL, "GET"). + SetHeaders(h). + AddHeaders(map[string]string{"X-C": "3"}) + + resp, err := req.Do() + if err != nil { + t.Fatalf("Do() error: %v", err) + } + resp.Close() +} + +func TestRequestSetCookiesAndAddCookies(t *testing.T) { + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + got := map[string]string{} + for _, c := range r.Cookies() { + got[c.Name] = c.Value + } + if got["a"] != "1" || got["b"] != "2" || got["c"] != "3" { + t.Fatalf("cookies=%v", got) + } + w.WriteHeader(http.StatusOK) + })) + defer s.Close() + + req := NewSimpleRequest(s.URL, "GET"). + SetCookies([]*http.Cookie{ + {Name: "a", Value: "1", Path: "/"}, + {Name: "b", Value: "2", Path: "/"}, + }). + AddCookies(map[string]string{"c": "3"}) + + resp, err := req.Do() + if err != nil { + t.Fatalf("Do() error: %v", err) + } + resp.Close() +} + +// ---- Body.Close / Response.CloseWithClient ---- + +func TestBodyClose(t *testing.T) { + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("ok")) + })) + defer s.Close() + + resp, err := Get(s.URL) + if err != nil { + t.Fatalf("Get() error: %v", err) + } + // 直接测 Body.Close + if err := resp.Body().Close(); err != nil { + t.Fatalf("Body.Close() error: %v", err) + } +} + +func TestResponseCloseWithClient(t *testing.T) { + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("ok")) + })) + defer s.Close() + + resp, err := Get(s.URL) + if err != nil { + t.Fatalf("Get() error: %v", err) + } + if err := resp.CloseWithClient(); err != nil { + t.Fatalf("CloseWithClient() error: %v", err) + } +} + +// 小兼容函数,避免你当前文件没引 strings 包时报错(可直接替换成 strings.NewReader) +func stringsNewReaderCompat(s string) io.Reader { + return io.NopCloser(io.MultiReader(io.LimitReader(io.NopCloser(stringsReader(s)), int64(len(s))))) +} + +// 纯标准库最小 reader +type stringsReader string + +func (sr stringsReader) Read(p []byte) (int, error) { + if len(sr) == 0 { + return 0, io.EOF + } + n := copy(p, []byte(sr)) + return n, nil +} diff --git a/benchmark_test.go b/benchmark_test.go new file mode 100644 index 0000000..b994d51 --- /dev/null +++ b/benchmark_test.go @@ -0,0 +1,197 @@ +package starnet + +import ( + "fmt" + "net/http" + "net/http/httptest" + "testing" +) + +func BenchmarkGetRequest(b *testing.B) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("OK")) + })) + defer server.Close() + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + resp, err := Get(server.URL) + if err != nil { + b.Fatalf("Get() error: %v", err) + } + resp.Body().String() + resp.Close() + } +} + +func BenchmarkGetRequestWithHeaders(b *testing.B) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("OK")) + })) + defer server.Close() + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + resp, err := Get(server.URL, + WithHeader("X-Custom", "value"), + WithUserAgent("BenchmarkAgent")) + if err != nil { + b.Fatalf("Get() error: %v", err) + } + resp.Body().String() + resp.Close() + } +} + +func BenchmarkPostRequest(b *testing.B) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("OK")) + })) + defer server.Close() + + testData := []byte("test data for benchmark") + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + resp, err := Post(server.URL, WithBody(testData)) + if err != nil { + b.Fatalf("Post() error: %v", err) + } + resp.Body().String() + resp.Close() + } +} + +func BenchmarkJSONRequest(b *testing.B) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"status":"ok"}`)) + })) + defer server.Close() + + type TestData struct { + Name string `json:"name"` + Value int `json:"value"` + } + + data := TestData{Name: "test", Value: 123} + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + resp, err := Post(server.URL, WithJSON(data)) + if err != nil { + b.Fatalf("Post() error: %v", err) + } + var result map[string]string + resp.Body().JSON(&result) + resp.Close() + } +} + +func BenchmarkConcurrentRequests(b *testing.B) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("OK")) + })) + defer server.Close() + + b.ResetTimer() + b.ReportAllocs() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + resp, err := Get(server.URL) + if err != nil { + b.Fatalf("Get() error: %v", err) + } + resp.Body().String() + resp.Close() + } + }) +} + +func BenchmarkRequestClone(b *testing.B) { + req := NewSimpleRequest("https://example.com", "GET"). + SetHeader("X-Custom", "value"). + AddQuery("key", "value") + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + _ = req.Clone() + } +} + +func BenchmarkClientCreation(b *testing.B) { + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + _ = NewClientNoErr() + } +} + +func BenchmarkRequestCreation(b *testing.B) { + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + _ = NewSimpleRequest("https://example.com", "GET") + } +} + +func BenchmarkResponseBodyRead(b *testing.B) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("test response data")) + })) + defer server.Close() + + // Pre-fetch response + resp, _ := Get(server.URL, WithAutoFetch(true)) + defer resp.Close() + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + _, _ = resp.Body().String() + } +} + +func BenchmarkDifferentResponseSizes(b *testing.B) { + sizes := []int{100, 1024, 10240, 102400} // 100B, 1KB, 10KB, 100KB + + for _, size := range sizes { + responseData := make([]byte, size) + for i := 0; i < size; i++ { + responseData[i] = 'A' + } + + b.Run(fmt.Sprintf("Size_%d", size), func(b *testing.B) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write(responseData) + })) + defer server.Close() + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + resp, err := Get(server.URL) + if err != nil { + b.Fatalf("Get() error: %v", err) + } + resp.Body().Bytes() + resp.Close() + } + }) + } +} diff --git a/body_test.go b/body_test.go new file mode 100644 index 0000000..e984649 --- /dev/null +++ b/body_test.go @@ -0,0 +1,145 @@ +package starnet + +import ( + "bytes" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestRequestBodyBytes(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + w.Write(body) + })) + defer server.Close() + + testData := []byte("test data") + req := NewSimpleRequest(server.URL, "POST").SetBody(testData) + + resp, err := req.Do() + if err != nil { + t.Fatalf("Do() error: %v", err) + } + defer resp.Close() + + body, _ := resp.Body().Bytes() + if !bytes.Equal(body, testData) { + t.Errorf("Body = %v; want %v", body, testData) + } +} + +func TestRequestBodyString(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + w.Write(body) + })) + defer server.Close() + + testData := "test string data" + req := NewSimpleRequest(server.URL, "POST").SetBodyString(testData) + + resp, err := req.Do() + if err != nil { + t.Fatalf("Do() error: %v", err) + } + defer resp.Close() + + body, _ := resp.Body().String() + if body != testData { + t.Errorf("Body = %v; want %v", body, testData) + } +} + +func TestRequestBodyReader(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + w.Write(body) + })) + defer server.Close() + + testData := "test reader data" + reader := strings.NewReader(testData) + req := NewSimpleRequest(server.URL, "POST").SetBodyReader(reader) + + resp, err := req.Do() + if err != nil { + t.Fatalf("Do() error: %v", err) + } + defer resp.Close() + + body, _ := resp.Body().String() + if body != testData { + t.Errorf("Body = %v; want %v", body, testData) + } +} + +func TestRequestJSON(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Content-Type") != ContentTypeJSON { + t.Errorf("Content-Type = %v; want %v", r.Header.Get("Content-Type"), ContentTypeJSON) + } + + var data map[string]string + json.NewDecoder(r.Body).Decode(&data) + json.NewEncoder(w).Encode(data) + })) + defer server.Close() + + testData := map[string]string{ + "name": "John", + "email": "john@example.com", + } + + req := NewSimpleRequest(server.URL, "POST").SetJSON(testData) + + resp, err := req.Do() + if err != nil { + t.Fatalf("Do() error: %v", err) + } + defer resp.Close() + + var result map[string]string + resp.Body().JSON(&result) + + if result["name"] != testData["name"] { + t.Errorf("name = %v; want %v", result["name"], testData["name"]) + } +} + +func TestRequestFormData(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.ParseForm() + data := make(map[string]string) + for k, v := range r.Form { + if len(v) > 0 { + data[k] = v[0] + } + } + json.NewEncoder(w).Encode(data) + })) + defer server.Close() + + req := NewSimpleRequest(server.URL, "POST"). + AddFormData("name", "John"). + AddFormData("email", "john@example.com") + + resp, err := req.Do() + if err != nil { + t.Fatalf("Do() error: %v", err) + } + defer resp.Close() + + var result map[string]string + resp.Body().JSON(&result) + + if result["name"] != "John" { + t.Errorf("name = %v; want John", result["name"]) + } + if result["email"] != "john@example.com" { + t.Errorf("email = %v; want john@example.com", result["email"]) + } +} diff --git a/client.go b/client.go new file mode 100644 index 0000000..3afc593 --- /dev/null +++ b/client.go @@ -0,0 +1,324 @@ +package starnet + +import ( + "context" + "crypto/tls" + "fmt" + "net/http" + "sync" + "time" +) + +// Client HTTP 客户端封装 +type Client struct { + client *http.Client + opts []RequestOpt + mu sync.RWMutex +} + +// NewClient 创建新的 Client +func NewClient(opts ...RequestOpt) (*Client, error) { + // 创建基础 Transport + baseTransport := &http.Transport{ + ForceAttemptHTTP2: true, + MaxIdleConns: 100, + MaxIdleConnsPerHost: 10, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + } + + httpClient := &http.Client{ + Transport: &Transport{base: baseTransport}, + //Timeout: DefaultTimeout, + } + + // 应用选项(如果有) + if len(opts) > 0 { + // 创建临时请求以应用选项 + req, err := newRequest(context.Background(), "", http.MethodGet, opts...) + if err != nil { + return nil, wrapError(err, "create client") + } + + /* + // 如果选项中有自定义配置,应用到 httpClient + if req.config.Network.Timeout > 0 { + httpClient.Timeout = req.config.Network.Timeout + } + + */ + + // 如果有自定义 Transport + if req.config.CustomTransport && req.config.Transport != nil { + httpClient.Transport = &Transport{base: req.config.Transport} + } + } + + return &Client{ + client: httpClient, + opts: opts, + }, nil +} + +// NewClientNoErr 创建新的 Client(忽略错误) +func NewClientNoErr(opts ...RequestOpt) *Client { + client, _ := NewClient(opts...) + if client == nil { + client = &Client{ + client: &http.Client{}, + opts: opts, + } + } + return client +} + +// NewClientFromHTTP 从 http.Client 创建 Client +func NewClientFromHTTP(httpClient *http.Client) (*Client, error) { + if httpClient == nil { + return nil, ErrNilClient + } + + // 确保 Transport 是我们的自定义类型 + if httpClient.Transport == nil { + httpClient.Transport = &Transport{ + base: &http.Transport{}, + } + } else { + switch t := httpClient.Transport.(type) { + case *Transport: + // 已经是我们的类型 + if t.base == nil { + t.base = &http.Transport{} + } + case *http.Transport: + // 包装标准 Transport + httpClient.Transport = &Transport{ + base: t, + } + default: + return nil, fmt.Errorf("unsupported transport type: %T", t) + } + } + + return &Client{ + client: httpClient, + }, nil +} + +// HTTPClient 获取底层 http.Client +func (c *Client) HTTPClient() *http.Client { + return c.client +} + +// RequestOptions 获取默认选项(返回副本) +func (c *Client) RequestOptions() []RequestOpt { + c.mu.RLock() + defer c.mu.RUnlock() + + opts := make([]RequestOpt, len(c.opts)) + copy(opts, c.opts) + return opts +} + +// SetOptions 设置默认选项 +func (c *Client) SetOptions(opts ...RequestOpt) *Client { + c.mu.Lock() + c.opts = opts + c.mu.Unlock() + return c +} + +// AddOptions 追加默认选项 +func (c *Client) AddOptions(opts ...RequestOpt) *Client { + c.mu.Lock() + c.opts = append(c.opts, opts...) + c.mu.Unlock() + return c +} + +// Clone 克隆 Client(深拷贝) +func (c *Client) Clone() *Client { + c.mu.RLock() + defer c.mu.RUnlock() + + // 克隆 Transport + var transport http.RoundTripper + if c.client.Transport != nil { + switch t := c.client.Transport.(type) { + case *Transport: + transport = &Transport{ + base: t.base.Clone(), + } + case *http.Transport: + transport = t.Clone() + default: + transport = c.client.Transport + } + } + + return &Client{ + client: &http.Client{ + Transport: transport, + CheckRedirect: c.client.CheckRedirect, + Jar: c.client.Jar, + Timeout: c.client.Timeout, + }, + opts: append([]RequestOpt(nil), c.opts...), + } +} + +// SetDefaultTLSConfig 设置默认 TLS 配置 +func (c *Client) SetDefaultTLSConfig(tlsConfig *tls.Config) *Client { + if transport, ok := c.client.Transport.(*Transport); ok { + transport.mu.Lock() + if transport.base.TLSClientConfig == nil { + transport.base.TLSClientConfig = &tls.Config{} + } + transport.base.TLSClientConfig = tlsConfig + transport.mu.Unlock() + } + return c +} + +// SetDefaultSkipTLSVerify 设置默认跳过 TLS 验证 +func (c *Client) SetDefaultSkipTLSVerify(skip bool) *Client { + if transport, ok := c.client.Transport.(*Transport); ok { + transport.mu.Lock() + if transport.base.TLSClientConfig == nil { + transport.base.TLSClientConfig = &tls.Config{} + } + transport.base.TLSClientConfig.InsecureSkipVerify = skip + transport.mu.Unlock() + } + return c +} + +// DisableRedirect 禁用重定向 +func (c *Client) DisableRedirect() *Client { + c.client.CheckRedirect = func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + } + return c +} + +// EnableRedirect 启用重定向 +func (c *Client) EnableRedirect() *Client { + c.client.CheckRedirect = nil + return c +} + +// NewRequest 创建新请求 +func (c *Client) NewRequest(url, method string, opts ...RequestOpt) (*Request, error) { + return c.NewRequestWithContext(context.Background(), url, method, opts...) +} + +// NewRequestWithContext 创建新请求(带 context) +func (c *Client) NewRequestWithContext(ctx context.Context, url, method string, opts ...RequestOpt) (*Request, error) { + // 合并 Client 级别和请求级别的选项 + c.mu.RLock() + allOpts := append(append([]RequestOpt(nil), c.opts...), opts...) + c.mu.RUnlock() + + req, err := newRequest(ctx, url, method, allOpts...) + if err != nil { + return nil, err + } + + req.client = c + req.httpClient = c.client + return req, nil +} + +// Get 发送 GET 请求 +func (c *Client) Get(url string, opts ...RequestOpt) (*Response, error) { + req, err := c.NewRequest(url, http.MethodGet, opts...) + if err != nil { + return nil, err + } + return req.Do() +} + +// Post 发送 POST 请求 +func (c *Client) Post(url string, opts ...RequestOpt) (*Response, error) { + req, err := c.NewRequest(url, http.MethodPost, opts...) + if err != nil { + return nil, err + } + return req.Do() +} + +// Put 发送 PUT 请求 +func (c *Client) Put(url string, opts ...RequestOpt) (*Response, error) { + req, err := c.NewRequest(url, http.MethodPut, opts...) + if err != nil { + return nil, err + } + return req.Do() +} + +// Delete 发送 DELETE 请求 +func (c *Client) Delete(url string, opts ...RequestOpt) (*Response, error) { + req, err := c.NewRequest(url, http.MethodDelete, opts...) + if err != nil { + return nil, err + } + return req.Do() +} + +// Head 发送 HEAD 请求 +func (c *Client) Head(url string, opts ...RequestOpt) (*Response, error) { + req, err := c.NewRequest(url, http.MethodHead, opts...) + if err != nil { + return nil, err + } + return req.Do() +} + +// Patch 发送 PATCH 请求 +func (c *Client) Patch(url string, opts ...RequestOpt) (*Response, error) { + req, err := c.NewRequest(url, http.MethodPatch, opts...) + if err != nil { + return nil, err + } + return req.Do() +} + +// Options 发送 OPTIONS 请求 +func (c *Client) Options(url string, opts ...RequestOpt) (*Response, error) { + req, err := c.NewRequest(url, http.MethodOptions, opts...) + if err != nil { + return nil, err + } + return req.Do() +} + +// NewSimpleRequest 创建新请求(忽略错误,支持链式调用) +func (c *Client) NewSimpleRequest(url, method string, opts ...RequestOpt) *Request { + return c.NewSimpleRequestWithContext(context.Background(), url, method, opts...) +} + +// NewSimpleRequestWithContext 创建新请求(带 context,忽略错误) +func (c *Client) NewSimpleRequestWithContext(ctx context.Context, url, method string, opts ...RequestOpt) *Request { + req, err := c.NewRequestWithContext(ctx, url, method, opts...) + if err != nil { + // 返回一个带错误的请求,保持与全局 NewSimpleRequest 行为一致 + return &Request{ + ctx: ctx, + url: url, + method: method, + err: err, + config: &RequestConfig{ + Headers: make(http.Header), + Queries: make(map[string][]string), + Body: BodyConfig{ + FormData: make(map[string][]string), + }, + }, + client: c, + httpClient: c.client, + autoFetch: DefaultFetchRespBody, + } + } + return req +} diff --git a/client_test.go b/client_test.go new file mode 100644 index 0000000..496d040 --- /dev/null +++ b/client_test.go @@ -0,0 +1,223 @@ +package starnet + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" +) + +func TestNewClient(t *testing.T) { + client, err := NewClient() + if err != nil { + t.Fatalf("NewClient() error: %v", err) + } + if client == nil { + t.Fatal("NewClient() returned nil") + } +} + +func TestNewClientNoErr(t *testing.T) { + client := NewClientNoErr() + if client == nil { + t.Fatal("NewClientNoErr() returned nil") + } +} + +func TestNewClientFromHTTP(t *testing.T) { + httpClient := &http.Client{ + Timeout: 10 * time.Second, + } + + client, err := NewClientFromHTTP(httpClient) + if err != nil { + t.Fatalf("NewClientFromHTTP() error: %v", err) + } + if client == nil { + t.Fatal("NewClientFromHTTP() returned nil") + } + + // Test with nil client + _, err = NewClientFromHTTP(nil) + if err == nil { + t.Error("NewClientFromHTTP(nil) should return error") + } +} + +func TestClientOptions(t *testing.T) { + client := NewClientNoErr() + + // Set options + client.SetOptions(WithTimeout(5 * time.Second)) + opts := client.RequestOptions() + if len(opts) != 1 { + t.Errorf("RequestOptions() length = %v; want 1", len(opts)) + } + + // Add options + client.AddOptions(WithUserAgent("TestAgent")) + opts = client.RequestOptions() + if len(opts) != 2 { + t.Errorf("RequestOptions() length = %v; want 2", len(opts)) + } +} + +func TestClientClone(t *testing.T) { + client := NewClientNoErr(WithTimeout(5 * time.Second)) + cloned := client.Clone() + + if cloned == nil { + t.Fatal("Clone() returned nil") + } + + // 修改克隆的 client + cloned.SetOptions(WithTimeout(10 * time.Second)) + + origOpts := client.RequestOptions() + clonedOpts := cloned.RequestOptions() + + // 原 client 应该还是 1 个选项 + if len(origOpts) != 1 { + t.Errorf("Original client options = %v; want 1", len(origOpts)) + } + + // 克隆的 client 应该是 1 个选项(被 SetOptions 覆盖) + if len(clonedOpts) != 1 { + t.Errorf("Cloned client options = %v; want 1", len(clonedOpts)) + } +} + +func TestClientHTTPMethods(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(r.Method)) + })) + defer server.Close() + + client := NewClientNoErr() + + tests := []struct { + name string + method func(string, ...RequestOpt) (*Response, error) + want string + }{ + {"GET", client.Get, "GET"}, + {"POST", client.Post, "POST"}, + {"PUT", client.Put, "PUT"}, + {"DELETE", client.Delete, "DELETE"}, + {"PATCH", client.Patch, "PATCH"}, + {"HEAD", client.Head, "HEAD"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resp, err := tt.method(server.URL) + if err != nil { + t.Fatalf("%s() error: %v", tt.name, err) + } + defer resp.Close() + + if tt.want != "HEAD" { + body, _ := resp.Body().String() + if body != tt.want { + t.Errorf("Body = %v; want %v", body, tt.want) + } + } + }) + } +} + +func TestClientRedirect(t *testing.T) { + redirectCount := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if redirectCount < 2 { + redirectCount++ + http.Redirect(w, r, "/redirected", http.StatusFound) + return + } + w.WriteHeader(http.StatusOK) + w.Write([]byte("final")) + })) + defer server.Close() + + // Test with redirect enabled (default) + client := NewClientNoErr() + resp, err := client.Get(server.URL) + if err != nil { + t.Fatalf("Get() error: %v", err) + } + resp.Close() + + if redirectCount != 2 { + t.Errorf("Redirect count = %v; want 2", redirectCount) + } + + // Test with redirect disabled + redirectCount = 0 + client.DisableRedirect() + resp2, err := client.Get(server.URL) + if err != nil { + t.Fatalf("Get() error: %v", err) + } + defer resp2.Close() + + if resp2.StatusCode != http.StatusFound { + t.Errorf("StatusCode = %v; want %v", resp2.StatusCode, http.StatusFound) + } +} + +func TestClientTLSConfig(t *testing.T) { + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("OK")) + })) + defer server.Close() + + // Without skip verify (should fail with self-signed cert) + client := NewClientNoErr() + _, err := client.Get(server.URL) + if err == nil { + t.Error("Expected TLS error with self-signed cert, got nil") + } + + // With skip verify + client.SetDefaultSkipTLSVerify(true) + resp, err := client.Get(server.URL) + if err != nil { + t.Fatalf("Get() with skip verify error: %v", err) + } + defer resp.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("StatusCode = %v; want %v", resp.StatusCode, http.StatusOK) + } +} + +func TestClientNewSimpleRequest(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("OK")) + })) + defer server.Close() + + client := NewClientNoErr() + + req := client.NewSimpleRequest(server.URL, "GET", WithHeader("X-Test", "v")) + if req == nil { + t.Fatal("NewSimpleRequest returned nil") + } + if req.Err() != nil { + t.Fatalf("NewSimpleRequest err: %v", req.Err()) + } + + resp, err := req.Do() + if err != nil { + t.Fatalf("Do() error: %v", err) + } + defer resp.Close() + + body, _ := resp.Body().String() + if body != "OK" { + t.Errorf("Body = %v; want OK", body) + } +} diff --git a/concurrent_test.go b/concurrent_test.go new file mode 100644 index 0000000..575504a --- /dev/null +++ b/concurrent_test.go @@ -0,0 +1,111 @@ +package starnet + +import ( + "fmt" + "net/http" + "net/http/httptest" + "sync" + "sync/atomic" + "testing" + "time" +) + +func TestConcurrentRequests(t *testing.T) { + var counter int64 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt64(&counter, 1) + w.WriteHeader(http.StatusOK) + w.Write([]byte("OK")) + })) + defer server.Close() + + client := NewClientNoErr() + concurrency := 100 + var wg sync.WaitGroup + wg.Add(concurrency) + + for i := 0; i < concurrency; i++ { + go func() { + defer wg.Done() + resp, err := client.Get(server.URL) + if err != nil { + t.Errorf("Get() error: %v", err) + return + } + resp.Close() + }() + } + + wg.Wait() + + if atomic.LoadInt64(&counter) != int64(concurrency) { + t.Errorf("counter = %v; want %v", counter, concurrency) + } +} + +func TestConcurrentClientModification(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + client := NewClientNoErr() + var wg sync.WaitGroup + wg.Add(200) + + // 100 goroutines reading + for i := 0; i < 100; i++ { + go func() { + defer wg.Done() + resp, err := client.Get(server.URL) + if err != nil { + t.Errorf("Get() error: %v", err) + return + } + resp.Close() + }() + } + + // 100 goroutines modifying options + for i := 0; i < 100; i++ { + go func(i int) { + defer wg.Done() + if i%2 == 0 { + client.AddOptions(WithTimeout(5 * time.Second)) + } else { + _ = client.RequestOptions() + } + }(i) + } + + wg.Wait() +} + +func TestConcurrentRequestClone(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + baseReq := NewSimpleRequest(server.URL, "GET").SetHeader("X-Base", "value") + + var wg sync.WaitGroup + wg.Add(50) + + for i := 0; i < 50; i++ { + go func(i int) { + defer wg.Done() + cloned := baseReq.Clone() + // 修复:使用有效的 header 值 + cloned.SetHeader("X-Index", fmt.Sprintf("%d", i)) + resp, err := cloned.Do() + if err != nil { + t.Errorf("Do() error: %v", err) + return + } + resp.Close() + }(i) + } + + wg.Wait() +} diff --git a/context.go b/context.go new file mode 100644 index 0000000..b88023f --- /dev/null +++ b/context.go @@ -0,0 +1,149 @@ +package starnet + +import ( + "context" + "crypto/tls" + "net" + "net/http" + "time" +) + +// contextKey 私有的 context key 类型(防止冲突) +type contextKey int + +const ( + ctxKeyTransport contextKey = iota + ctxKeyTLSConfig + ctxKeyProxy + ctxKeyCustomIP + ctxKeyCustomDNS + ctxKeyDialTimeout + ctxKeyTimeout + ctxKeyLookupIP + ctxKeyDialFunc +) + +// RequestContext 从 context 中提取的请求配置 +type RequestContext struct { + Transport *http.Transport + TLSConfig *tls.Config + Proxy string + CustomIP []string + CustomDNS []string + DialTimeout time.Duration + Timeout time.Duration + LookupIPFn func(ctx context.Context, host string) ([]net.IPAddr, error) + DialFn func(ctx context.Context, network, addr string) (net.Conn, error) +} + +// getRequestContext 从 context 中提取请求配置 +func getRequestContext(ctx context.Context) *RequestContext { + rc := &RequestContext{} + + if v := ctx.Value(ctxKeyTransport); v != nil { + rc.Transport, _ = v.(*http.Transport) + } + if v := ctx.Value(ctxKeyTLSConfig); v != nil { + rc.TLSConfig, _ = v.(*tls.Config) + } + if v := ctx.Value(ctxKeyProxy); v != nil { + rc.Proxy, _ = v.(string) + } + if v := ctx.Value(ctxKeyCustomIP); v != nil { + rc.CustomIP, _ = v.([]string) + } + if v := ctx.Value(ctxKeyCustomDNS); v != nil { + rc.CustomDNS, _ = v.([]string) + } + if v := ctx.Value(ctxKeyDialTimeout); v != nil { + rc.DialTimeout, _ = v.(time.Duration) + } + if v := ctx.Value(ctxKeyTimeout); v != nil { + rc.Timeout, _ = v.(time.Duration) + } + if v := ctx.Value(ctxKeyLookupIP); v != nil { + rc.LookupIPFn, _ = v.(func(context.Context, string) ([]net.IPAddr, error)) + } + if v := ctx.Value(ctxKeyDialFunc); v != nil { + rc.DialFn, _ = v.(func(context.Context, string, string) (net.Conn, error)) + } + + return rc +} + +// needsDynamicTransport 判断是否需要动态 Transport +func needsDynamicTransport(rc *RequestContext) bool { + return rc.Transport != nil || + rc.TLSConfig != nil || + rc.Proxy != "" || + rc.DialFn != nil || + (rc.DialTimeout > 0 && rc.DialTimeout != DefaultDialTimeout) || + (rc.Timeout > 0 && rc.Timeout != DefaultTimeout) || + len(rc.CustomIP) > 0 || + len(rc.CustomDNS) > 0 || + rc.LookupIPFn != nil +} + +// injectRequestConfig 将请求配置注入到 context +func injectRequestConfig(ctx context.Context, config *RequestConfig) context.Context { + execCtx := ctx + + // 处理 TLS 配置 + var tlsConfig *tls.Config + + if config.TLS.Config != nil { + tlsConfig = config.TLS.Config.Clone() + if config.TLS.SkipVerify { + tlsConfig.InsecureSkipVerify = true + } + } else if config.TLS.SkipVerify { + tlsConfig = &tls.Config{ + NextProtos: []string{"h2", "http/1.1"}, + InsecureSkipVerify: true, + } + } + + if tlsConfig != nil { + execCtx = context.WithValue(execCtx, ctxKeyTLSConfig, tlsConfig) + } + + // 注入代理 + if config.Network.Proxy != "" { + execCtx = context.WithValue(execCtx, ctxKeyProxy, config.Network.Proxy) + } + + // 注入自定义 IP + if len(config.DNS.CustomIP) > 0 { + execCtx = context.WithValue(execCtx, ctxKeyCustomIP, config.DNS.CustomIP) + } + + // 注入自定义 DNS + if len(config.DNS.CustomDNS) > 0 { + execCtx = context.WithValue(execCtx, ctxKeyCustomDNS, config.DNS.CustomDNS) + } + + // 总是注入 DialTimeout 和 Timeout(与原始代码一致) + if config.Network.DialTimeout > 0 { + execCtx = context.WithValue(execCtx, ctxKeyDialTimeout, config.Network.DialTimeout) + } + if config.Network.Timeout > 0 { + execCtx = context.WithValue(execCtx, ctxKeyTimeout, config.Network.Timeout) + } + + // 注入 DNS 解析函数 + if config.DNS.LookupFunc != nil { + execCtx = context.WithValue(execCtx, ctxKeyLookupIP, config.DNS.LookupFunc) + } + + // 注入 Dial 函数 + if config.Network.DialFunc != nil { + execCtx = context.WithValue(execCtx, ctxKeyDialFunc, config.Network.DialFunc) + } + + // 注入自定义 Transport + if config.CustomTransport && config.Transport != nil { + execCtx = context.WithValue(execCtx, ctxKeyTransport, config.Transport) + } + + return execCtx +} diff --git a/curl.go b/curl.go deleted file mode 100644 index be7896d..0000000 --- a/curl.go +++ /dev/null @@ -1,1869 +0,0 @@ -package starnet - -import ( - "bytes" - "context" - "crypto/tls" - "encoding/json" - "fmt" - "io" - "mime/multipart" - "net" - "net/http" - "net/url" - "os" - "strings" - "sync" - "time" -) - -type Request struct { - ctx context.Context - doCtx context.Context // 用于在请求中传递上下文信息 - uri string - method string - errInfo error - RequestOpts -} - -func (r *Request) Clone() *Request { - clonedRequest := &Request{ - ctx: r.ctx, - uri: r.uri, - method: r.method, - errInfo: r.errInfo, - RequestOpts: RequestOpts{ - headers: CloneHeader(r.headers), - cookies: CloneCookies(r.cookies), - bodyFormData: CloneStringMapSlice(r.bodyFormData), - bodyFileData: CloneFiles(r.bodyFileData), - contentLength: r.contentLength, - queries: CloneStringMapSlice(r.queries), - bodyDataBytes: CloneByteSlice(r.bodyDataBytes), - customTransport: r.customTransport, - proxy: r.proxy, - timeout: r.timeout, - dialTimeout: r.dialTimeout, - dialFn: r.dialFn, - alreadyApply: r.alreadyApply, - doRawRequest: r.doRawRequest, - skipTLSVerify: r.skipTLSVerify, - autoFetchRespBody: r.autoFetchRespBody, - customIP: CloneStringSlice(r.customIP), - alreadySetLookUpIPfn: r.alreadySetLookUpIPfn, - lookUpIPfn: r.lookUpIPfn, - customDNS: CloneStringSlice(r.customDNS), - basicAuth: r.basicAuth, - autoCalcContentLength: r.autoCalcContentLength, - }, - } - clonedRequest.rawClient = r.rawClient - - // 手动深拷贝嵌套引用类型 - if r.bodyDataReader != nil { - clonedRequest.bodyDataReader = r.bodyDataReader - } - - if r.fileUploadRecallFn != nil { - clonedRequest.fileUploadRecallFn = r.fileUploadRecallFn - } - - // 对于 tlsConfig 类型,需要手动复制 - if r.tlsConfig != nil { - clonedRequest.tlsConfig = r.tlsConfig.Clone() - } - if r.transport != nil { - clonedRequest.transport = r.transport - } - if r.doRawRequest { - clonedRequest.rawRequest = r.rawRequest - } - - if clonedRequest.rawRequest == nil { - clonedRequest.rawRequest, _ = http.NewRequestWithContext(clonedRequest.ctx, clonedRequest.method, clonedRequest.uri, nil) - } - return clonedRequest -} - -// CloneHeader 复制 http.Header -func CloneHeader(original http.Header) http.Header { - newHeader := make(http.Header) - for key, values := range original { - copiedValues := make([]string, len(values)) - copy(copiedValues, values) - newHeader[key] = copiedValues - } - return newHeader -} - -// CloneCookies 复制 []*http.Cookie -func CloneCookies(original []*http.Cookie) []*http.Cookie { - cloned := make([]*http.Cookie, len(original)) - for i, cookie := range original { - cloned[i] = &http.Cookie{ - Name: cookie.Name, - Value: cookie.Value, - Path: cookie.Path, - Domain: cookie.Domain, - Expires: cookie.Expires, - RawExpires: cookie.RawExpires, - MaxAge: cookie.MaxAge, - Secure: cookie.Secure, - HttpOnly: cookie.HttpOnly, - SameSite: cookie.SameSite, - Raw: cookie.Raw, - Unparsed: append([]string(nil), cookie.Unparsed...), - } - } - return cloned -} - -// CloneStringMapSlice 复制 map[string][]string -func CloneStringMapSlice(original map[string][]string) map[string][]string { - newMap := make(map[string][]string) - for key, values := range original { - copiedValues := make([]string, len(values)) - copy(copiedValues, values) - newMap[key] = copiedValues - } - return newMap -} - -// CloneFiles 复制 []RequestFile -func CloneFiles(original []RequestFile) []RequestFile { - newFiles := make([]RequestFile, len(original)) - copy(newFiles, original) - return newFiles -} - -// CloneByteSlice 复制 []byte -func CloneByteSlice(original []byte) []byte { - if original == nil { - return nil - } - newSlice := make([]byte, len(original)) - copy(newSlice, original) - return newSlice -} - -// CloneStringSlice 复制 []string -func CloneStringSlice(original []string) []string { - newSlice := make([]string, len(original)) - copy(newSlice, original) - return newSlice -} - -func (r *Request) Method() string { - return r.method -} - -func (r *Request) SetMethod(method string) error { - method = strings.ToUpper(method) - if !validMethod(method) { - return fmt.Errorf("invalid method: %s", method) - } - r.method = method - r.rawRequest.Method = method - return nil -} - -func (r *Request) SetMethodNoError(method string) *Request { - r.SetMethod(method) - return r -} - -func (r *Request) Uri() string { - return r.uri -} - -func (r *Request) SetUri(uri string) error { - if r.doRawRequest { - return fmt.Errorf("doRawRequest is true, cannot set uri") - } - u, err := url.Parse(uri) - if err != nil { - return fmt.Errorf("parse uri error: %s", err) - } - r.uri = uri - u.Host = removeEmptyPort(u.Host) - r.rawRequest.Host = u.Host - r.rawRequest.URL = u - if r.tlsConfig != nil { - r.tlsConfig.ServerName = u.Hostname() - } - return nil -} - -func (r *Request) SetUriNoError(uri string) *Request { - r.SetUri(uri) - return r -} - -func (r *Request) RawRequest() *http.Request { - return r.rawRequest -} - -func (r *Request) SetRawRequest(rawRequest *http.Request) *Request { - r.rawRequest = rawRequest - return r -} - -func (r *Request) RawClient() *http.Client { - return r.rawClient -} - -func (r *Request) SetRawClient(rawClient *http.Client) *Request { - r.rawClient = rawClient - return r -} - -// Do sends the HTTP request and returns the response. -func (r *Request) Do() (*Response, error) { - return Curl(r) -} - -// Get sends a GET request to the specified URI and returns the response. -func (r *Request) Get() (*Response, error) { - err := r.SetMethod("GET") - if err != nil { - return nil, err - } - return Curl(r) -} - -// Post sends a POST request with the provided data to the specified URI and returns the response. -func (r *Request) Post(data []byte) (*Response, error) { - err := r.SetMethod("POST") - if err != nil { - return nil, err - } - r.bodyDataBytes = data - r.bodyDataReader = nil - return Curl(r) -} - -type RequestOpts struct { - rawRequest *http.Request - rawClient *http.Client - transport *http.Transport - customTransport bool - - alreadyApply bool - bodyDataBytes []byte - bodyDataReader io.Reader - bodyFormData map[string][]string - bodyFileData []RequestFile - //以上优先度为 bodyDataReader> bodyDataBytes > bodyFormData > bodyFileData - fileUploadRecallFn func(filename string, upPos int64, total int64) - proxy string - timeout time.Duration - dialTimeout time.Duration - dialFn func(ctx context.Context, network, addr string) (net.Conn, error) - headers http.Header - cookies []*http.Cookie - - queries map[string][]string - //doRawRequest=true 不对request修改,直接发送 - doRawRequest bool - skipTLSVerify bool - tlsConfig *tls.Config - autoFetchRespBody bool - customIP []string - alreadySetLookUpIPfn bool - lookUpIPfn func(ctx context.Context, host string) ([]net.IPAddr, error) - customDNS []string - basicAuth [2]string - autoCalcContentLength bool - contentLength int64 // 人工设置 -} - -func (r *Request) ContentLength() int64 { - return r.contentLength -} - -// SetContentLength sets the Content-Length header for the request. -// This function will overwrite any existing or auto calculated Content-Length header. -// if the length is less than 0, it will not set the Content-Length header. chunked transfer encoding will be used instead. -// chunked transfer encoding may cause some servers to reject the request if they do not support it. -// Note that this function will not work if doRawRequest is true -func (r *Request) SetContentLength(contextLength int64) *Request { - r.contentLength = contextLength - return r -} - -func (r *Request) CustomTransport() bool { - return r.customTransport -} - -func (r *Request) SetCustomTransport(customTransport bool) *Request { - r.customTransport = customTransport - return r -} - -func (r *Request) FileUploadRecallFn() func(filename string, upPos int64, total int64) { - return r.fileUploadRecallFn -} - -func (r *Request) SetFileUploadRecallFn(FileUploadRecallFn func(filename string, upPos int64, total int64)) *Request { - r.fileUploadRecallFn = FileUploadRecallFn - return r -} - -func (r *Request) DialFn() func(ctx context.Context, network, addr string) (net.Conn, error) { - return r.dialFn -} - -// SetDialFn sets the dial function for the request. -func (r *Request) SetDialFn(dialFn func(ctx context.Context, network, addr string) (net.Conn, error)) { - r.dialFn = dialFn -} - -func (r *Request) AutoCalcContentLength() bool { - return r.autoCalcContentLength -} - -// SetAutoCalcContentLength sets whether to automatically calculate the Content-Length header based on the request body. -// WARN: If set to true, the Content-Length header will be set to the length of the request body, data will be cached in memory. -// So it may cause high memory usage if the request body is large. -// If set to false, the Content-Length header will not be set,unless the request body is a byte slice or bytes.Buffer which has a specific length. -// Note that this function will not work if doRawRequest is true or the ContentLength is already set. -func (r *Request) SetAutoCalcContentLength(autoCalcContentLength bool) error { - if r.doRawRequest { - return fmt.Errorf("doRawRequest is true, cannot set autoCalcContentLength") - } - r.autoCalcContentLength = autoCalcContentLength - return nil -} - -func (r *Request) SetAutoCalcContentLengthNoError(autoCalcContentLength bool) *Request { - r.SetAutoCalcContentLength(autoCalcContentLength) - return r -} - -// BasicAuth returns the username and password provided in the request's Authorization header. -func (r *Request) BasicAuth() (string, string) { - return r.basicAuth[0], r.basicAuth[1] -} - -// SetBasicAuth sets the request's Authorization header to use HTTP Basic Authentication with the provided username and password. -// Note: If doRawRequest is true, this function will nolonger work -func (r *Request) SetBasicAuth(username, password string) *Request { - r.basicAuth = [2]string{username, password} - return r -} - -func (r *Request) CustomDNS() []string { - return r.customDNS -} - -// SetCustomDNS sets the custom DNS servers for the request. -// Note: if LookUpIPfn is set, this function will not be used. -// if use custom Transport Dialer, this function will not work by default,but if the *http.Client is create by this package, it will work -func (r *Request) SetCustomDNS(customDNS []string) error { - for _, v := range customDNS { - if net.ParseIP(v) == nil { - return fmt.Errorf("invalid custom dns: %s", v) - } - } - r.customDNS = customDNS - return nil -} - -// SetCustomDNSNoError sets the custom DNS servers for the request. -// Note: if LookUpIPfn is set, this function will not be used. -// if use custom Transport Dialer, this function will not work by default,but if the *http.Client is create by this package, it will work -func (r *Request) SetCustomDNSNoError(customDNS []string) *Request { - r.SetCustomDNS(customDNS) - return r -} - -// AddCustomDNS adds custom DNS servers to the request. -// Note: if LookUpIPfn is set, this function will not be used. -// if use custom Transport Dialer, this function will not work by default,but if the *http.Client is create by this package, it will work -func (r *Request) AddCustomDNS(customDNS []string) error { - for _, v := range customDNS { - if net.ParseIP(v) == nil { - return fmt.Errorf("invalid custom dns: %s", v) - } - } - r.customDNS = customDNS - return nil -} - -// AddCustomDNSNoError adds custom DNS servers to the request. -// Note: if LookUpIPfn is set, this function will not be used. -// if use custom Transport Dialer, this function will not work by default,but if the *http.Client is create by this package, it will work -func (r *Request) AddCustomDNSNoError(customDNS []string) *Request { - r.AddCustomDNS(customDNS) - return r -} - -func (r *Request) LookUpIPfn() func(ctx context.Context, host string) ([]net.IPAddr, error) { - return r.lookUpIPfn -} - -// SetLookUpIPfn sets the function used to look up IP addresses for a given host. -// If lookUpIPfn is nil, it will use the default resolver's LookupIPAddr function. -// Note: if use custom Transport Dialer, this function will not work by default,but if the *http.Client is create by this package, it will work -// Note: if CustomHostIP is set, this function will not be used. -func (r *Request) SetLookUpIPfn(lookUpIPfn func(ctx context.Context, host string) ([]net.IPAddr, error)) *Request { - if lookUpIPfn == nil { - r.alreadySetLookUpIPfn = false - r.lookUpIPfn = net.DefaultResolver.LookupIPAddr - return r - } - r.lookUpIPfn = lookUpIPfn - r.alreadySetLookUpIPfn = true - return r -} - -// CustomHostIP returns the custom IP addresses used for the request. -func (r *Request) CustomHostIP() []string { - return r.customIP -} - -// SetCustomHostIP sets the custom IP addresses used for the request. -// if you want to use a specific IP address for a host without DNS resolution, you can set this. -// Set nil to clear the custom IP addresses. -// Note: lookUpIPfn will not be used if customIP is set. -func (r *Request) SetCustomHostIP(customIP []string) *Request { - r.customIP = customIP - return r -} - -// AddCustomHostIP adds a custom IP address to the request. -func (r *Request) AddCustomHostIP(customIP string) *Request { - r.customIP = append(r.customIP, customIP) - return r -} - -// BodyDataBytes returns the raw body data as a byte slice. -func (r *Request) BodyDataBytes() []byte { - return r.bodyDataBytes -} - -// SetBodyDataBytes sets the raw body data for the request. -// The priority order of the data is: bodyDataReader > **bodyDataBytes** > bodyFormData > bodyFileData. -// Note: If doRawRequest is true, this function will not work. -func (r *Request) SetBodyDataBytes(bodyDataBytes []byte) *Request { - r.bodyDataBytes = bodyDataBytes - return r -} - -// BodyDataReader returns the raw body data as an io.Reader. -func (r *Request) BodyDataReader() io.Reader { - return r.bodyDataReader -} - -// SetBodyDataReader sets the raw body data for the request as an io.Reader. -// The priority order of the data is: **bodyDataReader** > bodyDataBytes > bodyFormData > bodyFileData. -// Note: If doRawRequest is true, this function will not work. -func (r *Request) SetBodyDataReader(bodyDataReader io.Reader) *Request { - r.bodyDataReader = bodyDataReader - return r -} - -// BodyFormData returns the form data as a map of string slices. -// The priority order of the data is: bodyDataReader > bodyDataBytes > **bodyFormData** > bodyFileData. -// Note: If doRawRequest is true, this function will not work. -func (r *Request) BodyFormData() map[string][]string { - return r.bodyFormData -} - -// SetBodyFormData sets the form data for the request. -// The priority order of the data is: bodyDataReader > bodyDataBytes > **bodyFormData** > bodyFileData. -// Note: If doRawRequest is true, this function will not work. -func (r *Request) SetBodyFormData(bodyFormData map[string][]string) *Request { - r.bodyFormData = bodyFormData - return r -} - -// SetFormData is an alias for SetBodyFormData. -// It allows you to set form data in the request body. -// This is useful when you want to use a more descriptive name for the function. -// The priority order of the data is: bodyDataReader > bodyDataBytes > **bodyFormData** > bodyFileData. -// Note: If doRawRequest is true, this function will not work. -func (r *Request) SetFormData(bodyFormData map[string][]string) *Request { - return r.SetBodyFormData(bodyFormData) -} - -// AddFormMapData adds form data from a map to the request body. -// The priority order of the data is: bodyDataReader > bodyDataBytes > **bodyFormData** > bodyFileData. -// Note: If doRawRequest is true, this function will not work. -func (r *Request) AddFormMapData(bodyFormData map[string]string) *Request { - for k, v := range bodyFormData { - r.bodyFormData[k] = append(r.bodyFormData[k], v) - } - return r -} - -// AddFormData adds a single key-value pair to the form data in the request body. -// The priority order of the data is: bodyDataReader > bodyDataBytes > **bodyFormData** > bodyFileData. -// Note: If doRawRequest is true, this function will not work. -func (r *Request) AddFormData(k, v string) *Request { - r.bodyFormData[k] = append(r.bodyFormData[k], v) - return r -} - -// BodyFileData returns the file data as a slice of RequestFile. -// The priority order of the data is: bodyDataReader > bodyDataBytes > bodyFormData > **bodyFileData**. -// Note: If doRawRequest is true, this function will not work. -func (r *Request) BodyFileData() []RequestFile { - return r.bodyFileData -} - -// SetBodyFileData sets the file data for the request. -// The priority order of the data is: bodyDataReader > bodyDataBytes > bodyFormData > **bodyFileData**. -// Note: If doRawRequest is true, this function will not work. -func (r *Request) SetBodyFileData(bodyFileData []RequestFile) *Request { - r.bodyFileData = bodyFileData - return r -} - -// Proxy returns the proxy URL for the request. -func (r *Request) Proxy() string { - return r.proxy -} - -// SetProxy sets the proxy URL for the request. -func (r *Request) SetProxy(proxy string) *Request { - r.proxy = proxy - return r -} - -// Timeout returns the timeout duration for the request. -func (r *Request) Timeout() time.Duration { - return r.timeout -} - -// SetTimeout sets the timeout duration for the request. -func (r *Request) SetTimeout(timeout time.Duration) *Request { - r.timeout = timeout - return r -} - -// DialTimeout returns the dial timeout duration for the request. -func (r *Request) DialTimeout() time.Duration { - return r.dialTimeout -} - -// SetDialTimeout sets the dial timeout duration for the request. -func (r *Request) SetDialTimeout(dialTimeout time.Duration) *Request { - r.dialTimeout = dialTimeout - return r -} - -// Headers returns the request headers as an http.Header. -func (r *Request) Headers() http.Header { - return r.headers -} - -// SetHeaders sets the request headers using an http.Header. -// Note: If doRawRequest is true, this function will not work. -func (r *Request) SetHeaders(headers http.Header) *Request { - r.headers = headers - return r -} - -// AddHeader adds a single header to the request. -// This function will append the header if it already exists. -// Note: If doRawRequest is true, this function will not work. -func (r *Request) AddHeader(key, val string) *Request { - r.headers.Add(key, val) - return r -} - -// SetHeader sets a single header in the request. -// Note: If doRawRequest is true, this function will not work. -func (r *Request) SetHeader(key, val string) *Request { - r.headers.Set(key, val) - return r -} - -// DeleteHeader removes a header from the request. -// if the header has multiple values, it will remove all values for that header. -// Note: If doRawRequest is true, this function will not work. -func (r *Request) DeleteHeader(key string) *Request { - r.headers.Del(key) - return r -} - -// SetContentType sets the Content-Type header for the request. -// This function will overwrite any existing Content-Type header. -// Note: If doRawRequest is true, this function will not work. -func (r *Request) SetContentType(ct string) *Request { - r.headers.Set("Content-Type", ct) - return r -} - -// SetUserAgent sets the User-Agent header for the request. -// This function will overwrite any existing User-Agent header. -// Note: If doRawRequest is true, this function will not work. -func (r *Request) SetUserAgent(ua string) *Request { - r.headers.Set("User-Agent", ua) - return r -} - -// Cookies returns the request cookies as a slice of http.Cookie. -func (r *Request) Cookies() []*http.Cookie { - return r.cookies -} - -// SetCookies sets the request cookies using a slice of http.Cookie. -// you can also use SetHeader("Cookie", "cookie1=value1; cookie2=value2") to set cookies. -// Note: If doRawRequest is true, this function will not work. -func (r *Request) SetCookies(cookies []*http.Cookie) *Request { - r.cookies = cookies - return r -} - -// Transport returns the http.Transport used for the request. -func (r *Request) Transport() *http.Transport { - return r.transport -} - -// SetTransport set the http.Transport used for the request. -// Note: If doRawClient is true, this function will not work. -func (r *Request) SetTransport(transport *http.Transport) *Request { - r.transport = transport - r.customTransport = true - return r -} - -// Queries returns the request queries as a map of string slices. -func (r *Request) Queries() map[string][]string { - return r.queries -} - -// SetQueries sets the request queries using a map of string slices. -func (r *Request) SetQueries(queries map[string][]string) *Request { - r.queries = queries - return r -} - -// AddQueries adds multiple query parameters to the request. -func (r *Request) AddQueries(queries map[string]string) *Request { - for k, v := range queries { - r.queries[k] = append(r.queries[k], v) - } - return r -} - -// AddQuery adds a single query parameter to the request. -func (r *Request) AddQuery(key, value string) *Request { - r.queries[key] = append(r.queries[key], value) - return r -} - -// DelQueryKv removes a specific value from a query parameter. -func (r *Request) DelQueryKv(key, value string) *Request { - if _, ok := r.queries[key]; !ok { - return r - } - for i, v := range r.queries[key] { - if v == value { - r.queries[key] = append(r.queries[key][:i], r.queries[key][i+1:]...) - } - } - return r -} - -// DelQuery removes a query parameter from the request. -func (r *Request) DelQuery(key string) *Request { - if _, ok := r.queries[key]; !ok { - return r - } - delete(r.queries, key) - return r -} - -// DoRawRequest returns whether the request will be sent as a raw request. -func (r *Request) DoRawRequest() bool { - return r.doRawRequest -} - -// SetDoRawRequest sets whether the request will be sent as a raw request without any modifications. -// you can use this with function SetRawRequest to set a custom http.Request. -func (r *Request) SetDoRawRequest(doRawRequest bool) *Request { - r.doRawRequest = doRawRequest - return r -} - -// SkipTLSVerify returns whether the request will skip TLS verification. -func (r *Request) SkipTLSVerify() bool { - return r.skipTLSVerify -} - -// SetSkipTLSVerify Sets whether the request will skip TLS verification. -func (r *Request) SetSkipTLSVerify(skipTLSVerify bool) *Request { - r.skipTLSVerify = skipTLSVerify - return r -} - -// TlsConfig returns the TLS configuration used for the request. -func (r *Request) TlsConfig() *tls.Config { - return r.tlsConfig -} - -// SetTlsConfig sets the TLS configuration for the request. -// Note: If you use SetSkipTLSVerify function, it will automatically set the InsecureSkipVerify field to true in the tls.Config. -func (r *Request) SetTlsConfig(tlsConfig *tls.Config) *Request { - r.tlsConfig = tlsConfig - return r -} - -// AutoFetchRespBody returns whether the response body will be automatically fetched. -func (r *Request) AutoFetchRespBody() bool { - return r.autoFetchRespBody -} - -// SetAutoFetchRespBody sets whether the response body will be automatically fetched after the request is sent. -// If set to true, the response body will be read and stored in the Response object. -// if the body is too large, it may cause high memory usage. -// If set to false, you will need to manually read the response body using Response.Body() method. -func (r *Request) SetAutoFetchRespBody(autoFetchRespBody bool) *Request { - r.autoFetchRespBody = autoFetchRespBody - return r -} - -// ResetReqHeader resets the request headers to an empty http.Header. -func (r *Request) ResetReqHeader() *Request { - r.headers = make(http.Header) - return r -} - -// ResetReqCookies resets the request cookies to an empty slice. -func (r *Request) ResetReqCookies() *Request { - r.cookies = []*http.Cookie{} - return r -} - -// AddSimpleCookie add a key-value cookie to the request. -// the path will be set to "/" -func (r *Request) AddSimpleCookie(key, value string) *Request { - r.cookies = append(r.cookies, &http.Cookie{Name: key, Value: value, Path: "/"}) - return r -} - -// AddCookie adds a cookie to the request with the specified key, value, and path. -func (r *Request) AddCookie(key, value, path string) *Request { - r.cookies = append(r.cookies, &http.Cookie{Name: key, Value: value, Path: path}) - return r -} - -// AddFile adds a file to the request with the specified form name and file path. -// The file will be read and uploaded as a multipart/form-data request. -// the file type will be set to "application/octet-stream" by default. -func (r *Request) AddFile(formName, filepath string) error { - stat, err := os.Stat(filepath) - if err != nil { - return err - } - r.bodyFileData = append(r.bodyFileData, RequestFile{ - FormName: formName, - FileName: stat.Name(), - FileData: nil, - FileSize: stat.Size(), - FileType: "application/octet-stream", - FilePath: filepath, - }) - return nil -} - -// AddFileStream adds a file to the request with the specified form name, filename, size, and io.Reader stream. -// The file will be read and uploaded as a multipart/form-data request. -// the file type will be set to "application/octet-stream" by default. -func (r *Request) AddFileStream(formName, filename string, size int64, stream io.Reader) error { - r.bodyFileData = append(r.bodyFileData, RequestFile{ - FormName: formName, - FileName: filename, - FileData: stream, - FileSize: size, - FileType: "application/octet-stream", - }) - return nil -} - -// AddFileWithName adds a file to the request with the specified form name, file path, and filename. -// you can specify a custom filename for the file being uploaded. -// The file will be read and uploaded as a multipart/form-data request. -// the file type will be set to "application/octet-stream" by default. -func (r *Request) AddFileWithName(formName, filepath, filename string) error { - stat, err := os.Stat(filepath) - if err != nil { - return err - } - r.bodyFileData = append(r.bodyFileData, RequestFile{ - FormName: formName, - FileName: filename, - FileData: nil, - FileSize: stat.Size(), - FileType: "application/octet-stream", - FilePath: filepath, - }) - return nil -} - -// AddFileWithType adds a file to the request with the specified form name, file path, and file type. -// you can specify a custom file type for the file being uploaded. -// The file will be read and uploaded as a multipart/form-data request. -func (r *Request) AddFileWithType(formName, filepath, filetype string) error { - stat, err := os.Stat(filepath) - if err != nil { - return err - } - r.bodyFileData = append(r.bodyFileData, RequestFile{ - FormName: formName, - FileName: stat.Name(), - FileData: nil, - FileSize: stat.Size(), - FileType: filetype, - FilePath: filepath, - }) - return nil -} - -// AddFileWithNameAndType adds a file to the request with the specified form name, file path, filename, and file type. -// you can specify a custom filename and file type for the file being uploaded. -// The file will be read and uploaded as a multipart/form-data request. -func (r *Request) AddFileWithNameAndType(formName, filepath, filename, filetype string) error { - stat, err := os.Stat(filepath) - if err != nil { - return err - } - r.bodyFileData = append(r.bodyFileData, RequestFile{ - FormName: formName, - FileName: filename, - FileData: nil, - FileSize: stat.Size(), - FileType: filetype, - FilePath: filepath, - }) - return nil -} - -// AddFileStreamWithType adds a file to the request with the specified form name, filename, file type, size, and io.Reader stream. -// The file will be read and uploaded as a multipart/form-data request. -// you can specify a custom file type for the file being uploaded. -func (r *Request) AddFileStreamWithType(formName, filename, filetype string, size int64, stream io.Reader) error { - r.bodyFileData = append(r.bodyFileData, RequestFile{ - FormName: formName, - FileName: filename, - FileData: stream, - FileSize: size, - FileType: filetype, - }) - return nil -} - -// AddFileNoError adds a file to the request with the specified form name and file path. -// It will not return an error if the file cannot be added. -// this function is useful for chaining methods without error handling. -func (r *Request) AddFileNoError(formName, filepath string) *Request { - r.AddFile(formName, filepath) - return r -} - -// AddFileWithNameNoError adds a file to the request with the specified form name, file path, and filename. -// It will not return an error if the file cannot be added. -// this function is useful for chaining methods without error handling. -func (r *Request) AddFileWithNameNoError(formName, filepath, filename string) *Request { - r.AddFileWithName(formName, filepath, filename) - return r -} - -// AddFileWithTypeNoError adds a file to the request with the specified form name, file path, and file type. -// It will not return an error if the file cannot be added. -// this function is useful for chaining methods without error handling. -func (r *Request) AddFileWithTypeNoError(formName, filepath, filetype string) *Request { - r.AddFileWithType(formName, filepath, filetype) - return r -} - -// AddFileWithNameAndTypeNoError adds a file to the request with the specified form name, file path, filename, and file type. -// It will not return an error if the file cannot be added. -// this function is useful for chaining methods without error handling. -func (r *Request) AddFileWithNameAndTypeNoError(formName, filepath, filename, filetype string) *Request { - r.AddFileWithNameAndType(formName, filepath, filename, filetype) - return r -} - -// AddFileStreamNoError adds a file to the request with the specified form name, filename, size, and io.Reader stream. -// It will not return an error if the file cannot be added. -// this function is useful for chaining methods without error handling. -func (r *Request) AddFileStreamNoError(formName, filename string, size int64, stream io.Reader) *Request { - r.AddFileStream(formName, filename, size, stream) - return r -} - -// AddFileStreamWithTypeNoError adds a file to the request with the specified form name, filename, file type, size, and io.Reader stream. -// It will not return an error if the file cannot be added. -// this function is useful for chaining methods without error handling. -func (r *Request) AddFileStreamWithTypeNoError(formName, filename, filetype string, size int64, stream io.Reader) *Request { - r.AddFileStreamWithType(formName, filename, filetype, size, stream) - return r -} - -// HttpClient returns the http.Client used for the request. -func (r *Request) HttpClient() (*http.Client, error) { - err := applyOptions(r) - if err != nil { - return nil, err - } - return r.rawClient, nil -} - -type RequestFile struct { - FormName string - FileName string - FileData io.Reader - FileSize int64 - FileType string - FilePath string -} - -type RequestOpt func(opt *RequestOpts) error - -// WithDialTimeout sets the dial timeout for the request. -// If use custom Transport Dialer, this function will nolonger work. -func WithDialTimeout(timeout time.Duration) RequestOpt { - return func(opt *RequestOpts) error { - opt.dialTimeout = timeout - return nil - } -} - -// WithDial sets a custom dial function for the request. -// functions like WithDialTimeout will nolonger work if this function is used. -// If use custom Transport Dialer, this function will nolonger work. -func WithDial(fn func(ctx context.Context, network string, addr string) (net.Conn, error)) RequestOpt { - return func(opt *RequestOpts) error { - opt.dialFn = fn - return nil - } -} - -// WithTimeout sets the timeout for the request. -// If use custom Transport Dialer, this function will nolonger work. -func WithTimeout(timeout time.Duration) RequestOpt { - return func(opt *RequestOpts) error { - opt.timeout = timeout - return nil - } -} - -// WithTlsConfig sets the TLS configuration for the request. -// If use custom Transport Dialer, this function will nolonger work. -func WithTlsConfig(tlscfg *tls.Config) RequestOpt { - return func(opt *RequestOpts) error { - opt.tlsConfig = tlscfg - return nil - } -} - -// WithHeaders sets the request headers using an http.Header. -// If doRawRequest is true, this function will not work. -func WithHeader(key, val string) RequestOpt { - return func(opt *RequestOpts) error { - opt.headers.Set(key, val) - return nil - } -} - -// WithHeaderMap sets the request headers using a map of string to string. -// If doRawRequest is true, this function will not work. -func WithHeaderMap(header map[string]string) RequestOpt { - return func(opt *RequestOpts) error { - for key, val := range header { - opt.headers.Set(key, val) - } - return nil - } -} - -// WithReader sets the request body data using an io.Reader. -// The priority order of the data is: **bodyDataReader** > bodyDataBytes > bodyFormData > bodyFileData. -// If doRawRequest is true, this function will nolonger work. -func WithReader(r io.Reader) RequestOpt { - return func(opt *RequestOpts) error { - opt.bodyDataReader = r - return nil - } -} - -// WithBytes sets the request body data using a byte slice. -// The priority order of the data is: bodyDataReader > **bodyDataBytes** > bodyFormData > bodyFileData. -// If doRawRequest is true, this function will nolonger work. -func WithBytes(r []byte) RequestOpt { - return func(opt *RequestOpts) error { - opt.bodyDataBytes = r - return nil - } -} - -// WithFormData sets the request body data using a map of string slices. -// The priority order of the data is: bodyDataReader > bodyDataBytes > **bodyFormData** > bodyFileData. -// If doRawRequest is true, this function will nolonger work. -func WithFormData(data map[string][]string) RequestOpt { - return func(opt *RequestOpts) error { - opt.bodyFormData = data - return nil - } -} - -// WithFileDatas sets the request body file data using a slice of RequestFile. -// The priority order of the data is: bodyDataReader > bodyDataBytes > bodyFormData > **bodyFileData**. -// If doRawRequest is true, this function will nolonger work. -func WithFileDatas(data []RequestFile) RequestOpt { - return func(opt *RequestOpts) error { - opt.bodyFileData = data - return nil - } -} - -// WithFileData sets the request body file data using a single RequestFile. -// The priority order of the data is: bodyDataReader > bodyDataBytes > bodyFormData > **bodyFileData**. -// If doRawRequest is true, this function will nolonger work. -func WithFileData(data RequestFile) RequestOpt { - return func(opt *RequestOpts) error { - opt.bodyFileData = append(opt.bodyFileData, data) - return nil - } -} - -// WithAddFile adds a file to the request with the specified form name and file path. -// The priority order of the data is: bodyDataReader > bodyDataBytes > bodyFormData > **bodyFileData**. -// The file will be read and uploaded as a multipart/form-data request. -// the file type will be set to "application/octet-stream" by default. -// If doRawRequest is true, this function will nolonger work. -func WithAddFile(formName, filepath string) RequestOpt { - return func(opt *RequestOpts) error { - stat, err := os.Stat(filepath) - if err != nil { - return err - } - opt.bodyFileData = append(opt.bodyFileData, RequestFile{ - FormName: formName, - FileName: stat.Name(), - FileData: nil, - FileSize: stat.Size(), - FileType: "application/octet-stream", - FilePath: filepath, - }) - return nil - } -} - -// WithAddFileWithName adds a file to the request with the specified form name, file path, and filename. -// you can specify a custom filename for the file being uploaded. -// The priority order of the data is: bodyDataReader > bodyDataBytes > bodyFormData > **bodyFileData**. -// The file will be read and uploaded as a multipart/form-data request. -// If doRawRequest is true, this function will nolonger work. -func WithAddFileWithName(formName, filepath, filename string) RequestOpt { - return func(opt *RequestOpts) error { - stat, err := os.Stat(filepath) - if err != nil { - return err - } - opt.bodyFileData = append(opt.bodyFileData, RequestFile{ - FormName: formName, - FileName: filename, - FileData: nil, - FileSize: stat.Size(), - FileType: "application/octet-stream", - }) - return nil - } -} - -// WithAddFileWithType adds a file to the request with the specified form name, file path, and file type. -// you can specify a custom file type for the file being uploaded. -// The priority order of the data is: bodyDataReader > bodyDataBytes > bodyFormData > **bodyFileData**. -// The file will be read and uploaded as a multipart/form-data request. -func WithAddFileWithType(formName, filepath, filetype string) RequestOpt { - return func(opt *RequestOpts) error { - stat, err := os.Stat(filepath) - if err != nil { - return err - } - opt.bodyFileData = append(opt.bodyFileData, RequestFile{ - FormName: formName, - FileName: stat.Name(), - FileData: nil, - FileSize: stat.Size(), - FileType: filetype, - }) - return nil - } -} - -// WithAddFileWithNameAndType adds a file to the request with the specified form name, file path, filename, and file type. -// you can specify a custom filename and file type for the file being uploaded. -// The priority order of the data is: bodyDataReader > bodyDataBytes > bodyFormData > **bodyFileData**. -// The file will be read and uploaded as a multipart/form-data request. -func WithAddFileWithNameAndType(formName, filepath, filename, filetype string) RequestOpt { - return func(opt *RequestOpts) error { - stat, err := os.Stat(filepath) - if err != nil { - return err - } - opt.bodyFileData = append(opt.bodyFileData, RequestFile{ - FormName: formName, - FileName: filename, - FileData: nil, - FileSize: stat.Size(), - FileType: filetype, - }) - return nil - } -} - -// WithFetchRespBody sets whether the response body will be automatically fetched after the request is sent. -// If set to true, the response body will be read and stored in the Response object. -// If the body is too large, it may cause high memory usage. -// If set to false, you will need to manually read the response body using Response.Body() method. -func WithFetchRespBody(fetch bool) RequestOpt { - return func(opt *RequestOpts) error { - opt.autoFetchRespBody = fetch - return nil - } -} - -// WithCookies sets the request cookies using a slice of http.Cookie. -func WithCookies(ck []*http.Cookie) RequestOpt { - return func(opt *RequestOpts) error { - opt.cookies = ck - return nil - } -} - -// WithCookie adds a cookie to the request with the specified key, value, and path. -func WithCookie(key, val, path string) RequestOpt { - return func(opt *RequestOpts) error { - opt.cookies = append(opt.cookies, &http.Cookie{Name: key, Value: val, Path: path}) - return nil - } -} - -// WithSimpleCookie adds a simple cookie to the request with the specified key and value. -func WithSimpleCookie(key, val string) RequestOpt { - return func(opt *RequestOpts) error { - opt.cookies = append(opt.cookies, &http.Cookie{Name: key, Value: val, Path: "/"}) - return nil - } -} - -// WithCookieMap sets the request cookies using a map of string to string. -func WithCookieMap(header map[string]string, path string) RequestOpt { - return func(opt *RequestOpts) error { - for key, val := range header { - opt.cookies = append(opt.cookies, &http.Cookie{Name: key, Value: val, Path: path}) - } - return nil - } -} - -// WithQueries sets the request queries using a map of string slices. -func WithQueries(queries map[string][]string) RequestOpt { - return func(opt *RequestOpts) error { - opt.queries = queries - return nil - } -} - -// WithAddQueries adds multiple query parameters to the request using a map of string to string slices. -// if the key already exists, it will append the value to the existing slice. -func WithAddQueries(queries map[string]string) RequestOpt { - return func(opt *RequestOpts) error { - for k, v := range queries { - opt.queries[k] = append(opt.queries[k], v) - } - return nil - } -} - -// WithAddQuery adds a single query parameter to the request. -func WithAddQuery(key, val string) RequestOpt { - return func(opt *RequestOpts) error { - opt.queries[key] = append(opt.queries[key], val) - return nil - } -} - -// WithProxy sets the proxy URL for the request. -func WithProxy(proxy string) RequestOpt { - return func(opt *RequestOpts) error { - opt.proxy = proxy - return nil - } -} - -// WithProcess sets a callback function to process the file upload progress. -// The callback function will be called with the file name, uploaded bytes and total bytes. -// example: -// -// WithProcess(func(name string, uploaded int64, total int64) { -// fmt.Printf("Uploading %s: %d/%d bytes\n", name, uploaded, total) -// }) -func WithProcess(fn func(string, int64, int64)) RequestOpt { - return func(opt *RequestOpts) error { - opt.fileUploadRecallFn = fn - return nil - } -} - -// WithContentType sets the Content-Type header for the request. -// This function will overwrite any existing Content-Type header. -func WithContentType(ct string) RequestOpt { - return func(opt *RequestOpts) error { - opt.headers.Set("Content-Type", ct) - return nil - } -} - -// WithUserAgent sets the User-Agent header for the request. -// This function will overwrite any existing User-Agent header. -func WithUserAgent(ua string) RequestOpt { - return func(opt *RequestOpts) error { - opt.headers.Set("User-Agent", ua) - return nil - } -} - -// WithSkipTLSVerify sets whether the request will skip TLS verification. -// If set to true, the request will not verify the server's TLS certificate. -func WithSkipTLSVerify(skip bool) RequestOpt { - return func(opt *RequestOpts) error { - opt.skipTLSVerify = skip - return nil - } -} - -/* -// WithDisableRedirect sets whether the request will disable HTTP redirects. -// If set to true, the request will not follow redirects automatically. -// For example, if the server responds with a 301 or 302 status code, the request will not automatically follow the redirect. -// You will get the original response with the redirect status code and Location header. -func WithDisableRedirect(disable bool) RequestOpt { - return func(opt *RequestOpts) error { - opt.disableRedirect = disable - return nil - } -} -*/ - -// WithDoRawRequest sets whether the request will be sent as a raw request without any modifications. -// You can use this with function SetRawRequest to set a custom http.Request. -// If set to true, the request will not apply any modifications to the request headers, body, or other settings. -// If set to false, the request will apply the modifications as usual. -func WithDoRawRequest(doRawRequest bool) RequestOpt { - return func(opt *RequestOpts) error { - opt.doRawRequest = doRawRequest - return nil - } -} - -// WithTransport sets the http.Transport used for the request. -func WithTransport(hs *http.Transport) RequestOpt { - return func(opt *RequestOpts) error { - opt.transport = hs - opt.customTransport = true - return nil - } -} - -// WithRawRequest sets a custom http.Request for the request. -func WithRawRequest(req *http.Request) RequestOpt { - return func(opt *RequestOpts) error { - opt.rawRequest = req - return nil - } -} - -// WithRawClient sets a custom http.Client for the request. -func WithRawClient(hc *http.Client) RequestOpt { - return func(opt *RequestOpts) error { - opt.rawClient = hc - return nil - } -} - -// WithCustomHostIP sets custom IPs for the host. -// it means that the request will use the specified IPs to resolve the host instead of using DNS. -// Note: LookUpIPfn will be ignored if this function is used. -func WithCustomHostIP(ip []string) RequestOpt { - return func(opt *RequestOpts) error { - if len(ip) == 0 { - return nil - } - for _, v := range ip { - if net.ParseIP(v) == nil { - return fmt.Errorf("invalid custom ip: %s", v) - } - } - opt.customIP = ip - return nil - } -} - -// WithAddCustomHostIP adds a custom IP to the request. -func WithAddCustomHostIP(ip string) RequestOpt { - return func(opt *RequestOpts) error { - if net.ParseIP(ip) == nil { - return fmt.Errorf("invalid custom ip: %s", ip) - } - opt.customIP = append(opt.customIP, ip) - return nil - } -} - -// WithLookUpFn sets a custom function to look up IP addresses for the host. -// If set to nil, it will use the default net.Resolver.LookupIPAddr function. -// Note: If customDNS is set, this function will not be used. -func WithLookUpFn(lookUpIPfn func(ctx context.Context, host string) ([]net.IPAddr, error)) RequestOpt { - return func(opt *RequestOpts) error { - if lookUpIPfn == nil { - opt.alreadySetLookUpIPfn = false - opt.lookUpIPfn = net.DefaultResolver.LookupIPAddr - return nil - } - opt.lookUpIPfn = lookUpIPfn - opt.alreadySetLookUpIPfn = true - return nil - } -} - -// WithCustomDNS will use custom dns to resolve the host -// Note: if LookUpIPfn is set, this function will not be used -func WithCustomDNS(customDNS []string) RequestOpt { - return func(opt *RequestOpts) error { - for _, v := range customDNS { - if net.ParseIP(v) == nil { - return fmt.Errorf("invalid custom dns: %s", v) - } - } - opt.customDNS = customDNS - return nil - } -} - -// WithAddCustomDNS will use a custom dns to resolve the host -// Note: if LookUpIPfn is set, this function will not be used -func WithAddCustomDNS(customDNS string) RequestOpt { - return func(opt *RequestOpts) error { - if net.ParseIP(customDNS) == nil { - return fmt.Errorf("invalid custom dns: %s", customDNS) - } - opt.customDNS = append(opt.customDNS, customDNS) - return nil - } -} - -// WithAutoCalcContentLength sets whether to automatically calculate the Content-Length header based on the request body. -// WARN: If set to true, the Content-Length header will be set to the length of the request body, data will be cached in memory. -// So it may cause high memory usage if the request body is large. -// If set to false, the Content-Length header will not be set,unless the request body is a byte slice or bytes.Buffer which has a specific length. -// Note that this function will not work if doRawRequest is true or ContentLength already set -func WithAutoCalcContentLength(autoCalcContentLength bool) RequestOpt { - return func(opt *RequestOpts) error { - opt.autoCalcContentLength = autoCalcContentLength - return nil - } -} - -// WithContentLength sets the Content-Length for the request. -// This function will overwrite any existing or auto calculated Content-Length header. -// if the length is less than 0, it will not set the Content-Length header. chunked transfer encoding will be used instead. -// chunked transfer encoding may cause some servers to reject the request if they do not support it. -// Note that this function will not work if doRawRequest is true -func WithContentLength(length int64) RequestOpt { - return func(opt *RequestOpts) error { - opt.contentLength = length - return nil - } -} - -type Response struct { - *http.Response - req Request - data *Body - rawClient *http.Client -} - -type Body struct { - full []byte - raw io.ReadCloser - isFull bool - sync.Mutex -} - -func (b *Body) readAll() { - b.Lock() - defer b.Unlock() - if !b.isFull { - if b.raw == nil { - b.isFull = true - return - } - b.full, _ = io.ReadAll(b.raw) - b.isFull = true - b.raw.Close() - } -} - -// String will read the body and return it as a string. -// if the body is too large, it may cause high memory usage. -func (b *Body) String() string { - b.readAll() - return string(b.full) -} - -// Bytes will read the body and return it as a byte slice. -// if the body is too large, it may cause high memory usage. -func (b *Body) Bytes() []byte { - b.readAll() - return b.full -} - -// Unmarshal will read the body and unmarshal it into the given interface using json.Unmarshal -// if the body is too large, it may cause high memory usage. -func (b *Body) Unmarshal(u interface{}) error { - b.readAll() - return json.Unmarshal(b.full, u) -} - -// Reader returns a reader for the body -// if this function is called, other functions like String, Bytes, Unmarshal not work -func (b *Body) Reader() io.ReadCloser { - b.Lock() - defer b.Unlock() - if b.isFull { - return io.NopCloser(bytes.NewReader(b.full)) - } - b.isFull = true - return b.raw -} - -// Close closes the body reader. -func (b *Body) Close() error { - return b.raw.Close() -} - -// GetRequest returns the original Request object associated with the Response. -func (r *Response) GetRequest() Request { - return r.req -} - -// Body returns the Body object associated with the Response. -func (r *Response) Body() *Body { - return r.data -} - -// Close closes the response body and releases any resources associated with it. -func (r *Response) Close() error { - if r != nil && r.data != nil && r.data.raw != nil { - return r.Response.Body.Close() - } - return nil -} - -// CloseAll closes the response body and releases any resources associated with it. -// It also closes all idle connections in the http.Client if it is not nil. -func (r *Response) CloseAll() error { - if r.rawClient != nil { - r.rawClient.CloseIdleConnections() - } - return r.Close() -} - -// HttpClient returns the http.Client used for the request. -func (r *Response) HttpClient() *http.Client { - return r.rawClient -} - -// Curl sends the HTTP request and returns the response. -func Curl(r *Request) (*Response, error) { - r.errInfo = nil - err := applyOptions(r) - if err != nil { - return nil, fmt.Errorf("apply options error: %s", err) - } - r.rawRequest = r.rawRequest.WithContext(r.doCtx) - resp, err := r.rawClient.Do(r.rawRequest) - var res = Response{ - Response: resp, - req: *r, - data: new(Body), - rawClient: r.rawClient, - } - if err != nil { - res.Response = &http.Response{} - return &res, fmt.Errorf("do request error: %s", err) - } - res.data.raw = resp.Body - if r.autoFetchRespBody { - res.data.full, _ = io.ReadAll(resp.Body) - res.data.isFull = true - resp.Body.Close() - } - return &res, r.errInfo -} - -// NewReq creates a new Request with the specified URI and default method "GET". -func NewReq(uri string, opts ...RequestOpt) *Request { - return NewSimpleRequest(uri, "GET", opts...) -} - -// NewReqWithContext creates a new Request with the specified URI and default method "GET" using the provided context. -func NewReqWithContext(ctx context.Context, uri string, opts ...RequestOpt) *Request { - return NewSimpleRequestWithContext(ctx, uri, "GET", opts...) -} - -// NewSimpleRequest creates a new Request with the specified URI and method. -func NewSimpleRequest(uri string, method string, opts ...RequestOpt) *Request { - r, _ := newRequest(context.Background(), uri, method, opts...) - return r -} - -// NewRequest creates a new Request with the specified URI and method. -func NewRequest(uri string, method string, opts ...RequestOpt) (*Request, error) { - return newRequest(context.Background(), uri, method, opts...) -} - -// NewSimpleRequestWithContext creates a new Request with the specified URI and method using the provided context. -func NewSimpleRequestWithContext(ctx context.Context, uri string, method string, opts ...RequestOpt) *Request { - r, _ := newRequest(ctx, uri, method, opts...) - return r -} - -// NewRequestWithContext creates a new Request with the specified URI and method using the provided context. -func NewRequestWithContext(ctx context.Context, uri string, method string, opts ...RequestOpt) (*Request, error) { - return newRequest(ctx, uri, method, opts...) -} - -func newRequest(ctx context.Context, uri string, method string, opts ...RequestOpt) (*Request, error) { - var req *http.Request - var err error - if method == "" { - method = "GET" - } - method = strings.ToUpper(method) - req, err = http.NewRequestWithContext(ctx, method, uri, nil) - if err != nil { - return nil, err - } - var r = &Request{ - ctx: ctx, - uri: uri, - method: method, - RequestOpts: RequestOpts{ - rawRequest: req, - rawClient: nil, - timeout: DefaultTimeout, - dialTimeout: DefaultDialTimeout, - autoFetchRespBody: DefaultFetchRespBody, - lookUpIPfn: net.DefaultResolver.LookupIPAddr, - bodyFormData: make(map[string][]string), - queries: make(map[string][]string), - }, - } - - r.headers = make(http.Header) - if strings.ToUpper(method) == "POST" { - r.headers.Set("Content-Type", HEADER_FORM_URLENCODE) - } - r.headers.Set("User-Agent", "B612 Starnet / 0.3.0") - for _, v := range opts { - if v != nil { - err = v(&r.RequestOpts) - if err != nil { - return nil, err - } - } - } - if r.rawClient == nil { - r.rawClient = new(http.Client) - } - if r.tlsConfig == nil { - r.tlsConfig = &tls.Config{ - NextProtos: []string{"h2", "http/1.1"}, - } - } - r.tlsConfig.InsecureSkipVerify = r.skipTLSVerify - if r.transport == nil { - r.transport = &http.Transport{ - ForceAttemptHTTP2: true, - DialContext: DefaultDialFunc, - DialTLSContext: DefaultDialTlsFunc, - Proxy: DefaultProxyURL(), - } - } - return r, nil -} - -func applyDataReader(r *Request) error { - // 优先度为:bodyDataReader > bodyDataBytes > bodyFormData > bodyFileData - r.rawRequest.ContentLength = 0 - if r.bodyDataReader != nil { - switch v := r.bodyDataReader.(type) { - case *bytes.Buffer: - r.rawRequest.ContentLength = int64(v.Len()) - case *bytes.Reader: - r.rawRequest.ContentLength = int64(v.Len()) - case *strings.Reader: - r.rawRequest.ContentLength = int64(v.Len()) - } - r.rawRequest.Body = io.NopCloser(r.bodyDataReader) - return nil - } - if len(r.bodyDataBytes) != 0 { - r.rawRequest.ContentLength = int64(len(r.bodyDataBytes)) - r.rawRequest.Body = io.NopCloser(bytes.NewReader(r.bodyDataBytes)) - return nil - } - if len(r.bodyFormData) != 0 && len(r.bodyFileData) == 0 { - var body = url.Values{} - for k, v := range r.bodyFormData { - for _, vv := range v { - body.Add(k, vv) - } - } - r.rawRequest.ContentLength = int64(len(body.Encode())) - r.rawRequest.Body = io.NopCloser(strings.NewReader(body.Encode())) - return nil - } - if len(r.bodyFileData) != 0 { - var pr, pw = io.Pipe() - var w = multipart.NewWriter(pw) - r.rawRequest.Header.Set("Content-Type", w.FormDataContentType()) - go func() { - defer pw.Close() // ensure pipe writer is closed - if len(r.bodyFormData) != 0 { - for k, v := range r.bodyFormData { - for _, vv := range v { - if err := w.WriteField(k, vv); err != nil { - r.errInfo = err - pw.CloseWithError(err) // close pipe with error - return - } - } - } - } - for idx, v := range r.bodyFileData { - var fw, err = w.CreateFormFile(v.FormName, v.FileName) - if err != nil { - r.errInfo = err - pw.CloseWithError(err) // close pipe with error - return - } - if v.FileData == nil { - if v.FilePath != "" { - tmpFile, err := os.Open(v.FilePath) - if err != nil { - r.errInfo = err - pw.CloseWithError(err) // close pipe with error - return - } - defer tmpFile.Close() - v.FileData = tmpFile - } else { - r.errInfo = fmt.Errorf("io reader is nil") - pw.CloseWithError(fmt.Errorf("io reader is nil")) // close pipe with error - return - } - } - if _, err := copyWithContext(r.ctx, r.fileUploadRecallFn, v.FileName, v.FileSize, fw, v.FileData); err != nil { - r.errInfo = err - r.bodyFileData[idx] = v - pw.CloseWithError(err) // close pipe with error - return - } - } - - if err := w.Close(); err != nil { - pw.CloseWithError(err) // close pipe with error if writer close fails - } - }() - - r.rawRequest.Body = pr - return nil - } - return nil -} - -func applyOptions(r *Request) error { - defer func() { - r.alreadyApply = true - }() - var req = r.rawRequest - if !r.doRawRequest { - if r.queries != nil { - sid := req.URL.Query() - for k, v := range r.queries { - for _, vv := range v { - sid.Add(k, vv) - } - } - req.URL.RawQuery = sid.Encode() - } - for k, v := range r.headers { - for _, vv := range v { - req.Header.Add(k, vv) - } - } - if len(r.cookies) != 0 { - for _, v := range r.cookies { - req.AddCookie(v) - } - } - if r.basicAuth[0] != "" || r.basicAuth[1] != "" { - req.SetBasicAuth(r.basicAuth[0], r.basicAuth[1]) - } - err := applyDataReader(r) - if err != nil { - return fmt.Errorf("apply data reader error: %s", err) - } - if r.autoCalcContentLength { - if req.Body != nil { - data, err := io.ReadAll(req.Body) - if err != nil { - return fmt.Errorf("read data error: %s", err) - } - req.ContentLength = int64(len(data)) - req.Body = io.NopCloser(bytes.NewBuffer(data)) - } - } - if r.contentLength > 0 { - req.ContentLength = r.contentLength - } else if r.contentLength < 0 { - //force use chunked transfer encoding - req.ContentLength = 0 - } - } - if !r.alreadySetLookUpIPfn && len(r.customDNS) > 0 { - resolver := net.Resolver{ - PreferGo: true, - Dial: func(ctx context.Context, network, address string) (conn net.Conn, err error) { - for _, addr := range r.customDNS { - if conn, err = net.Dial("udp", addr+":53"); err != nil { - continue - } else { - return conn, nil - } - } - return - }, - } - r.lookUpIPfn = resolver.LookupIPAddr - } - if r.tlsConfig == nil { - r.tlsConfig = &tls.Config{ - NextProtos: []string{"h2", "http/1.1"}, - } - } - if r.tlsConfig.ServerName == "" { - r.tlsConfig.ServerName = r.rawRequest.URL.Hostname() - } - - r.tlsConfig.InsecureSkipVerify = r.skipTLSVerify - - if r.rawClient.Transport == nil { - r.rawClient.Transport = &Transport{base: r.transport} - } - - r.doCtx = context.WithValue(context.WithValue(r.ctx, "dialTimeout", r.dialTimeout), "timeout", r.timeout) - r.doCtx = context.WithValue(r.doCtx, "lookUpIP", r.lookUpIPfn) - if r.customIP != nil && len(r.customIP) > 0 { - r.doCtx = context.WithValue(r.doCtx, "customIP", r.customIP) - } - r.doCtx = context.WithValue(r.doCtx, "tlsConfig", r.tlsConfig) - if r.proxy != "" { - r.doCtx = context.WithValue(r.doCtx, "proxy", r.proxy) - } - if r.dialFn != nil { - r.doCtx = context.WithValue(r.doCtx, "dialFn", r.dialFn) - } - if r.customTransport { - r.doCtx = context.WithValue(r.doCtx, "custom", r.transport) - } - return nil -} - -func copyWithContext(ctx context.Context, recall func(string, int64, int64), filename string, total int64, dst io.Writer, src io.Reader) (written int64, err error) { - pr, pw := io.Pipe() - defer pr.Close() - - go func() { - defer pw.Close() - _, err := io.Copy(pw, src) - if err != nil { - pw.CloseWithError(err) - } - }() - var count int64 - buf := make([]byte, 4096) - for { - select { - case <-ctx.Done(): - return written, ctx.Err() - default: - nr, err := pr.Read(buf) - if err != nil { - if err == io.EOF { - if recall != nil { - go recall(filename, count, total) - } - return written, nil - } - return written, err - } - count += int64(nr) - if recall != nil { - go recall(filename, count, total) - } - nw, err := dst.Write(buf[:nr]) - if err != nil { - return written, err - } - if nr != nw { - return written, io.ErrShortWrite - } - written += int64(nr) - } - } -} - -func NewReqWithClient(client Client, uri string, opts ...RequestOpt) *Request { - return NewSimpleRequestWithClient(client, uri, "GET", opts...) -} - -func NewReqWithContextWithClient(ctx context.Context, client Client, uri string, opts ...RequestOpt) *Request { - return NewSimpleRequestWithContextWithClient(ctx, client, uri, "GET", opts...) -} - -func NewSimpleRequestWithClient(client Client, uri string, method string, opts ...RequestOpt) *Request { - r, _ := NewRequestWithContextWithClient(context.Background(), client, uri, method, opts...) - return r -} - -func NewRequestWithClient(client Client, uri string, method string, opts ...RequestOpt) (*Request, error) { - return NewRequestWithContextWithClient(context.Background(), client, uri, method, opts...) -} - -func NewSimpleRequestWithContextWithClient(ctx context.Context, client Client, uri string, method string, opts ...RequestOpt) *Request { - r, _ := NewRequestWithContextWithClient(ctx, client, uri, method, opts...) - return r -} - -func NewRequestWithContextWithClient(ctx context.Context, client Client, uri string, method string, opts ...RequestOpt) (*Request, error) { - if client.opts == nil { - client.opts = []RequestOpt{} - } - cOpts := append(client.opts, opts...) - req, err := newRequest(ctx, uri, method, cOpts...) - if err != nil { - return nil, err - } - req.rawClient = client.Client - return req, err -} diff --git a/curl_default.go b/curl_default.go deleted file mode 100644 index 55a48ae..0000000 --- a/curl_default.go +++ /dev/null @@ -1,198 +0,0 @@ -package starnet - -import ( - "context" - "crypto/tls" - "fmt" - "net" - "net/http" - "net/url" - "strings" - "time" -) - -const ( - HEADER_FORM_URLENCODE = `application/x-www-form-urlencoded` - HEADER_FORM_DATA = `multipart/form-data` - HEADER_JSON = `application/json` - HEADER_PLAIN = `text/plain` -) - -var ( - DefaultDialTimeout = 5 * time.Second - DefaultTimeout = 10 * time.Second - DefaultFetchRespBody = false - DefaultHttpClient = NewHttpClientNoErr() -) - -func UrlEncodeRaw(str string) string { - strs := strings.Replace(url.QueryEscape(str), "+", "%20", -1) - return strs -} - -func UrlEncode(str string) string { - return url.QueryEscape(str) -} - -func UrlDecode(str string) (string, error) { - return url.QueryUnescape(str) -} - -func BuildQuery(queryData map[string]string) string { - query := url.Values{} - for k, v := range queryData { - query.Add(k, v) - } - return query.Encode() -} - -// BuildPostForm takes a map of string keys and values, converts it into a URL-encoded query string, -// and then converts that string into a byte slice. This function is useful for preparing data for HTTP POST requests, -// where the server expects the request body to be URL-encoded form data. -// -// Parameters: -// queryMap: A map where the key-value pairs represent the form data to be sent in the HTTP POST request. -// -// Returns: -// A byte slice representing the URL-encoded form data. -func BuildPostForm(queryMap map[string]string) []byte { - return []byte(BuildQuery(queryMap)) -} - -func Get(uri string, opts ...RequestOpt) (*Response, error) { - return NewSimpleRequestWithClient(DefaultHttpClient, uri, "GET", opts...).Do() -} - -func Post(uri string, opts ...RequestOpt) (*Response, error) { - return NewSimpleRequestWithClient(DefaultHttpClient, uri, "POST", opts...).Do() -} - -func Options(uri string, opts ...RequestOpt) (*Response, error) { - return NewSimpleRequestWithClient(DefaultHttpClient, uri, "OPTIONS", opts...).Do() -} - -func Put(uri string, opts ...RequestOpt) (*Response, error) { - return NewSimpleRequestWithClient(DefaultHttpClient, uri, "PUT", opts...).Do() -} - -func Delete(uri string, opts ...RequestOpt) (*Response, error) { - return NewSimpleRequestWithClient(DefaultHttpClient, uri, "DELETE", opts...).Do() -} - -func Head(uri string, opts ...RequestOpt) (*Response, error) { - return NewSimpleRequestWithClient(DefaultHttpClient, uri, "HEAD", opts...).Do() -} - -func Patch(uri string, opts ...RequestOpt) (*Response, error) { - return NewSimpleRequestWithClient(DefaultHttpClient, uri, "PATCH", opts...).Do() -} - -func Trace(uri string, opts ...RequestOpt) (*Response, error) { - return NewSimpleRequestWithClient(DefaultHttpClient, uri, "TRACE", opts...).Do() -} - -func Connect(uri string, opts ...RequestOpt) (*Response, error) { - return NewSimpleRequestWithClient(DefaultHttpClient, uri, "CONNECT", opts...).Do() -} - -func DefaultCheckRedirectFunc(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse -} - -func DefaultDialFunc(ctx context.Context, netType, addr string) (net.Conn, error) { - var lastErr error - var addrs []string - if dialFn, ok := ctx.Value("dialFunc").(func(context.Context, string, string) (net.Conn, error)); ok { - if dialFn != nil { - return dialFn(ctx, netType, addr) - } - } - customIP, ok := ctx.Value("customIP").([]string) - if !ok { - customIP = nil - } - dialTimeout, ok := ctx.Value("dialTimeout").(time.Duration) - if !ok { - dialTimeout = DefaultDialTimeout - } - timeout, ok := ctx.Value("timeout").(time.Duration) - if !ok { - timeout = DefaultTimeout - } - lookUpIPfn, ok := ctx.Value("lookUpIP").(func(context.Context, string) ([]net.IPAddr, error)) - if !ok { - lookUpIPfn = net.DefaultResolver.LookupIPAddr - } - host, port, err := net.SplitHostPort(addr) - if err != nil { - return nil, err - } - proxy, ok := ctx.Value("proxy").(string) - if !ok { - proxy = "" - } - if proxy == "" && len(customIP) > 0 { - for _, v := range customIP { - ipAddr := net.ParseIP(v) - if ipAddr == nil { - return nil, fmt.Errorf("invalid custom ip: %s", customIP) - } - tmpAddr := net.JoinHostPort(v, port) - addrs = append(addrs, tmpAddr) - } - } else { - ipLists, err := lookUpIPfn(ctx, host) - if err != nil { - return nil, err - } - for _, v := range ipLists { - tmpAddr := net.JoinHostPort(v.String(), port) - addrs = append(addrs, tmpAddr) - } - } - for _, addr := range addrs { - c, err := net.DialTimeout(netType, addr, dialTimeout) - if err != nil { - lastErr = err - continue - } - if timeout != 0 { - err = c.SetDeadline(time.Now().Add(timeout)) - } - return c, nil - } - return nil, lastErr -} - -func DefaultDialTlsFunc(ctx context.Context, netType, addr string) (net.Conn, error) { - conn, err := DefaultDialFunc(ctx, netType, addr) - if err != nil { - return nil, err - } - tlsConfig, ok := ctx.Value("tlsConfig").(*tls.Config) - if !ok || tlsConfig == nil { - return nil, fmt.Errorf("tlsConfig is not set in context") - } - tlsConn := tls.Client(conn, tlsConfig) - if err := tlsConn.Handshake(); err != nil { - return nil, fmt.Errorf("tls handshake failed: %w", err) - } - return tlsConn, nil -} - -func DefaultProxyURL() func(*http.Request) (*url.URL, error) { - return func(req *http.Request) (*url.URL, error) { - if req == nil { - return nil, fmt.Errorf("request is nil") - } - proxyURL, ok := req.Context().Value("proxy").(string) - if !ok || proxyURL == "" { - return nil, nil - } - parsedURL, err := url.Parse(proxyURL) - if err != nil { - return nil, fmt.Errorf("failed to parse proxy URL: %w", err) - } - return parsedURL, nil - } -} diff --git a/curl_test.go b/curl_test.go deleted file mode 100644 index 8e8dc3e..0000000 --- a/curl_test.go +++ /dev/null @@ -1,728 +0,0 @@ -package starnet - -import ( - "fmt" - "net/http" - "net/http/httptest" - "testing" - "time" -) - -func TestUrlEncodeRaw(t *testing.T) { - input := "hello world!@#$%^&*()_+-=~`" - expected := "hello%20world%21%40%23%24%25%5E%26%2A%28%29_%2B-%3D~%60" - result := UrlEncodeRaw(input) - if result != expected { - t.Errorf("UrlEncodeRaw(%q) = %q; want %q", input, result, expected) - } -} - -func TestUrlEncode(t *testing.T) { - input := "hello world!@#$%^&*()_+-=~`" - expected := `hello+world%21%40%23%24%25%5E%26%2A%28%29_%2B-%3D~%60` - result := UrlEncode(input) - if result != expected { - t.Errorf("UrlEncode(%q) = %q; want %q", input, result, expected) - } -} - -func TestUrlDecode(t *testing.T) { - input := "hello%20world%21%40%23%24%25%5E%26*%28%29_%2B-%3D~%60" - expected := "hello world!@#$%^&*()_+-=~`" - result, err := UrlDecode(input) - if err != nil { - t.Errorf("UrlDecode(%q) returned error: %v", input, err) - } - if result != expected { - t.Errorf("UrlDecode(%q) = %q; want %q", input, result, expected) - } - - // Test for error case - invalidInput := "%zz" - _, err = UrlDecode(invalidInput) - if err == nil { - t.Errorf("UrlDecode(%q) expected error, got nil", invalidInput) - } -} - -func TestBuildPostForm_WithValidInput(t *testing.T) { - input := map[string]string{ - "key1": "value1", - "key2": "value2", - } - - expected := []byte("key1=value1&key2=value2") - - result := BuildPostForm(input) - - if string(result) != string(expected) { - t.Errorf("BuildPostForm(%v) = %v; want %v", input, result, expected) - } -} - -func TestBuildPostForm_WithEmptyInput(t *testing.T) { - input := map[string]string{} - - expected := []byte("") - - result := BuildPostForm(input) - - if string(result) != string(expected) { - t.Errorf("BuildPostForm(%v) = %v; want %v", input, result, expected) - } -} - -func TestBuildPostForm_WithNilInput(t *testing.T) { - var input map[string]string - - expected := []byte("") - - result := BuildPostForm(input) - - if string(result) != string(expected) { - t.Errorf("BuildPostForm(%v) = %v; want %v", input, result, expected) - } -} - -func TestGetRequest(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - rw.Write([]byte(`OK`)) - })) - defer server.Close() - - resp, err := Get(server.URL, WithSkipTLSVerify(true), WithHeader("hello", "world"), WithUserAgent("hello world")) - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - - body := resp.Body().String() - if body != "OK" { - t.Errorf("Expected OK, got %v", body) - } -} - -func TestPostRequest(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - if req.Method != http.MethodPost { - t.Errorf("Expected 'POST', got %v", req.Method) - } - rw.Write([]byte(`OK`)) - })) - defer server.Close() - - resp, err := Post(server.URL) - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - - body := resp.Body().String() - if body != "OK" { - t.Errorf("Expected OK, got %v", body) - } -} - -func TestOptionsRequestWithValidInput(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - if req.Method != http.MethodOptions { - t.Errorf("Expected 'OPTIONS', got %v", req.Method) - } - rw.Write([]byte(`OK`)) - })) - defer server.Close() - - resp, err := Options(server.URL) - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - - body := resp.Body().String() - if body != "OK" { - t.Errorf("Expected OK, got %v", body) - } -} - -func TestPutRequestWithValidInput(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - if req.Method != http.MethodPut { - t.Errorf("Expected 'PUT', got %v", req.Method) - } - rw.Write([]byte(`OK`)) - })) - defer server.Close() - - resp, err := Put(server.URL) - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - - body := resp.Body().String() - if body != "OK" { - t.Errorf("Expected OK, got %v", body) - } -} - -func TestDeleteRequestWithValidInput(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - if req.Method != http.MethodDelete { - t.Errorf("Expected 'DELETE', got %v", req.Method) - } - rw.Write([]byte(`OK`)) - })) - defer server.Close() - - resp, err := Delete(server.URL) - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - - body := resp.Body().String() - if body != "OK" { - t.Errorf("Expected OK, got %v", body) - } -} - -func TestHeadRequestWithValidInput(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - if req.Method != http.MethodHead { - t.Errorf("Expected 'HEAD', got %v", req.Method) - } - rw.Write([]byte(`OK`)) - })) - defer server.Close() - - resp, err := Head(server.URL) - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - - body := resp.Body().String() - if body == "OK" { - t.Errorf("Expected , got %v", body) - } -} - -func TestPatchRequestWithValidInput(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - if req.Method != http.MethodPatch { - t.Errorf("Expected 'PATCH', got %v", req.Method) - } - rw.Write([]byte(`OK`)) - })) - defer server.Close() - - resp, err := Patch(server.URL) - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - - body := resp.Body().String() - if body != "OK" { - t.Errorf("Expected OK, got %v", body) - } -} - -func TestTraceRequestWithValidInput(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - if req.Method != http.MethodTrace { - t.Errorf("Expected 'TRACE', got %v", req.Method) - } - rw.Write([]byte(`OK`)) - })) - defer server.Close() - - resp, err := Trace(server.URL) - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - - body := resp.Body().String() - if body != "OK" { - t.Errorf("Expected OK, got %v", body) - } -} - -func TestConnectRequestWithValidInput(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - if req.Method != http.MethodConnect { - t.Errorf("Expected 'CONNECT', got %v", req.Method) - } - rw.Write([]byte(`OK`)) - })) - defer server.Close() - - resp, err := Connect(server.URL) - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - - body := resp.Body().String() - if body != "OK" { - t.Errorf("Expected OK, got %v", body) - } -} -func TestMethodReturnsCorrectValue(t *testing.T) { - req := NewReq("https://example.com") - req.SetMethodNoError("GET") - if req.Method() != "GET" { - t.Errorf("Expected 'GET', got %v", req.Method()) - } -} - -func TestSetMethodHandlesInvalidInput(t *testing.T) { - req := NewReq("https://example.com") - err := req.SetMethod("我是谁") - if err == nil { - t.Errorf("Expected error, got nil") - } -} - -func TestSetMethodNoErrorSetsMethodCorrectly(t *testing.T) { - req := NewReq("https://example.com") - req.SetMethodNoError("POST") - if req.Method() != "POST" { - t.Errorf("Expected 'POST', got %v", req.Method()) - } -} - -func TestSetMethodNoErrorIgnoresInvalidInput(t *testing.T) { - req := NewReq("https://example.com") - req.SetMethodNoError("你是谁") - if req.Method() != "GET" { - t.Errorf("Expected '', got %v", req.Method()) - } -} - -func TestUriReturnsCorrectValue(t *testing.T) { - req := NewReq("https://example.com") - if req.Uri() != "https://example.com" { - t.Errorf("Expected 'https://example.com', got %v", req.Uri()) - } -} - -func TestSetUriHandlesValidInput(t *testing.T) { - req := NewReq("https://example.com") - err := req.SetUri("https://newexample.com") - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - if req.Uri() != "https://newexample.com" { - t.Errorf("Expected 'https://newexample.com', got %v", req.Uri()) - } -} - -func TestSetUriHandlesInvalidInput(t *testing.T) { - req := NewReq("https://example.com") - err := req.SetUri("://invalidurl") - if err == nil { - t.Errorf("Expected error, got nil") - } -} - -func TestSetUriNoErrorSetsUriCorrectly(t *testing.T) { - req := NewReq("https://example.com") - req.SetUriNoError("https://newexample.com") - if req.Uri() != "https://newexample.com" { - t.Errorf("Expected 'https://newexample.com', got %v", req.Uri()) - } -} - -func TestSetUriNoErrorIgnoresInvalidInput(t *testing.T) { - req := NewReq("https://example.com") - req.SetUriNoError("://invalidurl") - if req.Uri() != "https://example.com" { - t.Errorf("Expected 'https://example.com', got %v", req.Uri()) - } -} - -type postmanReply struct { - Args struct { - } `json:"args"` - Form map[string]string `json:"form"` - Headers map[string]string `json:"headers"` - Url string `json:"url"` -} - -func TestGet(t *testing.T) { - var reply postmanReply - resp, err := NewReq("https://postman-echo.com/get"). - AddHeader("hello", "nononmo"). - SetAutoCalcContentLengthNoError(true).Do() - if err != nil { - t.Error(err) - } - fmt.Println(resp.Proto) - err = resp.Body().Unmarshal(&reply) - if err != nil { - t.Error(err) - } - fmt.Println(resp.Body().String()) - fmt.Println(reply.Headers) - fmt.Println(resp.Cookies()) -} - -type testData struct { - name string - args *Request - want func(*Response) error - wantErr bool -} - -func headerTestData() []testData { - return []testData{ - { - name: "addHeader", - args: NewReq("https://postman-echo.com/get"). - AddHeader("b612", "test-data"). - AddHeader("b612", "test-header"). - AddSimpleCookie("b612", "test-cookie"). - SetHeader("User-Agent", "starnet test"), - want: func(resp *Response) error { - //fmt.Println(resp.Body().String()) - if resp == nil { - return fmt.Errorf("response is nil") - } - if resp.StatusCode != 200 { - return fmt.Errorf("status code is %d", resp.StatusCode) - } - var reply postmanReply - err := resp.Body().Unmarshal(&reply) - if err != nil { - return err - } - if reply.Headers["b612"] != "test-data, test-header" { - return fmt.Errorf("header not found") - } - if reply.Headers["user-agent"] != "starnet test" { - return fmt.Errorf("user-agent not found") - } - if reply.Headers["cookie"] != "b612=test-cookie" { - return fmt.Errorf("cookie not found") - } - return nil - }, - wantErr: false, - }, - { - name: "postForm", - args: NewSimpleRequest("https://postman-echo.com/post", "POST"). - AddHeader("b612", "test-data"). - AddHeader("b612", "test-header"). - AddSimpleCookie("b612", "test-cookie"). - SetHeader("User-Agent", "starnet test"). - //SetHeader("Content-Type", "application/x-www-form-urlencoded"). - AddFormData("hello", "world"). - AddFormData("hello2", "world2"). - SetMethodNoError("POST"), - want: func(resp *Response) error { - //fmt.Println(resp.Body().String()) - if resp == nil { - return fmt.Errorf("response is nil") - } - if resp.StatusCode != 200 { - return fmt.Errorf("status code is %d", resp.StatusCode) - } - var reply postmanReply - err := resp.Body().Unmarshal(&reply) - if err != nil { - return err - } - if reply.Headers["b612"] != "test-data, test-header" { - return fmt.Errorf("header not found") - } - if reply.Headers["user-agent"] != "starnet test" { - return fmt.Errorf("user-agent not found") - } - if reply.Headers["cookie"] != "b612=test-cookie" { - return fmt.Errorf("cookie not found") - } - if reply.Form["hello"] != "world" { - return fmt.Errorf("form data not found") - } - if reply.Form["hello2"] != "world2" { - return fmt.Errorf("form data not found") - } - return nil - }, - wantErr: false, - }, - } -} -func TestCurl(t *testing.T) { - for _, tt := range headerTestData() { - t.Run(tt.name, func(t *testing.T) { - got, err := Curl(tt.args) - if (err != nil) != tt.wantErr { - t.Errorf("Curl() error = %v, wantErr %v", err, tt.wantErr) - return - } - if tt.want != nil { - if err := tt.want(got); err != nil { - t.Errorf("Curl() = %v", err) - } - } - }) - } -} - -func TestReqClone(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - if req.Header.Get("hello") != "world" { - rw.WriteHeader(http.StatusBadRequest) - rw.Write([]byte("hello world failed")) - return - } - rw.Write([]byte(`OK`)) - })) - defer server.Close() - - req := NewSimpleRequestWithClient(NewClientFromHttpClientNoError(http.DefaultClient), server.URL, "GET", WithHeader("hello", "world")) - resp, err := req.Do() - if err != nil { - t.Error(err) - } - if resp.StatusCode != 200 { - resp.CloseAll() - t.Errorf("status code is %d", resp.StatusCode) - } - resp.CloseAll() - req = req.Clone() - req.AddHeader("ok", "good") - resp, err = req.Do() - if err != nil { - t.Error(err) - } - if resp.StatusCode != 200 { - resp.CloseAll() - t.Errorf("status code is %d", resp.StatusCode) - } - resp.CloseAll() -} - -func TestUploadFile(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - if req.Header.Get("hello") != "world" { - rw.WriteHeader(http.StatusBadRequest) - rw.Write([]byte("hello world failed")) - return - } - files, header, err := req.FormFile("666") - if err == nil { - fmt.Println(header.Filename) - fmt.Println(header.Size) - fmt.Println(files.Close()) - } - files, header, err = req.FormFile("777") - if err == nil { - fmt.Println(header.Filename) - fmt.Println(header.Size) - fmt.Println(files.Close()) - } - files, header, err = req.FormFile("888") - if err == nil { - fmt.Println(header.Filename) - fmt.Println(header.Size) - fmt.Println(files.Close()) - } - rw.Write([]byte(`OK`)) - })) - defer server.Close() - - req := NewSimpleRequestWithClient(NewClientFromHttpClientNoError(http.DefaultClient), server.URL, "GET", WithHeader("hello", "world")) - req.AddFileWithName("666", "./curl.go", "curl.go") - req.AddFile("777", "./go.mod") - req.AddFileWithNameAndType("888", "./ping.go", "ping.go", "html") - resp, err := req.Do() - if err != nil { - t.Error(err) - } - if resp.StatusCode != 200 { - resp.CloseAll() - t.Errorf("status code is %d", resp.StatusCode) - } - resp.CloseAll() - req = req.Clone() - req.AddHeader("ok", "good") - - resp, err = req.Do() - if err != nil { - t.Error(err) - } - if resp.StatusCode != 200 { - resp.CloseAll() - t.Errorf("status code is %d", resp.StatusCode) - } - resp.CloseAll() -} - -func TestTlsConfig(t *testing.T) { - server := httptest.NewTLSServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - if req.Header.Get("hello") != "world" { - rw.WriteHeader(http.StatusBadRequest) - rw.Write([]byte("hello world failed")) - return - } - rw.Write([]byte(`OK`)) - })) - defer server.Close() - client, err := NewHttpClient(WithSkipTLSVerify(false)) - if err != nil { - t.Error(err) - } - req := client.NewSimpleRequest(server.URL, "GET", WithHeader("hello", "world")) - //SetClientSkipVerify(client, true) - //req.SetDoRawClient(false) - //req.SetDoRawTransport(false) - req.SetSkipTLSVerify(true) - req.SetProxy("http://127.0.0.1:29992") - resp, err := req.Do() - if err != nil { - t.Error(err) - } - fmt.Println(resp.Proto) - if resp.StatusCode != 200 { - resp.CloseAll() - t.Errorf("status code is %d", resp.StatusCode) - } - resp.CloseAll() - req = req.Clone() - req.AddHeader("ok", "good") - resp, err = req.Do() - if err != nil { - t.Error(err) - } - if resp.StatusCode != 200 { - resp.CloseAll() - t.Errorf("status code is %d", resp.StatusCode) - } - resp.CloseAll() - req = req.Clone() - req.SetSkipTLSVerify(false) - resp, err = req.Do() - if err == nil { - t.Error(err) - } -} - -func TestHttpPostAndChunked(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - if req.Method != http.MethodPost { - t.Errorf("Expected 'POST', got %v", req.Method) - } - buf := make([]byte, 1024) - n, _ := req.Body.Read(buf) - if string(buf[:n]) != "hello world" { - t.Errorf("Expected body to be 'hello world', got %s", string(buf[:n])) - } - - if req.Header.Get("chunked") == "true" { - if req.TransferEncoding[0] != "chunked" { - t.Errorf("Expected Transfer-Encoding to be 'chunked', got %s", req.Header.Get("Transfer-Encoding")) - } - } else { - if len(req.TransferEncoding) > 0 && req.TransferEncoding[0] == "chunked" { - t.Errorf("Expected Transfer-Encoding to not be 'chunked', got %s", req.Header.Get("Transfer-Encoding")) - } - } - rw.Write([]byte(`OK`)) - })) - defer server.Close() - - resp, err := Post(server.URL, WithBytes([]byte("hello world")), WithContentLength(-1), WithHeader("Content-Type", "text/plain"), - WithHeader("chunked", "true")) - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - body := resp.Body().String() - if body != "OK" { - t.Errorf("Expected OK, got %v", body) - } - resp.Close() - - resp, err = Post(server.URL, WithBytes([]byte("hello world")), WithHeader("Content-Type", "text/plain"), - WithHeader("chunked", "false")) - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - defer resp.Close() - body = resp.Body().String() - if body != "OK" { - t.Errorf("Expected OK, got %v", body) - } -} - -func TestWithTimeout(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - time.Sleep(time.Second * 30) - rw.Write([]byte(`OK`)) - })) - funcList := []func(string, ...RequestOpt) (*Response, error){ - Get, - Post, - Put, - Delete, - Options, - Patch, - Head, - Trace, - Connect, - } - defer server.Close() - for i := 1; i < 30; i++ { - go func(i int) { - old := time.Now() - fn := funcList[i%len(funcList)] - resp, err := fn(server.URL, WithTimeout(time.Second*time.Duration(i))) - if time.Since(old) > time.Second*time.Duration(i+2) || time.Since(old) < time.Second*time.Duration(i) { - t.Errorf("timeout not work") - } - fmt.Println(time.Since(old)) - if err == nil { - t.Error(err) - resp.CloseAll() - } else { - fmt.Println(err) - } - }(i) - } - resp, err := Get(server.URL, WithTimeout(time.Second*60)) - if err != nil { - t.Error(err) - } else { - fmt.Println(resp.Body().String()) - if resp.StatusCode != 200 { - resp.CloseAll() - t.Errorf("status code is %d", resp.StatusCode) - } - resp.CloseAll() - } -} - -func TestConfigWithClient(t *testing.T) { - server := httptest.NewTLSServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - if req.Header.Get("hello") != "world" { - rw.WriteHeader(http.StatusBadRequest) - rw.Write([]byte("hello world failed")) - return - } - rw.Write([]byte(`OK`)) - })) - defer server.Close() - client, err := NewHttpClient(WithSkipTLSVerify(true)) - if err != nil { - t.Error(err) - } - req := client.NewSimpleRequest(server.URL, "GET", WithHeader("hello", "world")) - //SetClientSkipVerify(client, true) - //req.SetDoRawClient(false) - //req.SetDoRawTransport(false) - resp, err := req.Do() - if err != nil { - t.Error(err) - } - fmt.Println(resp.Proto) - if resp.StatusCode != 200 { - resp.CloseAll() - t.Errorf("status code is %d", resp.StatusCode) - } - resp.CloseAll() -} diff --git a/curl_transport.go b/curl_transport.go deleted file mode 100644 index 943f1fe..0000000 --- a/curl_transport.go +++ /dev/null @@ -1,178 +0,0 @@ -package starnet - -import ( - "context" - "crypto/tls" - "fmt" - "net/http" - "reflect" -) - -type Client struct { - *http.Client - opts []RequestOpt -} - -func (c Client) Options() []RequestOpt { - return c.opts -} - -func (c Client) SetOptions(opts ...RequestOpt) Client { - return Client{ - Client: c.Client, - opts: opts, - } -} - -// NewHttpClient creates a new http.Client with the specified options. -func NewHttpClient(opts ...RequestOpt) (Client, error) { - req, err := newRequest(context.Background(), "", "", opts...) - if err != nil { - return Client{}, err - } - defer func() { - req = nil - }() - cl, err := req.HttpClient() - return Client{ - Client: cl, - opts: opts, - }, err -} - -func NewHttpClientNoErr(opts ...RequestOpt) Client { - c, _ := NewHttpClient(opts...) - return c -} - -func NewClientFromHttpClient(httpClient *http.Client) (Client, error) { - if httpClient == nil { - return Client{}, fmt.Errorf("httpClient cannot be nil") - } - - if httpClient.Transport == nil { - httpClient.Transport = &Transport{ - base: &http.Transport{}, - } - } else { - switch t := httpClient.Transport.(type) { - case *Transport: - if t.base == nil { - t.base = &http.Transport{} - } - case *http.Transport: - httpClient.Transport = &Transport{ - base: t, - } - default: - return Client{}, fmt.Errorf("unsupported transport type: %T", t) - } - } - return Client{ - Client: httpClient, - }, nil -} - -func NewClientFromHttpClientNoError(httpClient *http.Client) Client { - return Client{Client: httpClient} -} - -// DisableRedirect returns whether the request will disable HTTP redirects. -// if true, the request will not follow redirects automatically. -// for example, if the server responds with a 301 or 302 status code, the request will not automatically follow the redirect. -// you will get the original response with the redirect status code and Location header. -func (c Client) DisableRedirect() bool { - return reflect.ValueOf(c.Client.CheckRedirect).Pointer() == reflect.ValueOf(DefaultCheckRedirectFunc).Pointer() -} - -// SetDisableRedirect sets whether the request will disable HTTP redirects. -// if true, the request will not follow redirects automatically. -// for example, if the server responds with a 301 or 302 status code, the request will not automatically follow the redirect. -// you will get the original response with the redirect status code and Location header. -func (c Client) SetDisableRedirect(disableRedirect bool) { - if disableRedirect { - c.Client.CheckRedirect = DefaultCheckRedirectFunc - } -} - -func (c Client) SetDefaultSkipTLSVerify(skip bool) { - if c.Client.Transport == nil { - c.Client.Transport = &Transport{ - base: &http.Transport{}, - } - } - if transport, ok := c.Client.Transport.(*Transport); ok { - if transport.base.TLSClientConfig == nil { - transport.base.TLSClientConfig = &tls.Config{} - } - transport.base.TLSClientConfig.InsecureSkipVerify = skip - } else if transport, ok := c.Client.Transport.(*http.Transport); ok { - if transport.TLSClientConfig == nil { - transport.TLSClientConfig = &tls.Config{} - } - transport.TLSClientConfig.InsecureSkipVerify = skip - } -} - -func (c Client) SetDefaultTLSConfig(tlsConfig *tls.Config) { - if c.Client.Transport == nil { - c.Client.Transport = &Transport{ - base: &http.Transport{}, - } - } - if transport, ok := c.Client.Transport.(*Transport); ok { - transport.base.TLSClientConfig = tlsConfig - } else if transport, ok := c.Client.Transport.(*http.Transport); ok { - transport.TLSClientConfig = tlsConfig - } -} - -func (c Client) NewRequest(url, method string, opts ...RequestOpt) (*Request, error) { - if c.Client == nil { - return nil, fmt.Errorf("http client is nil") - } - req, err := NewRequestWithContextWithClient(context.Background(), c, url, method, opts...) - return req, err -} - -func (c Client) NewRequestContext(ctx context.Context, url, method string, opts ...RequestOpt) (*Request, error) { - if c.Client == nil { - return nil, fmt.Errorf("http client is nil") - } - req, err := NewRequestWithContextWithClient(ctx, c, url, method, opts...) - return req, err -} - -func (c Client) NewSimpleRequest(url, method string, opts ...RequestOpt) *Request { - req, _ := c.NewRequest(url, method, opts...) - return req -} - -func (c Client) NewSimpleRequestContext(ctx context.Context, url, method string, opts ...RequestOpt) *Request { - req, _ := c.NewRequestContext(ctx, url, method, opts...) - return req -} - -type Transport struct { - base *http.Transport -} - -func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { - if t.base == nil { - t.base = &http.Transport{} - } - transport, ok := req.Context().Value("transport").(*http.Transport) - if ok && transport != nil { - return transport.RoundTrip(req) - } - proxy, ok := req.Context().Value("proxy").(string) - if ok && proxy != "" { - tlsConfig, ok := req.Context().Value("tlsConfig").(*tls.Config) - if ok && tlsConfig != nil { - tmpTransport := t.base.Clone() - tmpTransport.TLSClientConfig = tlsConfig - return tmpTransport.RoundTrip(req) - } - } - return t.base.RoundTrip(req) -} diff --git a/curlbench_test.go b/curlbench_test.go deleted file mode 100644 index 3894989..0000000 --- a/curlbench_test.go +++ /dev/null @@ -1,198 +0,0 @@ -package starnet - -import ( - "fmt" - "net/http" - "net/http/httptest" - "runtime" - "testing" -) - -// BenchmarkGetRequest 测试单个 GET 请求的性能 -func BenchmarkGetRequest(b *testing.B) { - // 创建测试服务器 - server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - rw.Write([]byte(`OK`)) - })) - defer server.Close() - - // 重置计时器,排除设置代码的影响 - b.ResetTimer() - - // 报告内存分配情况 - b.ReportAllocs() - - // 运行基准测试 - for i := 0; i < b.N; i++ { - resp, err := Get(server.URL, WithSkipTLSVerify(true)) - if err != nil { - b.Errorf("Unexpected error: %v", err) - } - - body := resp.Body().String() - if body != "OK" { - b.Errorf("Expected OK, got %v", body) - } - } -} - -// BenchmarkGetRequestWithHeaders 测试带请求头的 GET 请求性能 -func BenchmarkGetRequestWithHeaders(b *testing.B) { - server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - // 验证请求头 - if req.Header.Get("hello") != "world" { - rw.WriteHeader(http.StatusBadRequest) - return - } - rw.Write([]byte(`OK`)) - })) - defer server.Close() - - b.ResetTimer() - b.ReportAllocs() - - for i := 0; i < b.N; i++ { - resp, err := Get(server.URL, - WithSkipTLSVerify(true), - WithHeader("hello", "world"), - WithUserAgent("hello world")) - if err != nil { - b.Errorf("Unexpected error: %v", err) - } - - body := resp.Body().String() - if body != "OK" { - b.Errorf("Expected OK, got %v", body) - } - } -} - -// BenchmarkPostRequest 测试 POST 请求的性能 -func BenchmarkPostRequest(b *testing.B) { - server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - // 读取并返回请求体 - body := make([]byte, req.ContentLength) - req.Body.Read(body) - rw.Write(body) - })) - defer server.Close() - - testData := "This is a test payload for POST request" - - b.ResetTimer() - b.ReportAllocs() - - for i := 0; i < b.N; i++ { - resp, err := Post(server.URL, - WithSkipTLSVerify(true), - WithBytes([]byte(testData)), - WithContentType("text/plain")) - if err != nil { - b.Errorf("Unexpected error: %v", err) - } - - body := resp.Body().String() - if body != testData { - b.Errorf("Expected %s, got %v", testData, body) - } - } -} - -// BenchmarkConcurrentRequests 测试并发请求性能 -func BenchmarkConcurrentRequests(b *testing.B) { - server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - rw.Write([]byte(`OK`)) - })) - defer server.Close() - - b.ResetTimer() - b.ReportAllocs() - - // 运行并发基准测试 - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - resp, err := Get(server.URL, WithSkipTLSVerify(true)) - if err != nil { - b.Errorf("Unexpected error: %v", err) - } - - body := resp.Body().String() - if body != "OK" { - b.Errorf("Expected OK, got %v", body) - } - } - }) -} - -// BenchmarkMemoryUsage 专门测试内存使用情况 -func BenchmarkMemoryUsage(b *testing.B) { - server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - rw.Write([]byte(`OK`)) - })) - defer server.Close() - - // 禁用默认的测试时间,只关注内存分配 - b.ReportAllocs() - - var memStatsStart, memStatsEnd runtime.MemStats - runtime.GC() - runtime.ReadMemStats(&memStatsStart) - - for i := 0; i < b.N; i++ { - resp, err := Get(server.URL, WithSkipTLSVerify(true)) - if err != nil { - b.Errorf("Unexpected error: %v", err) - } - - body := resp.Body().String() - if body != "OK" { - b.Errorf("Expected OK, got %v", body) - } - } - - runtime.GC() - runtime.ReadMemStats(&memStatsEnd) - - // 计算每次操作的平均内存分配 - allocsPerOp := float64(memStatsEnd.Mallocs-memStatsStart.Mallocs) / float64(b.N) - bytesPerOp := float64(memStatsEnd.TotalAlloc-memStatsStart.TotalAlloc) / float64(b.N) - - b.ReportMetric(allocsPerOp, "allocs/op") - b.ReportMetric(bytesPerOp, "bytes/op") -} - -// BenchmarkDifferentResponseSizes 测试不同响应大小的性能 -func BenchmarkDifferentResponseSizes(b *testing.B) { - // 测试不同大小的响应 - responseSizes := []int{100, 1024, 10240, 102400} // 100B, 1KB, 10KB, 100KB - - for _, size := range responseSizes { - // 生成指定大小的响应数据 - responseData := make([]byte, size) - for i := 0; i < size; i++ { - responseData[i] = 'A' - } - - b.Run(fmt.Sprintf("Size_%d", size), func(b *testing.B) { - server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - rw.Write(responseData) - })) - defer server.Close() - - b.ResetTimer() - b.ReportAllocs() - - for i := 0; i < b.N; i++ { - resp, err := Get(server.URL, WithSkipTLSVerify(true)) - if err != nil { - b.Errorf("Unexpected error: %v", err) - } - - body := resp.Body().Bytes() - if len(body) != size { - b.Errorf("Expected size %d, got %d", size, len(body)) - } - } - }) - } -} diff --git a/defaults.go b/defaults.go new file mode 100644 index 0000000..bf77dc5 --- /dev/null +++ b/defaults.go @@ -0,0 +1,147 @@ +package starnet + +import ( + "net/http" + "sync" + "time" +) + +var ( + defaultClient *Client + defaultHTTPClient *http.Client + defaultClientOnce sync.Once + defaultHTTPOnce sync.Once + + defaultMu sync.RWMutex +) + +// DefaultClient 获取默认 Client(单例) +func DefaultClient() *Client { + defaultMu.RLock() + if defaultClient != nil { + c := defaultClient + defaultMu.RUnlock() + return c + } + defaultMu.RUnlock() + + defaultClientOnce.Do(func() { + c := NewClientNoErr() + defaultMu.Lock() + defaultClient = c + defaultMu.Unlock() + }) + + defaultMu.RLock() + c := defaultClient + defaultMu.RUnlock() + return c +} + +// DefaultHTTPClient 获取默认 http.Client(单例) +func DefaultHTTPClient() *http.Client { + defaultMu.RLock() + if defaultHTTPClient != nil { + c := defaultHTTPClient + defaultMu.RUnlock() + return c + } + defaultMu.RUnlock() + + defaultHTTPOnce.Do(func() { + c := &http.Client{ + Transport: &Transport{ + base: &http.Transport{ + ForceAttemptHTTP2: true, + MaxIdleConns: 100, + MaxIdleConnsPerHost: 10, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + }, + }, + Timeout: 0, // 由请求级控制超时 + } + defaultMu.Lock() + defaultHTTPClient = c + defaultMu.Unlock() + }) + + defaultMu.RLock() + c := defaultHTTPClient + defaultMu.RUnlock() + return c +} + +// SetDefaultClient 设置默认 Client +func SetDefaultClient(client *Client) { + defaultMu.Lock() + defer defaultMu.Unlock() + + defaultClient = client + // 标记 once 已完成,避免后续 DefaultClient() 再次初始化覆盖 + defaultClientOnce.Do(func() {}) +} + +// SetDefaultHTTPClient 设置默认 http.Client +func SetDefaultHTTPClient(client *http.Client) { + defaultMu.Lock() + defer defaultMu.Unlock() + + defaultHTTPClient = client + // 标记 once 已完成,避免后续 DefaultHTTPClient() 再次初始化覆盖 + defaultHTTPOnce.Do(func() {}) +} + +// Get 发送 GET 请求(使用默认 Client) +func Get(url string, opts ...RequestOpt) (*Response, error) { + return DefaultClient().Get(url, opts...) +} + +// Post 发送 POST 请求(使用默认 Client) +func Post(url string, opts ...RequestOpt) (*Response, error) { + return DefaultClient().Post(url, opts...) +} + +// Put 发送 PUT 请求(使用默认 Client) +func Put(url string, opts ...RequestOpt) (*Response, error) { + return DefaultClient().Put(url, opts...) +} + +// Delete 发送 DELETE 请求(使用默认 Client) +func Delete(url string, opts ...RequestOpt) (*Response, error) { + return DefaultClient().Delete(url, opts...) +} + +// Head 发送 HEAD 请求(使用默认 Client) +func Head(url string, opts ...RequestOpt) (*Response, error) { + return DefaultClient().Head(url, opts...) +} + +// Patch 发送 PATCH 请求(使用默认 Client) +func Patch(url string, opts ...RequestOpt) (*Response, error) { + return DefaultClient().Patch(url, opts...) +} + +// Options 发送 OPTIONS 请求(使用默认 Client) +func Options(url string, opts ...RequestOpt) (*Response, error) { + return DefaultClient().Options(url, opts...) +} + +// Trace 发送 TRACE 请求(使用默认 Client) +func Trace(url string, opts ...RequestOpt) (*Response, error) { + req, err := DefaultClient().NewRequest(url, http.MethodTrace, opts...) + if err != nil { + return nil, err + } + return req.Do() +} + +// Connect 发送 CONNECT 请求(使用默认 Client) +func Connect(url string, opts ...RequestOpt) (*Response, error) { + req, err := DefaultClient().NewRequest(url, http.MethodConnect, opts...) + if err != nil { + return nil, err + } + return req.Do() +} diff --git a/dialer.go b/dialer.go new file mode 100644 index 0000000..1f84667 --- /dev/null +++ b/dialer.go @@ -0,0 +1,148 @@ +package starnet + +import ( + "context" + "crypto/tls" + "fmt" + "net" + "time" +) + +// defaultDialFunc 默认 Dial 函数(支持自定义 IP 和 DNS) +func defaultDialFunc(ctx context.Context, network, addr string) (net.Conn, error) { + // 提取配置 + reqCtx := getRequestContext(ctx) + + dialTimeout := reqCtx.DialTimeout + if dialTimeout == 0 { + dialTimeout = DefaultDialTimeout + } + + timeout := reqCtx.Timeout + if timeout == 0 { + timeout = DefaultTimeout + } + + // 解析地址 + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, wrapError(err, "split host port") + } + + // 获取 IP 地址列表 + var addrs []string + + // 优先级1:直接指定的 IP + if len(reqCtx.CustomIP) > 0 { + for _, ip := range reqCtx.CustomIP { + addrs = append(addrs, net.JoinHostPort(ip, port)) + } + } else { + // 优先级2:DNS 解析 + var ipAddrs []net.IPAddr + + // 使用自定义解析函数 + if reqCtx.LookupIPFn != nil { + ipAddrs, err = reqCtx.LookupIPFn(ctx, host) + } else if len(reqCtx.CustomDNS) > 0 { + // 使用自定义 DNS 服务器 + resolver := &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { + var lastErr error + for _, dnsServer := range reqCtx.CustomDNS { + conn, err := net.Dial("udp", net.JoinHostPort(dnsServer, "53")) + if err != nil { + lastErr = err + continue + } + return conn, nil + } + return nil, lastErr + }, + } + ipAddrs, err = resolver.LookupIPAddr(ctx, host) + } else { + // 使用默认解析器 + ipAddrs, err = net.DefaultResolver.LookupIPAddr(ctx, host) + } + + if err != nil { + return nil, wrapError(err, "lookup ip") + } + + for _, ipAddr := range ipAddrs { + addrs = append(addrs, net.JoinHostPort(ipAddr.String(), port)) + } + } + + // 尝试连接所有地址 + var lastErr error + for _, addr := range addrs { + conn, err := net.DialTimeout(network, addr, dialTimeout) + if err != nil { + lastErr = err + continue + } + + // 设置总超时 + if timeout > 0 { + conn.SetDeadline(time.Now().Add(timeout)) + } + + return conn, nil + } + + if lastErr != nil { + return nil, wrapError(lastErr, "dial all addresses failed") + } + + return nil, fmt.Errorf("no addresses to dial") +} + +// defaultDialTLSFunc 默认 TLS Dial 函数 +func defaultDialTLSFunc(ctx context.Context, network, addr string) (net.Conn, error) { + // 先建立 TCP 连接 + conn, err := defaultDialFunc(ctx, network, addr) + if err != nil { + return nil, err + } + + // 提取 TLS 配置 + reqCtx := getRequestContext(ctx) + tlsConfig := reqCtx.TLSConfig + if tlsConfig == nil { + tlsConfig = &tls.Config{} + } + + // 执行 TLS 握手 + tlsConn := tls.Client(conn, tlsConfig) + if err := tlsConn.Handshake(); err != nil { + conn.Close() + return nil, wrapError(err, "tls handshake") + } + + return tlsConn, nil +} + +/* +// defaultProxyFunc 默认代理函数 +func defaultProxyFunc(req *http.Request) (*url.URL, error) { + if req == nil { + return nil, fmt.Errorf("request is nil") + } + + reqCtx := getRequestContext(req.Context()) + if reqCtx.Proxy == "" { + return nil, nil + } + + proxyURL, err := url.Parse(reqCtx.Proxy) + if err != nil { + return nil, wrapError(err, "parse proxy url") + } + + return proxyURL, nil +} + +*/ diff --git a/dns_test.go b/dns_test.go new file mode 100644 index 0000000..75fa99b --- /dev/null +++ b/dns_test.go @@ -0,0 +1,103 @@ +package starnet + +import ( + "context" + "net" + "testing" +) + +func TestRequestCustomIP(t *testing.T) { + customIPs := []string{"1.2.3.4", "5.6.7.8"} + req := NewSimpleRequest("http://example.com", "GET"). + SetCustomIP(customIPs) + + if len(req.config.DNS.CustomIP) != 2 { + t.Errorf("CustomIP length = %v; want 2", len(req.config.DNS.CustomIP)) + } + + for i, ip := range req.config.DNS.CustomIP { + if ip != customIPs[i] { + t.Errorf("CustomIP[%d] = %v; want %v", i, ip, customIPs[i]) + } + } +} + +func TestRequestCustomIPInvalid(t *testing.T) { + req := NewSimpleRequest("http://example.com", "GET"). + SetCustomIP([]string{"invalid-ip"}) + + if req.Err() == nil { + t.Error("Expected error for invalid IP, got nil") + } +} + +func TestRequestCustomDNS(t *testing.T) { + dnsServers := []string{"8.8.8.8", "1.1.1.1"} + req := NewSimpleRequest("http://example.com", "GET"). + SetCustomDNS(dnsServers) + + if len(req.config.DNS.CustomDNS) != 2 { + t.Errorf("CustomDNS length = %v; want 2", len(req.config.DNS.CustomDNS)) + } +} + +func TestRequestCustomDNSInvalid(t *testing.T) { + req := NewSimpleRequest("http://example.com", "GET"). + SetCustomDNS([]string{"invalid-dns"}) + + if req.Err() == nil { + t.Error("Expected error for invalid DNS, got nil") + } +} + +func TestRequestLookupFunc(t *testing.T) { + called := false + lookupFunc := func(ctx context.Context, host string) ([]net.IPAddr, error) { + called = true + return []net.IPAddr{ + {IP: net.ParseIP("1.2.3.4")}, + }, nil + } + + req := NewSimpleRequest("http://example.com", "GET"). + SetLookupFunc(lookupFunc) + + if req.config.DNS.LookupFunc == nil { + t.Error("LookupFunc not set") + } + + // Call the function to verify it works + ips, err := req.config.DNS.LookupFunc(context.Background(), "example.com") + if err != nil { + t.Errorf("LookupFunc error: %v", err) + } + if !called { + t.Error("LookupFunc was not called") + } + if len(ips) != 1 { + t.Errorf("IPs length = %v; want 1", len(ips)) + } +} + +func TestDNSPriority(t *testing.T) { + // CustomIP should have highest priority + req := NewSimpleRequest("http://example.com", "GET"). + SetCustomIP([]string{"1.2.3.4"}). + SetCustomDNS([]string{"8.8.8.8"}). + SetLookupFunc(func(ctx context.Context, host string) ([]net.IPAddr, error) { + return []net.IPAddr{{IP: net.ParseIP("5.6.7.8")}}, nil + }) + + // CustomIP should be set + if len(req.config.DNS.CustomIP) == 0 { + t.Error("CustomIP should be set") + } + + // Others should also be set (but CustomIP takes priority in actual use) + if len(req.config.DNS.CustomDNS) == 0 { + t.Error("CustomDNS should be set") + } + if req.config.DNS.LookupFunc == nil { + t.Error("LookupFunc should be set") + } +} diff --git a/errors.go b/errors.go new file mode 100644 index 0000000..3bba190 --- /dev/null +++ b/errors.go @@ -0,0 +1,58 @@ +package starnet + +import ( + "errors" + "fmt" +) + +var ( + // ErrInvalidMethod 无效的 HTTP 方法 + ErrInvalidMethod = errors.New("starnet: invalid HTTP method") + + // ErrInvalidURL 无效的 URL + ErrInvalidURL = errors.New("starnet: invalid URL") + + // ErrInvalidIP 无效的 IP 地址 + ErrInvalidIP = errors.New("starnet: invalid IP address") + + // ErrInvalidDNS 无效的 DNS 服务器 + ErrInvalidDNS = errors.New("starnet: invalid DNS server") + + // ErrNilClient HTTP Client 为 nil + ErrNilClient = errors.New("starnet: http client is nil") + + // ErrNilReader Reader 为 nil + ErrNilReader = errors.New("starnet: reader is nil") + + // ErrFileNotFound 文件不存在 + ErrFileNotFound = errors.New("starnet: file not found") + + // ErrRequestNotPrepared 请求未准备好 + ErrRequestNotPrepared = errors.New("starnet: request not prepared") + + // ErrBodyAlreadyConsumed Body 已被消费 + ErrBodyAlreadyConsumed = errors.New("starnet: response body already consumed") +) + +// wrapError 包装错误,添加上下文信息 +func wrapError(err error, format string, args ...interface{}) error { + if err == nil { + return nil + } + msg := fmt.Sprintf(format, args...) + return fmt.Errorf("%s: %w", msg, err) +} + +var ( + // ErrNilConn indicates a nil net.Conn argument. + ErrNilConn = errors.New("starnet: nil connection") + + // ErrNonTLSNotAllowed indicates plain TCP was detected while non-TLS is forbidden. + ErrNonTLSNotAllowed = errors.New("starnet: non-TLS connection not allowed") + + // ErrNotTLS indicates caller asked for TLS-only object but conn is plain TCP. + ErrNotTLS = errors.New("starnet: connection is not TLS") + + // ErrNoTLSConfig indicates TLS was detected but no usable TLS config is available. + ErrNoTLSConfig = errors.New("starnet: no TLS config available") +) diff --git a/example_test.go b/example_test.go new file mode 100644 index 0000000..6bfdecd --- /dev/null +++ b/example_test.go @@ -0,0 +1,200 @@ +package starnet_test + +import ( + "fmt" + "net/http" + "net/http/httptest" + "time" + + "b612.me/starnet" +) + +func ExampleGet() { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("Hello, World!")) + })) + defer server.Close() + + resp, err := starnet.Get(server.URL) + if err != nil { + panic(err) + } + defer resp.Close() + + body, _ := resp.Body().String() + fmt.Println(body) + // Output: Hello, World! +} + +func ExamplePost() { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("Posted")) + })) + defer server.Close() + + resp, err := starnet.Post(server.URL, + starnet.WithBodyString("test data")) + if err != nil { + panic(err) + } + defer resp.Close() + + body, _ := resp.Body().String() + fmt.Println(body) + // Output: Posted +} + +func ExampleNewSimpleRequest() { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("OK")) + })) + defer server.Close() + + req := starnet.NewSimpleRequest(server.URL, "GET"). + SetHeader("X-Custom", "value"). + AddQuery("name", "test") + + resp, err := req.Do() + if err != nil { + panic(err) + } + defer resp.Close() + + fmt.Println(resp.StatusCode) + // Output: 200 +} + +func ExampleClient_Get() { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("Client GET")) + })) + defer server.Close() + + client := starnet.NewClientNoErr( + starnet.WithTimeout(10*time.Second), + starnet.WithUserAgent("MyApp/1.0"), + ) + + resp, err := client.Get(server.URL) + if err != nil { + panic(err) + } + defer resp.Close() + + body, _ := resp.Body().String() + fmt.Println(body) + // Output: Client GET +} + +func ExampleRequest_SetJSON() { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"status":"ok"}`)) + })) + defer server.Close() + + type User struct { + Name string `json:"name"` + Email string `json:"email"` + } + + user := User{Name: "John", Email: "john@example.com"} + + resp, err := starnet.NewSimpleRequest(server.URL, "POST"). + SetJSON(user). + Do() + if err != nil { + panic(err) + } + defer resp.Close() + + var result map[string]string + resp.Body().JSON(&result) + fmt.Println(result["status"]) + // Output: ok +} + +func ExampleRequest_AddFormData() { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.ParseForm() + fmt.Fprintf(w, "name=%s", r.FormValue("name")) + })) + defer server.Close() + + resp, err := starnet.NewSimpleRequest(server.URL, "POST"). + AddFormData("name", "John"). + AddFormData("age", "30"). + Do() + if err != nil { + panic(err) + } + defer resp.Close() + + body, _ := resp.Body().String() + fmt.Println(body) + // Output: name=John +} + +func ExampleRequest_SetSkipTLSVerify() { + // This example shows how to skip TLS verification + // Useful for testing with self-signed certificates + + req := starnet.NewSimpleRequest("https://self-signed.example.com", "GET"). + SetSkipTLSVerify(true) + + // In a real scenario, you would call req.Do() + fmt.Println(req.Method()) + // Output: GET +} + +func ExampleRequest_Clone() { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("OK")) + })) + defer server.Close() + + baseReq := starnet.NewSimpleRequest(server.URL, "GET"). + SetHeader("X-API-Key", "secret") + + // Clone and modify + req1 := baseReq.Clone().AddQuery("page", "1") + req2 := baseReq.Clone().AddQuery("page", "2") + + resp1, _ := req1.Do() + resp2, _ := req2.Do() + + defer resp1.Close() + defer resp2.Close() + + fmt.Println(resp1.StatusCode, resp2.StatusCode) + // Output: 200 200 +} + +func ExampleClient_SetDefaultSkipTLSVerify() { + client := starnet.NewClientNoErr() + client.SetDefaultSkipTLSVerify(true) + + // All requests from this client will skip TLS verification + // unless overridden at request level + + fmt.Println("Client configured") + // Output: Client configured +} + +func ExampleWithTimeout() { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(100 * time.Millisecond) + w.Write([]byte("OK")) + })) + defer server.Close() + + resp, err := starnet.Get(server.URL, + starnet.WithTimeout(200*time.Millisecond)) + if err != nil { + panic(err) + } + defer resp.Close() + + fmt.Println(resp.StatusCode) + // Output: 200 +} diff --git a/file_upload_test.go b/file_upload_test.go new file mode 100644 index 0000000..f532dbd --- /dev/null +++ b/file_upload_test.go @@ -0,0 +1,172 @@ +package starnet + +import ( + "io" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" +) + +func TestRequestAddFileStream(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + err := r.ParseMultipartForm(10 << 20) // 10 MB + if err != nil { + t.Fatalf("ParseMultipartForm error: %v", err) + } + + file, header, err := r.FormFile("file") + if err != nil { + t.Fatalf("FormFile error: %v", err) + } + defer file.Close() + + content, _ := io.ReadAll(file) + w.Write([]byte(header.Filename + ":" + string(content))) + })) + defer server.Close() + + fileContent := "test file content" + reader := strings.NewReader(fileContent) + + req := NewSimpleRequest(server.URL, "POST"). + AddFileStream("file", "test.txt", int64(len(fileContent)), reader) + + resp, err := req.Do() + if err != nil { + t.Fatalf("Do() error: %v", err) + } + defer resp.Close() + + body, _ := resp.Body().String() + expected := "test.txt:" + fileContent + if body != expected { + t.Errorf("Body = %v; want %v", body, expected) + } +} + +func TestRequestAddFileWithFormData(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + err := r.ParseMultipartForm(10 << 20) + if err != nil { + t.Fatalf("ParseMultipartForm error: %v", err) + } + + // Check form field + name := r.FormValue("name") + if name != "John" { + t.Errorf("name = %v; want John", name) + } + + // Check file + file, header, err := r.FormFile("file") + if err != nil { + t.Fatalf("FormFile error: %v", err) + } + defer file.Close() + + w.Write([]byte("OK:" + header.Filename)) + })) + defer server.Close() + + fileContent := "file data" + reader := strings.NewReader(fileContent) + + req := NewSimpleRequest(server.URL, "POST"). + AddFormData("name", "John"). + AddFileStream("file", "document.txt", int64(len(fileContent)), reader) + + resp, err := req.Do() + if err != nil { + t.Fatalf("Do() error: %v", err) + } + defer resp.Close() + + body, _ := resp.Body().String() + if !strings.Contains(body, "document.txt") { + t.Errorf("Body should contain filename, got: %v", body) + } +} + +func TestRequestUploadProgress(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.ParseMultipartForm(10 << 20) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + progressCalled := false + var lastUploaded int64 + + fileContent := strings.Repeat("a", 1024*10) // 10KB + reader := strings.NewReader(fileContent) + + req := NewSimpleRequest(server.URL, "POST"). + SetUploadProgress(func(filename string, uploaded, total int64) { + progressCalled = true + lastUploaded = uploaded + if filename != "test.txt" { + t.Errorf("filename = %v; want test.txt", filename) + } + }). + AddFileStream("file", "test.txt", int64(len(fileContent)), reader) + + resp, err := req.Do() + if err != nil { + t.Fatalf("Do() error: %v", err) + } + defer resp.Close() + + if !progressCalled { + t.Error("Progress callback was not called") + } + + if lastUploaded != int64(len(fileContent)) { + t.Errorf("lastUploaded = %v; want %v", lastUploaded, len(fileContent)) + } +} + +// TestRequestAddFileFromDisk tests uploading a real file from disk +func TestRequestAddFileFromDisk(t *testing.T) { + // Create a temporary file + tmpDir := t.TempDir() + tmpFile := filepath.Join(tmpDir, "test.txt") + fileContent := []byte("test file content from disk") + + err := os.WriteFile(tmpFile, fileContent, 0644) + if err != nil { + t.Fatalf("WriteFile error: %v", err) + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + err := r.ParseMultipartForm(10 << 20) + if err != nil { + t.Fatalf("ParseMultipartForm error: %v", err) + } + + file, header, err := r.FormFile("file") + if err != nil { + t.Fatalf("FormFile error: %v", err) + } + defer file.Close() + + content, _ := io.ReadAll(file) + w.Write([]byte(header.Filename + ":" + string(content))) + })) + defer server.Close() + + req := NewSimpleRequest(server.URL, "POST").AddFile("file", tmpFile) + + resp, err := req.Do() + if err != nil { + t.Fatalf("Do() error: %v", err) + } + defer resp.Close() + + body, _ := resp.Body().String() + if !strings.Contains(body, string(fileContent)) { + t.Errorf("Body should contain file content, got: %v", body) + } +} diff --git a/header_test.go b/header_test.go new file mode 100644 index 0000000..59431fb --- /dev/null +++ b/header_test.go @@ -0,0 +1,140 @@ +package starnet + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" +) + +func TestRequestHeaders(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + headers := make(map[string]string) + for k, v := range r.Header { + if len(v) > 0 { + headers[k] = v[0] + } + } + json.NewEncoder(w).Encode(headers) + })) + defer server.Close() + + req := NewSimpleRequest(server.URL, "GET"). + SetHeader("X-Custom-Header", "value1"). + AddHeader("X-Multi-Header", "value1"). + AddHeader("X-Multi-Header", "value2") + + resp, err := req.Do() + if err != nil { + t.Fatalf("Do() error: %v", err) + } + defer resp.Close() + + var headers map[string]string + resp.Body().JSON(&headers) + + if headers["X-Custom-Header"] != "value1" { + t.Errorf("X-Custom-Header = %v; want value1", headers["X-Custom-Header"]) + } +} + +func TestRequestCookies(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + cookies := make(map[string]string) + for _, cookie := range r.Cookies() { + cookies[cookie.Name] = cookie.Value + } + json.NewEncoder(w).Encode(cookies) + })) + defer server.Close() + + req := NewSimpleRequest(server.URL, "GET"). + AddSimpleCookie("session", "abc123"). + AddSimpleCookie("user", "john") + + resp, err := req.Do() + if err != nil { + t.Fatalf("Do() error: %v", err) + } + defer resp.Close() + + var cookies map[string]string + resp.Body().JSON(&cookies) + + if cookies["session"] != "abc123" { + t.Errorf("session cookie = %v; want abc123", cookies["session"]) + } + if cookies["user"] != "john" { + t.Errorf("user cookie = %v; want john", cookies["user"]) + } +} + +func TestRequestUserAgent(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(r.UserAgent())) + })) + defer server.Close() + + req := NewSimpleRequest(server.URL, "GET"). + SetUserAgent("CustomAgent/1.0") + + resp, err := req.Do() + if err != nil { + t.Fatalf("Do() error: %v", err) + } + defer resp.Close() + + body, _ := resp.Body().String() + if body != "CustomAgent/1.0" { + t.Errorf("User-Agent = %v; want CustomAgent/1.0", body) + } +} + +func TestRequestBearerToken(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + auth := r.Header.Get("Authorization") + w.Write([]byte(auth)) + })) + defer server.Close() + + req := NewSimpleRequest(server.URL, "GET"). + SetBearerToken("mytoken123") + + resp, err := req.Do() + if err != nil { + t.Fatalf("Do() error: %v", err) + } + defer resp.Close() + + body, _ := resp.Body().String() + expected := "Bearer mytoken123" + if body != expected { + t.Errorf("Authorization = %v; want %v", body, expected) + } +} + +func TestRequestBasicAuth(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + username, password, ok := r.BasicAuth() + if !ok { + w.WriteHeader(http.StatusUnauthorized) + return + } + w.Write([]byte(username + ":" + password)) + })) + defer server.Close() + + req := NewSimpleRequest(server.URL, "GET"). + SetBasicAuth("user", "pass") + + resp, err := req.Do() + if err != nil { + t.Fatalf("Do() error: %v", err) + } + defer resp.Close() + + body, _ := resp.Body().String() + if body != "user:pass" { + t.Errorf("BasicAuth = %v; want user:pass", body) + } +} diff --git a/httpguts.go b/httpguts.go deleted file mode 100644 index baf8d38..0000000 --- a/httpguts.go +++ /dev/null @@ -1,120 +0,0 @@ -package starnet - -import "strings" - -var isTokenTable = [127]bool{ - '!': true, - '#': true, - '$': true, - '%': true, - '&': true, - '\'': true, - '*': true, - '+': true, - '-': true, - '.': true, - '0': true, - '1': true, - '2': true, - '3': true, - '4': true, - '5': true, - '6': true, - '7': true, - '8': true, - '9': true, - 'A': true, - 'B': true, - 'C': true, - 'D': true, - 'E': true, - 'F': true, - 'G': true, - 'H': true, - 'I': true, - 'J': true, - 'K': true, - 'L': true, - 'M': true, - 'N': true, - 'O': true, - 'P': true, - 'Q': true, - 'R': true, - 'S': true, - 'T': true, - 'U': true, - 'W': true, - 'V': true, - 'X': true, - 'Y': true, - 'Z': true, - '^': true, - '_': true, - '`': true, - 'a': true, - 'b': true, - 'c': true, - 'd': true, - 'e': true, - 'f': true, - 'g': true, - 'h': true, - 'i': true, - 'j': true, - 'k': true, - 'l': true, - 'm': true, - 'n': true, - 'o': true, - 'p': true, - 'q': true, - 'r': true, - 's': true, - 't': true, - 'u': true, - 'v': true, - 'w': true, - 'x': true, - 'y': true, - 'z': true, - '|': true, - '~': true, -} - -func IsTokenRune(r rune) bool { - i := int(r) - return i < len(isTokenTable) && isTokenTable[i] -} - -func validMethod(method string) bool { - /* - Method = "OPTIONS" ; Section 9.2 - | "GET" ; Section 9.3 - | "HEAD" ; Section 9.4 - | "POST" ; Section 9.5 - | "PUT" ; Section 9.6 - | "DELETE" ; Section 9.7 - | "TRACE" ; Section 9.8 - | "CONNECT" ; Section 9.9 - | extension-method - extension-method = token - token = 1* - */ - return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1 -} - -func isNotToken(r rune) bool { - return !IsTokenRune(r) -} - -func hasPort(s string) bool { return strings.LastIndex(s, ":") > strings.LastIndex(s, "]") } - -// removeEmptyPort strips the empty port in ":port" to "" -// as mandated by RFC 3986 Section 6.2.3. -func removeEmptyPort(host string) string { - if hasPort(host) { - return strings.TrimSuffix(host, ":") - } - return host -} diff --git a/integration_test.go b/integration_test.go new file mode 100644 index 0000000..30a0cc3 --- /dev/null +++ b/integration_test.go @@ -0,0 +1,258 @@ +package starnet + +import ( + "os" + "testing" + "time" +) + +// 这些测试使用 httpbin.org 作为测试服务 +// 可以通过环境变量 STARNET_INTEGRATION_TEST=1 来启用 + +func skipIfNoIntegration(t *testing.T) { + if os.Getenv("STARNET_INTEGRATION_TEST") != "1" { + t.Skip("Skipping integration test. Set STARNET_INTEGRATION_TEST=1 to run") + } +} + +func TestIntegrationHTTPBinGet(t *testing.T) { + skipIfNoIntegration(t) + + resp, err := Get("https://httpbin.org/get", + WithQuery("name", "starnet"), + WithQuery("version", "1.0")) + if err != nil { + t.Fatalf("Get() error: %v", err) + } + defer resp.Close() + + if resp.StatusCode != 200 { + t.Errorf("StatusCode = %v; want 200", resp.StatusCode) + } + + var result map[string]interface{} + err = resp.Body().JSON(&result) + if err != nil { + t.Fatalf("JSON() error: %v", err) + } + + args, ok := result["args"].(map[string]interface{}) + if !ok { + t.Fatal("args not found in response") + } + + if args["name"] != "starnet" { + t.Errorf("args[name] = %v; want starnet", args["name"]) + } +} + +func TestIntegrationHTTPBinPost(t *testing.T) { + skipIfNoIntegration(t) + + type PostData struct { + Name string `json:"name"` + Email string `json:"email"` + } + + data := PostData{ + Name: "John Doe", + Email: "john@example.com", + } + + resp, err := Post("https://httpbin.org/post", WithJSON(data)) + if err != nil { + t.Fatalf("Post() error: %v", err) + } + defer resp.Close() + + if resp.StatusCode != 200 { + t.Errorf("StatusCode = %v; want 200", resp.StatusCode) + } + + var result map[string]interface{} + err = resp.Body().JSON(&result) + if err != nil { + t.Fatalf("JSON() error: %v", err) + } + + jsonData, ok := result["json"].(map[string]interface{}) + if !ok { + t.Fatal("json not found in response") + } + + if jsonData["name"] != data.Name { + t.Errorf("name = %v; want %v", jsonData["name"], data.Name) + } +} + +func TestIntegrationHTTPBinHeaders(t *testing.T) { + skipIfNoIntegration(t) + + resp, err := Get("https://httpbin.org/headers", + WithHeader("X-Custom-Header", "test-value"), + WithUserAgent("Starnet-Test/1.0")) + if err != nil { + t.Fatalf("Get() error: %v", err) + } + defer resp.Close() + + var result map[string]interface{} + err = resp.Body().JSON(&result) + if err != nil { + t.Fatalf("JSON() error: %v", err) + } + + headers, ok := result["headers"].(map[string]interface{}) + if !ok { + t.Fatal("headers not found in response") + } + + if headers["X-Custom-Header"] != "test-value" { + t.Errorf("X-Custom-Header = %v; want test-value", headers["X-Custom-Header"]) + } +} + +func TestIntegrationHTTPBinBasicAuth(t *testing.T) { + skipIfNoIntegration(t) + + resp, err := Get("https://httpbin.org/basic-auth/user/passwd", + WithBasicAuth("user", "passwd")) + if err != nil { + t.Fatalf("Get() error: %v", err) + } + defer resp.Close() + + if resp.StatusCode != 200 { + t.Errorf("StatusCode = %v; want 200", resp.StatusCode) + } + + var result map[string]interface{} + err = resp.Body().JSON(&result) + if err != nil { + t.Fatalf("JSON() error: %v", err) + } + + if result["authenticated"] != true { + t.Error("authenticated should be true") + } +} + +func TestIntegrationHTTPBinDelay(t *testing.T) { + skipIfNoIntegration(t) + + // Test timeout + start := time.Now() + _, err := Get("https://httpbin.org/delay/3", + WithTimeout(1*time.Second)) + elapsed := time.Since(start) + + if err == nil { + t.Error("Expected timeout error, got nil") + } + + if elapsed > 2*time.Second { + t.Errorf("Timeout took too long: %v", elapsed) + } +} + +func TestIntegrationHTTPBinRedirect(t *testing.T) { + skipIfNoIntegration(t) + + // Test with redirect enabled + client := NewClientNoErr() + resp, err := client.Get("https://httpbin.org/redirect/2") + if err != nil { + t.Fatalf("Get() error: %v", err) + } + defer resp.Close() + + if resp.StatusCode != 200 { + t.Errorf("StatusCode = %v; want 200 (after redirect)", resp.StatusCode) + } + + // Test with redirect disabled + client.DisableRedirect() + resp2, err := client.Get("https://httpbin.org/redirect/2") + if err != nil { + t.Fatalf("Get() error: %v", err) + } + defer resp2.Close() + + if resp2.StatusCode != 302 { + t.Errorf("StatusCode = %v; want 302 (redirect disabled)", resp2.StatusCode) + } +} + +func TestIntegrationHTTPBinCookies(t *testing.T) { + skipIfNoIntegration(t) + + // 创建一个禁用重定向的 Client + client := NewClientNoErr() + client.DisableRedirect() + + resp, err := client.Get("https://httpbin.org/cookies/set?name=value") + if err != nil { + t.Fatalf("Get() error: %v", err) + } + defer resp.Close() + + // 现在应该能获取到 Set-Cookie + cookies := resp.Cookies() + if len(cookies) == 0 { + t.Error("Expected cookies in response") + } + + // 验证 cookie + found := false + for _, cookie := range cookies { + if cookie.Name == "name" && cookie.Value == "value" { + found = true + break + } + } + if !found { + t.Error("Expected cookie 'name=value' not found") + } +} + +func TestIntegrationHTTPBinUserAgent(t *testing.T) { + skipIfNoIntegration(t) + + customUA := "Starnet-Integration-Test/1.0" + resp, err := Get("https://httpbin.org/user-agent", + WithUserAgent(customUA)) + if err != nil { + t.Fatalf("Get() error: %v", err) + } + defer resp.Close() + + var result map[string]interface{} + err = resp.Body().JSON(&result) + if err != nil { + t.Fatalf("JSON() error: %v", err) + } + + if result["user-agent"] != customUA { + t.Errorf("user-agent = %v; want %v", result["user-agent"], customUA) + } +} + +func TestIntegrationHTTPBinGzip(t *testing.T) { + skipIfNoIntegration(t) + + resp, err := Get("https://httpbin.org/gzip") + if err != nil { + t.Fatalf("Get() error: %v", err) + } + defer resp.Close() + + var result map[string]interface{} + err = resp.Body().JSON(&result) + if err != nil { + t.Fatalf("JSON() error: %v", err) + } + + if result["gzipped"] != true { + t.Error("Response should be gzipped") + } +} diff --git a/options.go b/options.go new file mode 100644 index 0000000..1b55197 --- /dev/null +++ b/options.go @@ -0,0 +1,390 @@ +package starnet + +import ( + "context" + "crypto/tls" + "encoding/json" + "io" + "net" + "net/http" + "os" + "time" +) + +// WithTimeout 设置请求总超时时间 +// timeout > 0: 使用该超时 +// timeout = 0: 使用 Client 默认超时 +// timeout < 0: 禁用本次请求超时(覆盖 Client.Timeout=0) +func WithTimeout(timeout time.Duration) RequestOpt { + return func(r *Request) error { + r.config.Network.Timeout = timeout + return nil + } +} + +// WithDialTimeout 设置连接超时时间 +func WithDialTimeout(timeout time.Duration) RequestOpt { + return func(r *Request) error { + r.config.Network.DialTimeout = timeout + return nil + } +} + +// WithProxy 设置代理 +func WithProxy(proxy string) RequestOpt { + return func(r *Request) error { + r.config.Network.Proxy = proxy + return nil + } +} + +// WithDialFunc 设置自定义 Dial 函数 +func WithDialFunc(fn func(ctx context.Context, network, addr string) (net.Conn, error)) RequestOpt { + return func(r *Request) error { + r.config.Network.DialFunc = fn + return nil + } +} + +// WithTLSConfig 设置 TLS 配置 +func WithTLSConfig(tlsConfig *tls.Config) RequestOpt { + return func(r *Request) error { + r.config.TLS.Config = tlsConfig + return nil + } +} + +// WithSkipTLSVerify 设置是否跳过 TLS 验证 +func WithSkipTLSVerify(skip bool) RequestOpt { + return func(r *Request) error { + r.config.TLS.SkipVerify = skip + return nil + } +} + +// WithCustomIP 设置自定义 IP +func WithCustomIP(ips []string) RequestOpt { + return func(r *Request) error { + for _, ip := range ips { + if net.ParseIP(ip) == nil { + return wrapError(ErrInvalidIP, "ip: %s", ip) + } + } + r.config.DNS.CustomIP = ips + return nil + } +} + +// WithAddCustomIP 添加自定义 IP +func WithAddCustomIP(ip string) RequestOpt { + return func(r *Request) error { + if net.ParseIP(ip) == nil { + return wrapError(ErrInvalidIP, "ip: %s", ip) + } + r.config.DNS.CustomIP = append(r.config.DNS.CustomIP, ip) + return nil + } +} + +// WithCustomDNS 设置自定义 DNS 服务器 +func WithCustomDNS(dnsServers []string) RequestOpt { + return func(r *Request) error { + for _, dns := range dnsServers { + if net.ParseIP(dns) == nil { + return wrapError(ErrInvalidDNS, "dns: %s", dns) + } + } + r.config.DNS.CustomDNS = dnsServers + return nil + } +} + +// WithAddCustomDNS 添加自定义 DNS 服务器 +func WithAddCustomDNS(dns string) RequestOpt { + return func(r *Request) error { + if net.ParseIP(dns) == nil { + return wrapError(ErrInvalidDNS, "dns: %s", dns) + } + r.config.DNS.CustomDNS = append(r.config.DNS.CustomDNS, dns) + return nil + } +} + +// WithLookupFunc 设置自定义 DNS 解析函数 +func WithLookupFunc(fn func(ctx context.Context, host string) ([]net.IPAddr, error)) RequestOpt { + return func(r *Request) error { + r.config.DNS.LookupFunc = fn + return nil + } +} + +// WithHeader 设置 Header +func WithHeader(key, value string) RequestOpt { + return func(r *Request) error { + r.config.Headers.Set(key, value) + return nil + } +} + +// WithHeaders 批量设置 Headers +func WithHeaders(headers map[string]string) RequestOpt { + return func(r *Request) error { + for k, v := range headers { + r.config.Headers.Set(k, v) + } + return nil + } +} + +// WithContentType 设置 Content-Type +func WithContentType(contentType string) RequestOpt { + return func(r *Request) error { + r.config.Headers.Set("Content-Type", contentType) + return nil + } +} + +// WithUserAgent 设置 User-Agent +func WithUserAgent(userAgent string) RequestOpt { + return func(r *Request) error { + r.config.Headers.Set("User-Agent", userAgent) + return nil + } +} + +// WithBearerToken 设置 Bearer Token +func WithBearerToken(token string) RequestOpt { + return func(r *Request) error { + r.config.Headers.Set("Authorization", "Bearer "+token) + return nil + } +} + +// WithBasicAuth 设置 Basic 认证 +func WithBasicAuth(username, password string) RequestOpt { + return func(r *Request) error { + r.config.BasicAuth = [2]string{username, password} + return nil + } +} + +// WithCookie 添加 Cookie +func WithCookie(name, value, path string) RequestOpt { + return func(r *Request) error { + r.config.Cookies = append(r.config.Cookies, &http.Cookie{ + Name: name, + Value: value, + Path: path, + }) + return nil + } +} + +// WithSimpleCookie 添加简单 Cookie(path 为 /) +func WithSimpleCookie(name, value string) RequestOpt { + return func(r *Request) error { + r.config.Cookies = append(r.config.Cookies, &http.Cookie{ + Name: name, + Value: value, + Path: "/", + }) + return nil + } +} + +// WithCookies 批量添加 Cookies +func WithCookies(cookies map[string]string) RequestOpt { + return func(r *Request) error { + for name, value := range cookies { + r.config.Cookies = append(r.config.Cookies, &http.Cookie{ + Name: name, + Value: value, + Path: "/", + }) + } + return nil + } +} + +// WithBody 设置请求体(字节) +func WithBody(body []byte) RequestOpt { + return func(r *Request) error { + r.config.Body.Bytes = body + r.config.Body.Reader = nil + return nil + } +} + +// WithBodyString 设置请求体(字符串) +func WithBodyString(body string) RequestOpt { + return func(r *Request) error { + r.config.Body.Bytes = []byte(body) + r.config.Body.Reader = nil + return nil + } +} + +// WithBodyReader 设置请求体(Reader) +func WithBodyReader(reader io.Reader) RequestOpt { + return func(r *Request) error { + r.config.Body.Reader = reader + r.config.Body.Bytes = nil + return nil + } +} + +// WithJSON 设置 JSON 请求体 +func WithJSON(v interface{}) RequestOpt { + return func(r *Request) error { + data, err := json.Marshal(v) + if err != nil { + return wrapError(err, "marshal json") + } + r.config.Headers.Set("Content-Type", ContentTypeJSON) + r.config.Body.Bytes = data + r.config.Body.Reader = nil + return nil + } +} + +// WithFormData 设置表单数据 +func WithFormData(data map[string][]string) RequestOpt { + return func(r *Request) error { + r.config.Body.FormData = data + return nil + } +} + +// WithFormDataMap 设置表单数据(简化版) +func WithFormDataMap(data map[string]string) RequestOpt { + return func(r *Request) error { + for k, v := range data { + r.config.Body.FormData[k] = []string{v} + } + return nil + } +} + +// WithAddFormData 添加表单数据 +func WithAddFormData(key, value string) RequestOpt { + return func(r *Request) error { + r.config.Body.FormData[key] = append(r.config.Body.FormData[key], value) + return nil + } +} + +// WithFile 添加文件 +func WithFile(formName, filePath string) RequestOpt { + return func(r *Request) error { + stat, err := os.Stat(filePath) + if err != nil { + return wrapError(ErrFileNotFound, "file: %s", filePath) + } + + r.config.Body.Files = append(r.config.Body.Files, RequestFile{ + FormName: formName, + FileName: stat.Name(), + FilePath: filePath, + FileSize: stat.Size(), + FileType: ContentTypeOctetStream, + }) + + return nil + } +} + +// WithFileStream 添加文件流 +func WithFileStream(formName, fileName string, size int64, reader io.Reader) RequestOpt { + return func(r *Request) error { + if reader == nil { + return ErrNilReader + } + + r.config.Body.Files = append(r.config.Body.Files, RequestFile{ + FormName: formName, + FileName: fileName, + FileData: reader, + FileSize: size, + FileType: ContentTypeOctetStream, + }) + + return nil + } +} + +// WithQuery 添加查询参数 +func WithQuery(key, value string) RequestOpt { + return func(r *Request) error { + r.config.Queries[key] = append(r.config.Queries[key], value) + return nil + } +} + +// WithQueries 批量添加查询参数 +func WithQueries(queries map[string]string) RequestOpt { + return func(r *Request) error { + for k, v := range queries { + r.config.Queries[k] = append(r.config.Queries[k], v) + } + return nil + } +} + +// WithContentLength 设置 Content-Length +func WithContentLength(length int64) RequestOpt { + return func(r *Request) error { + r.config.ContentLength = length + return nil + } +} + +// WithAutoCalcContentLength 设置是否自动计算 Content-Length +func WithAutoCalcContentLength(auto bool) RequestOpt { + return func(r *Request) error { + r.config.AutoCalcContentLength = auto + return nil + } +} + +// WithUploadProgress 设置文件上传进度回调 +func WithUploadProgress(fn UploadProgressFunc) RequestOpt { + return func(r *Request) error { + r.config.UploadProgress = fn + return nil + } +} + +// WithTransport 设置自定义 Transport +func WithTransport(transport *http.Transport) RequestOpt { + return func(r *Request) error { + r.config.Transport = transport + r.config.CustomTransport = true + return nil + } +} + +// WithAutoFetch 设置是否自动获取响应体 +func WithAutoFetch(auto bool) RequestOpt { + return func(r *Request) error { + r.autoFetch = auto + return nil + } +} + +// WithRawRequest 设置原始请求 +func WithRawRequest(httpReq *http.Request) RequestOpt { + return func(r *Request) error { + r.httpReq = httpReq + r.doRaw = true + return nil + } +} + +// WithContext 设置 context +func WithContext(ctx context.Context) RequestOpt { + return func(r *Request) error { + r.ctx = ctx + r.httpReq = r.httpReq.WithContext(ctx) + return nil + } +} diff --git a/options_test.go b/options_test.go new file mode 100644 index 0000000..4f2da2b --- /dev/null +++ b/options_test.go @@ -0,0 +1,234 @@ +package starnet + +import ( + "context" + "encoding/json" + "io" + "net" + "net/http" + "net/http/httptest" + "os" + "strings" + "sync/atomic" + "testing" + "time" +) + +func TestWithJSONOpt(t *testing.T) { + type payload struct { + Name string `json:"name"` + Age int `json:"age"` + } + + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if ct := r.Header.Get("Content-Type"); ct != ContentTypeJSON { + t.Fatalf("content-type=%s", ct) + } + var p payload + if err := json.NewDecoder(r.Body).Decode(&p); err != nil { + t.Fatalf("decode err: %v", err) + } + if p.Name != "alice" || p.Age != 18 { + t.Fatalf("payload mismatch: %+v", p) + } + w.WriteHeader(http.StatusOK) + })) + defer s.Close() + + resp, err := Post(s.URL, WithJSON(payload{Name: "alice", Age: 18})) + if err != nil { + t.Fatalf("Post error: %v", err) + } + resp.Close() +} + +func TestWithFileOpt(t *testing.T) { + // temp file + cleanup + f, err := os.CreateTemp("", "starnet-upload-*.txt") + if err != nil { + t.Fatal(err) + } + defer os.Remove(f.Name()) + _, _ = f.WriteString("hello-file") + _ = f.Close() + + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := r.ParseMultipartForm(10 << 20); err != nil { + t.Fatalf("parse form err: %v", err) + } + file, header, err := r.FormFile("file") + if err != nil { + t.Fatalf("form file err: %v", err) + } + defer file.Close() + b, _ := io.ReadAll(file) + if header.Filename == "" || string(b) != "hello-file" { + t.Fatalf("upload mismatch filename=%q body=%q", header.Filename, string(b)) + } + w.WriteHeader(http.StatusOK) + })) + defer s.Close() + + resp, err := Post(s.URL, WithFile("file", f.Name())) + if err != nil { + t.Fatalf("Post error: %v", err) + } + resp.Close() +} + +func TestWithFileStreamOpt(t *testing.T) { + content := "stream-content" + reader := strings.NewReader(content) + + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := r.ParseMultipartForm(10 << 20); err != nil { + t.Fatalf("parse form err: %v", err) + } + file, header, err := r.FormFile("up") + if err != nil { + t.Fatalf("form file err: %v", err) + } + defer file.Close() + b, _ := io.ReadAll(file) + if header.Filename != "a.txt" || string(b) != content { + t.Fatalf("upload mismatch filename=%q body=%q", header.Filename, string(b)) + } + w.WriteHeader(http.StatusOK) + })) + defer s.Close() + + resp, err := Post(s.URL, WithFileStream("up", "a.txt", int64(len(content)), reader)) + if err != nil { + t.Fatalf("Post error: %v", err) + } + resp.Close() +} + +func TestWithQueryOpt(t *testing.T) { + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Query().Get("k") != "v" { + t.Fatalf("query mismatch: %v", r.URL.Query()) + } + w.WriteHeader(http.StatusOK) + })) + defer s.Close() + + resp, err := Get(s.URL, WithQuery("k", "v")) + if err != nil { + t.Fatalf("Get error: %v", err) + } + resp.Close() +} + +func TestWithUploadProgressOpt(t *testing.T) { + var called int32 + var last int64 + content := strings.Repeat("x", 4096) + + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = r.ParseMultipartForm(10 << 20) + w.WriteHeader(http.StatusOK) + })) + defer s.Close() + + resp, err := Post(s.URL, + WithUploadProgress(func(filename string, uploaded, total int64) { + atomic.StoreInt32(&called, 1) + last = uploaded + }), + WithFileStream("f", "p.txt", int64(len(content)), strings.NewReader(content)), + ) + if err != nil { + t.Fatalf("Post error: %v", err) + } + resp.Close() + + if atomic.LoadInt32(&called) == 0 { + t.Fatal("progress not called") + } + if last != int64(len(content)) { + t.Fatalf("last uploaded=%d want=%d", last, len(content)) + } +} + +func TestWithTransportOpt(t *testing.T) { + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer s.Close() + + resp, err := Get(s.URL, WithTransport(&http.Transport{})) + if err != nil { + t.Fatalf("Get error: %v", err) + } + resp.Close() +} + +func TestWithContextOpt(t *testing.T) { + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(200 * time.Millisecond) + w.WriteHeader(http.StatusOK) + })) + defer s.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + _, err := Get(s.URL, WithContext(ctx)) + if err == nil { + t.Fatal("expected context timeout error") + } +} + +func TestWithCustomDNSOpt_ConfigApplied(t *testing.T) { + req := NewSimpleRequest("http://example.com", "GET", WithCustomDNS([]string{"8.8.8.8", "1.1.1.1"})) + if req.Err() != nil { + t.Fatalf("unexpected err: %v", req.Err()) + } + if len(req.config.DNS.CustomDNS) != 2 { + t.Fatalf("custom dns len=%d", len(req.config.DNS.CustomDNS)) + } +} + +func TestWithAddCustomIPOpt(t *testing.T) { + req := NewSimpleRequest("http://example.com", "GET", WithAddCustomIP("1.2.3.4")) + if req.Err() != nil { + t.Fatalf("unexpected err: %v", req.Err()) + } + if len(req.config.DNS.CustomIP) != 1 || req.config.DNS.CustomIP[0] != "1.2.3.4" { + t.Fatalf("custom ip mismatch: %v", req.config.DNS.CustomIP) + } +} + +func TestWithCustomIPOpt(t *testing.T) { + req := NewSimpleRequest("http://example.com", "GET", WithCustomIP([]string{"1.1.1.1", "8.8.8.8"})) + if req.Err() != nil { + t.Fatalf("unexpected err: %v", req.Err()) + } + if len(req.config.DNS.CustomIP) != 2 { + t.Fatalf("custom ip len=%d", len(req.config.DNS.CustomIP)) + } +} + +func TestWithDialFuncOpt(t *testing.T) { + called := int32(0) + fn := func(ctx context.Context, network, addr string) (net.Conn, error) { + atomic.StoreInt32(&called, 1) + return nil, io.EOF + } + req := NewSimpleRequest("http://example.com", "GET", WithDialFunc(fn)) + if req.config.Network.DialFunc == nil { + t.Fatal("dial func not set") + } + _, _ = req.config.Network.DialFunc(context.Background(), "tcp", "x:1") + if atomic.LoadInt32(&called) == 0 { + t.Fatal("dial func not called") + } +} + +func TestWithDialTimeoutOpt(t *testing.T) { + req := NewSimpleRequest("http://example.com", "GET", WithDialTimeout(123*time.Millisecond)) + if req.config.Network.DialTimeout != 123*time.Millisecond { + t.Fatalf("dial timeout=%v", req.config.Network.DialTimeout) + } +} diff --git a/proxy_test.go b/proxy_test.go new file mode 100644 index 0000000..ee4cf22 --- /dev/null +++ b/proxy_test.go @@ -0,0 +1,50 @@ +package starnet + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func TestRequestProxy(t *testing.T) { + // Create a proxy server + proxyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Proxy received the request + w.Header().Set("X-Proxied", "true") + w.WriteHeader(http.StatusOK) + w.Write([]byte("proxied")) + })) + defer proxyServer.Close() + + // Note: This is a simplified test. Real proxy testing requires more setup + req := NewSimpleRequest("http://example.com", "GET"). + SetProxy(proxyServer.URL) + + // Just verify the proxy is set in config + if req.config.Network.Proxy != proxyServer.URL { + t.Errorf("Proxy = %v; want %v", req.config.Network.Proxy, proxyServer.URL) + } +} + +func TestClientLevelProxy(t *testing.T) { + proxyURL := "http://proxy.example.com:8080" + client := NewClientNoErr(WithProxy(proxyURL)) + + req, _ := client.NewRequest("http://example.com", "GET") + if req.config.Network.Proxy != proxyURL { + t.Errorf("Proxy = %v; want %v", req.config.Network.Proxy, proxyURL) + } +} + +func TestRequestLevelProxyOverride(t *testing.T) { + clientProxy := "http://client-proxy.com:8080" + requestProxy := "http://request-proxy.com:8080" + + client := NewClientNoErr(WithProxy(clientProxy)) + req, _ := client.NewRequest("http://example.com", "GET", WithProxy(requestProxy)) + + // Request level should override client level + if req.config.Network.Proxy != requestProxy { + t.Errorf("Proxy = %v; want %v", req.config.Network.Proxy, requestProxy) + } +} diff --git a/query_test.go b/query_test.go new file mode 100644 index 0000000..20ca298 --- /dev/null +++ b/query_test.go @@ -0,0 +1,98 @@ +package starnet + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" +) + +func TestRequestQuery(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + query := r.URL.Query() + result := make(map[string][]string) + for k, v := range query { + result[k] = v + } + json.NewEncoder(w).Encode(result) + })) + defer server.Close() + + req := NewSimpleRequest(server.URL, "GET"). + AddQuery("name", "John"). + AddQuery("age", "30"). + AddQuery("tags", "go"). + AddQuery("tags", "http") + + resp, err := req.Do() + if err != nil { + t.Fatalf("Do() error: %v", err) + } + defer resp.Close() + + var result map[string][]string + resp.Body().JSON(&result) + + if len(result["name"]) != 1 || result["name"][0] != "John" { + t.Errorf("name = %v; want [John]", result["name"]) + } + if len(result["tags"]) != 2 { + t.Errorf("tags length = %v; want 2", len(result["tags"])) + } +} + +func TestRequestSetQuery(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + query := r.URL.Query() + w.Write([]byte(query.Get("key"))) + })) + defer server.Close() + + req := NewSimpleRequest(server.URL, "GET"). + SetQuery("key", "value1"). + SetQuery("key", "value2") // Should overwrite + + resp, err := req.Do() + if err != nil { + t.Fatalf("Do() error: %v", err) + } + defer resp.Close() + + body, _ := resp.Body().String() + if body != "value2" { + t.Errorf("query value = %v; want value2", body) + } +} + +func TestRequestDeleteQuery(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + query := r.URL.Query() + result := make(map[string]string) + for k := range query { + result[k] = query.Get(k) + } + json.NewEncoder(w).Encode(result) + })) + defer server.Close() + + req := NewSimpleRequest(server.URL, "GET"). + AddQuery("keep", "yes"). + AddQuery("delete", "no"). + DeleteQuery("delete") + + resp, err := req.Do() + if err != nil { + t.Fatalf("Do() error: %v", err) + } + defer resp.Close() + + var result map[string]string + resp.Body().JSON(&result) + + if _, exists := result["delete"]; exists { + t.Error("delete query should not exist") + } + if result["keep"] != "yes" { + t.Errorf("keep = %v; want yes", result["keep"]) + } +} diff --git a/request.go b/request.go new file mode 100644 index 0000000..8c48e66 --- /dev/null +++ b/request.go @@ -0,0 +1,332 @@ +package starnet + +import ( + "context" + "fmt" + "net/http" + "net/url" + "strings" +) + +// Request HTTP 请求 +type Request struct { + ctx context.Context + execCtx context.Context // 执行时的 context(注入了配置) + url string + method string + err error // 累积的错误 + + config *RequestConfig + client *Client + httpClient *http.Client + httpReq *http.Request + + applied bool // 是否已应用配置 + doRaw bool // 是否使用原始请求(不修改) + autoFetch bool // 是否自动获取响应体 +} + +// newRequest 创建新请求(内部使用) +func newRequest(ctx context.Context, urlStr string, method string, opts ...RequestOpt) (*Request, error) { + if method == "" { + method = http.MethodGet + } + method = strings.ToUpper(method) + + // 创建 http.Request + httpReq, err := http.NewRequestWithContext(ctx, method, urlStr, nil) + if err != nil { + return nil, wrapError(err, "create http request") + } + + // 初始化配置 + config := &RequestConfig{ + Network: NetworkConfig{ + DialTimeout: DefaultDialTimeout, + Timeout: DefaultTimeout, + }, + Headers: make(http.Header), + Queries: make(map[string][]string), + Body: BodyConfig{ + FormData: make(map[string][]string), + }, + } + + // 设置默认 User-Agent + config.Headers.Set("User-Agent", DefaultUserAgent) + + // POST 请求默认 Content-Type + if method == http.MethodPost { + config.Headers.Set("Content-Type", ContentTypeFormURLEncoded) + } + + req := &Request{ + ctx: ctx, + url: urlStr, + method: method, + config: config, + httpReq: httpReq, + autoFetch: DefaultFetchRespBody, + } + + // 应用选项 + for _, opt := range opts { + if opt != nil { + if err := opt(req); err != nil { + req.err = err + return req, nil // 不返回错误,累积到 req.err + } + } + } + + return req, nil +} + +// NewRequest 创建新请求 +func NewRequest(url, method string, opts ...RequestOpt) (*Request, error) { + return newRequest(context.Background(), url, method, opts...) +} + +// NewRequestWithContext 创建新请求(带 context) +func NewRequestWithContext(ctx context.Context, url, method string, opts ...RequestOpt) (*Request, error) { + return newRequest(ctx, url, method, opts...) +} + +// NewSimpleRequest 创建新请求(忽略错误,支持链式调用) +func NewSimpleRequest(url, method string, opts ...RequestOpt) *Request { + req, err := newRequest(context.Background(), url, method, opts...) + if err != nil { + // 返回一个带错误的请求 + return &Request{ + ctx: context.Background(), + url: url, + method: method, + err: err, + config: &RequestConfig{ + Headers: make(http.Header), + Queries: make(map[string][]string), + Body: BodyConfig{ + FormData: make(map[string][]string), + }, + }, + } + } + return req +} + +// NewSimpleRequestWithContext 创建新请求(带 context,忽略错误) +func NewSimpleRequestWithContext(ctx context.Context, url, method string, opts ...RequestOpt) *Request { + req, err := newRequest(ctx, url, method, opts...) + if err != nil { + return &Request{ + ctx: ctx, + url: url, + method: method, + err: err, + config: &RequestConfig{ + Headers: make(http.Header), + Queries: make(map[string][]string), + Body: BodyConfig{ + FormData: make(map[string][]string), + }, + }, + } + } + return req +} + +// Clone 克隆请求 +func (r *Request) Clone() *Request { + cloned := &Request{ + ctx: r.ctx, + url: r.url, + method: r.method, + err: r.err, + config: r.config.Clone(), + client: r.client, + httpClient: r.httpClient, + applied: false, // 重置应用状态 + doRaw: r.doRaw, + autoFetch: r.autoFetch, + } + + // 重新创建 http.Request + if !r.doRaw { + cloned.httpReq, _ = http.NewRequestWithContext(cloned.ctx, cloned.method, cloned.url, nil) + } else { + cloned.httpReq = r.httpReq + } + + return cloned +} + +// Err 获取累积的错误 +func (r *Request) Err() error { + return r.err +} + +// Context 获取 context +func (r *Request) Context() context.Context { + return r.ctx +} + +// SetContext 设置 context +func (r *Request) SetContext(ctx context.Context) *Request { + if r.err != nil { + return r + } + r.ctx = ctx + r.httpReq = r.httpReq.WithContext(ctx) + return r +} + +// Method 获取 HTTP 方法 +func (r *Request) Method() string { + return r.method +} + +// SetMethod 设置 HTTP 方法 +func (r *Request) SetMethod(method string) *Request { + if r.err != nil { + return r + } + + method = strings.ToUpper(method) + if !validMethod(method) { + r.err = wrapError(ErrInvalidMethod, "method: %s", method) + return r + } + + r.method = method + r.httpReq.Method = method + return r +} + +// URL 获取 URL +func (r *Request) URL() string { + return r.url +} + +// SetURL 设置 URL +func (r *Request) SetURL(urlStr string) *Request { + if r.err != nil { + return r + } + + if r.doRaw { + r.err = fmt.Errorf("cannot set URL when using raw request") + return r + } + + u, err := url.Parse(urlStr) + if err != nil { + r.err = wrapError(ErrInvalidURL, "url: %s", urlStr) + return r + } + + r.url = urlStr + u.Host = removeEmptyPort(u.Host) + r.httpReq.Host = u.Host + r.httpReq.URL = u + + // 更新 TLS ServerName + if r.config.TLS.Config != nil { + r.config.TLS.Config.ServerName = u.Hostname() + } + + return r +} + +// RawRequest 获取底层 http.Request +func (r *Request) RawRequest() *http.Request { + return r.httpReq +} + +// SetRawRequest 设置底层 http.Request(启用原始模式) +func (r *Request) SetRawRequest(httpReq *http.Request) *Request { + if r.err != nil { + return r + } + r.httpReq = httpReq + r.doRaw = true + return r +} + +// EnableRawMode 启用原始模式(不修改请求) +func (r *Request) EnableRawMode() *Request { + r.doRaw = true + return r +} + +// DisableRawMode 禁用原始模式 +func (r *Request) DisableRawMode() *Request { + r.doRaw = false + return r +} + +// SetAutoFetch 设置是否自动获取响应体 +func (r *Request) SetAutoFetch(auto bool) *Request { + r.autoFetch = auto + return r +} + +// Do 执行请求 +func (r *Request) Do() (*Response, error) { + // 检查累积的错误 + if r.err != nil { + return nil, r.err + } + + // 准备请求 + if err := r.prepare(); err != nil { + return nil, wrapError(err, "prepare request") + } + + // 执行请求 + httpResp, err := r.httpClient.Do(r.httpReq) + if err != nil { + return &Response{ + Response: &http.Response{}, + request: r, + httpClient: r.httpClient, + body: &Body{}, + }, wrapError(err, "do request") + } + + // 创建响应 + resp := &Response{ + Response: httpResp, + request: r, + httpClient: r.httpClient, + body: &Body{ + raw: httpResp.Body, + }, + } + + // 自动获取响应体 + if r.autoFetch { + resp.body.readAll() + } + + return resp, nil +} + +// Get 发送 GET 请求 +func (r *Request) Get() (*Response, error) { + return r.SetMethod(http.MethodGet).Do() +} + +// Post 发送 POST 请求 +func (r *Request) Post() (*Response, error) { + return r.SetMethod(http.MethodPost).Do() +} + +// Put 发送 PUT 请求 +func (r *Request) Put() (*Response, error) { + return r.SetMethod(http.MethodPut).Do() +} + +// Delete 发送 DELETE 请求 +func (r *Request) Delete() (*Response, error) { + return r.SetMethod(http.MethodDelete).Do() +} diff --git a/request_body.go b/request_body.go new file mode 100644 index 0000000..fdba83d --- /dev/null +++ b/request_body.go @@ -0,0 +1,463 @@ +package starnet + +import ( + "bytes" + "encoding/json" + "io" + "mime/multipart" + "net/http" + "net/url" + "os" + "strings" +) + +// SetBody 设置请求体(字节) +func (r *Request) SetBody(body []byte) *Request { + if r.err != nil { + return r + } + if r.doRaw { + return r + } + r.config.Body.Bytes = body + r.config.Body.Reader = nil + return r +} + +// SetBodyReader 设置请求体(Reader) +func (r *Request) SetBodyReader(reader io.Reader) *Request { + if r.err != nil { + return r + } + if r.doRaw { + return r + } + r.config.Body.Reader = reader + r.config.Body.Bytes = nil + return r +} + +// SetBodyString 设置请求体(字符串) +func (r *Request) SetBodyString(body string) *Request { + return r.SetBody([]byte(body)) +} + +// SetJSON 设置 JSON 请求体 +func (r *Request) SetJSON(v interface{}) *Request { + if r.err != nil { + return r + } + + data, err := json.Marshal(v) + if err != nil { + r.err = wrapError(err, "marshal json") + return r + } + + return r.SetContentType(ContentTypeJSON).SetBody(data) +} + +// SetFormData 设置表单数据(覆盖) +func (r *Request) SetFormData(data map[string][]string) *Request { + if r.err != nil { + return r + } + if r.doRaw { + return r + } + r.config.Body.FormData = data + return r +} + +// AddFormData 添加表单数据 +func (r *Request) AddFormData(key, value string) *Request { + if r.err != nil { + return r + } + if r.doRaw { + return r + } + r.config.Body.FormData[key] = append(r.config.Body.FormData[key], value) + return r +} + +// AddFormDataMap 批量添加表单数据 +func (r *Request) AddFormDataMap(data map[string]string) *Request { + if r.err != nil { + return r + } + if r.doRaw { + return r + } + for k, v := range data { + r.config.Body.FormData[k] = append(r.config.Body.FormData[k], v) + } + return r +} + +// AddFile 添加文件(从路径) +func (r *Request) AddFile(formName, filePath string) *Request { + if r.err != nil { + return r + } + + stat, err := os.Stat(filePath) + if err != nil { + r.err = wrapError(ErrFileNotFound, "file: %s", filePath) + return r + } + + r.config.Body.Files = append(r.config.Body.Files, RequestFile{ + FormName: formName, + FileName: stat.Name(), + FilePath: filePath, + FileSize: stat.Size(), + FileType: ContentTypeOctetStream, + }) + + return r +} + +// AddFileWithName 添加文件(指定文件名) +func (r *Request) AddFileWithName(formName, filePath, fileName string) *Request { + if r.err != nil { + return r + } + + stat, err := os.Stat(filePath) + if err != nil { + r.err = wrapError(ErrFileNotFound, "file: %s", filePath) + return r + } + + r.config.Body.Files = append(r.config.Body.Files, RequestFile{ + FormName: formName, + FileName: fileName, + FilePath: filePath, + FileSize: stat.Size(), + FileType: ContentTypeOctetStream, + }) + + return r +} + +// AddFileWithType 添加文件(指定 MIME 类型) +func (r *Request) AddFileWithType(formName, filePath, fileType string) *Request { + if r.err != nil { + return r + } + + stat, err := os.Stat(filePath) + if err != nil { + r.err = wrapError(ErrFileNotFound, "file: %s", filePath) + return r + } + + r.config.Body.Files = append(r.config.Body.Files, RequestFile{ + FormName: formName, + FileName: stat.Name(), + FilePath: filePath, + FileSize: stat.Size(), + FileType: fileType, + }) + + return r +} + +// AddFileStream 添加文件流 +func (r *Request) AddFileStream(formName, fileName string, size int64, reader io.Reader) *Request { + if r.err != nil { + return r + } + + if reader == nil { + r.err = ErrNilReader + return r + } + + r.config.Body.Files = append(r.config.Body.Files, RequestFile{ + FormName: formName, + FileName: fileName, + FileData: reader, + FileSize: size, + FileType: ContentTypeOctetStream, + }) + + return r +} + +// AddFileStreamWithType 添加文件流(指定 MIME 类型) +func (r *Request) AddFileStreamWithType(formName, fileName, fileType string, size int64, reader io.Reader) *Request { + if r.err != nil { + return r + } + + if reader == nil { + r.err = ErrNilReader + return r + } + + r.config.Body.Files = append(r.config.Body.Files, RequestFile{ + FormName: formName, + FileName: fileName, + FileData: reader, + FileSize: size, + FileType: fileType, + }) + + return r +} + +// applyBody 应用请求体 +func (r *Request) applyBody() error { + // 优先级:Reader > Bytes > Files > FormData + + // 1. Reader + if r.config.Body.Reader != nil { + r.httpReq.Body = io.NopCloser(r.config.Body.Reader) + + // 尝试获取长度 + switch v := r.config.Body.Reader.(type) { + case *bytes.Buffer: + r.httpReq.ContentLength = int64(v.Len()) + case *bytes.Reader: + r.httpReq.ContentLength = int64(v.Len()) + case *strings.Reader: + r.httpReq.ContentLength = int64(v.Len()) + } + + return nil + } + + // 2. Bytes + if len(r.config.Body.Bytes) > 0 { + r.httpReq.Body = io.NopCloser(bytes.NewReader(r.config.Body.Bytes)) + r.httpReq.ContentLength = int64(len(r.config.Body.Bytes)) + return nil + } + + // 3. Files(multipart/form-data) + if len(r.config.Body.Files) > 0 { + return r.applyMultipartBody() + } + + // 4. FormData(application/x-www-form-urlencoded) + if len(r.config.Body.FormData) > 0 { + values := url.Values{} + for k, vs := range r.config.Body.FormData { + for _, v := range vs { + values.Add(k, v) + } + } + encoded := values.Encode() + r.httpReq.Body = io.NopCloser(strings.NewReader(encoded)) + r.httpReq.ContentLength = int64(len(encoded)) + return nil + } + + return nil +} + +// applyMultipartBody 应用 multipart 请求体 +func (r *Request) applyMultipartBody() error { + pr, pw := io.Pipe() + writer := multipart.NewWriter(pw) + + // 设置 Content-Type + r.httpReq.Header.Set("Content-Type", writer.FormDataContentType()) + r.httpReq.Body = pr + + // 在 goroutine 中写入数据 + go func() { + defer pw.Close() + defer writer.Close() + + // 写入表单字段 + for k, vs := range r.config.Body.FormData { + for _, v := range vs { + if err := writer.WriteField(k, v); err != nil { + pw.CloseWithError(wrapError(err, "write form field")) + return + } + } + } + + // 写入文件 + for _, file := range r.config.Body.Files { + if err := r.writeFile(writer, file); err != nil { + pw.CloseWithError(err) + return + } + } + }() + + return nil +} + +// writeFile 写入文件到 multipart writer +func (r *Request) writeFile(writer *multipart.Writer, file RequestFile) error { + // 创建文件字段 + part, err := writer.CreateFormFile(file.FormName, file.FileName) + if err != nil { + return wrapError(err, "create form file") + } + + // 获取文件数据源 + var reader io.Reader + if file.FileData != nil { + reader = file.FileData + } else if file.FilePath != "" { + f, err := os.Open(file.FilePath) + if err != nil { + return wrapError(err, "open file") + } + defer f.Close() + reader = f + } else { + return ErrNilReader + } + + // 复制文件数据(带进度) + if r.config.UploadProgress != nil { + _, err = copyWithProgress(r.ctx, part, reader, file.FileName, file.FileSize, r.config.UploadProgress) + } else { + _, err = io.Copy(part, reader) + } + + if err != nil { + return wrapError(err, "copy file data") + } + + return nil +} + +// prepare 准备请求(应用配置) +func (r *Request) prepare() error { + if r.applied { + return nil + } + defer func() { r.applied = true }() + // 即使 raw 模式也要确保有 httpClient + if r.httpClient == nil { + var err error + r.httpClient, err = r.buildHTTPClient() + if err != nil { + return err + } + } + // 原始模式不修改请求内容 + if r.doRaw { + return nil + } + + // 应用查询参数 + if len(r.config.Queries) > 0 { + q := r.httpReq.URL.Query() + for k, values := range r.config.Queries { + for _, v := range values { + q.Add(k, v) + } + } + r.httpReq.URL.RawQuery = q.Encode() + } + + // 应用 Headers + for k, values := range r.config.Headers { + for _, v := range values { + r.httpReq.Header.Add(k, v) + } + } + + // 应用 Cookies + for _, cookie := range r.config.Cookies { + r.httpReq.AddCookie(cookie) + } + + // 应用 Basic Auth + if r.config.BasicAuth[0] != "" || r.config.BasicAuth[1] != "" { + r.httpReq.SetBasicAuth(r.config.BasicAuth[0], r.config.BasicAuth[1]) + } + + // 应用请求体 + if err := r.applyBody(); err != nil { + return err + } + + // 应用 Content-Length + if r.config.ContentLength > 0 { + r.httpReq.ContentLength = r.config.ContentLength + } else if r.config.ContentLength < 0 { + r.httpReq.ContentLength = 0 + } + + // 自动计算 Content-Length + if r.config.AutoCalcContentLength && r.httpReq.Body != nil { + data, err := io.ReadAll(r.httpReq.Body) + if err != nil { + return wrapError(err, "read body for content length") + } + r.httpReq.ContentLength = int64(len(data)) + r.httpReq.Body = io.NopCloser(bytes.NewBuffer(data)) + } + + // 设置 TLS ServerName(如果有 TLS Config) + if r.config.TLS.Config != nil && r.httpReq.URL != nil { + r.config.TLS.Config.ServerName = r.httpReq.URL.Hostname() + } + + // 注入配置到 context + r.execCtx = injectRequestConfig(r.ctx, r.config) + r.httpReq = r.httpReq.WithContext(r.execCtx) + + return nil +} + +// buildHTTPClient 构建 HTTP Client +func (r *Request) buildHTTPClient() (*http.Client, error) { + applyTimeoutOverride := func(base *http.Client) *http.Client { + // 没有 base 时兜底 + if base == nil { + base = &http.Client{} + } + + rt := r.config.Network.Timeout + + // 语义: + // rt < 0 : 本次请求禁用超时(Timeout = 0) + // rt = 0 : 沿用 base.Timeout + // rt > 0 : 本次请求超时覆盖 + if rt == 0 { + return base + } + + clone := &http.Client{ + Transport: base.Transport, + CheckRedirect: base.CheckRedirect, + Jar: base.Jar, + } + + if rt < 0 { + clone.Timeout = 0 + } else { + clone.Timeout = rt + } + return clone + } + + // 优先使用请求关联的 Client + if r.client != nil { + return applyTimeoutOverride(r.client.HTTPClient()), nil + } + + // 自定义 Transport + if r.config.CustomTransport && r.config.Transport != nil { + base := &http.Client{ + Transport: &Transport{base: r.config.Transport}, + Timeout: 0, + } + return applyTimeoutOverride(base), nil + } + + // 默认全局 client + return applyTimeoutOverride(DefaultHTTPClient()), nil +} diff --git a/request_config.go b/request_config.go new file mode 100644 index 0000000..fed3fe0 --- /dev/null +++ b/request_config.go @@ -0,0 +1,269 @@ +package starnet + +import ( + "context" + "crypto/tls" + "fmt" + "net" + "net/http" + "time" +) + +// SetTimeout 设置请求总超时时间 +// timeout > 0: 使用该超时 +// timeout = 0: 使用 Client 默认超时 +// timeout < 0: 禁用本次请求超时(覆盖 Client.Timeout=0) +func (r *Request) SetTimeout(timeout time.Duration) *Request { + if r.err != nil { + return r + } + r.config.Network.Timeout = timeout + return r +} + +// SetDialTimeout 设置连接超时时间 +func (r *Request) SetDialTimeout(timeout time.Duration) *Request { + if r.err != nil { + return r + } + r.config.Network.DialTimeout = timeout + return r +} + +// SetProxy 设置代理 +func (r *Request) SetProxy(proxy string) *Request { + if r.err != nil { + return r + } + r.config.Network.Proxy = proxy + return r +} + +// SetDialFunc 设置自定义 Dial 函数 +func (r *Request) SetDialFunc(fn func(ctx context.Context, network, addr string) (net.Conn, error)) *Request { + if r.err != nil { + return r + } + r.config.Network.DialFunc = fn + return r +} + +// SetTLSConfig 设置 TLS 配置 +func (r *Request) SetTLSConfig(tlsConfig *tls.Config) *Request { + if r.err != nil { + return r + } + r.config.TLS.Config = tlsConfig + return r +} + +// SetSkipTLSVerify 设置是否跳过 TLS 验证 +func (r *Request) SetSkipTLSVerify(skip bool) *Request { + if r.err != nil { + return r + } + r.config.TLS.SkipVerify = skip + return r +} + +// SetCustomIP 设置自定义 IP(直接指定 IP,跳过 DNS) +func (r *Request) SetCustomIP(ips []string) *Request { + if r.err != nil { + return r + } + + // 验证 IP 格式 + for _, ip := range ips { + if net.ParseIP(ip) == nil { + r.err = wrapError(ErrInvalidIP, "ip: %s", ip) + return r + } + } + + r.config.DNS.CustomIP = ips + return r +} + +// AddCustomIP 添加自定义 IP +func (r *Request) AddCustomIP(ip string) *Request { + if r.err != nil { + return r + } + + if net.ParseIP(ip) == nil { + r.err = wrapError(ErrInvalidIP, "ip: %s", ip) + return r + } + + r.config.DNS.CustomIP = append(r.config.DNS.CustomIP, ip) + return r +} + +// SetCustomDNS 设置自定义 DNS 服务器 +func (r *Request) SetCustomDNS(dnsServers []string) *Request { + if r.err != nil { + return r + } + + // 验证 DNS 服务器格式 + for _, dns := range dnsServers { + if net.ParseIP(dns) == nil { + r.err = wrapError(ErrInvalidDNS, "dns: %s", dns) + return r + } + } + + r.config.DNS.CustomDNS = dnsServers + return r +} + +// AddCustomDNS 添加自定义 DNS 服务器 +func (r *Request) AddCustomDNS(dns string) *Request { + if r.err != nil { + return r + } + + if net.ParseIP(dns) == nil { + r.err = wrapError(ErrInvalidDNS, "dns: %s", dns) + return r + } + + r.config.DNS.CustomDNS = append(r.config.DNS.CustomDNS, dns) + return r +} + +// SetLookupFunc 设置自定义 DNS 解析函数 +func (r *Request) SetLookupFunc(fn func(ctx context.Context, host string) ([]net.IPAddr, error)) *Request { + if r.err != nil { + return r + } + r.config.DNS.LookupFunc = fn + return r +} + +// SetBasicAuth 设置 Basic 认证 +func (r *Request) SetBasicAuth(username, password string) *Request { + if r.err != nil { + return r + } + r.config.BasicAuth = [2]string{username, password} + return r +} + +// SetContentLength 设置 Content-Length +func (r *Request) SetContentLength(length int64) *Request { + if r.err != nil { + return r + } + r.config.ContentLength = length + return r +} + +// SetAutoCalcContentLength 设置是否自动计算 Content-Length +// 警告:启用后会将整个 body 读入内存 +func (r *Request) SetAutoCalcContentLength(auto bool) *Request { + if r.err != nil { + return r + } + + if r.doRaw { + r.err = fmt.Errorf("cannot set auto calc content length in raw mode") + return r + } + + r.config.AutoCalcContentLength = auto + return r +} + +// SetTransport 设置自定义 Transport +func (r *Request) SetTransport(transport *http.Transport) *Request { + if r.err != nil { + return r + } + r.config.Transport = transport + r.config.CustomTransport = true + return r +} + +// SetUploadProgress 设置文件上传进度回调 +func (r *Request) SetUploadProgress(fn UploadProgressFunc) *Request { + if r.err != nil { + return r + } + r.config.UploadProgress = fn + return r +} + +// AddQuery 添加查询参数 +func (r *Request) AddQuery(key, value string) *Request { + if r.err != nil { + return r + } + r.config.Queries[key] = append(r.config.Queries[key], value) + return r +} + +// SetQuery 设置查询参数(覆盖) +func (r *Request) SetQuery(key, value string) *Request { + if r.err != nil { + return r + } + r.config.Queries[key] = []string{value} + return r +} + +// SetQueries 设置所有查询参数(覆盖) +func (r *Request) SetQueries(queries map[string][]string) *Request { + if r.err != nil { + return r + } + r.config.Queries = queries + return r +} + +// AddQueries 批量添加查询参数 +func (r *Request) AddQueries(queries map[string]string) *Request { + if r.err != nil { + return r + } + for k, v := range queries { + r.config.Queries[k] = append(r.config.Queries[k], v) + } + return r +} + +// DeleteQuery 删除查询参数 +func (r *Request) DeleteQuery(key string) *Request { + if r.err != nil { + return r + } + delete(r.config.Queries, key) + return r +} + +// DeleteQueryValue 删除查询参数的特定值 +func (r *Request) DeleteQueryValue(key, value string) *Request { + if r.err != nil { + return r + } + + values, ok := r.config.Queries[key] + if !ok { + return r + } + + newValues := make([]string, 0, len(values)) + for _, v := range values { + if v != value { + newValues = append(newValues, v) + } + } + + if len(newValues) == 0 { + delete(r.config.Queries, key) + } else { + r.config.Queries[key] = newValues + } + + return r +} diff --git a/request_header.go b/request_header.go new file mode 100644 index 0000000..55f48be --- /dev/null +++ b/request_header.go @@ -0,0 +1,180 @@ +package starnet + +import ( + "net/http" +) + +// SetHeader 设置 Header(覆盖) +func (r *Request) SetHeader(key, value string) *Request { + if r.err != nil { + return r + } + if r.doRaw { + return r + } + r.config.Headers.Set(key, value) + return r +} + +// AddHeader 添加 Header +func (r *Request) AddHeader(key, value string) *Request { + if r.err != nil { + return r + } + if r.doRaw { + return r + } + r.config.Headers.Add(key, value) + return r +} + +// SetHeaders 设置所有 Headers(覆盖) +func (r *Request) SetHeaders(headers http.Header) *Request { + if r.err != nil { + return r + } + if r.doRaw { + return r + } + r.config.Headers = headers + return r +} + +// AddHeaders 批量添加 Headers +func (r *Request) AddHeaders(headers map[string]string) *Request { + if r.err != nil { + return r + } + if r.doRaw { + return r + } + for k, v := range headers { + r.config.Headers.Add(k, v) + } + return r +} + +// DeleteHeader 删除 Header +func (r *Request) DeleteHeader(key string) *Request { + if r.err != nil { + return r + } + if r.doRaw { + return r + } + r.config.Headers.Del(key) + return r +} + +// GetHeader 获取 Header +func (r *Request) GetHeader(key string) string { + return r.config.Headers.Get(key) +} + +// Headers 获取所有 Headers +func (r *Request) Headers() http.Header { + return r.config.Headers +} + +// SetContentType 设置 Content-Type +func (r *Request) SetContentType(contentType string) *Request { + return r.SetHeader("Content-Type", contentType) +} + +// SetUserAgent 设置 User-Agent +func (r *Request) SetUserAgent(userAgent string) *Request { + return r.SetHeader("User-Agent", userAgent) +} + +// SetReferer 设置 Referer +func (r *Request) SetReferer(referer string) *Request { + return r.SetHeader("Referer", referer) +} + +// SetBearerToken 设置 Bearer Token +func (r *Request) SetBearerToken(token string) *Request { + return r.SetHeader("Authorization", "Bearer "+token) +} + +// AddCookie 添加 Cookie +func (r *Request) AddCookie(cookie *http.Cookie) *Request { + if r.err != nil { + return r + } + if r.doRaw { + return r + } + r.config.Cookies = append(r.config.Cookies, cookie) + return r +} + +// AddSimpleCookie 添加简单 Cookie(path 为 /) +func (r *Request) AddSimpleCookie(name, value string) *Request { + return r.AddCookie(&http.Cookie{ + Name: name, + Value: value, + Path: "/", + }) +} + +// AddCookieKV 添加 Cookie(指定 path) +func (r *Request) AddCookieKV(name, value, path string) *Request { + return r.AddCookie(&http.Cookie{ + Name: name, + Value: value, + Path: path, + }) +} + +// SetCookies 设置所有 Cookies(覆盖) +func (r *Request) SetCookies(cookies []*http.Cookie) *Request { + if r.err != nil { + return r + } + if r.doRaw { + return r + } + r.config.Cookies = cookies + return r +} + +// AddCookies 批量添加 Cookies +func (r *Request) AddCookies(cookies map[string]string) *Request { + if r.err != nil { + return r + } + if r.doRaw { + return r + } + for name, value := range cookies { + r.config.Cookies = append(r.config.Cookies, &http.Cookie{ + Name: name, + Value: value, + Path: "/", + }) + } + return r +} + +// Cookies 获取所有 Cookies +func (r *Request) Cookies() []*http.Cookie { + return r.config.Cookies +} + +// ResetHeaders 重置所有 Headers +func (r *Request) ResetHeaders() *Request { + if r.err != nil { + return r + } + r.config.Headers = make(http.Header) + return r +} + +// ResetCookies 重置所有 Cookies +func (r *Request) ResetCookies() *Request { + if r.err != nil { + return r + } + r.config.Cookies = []*http.Cookie{} + return r +} diff --git a/request_test.go b/request_test.go new file mode 100644 index 0000000..e5019e3 --- /dev/null +++ b/request_test.go @@ -0,0 +1,172 @@ +package starnet + +import ( + "context" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" +) + +func TestNewSimpleRequest(t *testing.T) { + tests := []struct { + name string + url string + method string + expectErr bool + }{ + { + name: "valid GET request", + url: "https://example.com", + method: "GET", + expectErr: false, + }, + { + name: "valid POST request", + url: "https://example.com", + method: "POST", + expectErr: false, + }, + { + name: "invalid URL", + url: "://invalid", + method: "GET", + expectErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req, err := NewRequest(tt.url, tt.method) + if tt.expectErr { + if err == nil && req.Err() == nil { + t.Errorf("NewRequest() expected error, got nil") + } + } else { + if err != nil { + t.Errorf("NewRequest() unexpected error: %v", err) + } + if req.Method() != strings.ToUpper(tt.method) { + t.Errorf("Method = %v; want %v", req.Method(), tt.method) + } + } + }) + } +} + +func TestRequestMethods(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Method", r.Method) + w.WriteHeader(http.StatusOK) + w.Write([]byte(r.Method)) + })) + defer server.Close() + + methods := []string{"GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"} + + for _, method := range methods { + t.Run(method, func(t *testing.T) { + req := NewSimpleRequest(server.URL, method) + resp, err := req.Do() + if err != nil { + t.Fatalf("Do() error: %v", err) + } + defer resp.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("StatusCode = %v; want %v", resp.StatusCode, http.StatusOK) + } + + if method != "HEAD" { + body, _ := resp.Body().String() + if body != method { + t.Errorf("Body = %v; want %v", body, method) + } + } + }) + } +} + +func TestRequestSetMethod(t *testing.T) { + req := NewSimpleRequest("https://example.com", "GET") + + req.SetMethod("POST") + if req.Method() != "POST" { + t.Errorf("Method = %v; want POST", req.Method()) + } + + req.SetMethod("invalid method!") + if req.Err() == nil { + t.Error("SetMethod with invalid method should set error") + } +} + +func TestRequestSetURL(t *testing.T) { + req := NewSimpleRequest("https://example.com", "GET") + + req.SetURL("https://newexample.com") + if req.URL() != "https://newexample.com" { + t.Errorf("URL = %v; want https://newexample.com", req.URL()) + } + + req2 := NewSimpleRequest("https://example.com", "GET") + req2.SetURL("://invalid") + if req2.Err() == nil { + t.Error("SetURL with invalid URL should set error") + } +} + +func TestRequestClone(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("X-Test") != "value" { + w.WriteHeader(http.StatusBadRequest) + return + } + w.WriteHeader(http.StatusOK) + w.Write([]byte("OK")) + })) + defer server.Close() + + req := NewSimpleRequest(server.URL, "GET"). + SetHeader("X-Test", "value") + + // 第一次请求 + resp, err := req.Do() + if err != nil { + t.Fatalf("Do() error: %v", err) + } + resp.Close() + + // 克隆请求 + cloned := req.Clone() + cloned.SetHeader("X-Extra", "extra") + + // 克隆的请求应该也能成功 + resp2, err := cloned.Do() + if err != nil { + t.Fatalf("Cloned Do() error: %v", err) + } + defer resp2.Close() + + if resp2.StatusCode != http.StatusOK { + t.Errorf("Cloned request StatusCode = %v; want %v", resp2.StatusCode, http.StatusOK) + } +} + +func TestRequestContext(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(100 * time.Millisecond) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + req := NewSimpleRequest(server.URL, "GET").SetContext(ctx) + _, err := req.Do() + if err == nil { + t.Error("Expected timeout error, got nil") + } +} diff --git a/response.go b/response.go new file mode 100644 index 0000000..1de8db4 --- /dev/null +++ b/response.go @@ -0,0 +1,161 @@ +package starnet + +import ( + "bytes" + "encoding/json" + "io" + "net/http" + "sync" +) + +// Response HTTP 响应 +type Response struct { + *http.Response + request *Request + httpClient *http.Client + body *Body +} + +// Body 响应体 +type Body struct { + raw io.ReadCloser + data []byte + consumed bool + mu sync.Mutex +} + +// Request 获取原始请求 +func (r *Response) Request() *Request { + return r.request +} + +// Body 获取响应体 +func (r *Response) Body() *Body { + return r.body +} + +// Close 关闭响应体 +func (r *Response) Close() error { + if r.body != nil && r.body.raw != nil { + return r.body.raw.Close() + } + return nil +} + +// CloseWithClient 关闭响应体并关闭空闲连接 +func (r *Response) CloseWithClient() error { + if r.httpClient != nil { + r.httpClient.CloseIdleConnections() + } + return r.Close() +} + +// readAll 读取所有数据 +func (b *Body) readAll() error { + b.mu.Lock() + defer b.mu.Unlock() + + if b.consumed { + return nil + } + + if b.raw == nil { + b.consumed = true + return nil + } + + data, err := io.ReadAll(b.raw) + if err != nil { + return wrapError(err, "read response body") + } + + b.data = data + b.consumed = true + b.raw.Close() + + return nil +} + +// Bytes 获取响应体字节 +func (b *Body) Bytes() ([]byte, error) { + if err := b.readAll(); err != nil { + return nil, err + } + return b.data, nil +} + +// String 获取响应体字符串 +func (b *Body) String() (string, error) { + data, err := b.Bytes() + if err != nil { + return "", err + } + return string(data), nil +} + +// JSON 解析 JSON 响应 +func (b *Body) JSON(v interface{}) error { + data, err := b.Bytes() + if err != nil { + return err + } + return json.Unmarshal(data, v) +} + +// Reader 获取 Reader(只能调用一次) +func (b *Body) Reader() (io.ReadCloser, error) { + b.mu.Lock() + defer b.mu.Unlock() + + if b.consumed { + if b.data != nil { + // 已读取,返回缓存数据的 Reader + return io.NopCloser(bytes.NewReader(b.data)), nil + } + return nil, ErrBodyAlreadyConsumed + } + + b.consumed = true + return b.raw, nil +} + +// IsConsumed 检查是否已消费 +func (b *Body) IsConsumed() bool { + b.mu.Lock() + defer b.mu.Unlock() + return b.consumed +} + +// Close 关闭 Body +func (b *Body) Close() error { + b.mu.Lock() + defer b.mu.Unlock() + + if b.raw != nil { + return b.raw.Close() + } + return nil +} + +// MustBytes 获取响应体字节(忽略错误,失败返回 nil) +func (b *Body) MustBytes() []byte { + data, err := b.Bytes() + if err != nil { + return nil + } + return data +} + +// MustString 获取响应体字符串(忽略错误,失败返回空串) +func (b *Body) MustString() string { + s, err := b.String() + if err != nil { + return "" + } + return s +} + +// Unmarshal 解析 JSON 响应(兼容旧 API) +func (b *Body) Unmarshal(v interface{}) error { + return b.JSON(v) +} diff --git a/response_test.go b/response_test.go new file mode 100644 index 0000000..a82a542 --- /dev/null +++ b/response_test.go @@ -0,0 +1,179 @@ +package starnet + +import ( + "io" + "net/http" + "net/http/httptest" + "testing" +) + +func TestResponseBody(t *testing.T) { + testData := "test response data" + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(testData)) + })) + defer server.Close() + + resp, err := Get(server.URL) + if err != nil { + t.Fatalf("Get() error: %v", err) + } + defer resp.Close() + + // Test String() + body, err := resp.Body().String() + if err != nil { + t.Fatalf("Body().String() error: %v", err) + } + if body != testData { + t.Errorf("Body = %v; want %v", body, testData) + } + + // Test multiple reads (should work because body is cached) + body2, err := resp.Body().String() + if err != nil { + t.Fatalf("Second Body().String() error: %v", err) + } + if body2 != testData { + t.Errorf("Second Body = %v; want %v", body2, testData) + } +} + +func TestResponseJSON(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"name":"John","age":30}`)) + })) + defer server.Close() + + resp, err := Get(server.URL) + if err != nil { + t.Fatalf("Get() error: %v", err) + } + defer resp.Close() + + var result struct { + Name string `json:"name"` + Age int `json:"age"` + } + + err = resp.Body().JSON(&result) + if err != nil { + t.Fatalf("Body().JSON() error: %v", err) + } + + if result.Name != "John" { + t.Errorf("Name = %v; want John", result.Name) + } + if result.Age != 30 { + t.Errorf("Age = %v; want 30", result.Age) + } +} + +func TestResponseBytes(t *testing.T) { + testData := []byte("binary data") + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write(testData) + })) + defer server.Close() + + resp, err := Get(server.URL) + if err != nil { + t.Fatalf("Get() error: %v", err) + } + defer resp.Close() + + body, err := resp.Body().Bytes() + if err != nil { + t.Fatalf("Body().Bytes() error: %v", err) + } + + if string(body) != string(testData) { + t.Errorf("Body = %v; want %v", body, testData) + } +} + +func TestResponseReader(t *testing.T) { + testData := "stream data" + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(testData)) + })) + defer server.Close() + + resp, err := Get(server.URL) + if err != nil { + t.Fatalf("Get() error: %v", err) + } + defer resp.Close() + + reader, err := resp.Body().Reader() + if err != nil { + t.Fatalf("Body().Reader() error: %v", err) + } + defer reader.Close() + + body, err := io.ReadAll(reader) + if err != nil { + t.Fatalf("ReadAll() error: %v", err) + } + + if string(body) != testData { + t.Errorf("Body = %v; want %v", string(body), testData) + } +} + +func TestResponseAutoFetch(t *testing.T) { + testData := "auto fetch data" + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(testData)) + })) + defer server.Close() + + // With auto fetch + resp, err := Get(server.URL, WithAutoFetch(true)) + if err != nil { + t.Fatalf("Get() error: %v", err) + } + defer resp.Close() + + if !resp.Body().IsConsumed() { + t.Error("Body should be consumed with auto fetch") + } + + body, _ := resp.Body().String() + if body != testData { + t.Errorf("Body = %v; want %v", body, testData) + } +} + +func TestResponseStatusCode(t *testing.T) { + tests := []struct { + name string + statusCode int + }{ + {"OK", http.StatusOK}, + {"Created", http.StatusCreated}, + {"BadRequest", http.StatusBadRequest}, + {"NotFound", http.StatusNotFound}, + {"InternalServerError", http.StatusInternalServerError}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(tt.statusCode) + })) + defer server.Close() + + resp, err := Get(server.URL) + if err != nil { + t.Fatalf("Get() error: %v", err) + } + defer resp.Close() + + if resp.StatusCode != tt.statusCode { + t.Errorf("StatusCode = %v; want %v", resp.StatusCode, tt.statusCode) + } + }) + } +} diff --git a/timeout_test.go b/timeout_test.go new file mode 100644 index 0000000..4c52594 --- /dev/null +++ b/timeout_test.go @@ -0,0 +1,66 @@ +package starnet + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" +) + +func TestRequestTimeout(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(200 * time.Millisecond) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + // Should timeout + req := NewSimpleRequest(server.URL, "GET").SetTimeout(100 * time.Millisecond) + _, err := req.Do() + if err == nil { + t.Error("Expected timeout error, got nil") + } + + // Should succeed + req2 := NewSimpleRequest(server.URL, "GET").SetTimeout(300 * time.Millisecond) + resp, err := req2.Do() + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if resp != nil { + resp.Close() + } +} + +func TestRequestDialTimeout(t *testing.T) { + // Use a non-routable IP to test dial timeout + req := NewSimpleRequest("http://192.0.2.1:80", "GET"). + SetDialTimeout(100 * time.Millisecond) + + start := time.Now() + _, err := req.Do() + elapsed := time.Since(start) + + if err == nil { + t.Error("Expected dial timeout error, got nil") + } + + // Should timeout within reasonable time (not wait forever) + if elapsed > 2*time.Second { + t.Errorf("Dial timeout took too long: %v", elapsed) + } +} + +func TestClientTimeout(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(200 * time.Millisecond) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + client := NewClientNoErr(WithTimeout(100 * time.Millisecond)) + _, err := client.Get(server.URL) + if err == nil { + t.Error("Expected timeout error, got nil") + } +} diff --git a/tls_test.go b/tls_test.go new file mode 100644 index 0000000..220b6db --- /dev/null +++ b/tls_test.go @@ -0,0 +1,102 @@ +package starnet + +import ( + "crypto/tls" + "net/http" + "net/http/httptest" + "testing" +) + +func TestRequestSkipTLSVerify(t *testing.T) { + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("OK")) + })) + defer server.Close() + + // Without skip verify (should fail) + req := NewSimpleRequest(server.URL, "GET") + _, err := req.Do() + if err == nil { + t.Error("Expected TLS error without skip verify, got nil") + } + + // With skip verify (should succeed) + req2 := NewSimpleRequest(server.URL, "GET").SetSkipTLSVerify(true) + resp, err := req2.Do() + if err != nil { + t.Fatalf("Do() with skip verify error: %v", err) + } + defer resp.Close() + + body, _ := resp.Body().String() + if body != "OK" { + t.Errorf("Body = %v; want OK", body) + } +} + +func TestRequestCustomTLSConfig(t *testing.T) { + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("OK")) + })) + defer server.Close() + + tlsConfig := &tls.Config{ + InsecureSkipVerify: true, + MinVersion: tls.VersionTLS12, + } + + req := NewSimpleRequest(server.URL, "GET").SetTLSConfig(tlsConfig) + resp, err := req.Do() + if err != nil { + t.Fatalf("Do() error: %v", err) + } + defer resp.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("StatusCode = %v; want %v", resp.StatusCode, http.StatusOK) + } +} + +func TestClientDefaultTLSConfig(t *testing.T) { + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + client := NewClientNoErr() + client.SetDefaultSkipTLSVerify(true) + + resp, err := client.Get(server.URL) + if err != nil { + t.Fatalf("Get() error: %v", err) + } + defer resp.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("StatusCode = %v; want %v", resp.StatusCode, http.StatusOK) + } +} + +func TestRequestLevelTLSOverride(t *testing.T) { + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + // Client level: skip verify = false + client := NewClientNoErr() + client.SetDefaultSkipTLSVerify(false) + + // Request level: skip verify = true (should override) + resp, err := client.Get(server.URL, WithSkipTLSVerify(true)) + if err != nil { + t.Fatalf("Get() error: %v", err) + } + defer resp.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("StatusCode = %v; want %v", resp.StatusCode, http.StatusOK) + } +} diff --git a/tlsconfig.go b/tlsconfig.go new file mode 100644 index 0000000..5d18e09 --- /dev/null +++ b/tlsconfig.go @@ -0,0 +1,55 @@ +package starnet + +import ( + "crypto/tls" + "net" + "time" +) + +// GetConfigForClientFunc selects TLS config by hostname/SNI. +type GetConfigForClientFunc func(hostname string) (*tls.Config, error) + +// ListenerConfig controls listener behavior. +type ListenerConfig struct { + // BaseTLSConfig is used for TLS when dynamic selection returns nil. + BaseTLSConfig *tls.Config + + // GetConfigForClient selects TLS config for a hostname. + GetConfigForClient GetConfigForClientFunc + + // AllowNonTLS allows plain TCP fallback. + AllowNonTLS bool + + // SniffTimeout bounds protocol sniffing time. 0 means no timeout. + SniffTimeout time.Duration + + // MaxClientHelloBytes limits buffered sniff data. + // If <= 0, default 64KiB. + MaxClientHelloBytes int + + // Logger is optional. + Logger Logger +} + +// DefaultListenerConfig returns a conservative default config. +func DefaultListenerConfig() ListenerConfig { + return ListenerConfig{ + AllowNonTLS: false, + SniffTimeout: 5 * time.Second, + MaxClientHelloBytes: 64 * 1024, + } +} + +// TLSDefaults returns a TLS config baseline. +// Caller should set Certificates / GetCertificate as needed. +func TLSDefaults() *tls.Config { + return &tls.Config{ + MinVersion: tls.VersionTLS12, + } +} + +// DialConfig controls dialing behavior. +type DialConfig struct { + Timeout time.Duration + LocalAddr net.Addr +} diff --git a/tlssniffer.go b/tlssniffer.go index ebc02ea..16497e5 100644 --- a/tlssniffer.go +++ b/tlssniffer.go @@ -10,269 +10,231 @@ import ( "time" ) -type myConn struct { - reader io.Reader - conn net.Conn - isReadOnly bool - multiReader io.Reader +// replayConn replays buffered bytes first, then reads from live conn. +type replayConn struct { + reader io.Reader + conn net.Conn } -func (c *myConn) Read(p []byte) (int, error) { - if c.isReadOnly { - return c.reader.Read(p) +func newReplayConn(buffered io.Reader, conn net.Conn) *replayConn { + return &replayConn{ + reader: io.MultiReader(buffered, conn), + conn: conn, } - if c.multiReader == nil { - c.multiReader = io.MultiReader(c.reader, c.conn) - } - return c.multiReader.Read(p) } -func (c *myConn) Write(p []byte) (int, error) { - if c.isReadOnly { - return 0, io.ErrClosedPipe - } - return c.conn.Write(p) -} -func (c *myConn) Close() error { - if c.isReadOnly { - return nil - } - return c.conn.Close() -} -func (c *myConn) LocalAddr() net.Addr { - if c.isReadOnly { - return nil - } - return c.conn.LocalAddr() -} -func (c *myConn) RemoteAddr() net.Addr { - if c.isReadOnly { - return nil - } - return c.conn.RemoteAddr() -} -func (c *myConn) SetDeadline(t time.Time) error { - if c.isReadOnly { - return nil - } - return c.conn.SetDeadline(t) -} -func (c *myConn) SetReadDeadline(t time.Time) error { - if c.isReadOnly { - return nil - } - return c.conn.SetReadDeadline(t) -} -func (c *myConn) SetWriteDeadline(t time.Time) error { - if c.isReadOnly { - return nil - } - return c.conn.SetWriteDeadline(t) +func (c *replayConn) Read(p []byte) (int, error) { return c.reader.Read(p) } +func (c *replayConn) Write(p []byte) (int, error) { return c.conn.Write(p) } +func (c *replayConn) Close() error { return c.conn.Close() } +func (c *replayConn) LocalAddr() net.Addr { return c.conn.LocalAddr() } +func (c *replayConn) RemoteAddr() net.Addr { return c.conn.RemoteAddr() } +func (c *replayConn) SetDeadline(t time.Time) error { return c.conn.SetDeadline(t) } +func (c *replayConn) SetReadDeadline(t time.Time) error { return c.conn.SetReadDeadline(t) } +func (c *replayConn) SetWriteDeadline(t time.Time) error { return c.conn.SetWriteDeadline(t) } + +// SniffResult describes protocol sniffing result. +type SniffResult struct { + IsTLS bool + Hostname string + Buffer *bytes.Buffer } -type Listener struct { - net.Listener - cfg *tls.Config - getConfigForClient func(hostname string) *tls.Config - allowNonTls bool +// Sniffer detects protocol and metadata from initial bytes. +type Sniffer interface { + Sniff(conn net.Conn, maxBytes int) (SniffResult, error) } -func (l *Listener) GetConfigForClient() func(hostname string) *tls.Config { - return l.getConfigForClient -} +// TLSSniffer is the default sniffer implementation. +type TLSSniffer struct{} -func (l *Listener) SetConfigForClient(getConfigForClient func(hostname string) *tls.Config) { - l.getConfigForClient = getConfigForClient -} - -func Listen(network, address string) (*Listener, error) { - listener, err := net.Listen(network, address) - if err != nil { - return nil, err - } - return &Listener{Listener: listener}, nil -} - -func ListenTLSWithListenConfig(liscfg net.ListenConfig, network, address string, config *tls.Config, getConfigForClient func(hostname string) *tls.Config, allowNonTls bool) (*Listener, error) { - listener, err := liscfg.Listen(context.Background(), network, address) - if err != nil { - return nil, err - } - return &Listener{ - Listener: listener, - cfg: config, - getConfigForClient: getConfigForClient, - allowNonTls: allowNonTls, - }, nil -} - -func ListenWithListener(listener net.Listener, config *tls.Config, getConfigForClient func(hostname string) *tls.Config, allowNonTls bool) (*Listener, error) { - return &Listener{ - Listener: listener, - cfg: config, - getConfigForClient: getConfigForClient, - allowNonTls: allowNonTls, - }, nil -} - -func ListenTLSWithConfig(network, address string, config *tls.Config, getConfigForClient func(hostname string) *tls.Config, allowNonTls bool) (*Listener, error) { - listener, err := net.Listen(network, address) - if err != nil { - return nil, err - } - return &Listener{ - Listener: listener, - cfg: config, - getConfigForClient: getConfigForClient, - allowNonTls: allowNonTls, - }, nil -} - -func ListenTLS(network, address string, certFile, keyFile string, allowNonTls bool) (*Listener, error) { - config, err := tls.LoadX509KeyPair(certFile, keyFile) - if err != nil { - return nil, err +// Sniff detects TLS and extracts SNI when possible. +func (s TLSSniffer) Sniff(conn net.Conn, maxBytes int) (SniffResult, error) { + if maxBytes <= 0 { + maxBytes = 64 * 1024 } - tlsConfig := &tls.Config{ - Certificates: []tls.Certificate{config}, - } + var buf bytes.Buffer + limited := &io.LimitedReader{R: conn, N: int64(maxBytes)} + tee := io.TeeReader(limited, &buf) - listener, err := net.Listen(network, address) - if err != nil { - return nil, err - } - - return &Listener{ - Listener: listener, - cfg: tlsConfig, - allowNonTls: allowNonTls, - }, nil -} - -func (l *Listener) Accept() (net.Conn, error) { - conn, err := l.Listener.Accept() - if err != nil { - return nil, err - } - return &Conn{ - Conn: conn, - tlsCfg: l.cfg, - getConfigForClient: l.getConfigForClient, - allowNonTls: l.allowNonTls, - }, nil -} - -type Conn struct { - net.Conn - once sync.Once - initErr error - isTLS bool - tlsCfg *tls.Config - tlsConn *tls.Conn - buffer *bytes.Buffer - noTlsReader io.Reader - isOriginal bool - getConfigForClient func(hostname string) *tls.Config - hostname string - allowNonTls bool -} - -func (c *Conn) Hostname() string { - if c.hostname != "" { - return c.hostname - } - if c.isTLS && c.tlsConn != nil { - if c.tlsConn.ConnectionState().ServerName != "" { - c.hostname = c.tlsConn.ConnectionState().ServerName - return c.hostname - } - } - return "" -} - -func (c *Conn) IsTLS() bool { - return c.isTLS -} - -func (c *Conn) TlsConn() *tls.Conn { - return c.tlsConn -} - -func (c *Conn) isTLSConnection() (bool, error) { - if c.getConfigForClient == nil { - peek := make([]byte, 5) - n, err := io.ReadFull(c.Conn, peek) - if err != nil { - return false, err - } - - isTLS := n >= 3 && peek[0] == 0x16 && peek[1] == 0x03 - - c.buffer = bytes.NewBuffer(peek[:n]) - return isTLS, nil - } - - c.buffer = new(bytes.Buffer) - r := io.TeeReader(c.Conn, c.buffer) var hello *tls.ClientHelloInfo - tls.Server(&myConn{reader: r, isReadOnly: true}, &tls.Config{ - GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) { - hello = new(tls.ClientHelloInfo) - *hello = *argHello + _ = tls.Server(readOnlyConn{r: tee, raw: conn}, &tls.Config{ + GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) { + cp := *ch + hello = &cp return nil, nil }, }).Handshake() - peek := c.buffer.Bytes() - n := len(peek) - isTLS := n >= 3 && peek[0] == 0x16 && peek[1] == 0x03 - if hello == nil { - return isTLS, nil + + peek := buf.Bytes() + isTLS := len(peek) >= 3 && peek[0] == 0x16 && peek[1] == 0x03 + + out := SniffResult{ + IsTLS: isTLS, + Buffer: bytes.NewBuffer(append([]byte(nil), peek...)), } - c.hostname = hello.ServerName - if c.hostname == "" { - c.hostname, _, _ = net.SplitHostPort(c.Conn.LocalAddr().String()) + if hello != nil { + out.Hostname = hello.ServerName + } + return out, nil +} + +// readOnlyConn rejects writes/close and reads from a reader. +type readOnlyConn struct { + r io.Reader + raw net.Conn +} + +func (c readOnlyConn) Read(p []byte) (int, error) { return c.r.Read(p) } +func (c readOnlyConn) Write(p []byte) (int, error) { return 0, io.ErrClosedPipe } +func (c readOnlyConn) Close() error { return nil } +func (c readOnlyConn) LocalAddr() net.Addr { return c.raw.LocalAddr() } +func (c readOnlyConn) RemoteAddr() net.Addr { return c.raw.RemoteAddr() } +func (c readOnlyConn) SetDeadline(_ time.Time) error { return nil } +func (c readOnlyConn) SetReadDeadline(_ time.Time) error { return nil } +func (c readOnlyConn) SetWriteDeadline(_ time.Time) error { return nil } + +// Conn wraps net.Conn with lazy protocol initialization. +type Conn struct { + net.Conn + + once sync.Once + initErr error + closeOnce sync.Once + + isTLS bool + tlsConn *tls.Conn + plainConn net.Conn + + hostname string + + baseTLSConfig *tls.Config + getConfigForClient GetConfigForClientFunc + allowNonTLS bool + sniffer Sniffer + sniffTimeout time.Duration + maxClientHello int + logger Logger + stats *Stats + skipSniff bool +} + +func newConn(raw net.Conn, cfg ListenerConfig, stats *Stats) *Conn { + return &Conn{ + Conn: raw, + plainConn: raw, + baseTLSConfig: cfg.BaseTLSConfig, + getConfigForClient: cfg.GetConfigForClient, + allowNonTLS: cfg.AllowNonTLS, + sniffer: TLSSniffer{}, + sniffTimeout: cfg.SniffTimeout, + maxClientHello: cfg.MaxClientHelloBytes, + logger: cfg.Logger, + stats: stats, } - return isTLS, nil } func (c *Conn) init() { c.once.Do(func() { - if c.isOriginal { + if c.skipSniff { return } - if c.tlsCfg != nil { - isTLS, err := c.isTLSConnection() - if err != nil { - c.initErr = err - return - } - c.isTLS = isTLS + if c.baseTLSConfig == nil && c.getConfigForClient == nil { + c.isTLS = false + return } + if c.sniffTimeout > 0 { + _ = c.Conn.SetReadDeadline(time.Now().Add(c.sniffTimeout)) + } + res, err := c.sniffer.Sniff(c.Conn, c.maxClientHello) + if c.sniffTimeout > 0 { + _ = c.Conn.SetReadDeadline(time.Time{}) + } + if err != nil { + c.initErr = err + c.failAndClose("sniff failed: %v", err) + return + } + + c.isTLS = res.IsTLS + c.hostname = res.Hostname + if c.isTLS { - var cfg = c.tlsCfg - if c.getConfigForClient != nil { - cfg = c.getConfigForClient(c.hostname) - if cfg == nil { - cfg = c.tlsCfg - } + if c.stats != nil { + c.stats.incTLSDetected() } - c.tlsConn = tls.Server(&myConn{ - reader: c.buffer, - conn: c.Conn, - isReadOnly: false, - }, cfg) - } else { - if !c.allowNonTls { - c.initErr = net.ErrClosed + tlsCfg, errCfg := c.selectTLSConfig() + if errCfg != nil { + c.initErr = errCfg + c.failAndClose("tls config select failed: %v", errCfg) return } - c.noTlsReader = io.MultiReader(c.buffer, c.Conn) + rc := newReplayConn(bytes.NewBuffer(res.Buffer.Bytes()), c.Conn) + c.tlsConn = tls.Server(rc, tlsCfg) + return } + + if c.stats != nil { + c.stats.incPlainDetected() + } + if !c.allowNonTLS { + c.initErr = ErrNonTLSNotAllowed + c.failAndClose("plain tcp rejected") + return + } + c.plainConn = newReplayConn(bytes.NewBuffer(res.Buffer.Bytes()), c.Conn) }) } +func (c *Conn) failAndClose(format string, v ...interface{}) { + if c.stats != nil { + c.stats.incInitFailures() + } + if c.logger != nil { + c.logger.Printf("starnet: "+format, v...) + } + _ = c.Close() +} + +func (c *Conn) selectTLSConfig() (*tls.Config, error) { + if c.getConfigForClient != nil { + cfg, err := c.getConfigForClient(c.hostname) + if err != nil { + return nil, err + } + if cfg != nil { + return cfg, nil + } + } + if c.baseTLSConfig != nil { + return c.baseTLSConfig, nil + } + return nil, ErrNoTLSConfig +} + +// Hostname returns sniffed SNI hostname (if any). +func (c *Conn) Hostname() string { + c.init() + return c.hostname +} + +func (c *Conn) IsTLS() bool { + c.init() + return c.initErr == nil && c.isTLS +} + +func (c *Conn) TLSConn() (*tls.Conn, error) { + c.init() + if c.initErr != nil { + return nil, c.initErr + } + if !c.isTLS || c.tlsConn == nil { + return nil, ErrNotTLS + } + return c.tlsConn, nil +} + func (c *Conn) Read(b []byte) (int, error) { c.init() if c.initErr != nil { @@ -281,7 +243,7 @@ func (c *Conn) Read(b []byte) (int, error) { if c.isTLS { return c.tlsConn.Read(b) } - return c.noTlsReader.Read(b) + return c.plainConn.Read(b) } func (c *Conn) Write(b []byte) (int, error) { @@ -289,113 +251,250 @@ func (c *Conn) Write(b []byte) (int, error) { if c.initErr != nil { return 0, c.initErr } - if c.isTLS { return c.tlsConn.Write(b) } - return c.Conn.Write(b) + return c.plainConn.Write(b) } func (c *Conn) Close() error { - if c.isTLS && c.tlsConn != nil { - return c.tlsConn.Close() - } - return c.Conn.Close() + var err error + c.closeOnce.Do(func() { + if c.tlsConn != nil { + err = c.tlsConn.Close() + } else { + err = c.Conn.Close() + } + if c.stats != nil { + c.stats.incClosed() + } + }) + return err } func (c *Conn) SetDeadline(t time.Time) error { + c.init() + if c.initErr != nil { + return c.initErr + } if c.isTLS && c.tlsConn != nil { return c.tlsConn.SetDeadline(t) } - return c.Conn.SetDeadline(t) + return c.plainConn.SetDeadline(t) } func (c *Conn) SetReadDeadline(t time.Time) error { + c.init() + if c.initErr != nil { + return c.initErr + } if c.isTLS && c.tlsConn != nil { return c.tlsConn.SetReadDeadline(t) } - return c.Conn.SetReadDeadline(t) + return c.plainConn.SetReadDeadline(t) } func (c *Conn) SetWriteDeadline(t time.Time) error { + c.init() + if c.initErr != nil { + return c.initErr + } if c.isTLS && c.tlsConn != nil { return c.tlsConn.SetWriteDeadline(t) } - return c.Conn.SetWriteDeadline(t) + return c.plainConn.SetWriteDeadline(t) } -func (c *Conn) TlsConnection() (*tls.Conn, error) { - if c.initErr != nil { - return nil, c.initErr - } - if !c.isTLS { - return nil, net.ErrClosed - } - return c.tlsConn, nil +// Listener wraps net.Listener and returns starnet.Conn from Accept. +type Listener struct { + net.Listener + + mu sync.RWMutex + cfg ListenerConfig + stats Stats } -func (c *Conn) OriginalConn() net.Conn { - return c.Conn -} - -func NewClientTlsConn(conn net.Conn, cfg *tls.Config) (*Conn, error) { - if conn == nil { - return nil, net.ErrClosed - } - c := &Conn{ - Conn: conn, - isTLS: true, - tlsCfg: cfg, - tlsConn: tls.Client(conn, cfg), - isOriginal: true, - } - return c, nil -} - -func NewServerTlsConn(conn net.Conn, cfg *tls.Config) (*Conn, error) { - if conn == nil { - return nil, net.ErrClosed - } - c := &Conn{ - Conn: conn, - isTLS: true, - tlsCfg: cfg, - tlsConn: tls.Server(conn, cfg), - isOriginal: true, - } - c.init() - return c, nil -} - -func Dial(network, address string) (*Conn, error) { - conn, err := net.Dial(network, address) +// Listen creates a plain listener config (no TLS detection). +func Listen(network, address string) (*Listener, error) { + ln, err := net.Listen(network, address) if err != nil { return nil, err } + cfg := DefaultListenerConfig() + cfg.AllowNonTLS = true + cfg.BaseTLSConfig = nil + cfg.GetConfigForClient = nil + return &Listener{Listener: ln, cfg: cfg}, nil +} + +// ListenWithConfig creates a listener with full config. +func ListenWithConfig(network, address string, cfg ListenerConfig) (*Listener, error) { + ln, err := net.Listen(network, address) + if err != nil { + return nil, err + } + return &Listener{Listener: ln, cfg: normalizeConfig(cfg)}, nil +} + +// ListenWithListenConfig creates listener using net.ListenConfig. +func ListenWithListenConfig(lc net.ListenConfig, network, address string, cfg ListenerConfig) (*Listener, error) { + ln, err := lc.Listen(context.Background(), network, address) + if err != nil { + return nil, err + } + return &Listener{Listener: ln, cfg: normalizeConfig(cfg)}, nil +} + +// ListenTLS creates TLS listener from cert/key paths. +func ListenTLS(network, address, certFile, keyFile string, allowNonTLS bool) (*Listener, error) { + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return nil, err + } + cfg := DefaultListenerConfig() + cfg.AllowNonTLS = allowNonTLS + cfg.BaseTLSConfig = TLSDefaults() + cfg.BaseTLSConfig.Certificates = []tls.Certificate{cert} + return ListenWithConfig(network, address, cfg) +} + +func normalizeConfig(cfg ListenerConfig) ListenerConfig { + out := DefaultListenerConfig() + out.AllowNonTLS = cfg.AllowNonTLS + out.SniffTimeout = cfg.SniffTimeout + out.MaxClientHelloBytes = cfg.MaxClientHelloBytes + out.BaseTLSConfig = cfg.BaseTLSConfig + out.GetConfigForClient = cfg.GetConfigForClient + out.Logger = cfg.Logger + if out.MaxClientHelloBytes <= 0 { + out.MaxClientHelloBytes = 64 * 1024 + } + return out +} + +// SetConfig atomically replaces listener config for new accepted connections. +func (l *Listener) SetConfig(cfg ListenerConfig) { + l.mu.Lock() + l.cfg = normalizeConfig(cfg) + l.mu.Unlock() +} + +// Config returns a copy of current config. +func (l *Listener) Config() ListenerConfig { + l.mu.RLock() + cfg := l.cfg + l.mu.RUnlock() + return cfg +} + +// Stats returns current counters snapshot. +func (l *Listener) Stats() StatsSnapshot { + return l.stats.Snapshot() +} + +func (l *Listener) Accept() (net.Conn, error) { + raw, err := l.Listener.Accept() + if err != nil { + return nil, err + } + l.stats.incAccepted() + + l.mu.RLock() + cfg := l.cfg + l.mu.RUnlock() + + return newConn(raw, cfg, &l.stats), nil +} + +// AcceptContext supports cancellation by closing accepted conn when ctx is done early. +func (l *Listener) AcceptContext(ctx context.Context) (net.Conn, error) { + type result struct { + c net.Conn + err error + } + ch := make(chan result, 1) + go func() { + c, err := l.Accept() + ch <- result{c: c, err: err} + }() + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case r := <-ch: + return r.c, r.err + } +} + +// Dial creates a plain TCP starnet.Conn. +func Dial(network, address string) (*Conn, error) { + raw, err := net.Dial(network, address) + if err != nil { + return nil, err + } + cfg := DefaultListenerConfig() + cfg.AllowNonTLS = true + cfg.BaseTLSConfig = nil + cfg.GetConfigForClient = nil + c := newConn(raw, cfg, nil) + c.isTLS = false + return c, nil +} + +// DialWithConfig dials with net.Dialer options. +func DialWithConfig(network, address string, dc DialConfig) (*Conn, error) { + d := net.Dialer{ + Timeout: dc.Timeout, + LocalAddr: dc.LocalAddr, + } + raw, err := d.Dial(network, address) + if err != nil { + return nil, err + } + cfg := DefaultListenerConfig() + cfg.AllowNonTLS = true + c := newConn(raw, cfg, nil) + c.isTLS = false + return c, nil +} + +// DialTLSWithConfig creates a TLS client connection wrapper. +func DialTLSWithConfig(network, address string, tlsCfg *tls.Config, timeout time.Duration) (*Conn, error) { + d := net.Dialer{Timeout: timeout} + raw, err := d.Dial(network, address) + if err != nil { + return nil, err + } + tc := tls.Client(raw, tlsCfg) return &Conn{ - Conn: conn, - isTLS: false, - tlsCfg: nil, - tlsConn: nil, - noTlsReader: conn, - isOriginal: true, + Conn: raw, + plainConn: raw, + isTLS: true, + tlsConn: tc, + hostname: "", + initErr: nil, + allowNonTLS: false, + skipSniff: true, }, nil } -func DialTLS(network, address string, certFile, keyFile string) (*Conn, error) { - config, err := tls.LoadX509KeyPair(certFile, keyFile) +// DialTLS creates TLS client conn from cert/key paths. +func DialTLS(network, address, certFile, keyFile string) (*Conn, error) { + cert, err := tls.LoadX509KeyPair(certFile, keyFile) if err != nil { return nil, err } - - tlsConfig := &tls.Config{ - Certificates: []tls.Certificate{config}, - } - - conn, err := net.Dial(network, address) - if err != nil { - return nil, err - } - - return NewClientTlsConn(conn, tlsConfig) + cfg := TLSDefaults() + cfg.Certificates = []tls.Certificate{cert} + return DialTLSWithConfig(network, address, cfg, 0) +} + +func WrapListener(listener net.Listener, cfg ListenerConfig) (*Listener, error) { + if listener == nil { + return nil, ErrNilConn + } + return &Listener{ + Listener: listener, + cfg: normalizeConfig(cfg), + }, nil } diff --git a/tlssniffer_test.go b/tlssniffer_test.go new file mode 100644 index 0000000..6b523fe --- /dev/null +++ b/tlssniffer_test.go @@ -0,0 +1,691 @@ +package starnet + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "errors" + "io" + "math/big" + "net" + "os" + "sync" + "testing" + "time" +) + +// ---------- cert helpers ---------- + +func genSelfSignedCertPEM(t *testing.T, dnsNames ...string) (certPEM, keyPEM []byte) { + t.Helper() + + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("GenerateKey: %v", err) + } + + serial, err := rand.Int(rand.Reader, big.NewInt(1<<62)) + if err != nil { + t.Fatalf("serial: %v", err) + } + + tpl := &x509.Certificate{ + SerialNumber: serial, + Subject: pkix.Name{ + CommonName: "starnet-test", + }, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(24 * time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, + DNSNames: dnsNames, + } + + der, err := x509.CreateCertificate(rand.Reader, tpl, tpl, &priv.PublicKey, priv) + if err != nil { + t.Fatalf("CreateCertificate: %v", err) + } + + certPEM = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der}) + keyPEM = pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}) + return certPEM, keyPEM +} + +func genSelfSignedCert(t *testing.T, dnsNames ...string) tls.Certificate { + t.Helper() + certPEM, keyPEM := genSelfSignedCertPEM(t, dnsNames...) + cert, err := tls.X509KeyPair(certPEM, keyPEM) + if err != nil { + t.Fatalf("X509KeyPair: %v", err) + } + return cert +} + +func writeTempCertFiles(t *testing.T, dnsNames ...string) (certFile, keyFile string, cleanup func()) { + t.Helper() + + certPEM, keyPEM := genSelfSignedCertPEM(t, dnsNames...) + + cf, err := os.CreateTemp("", "starnet-cert-*.pem") + if err != nil { + t.Fatalf("CreateTemp cert: %v", err) + } + kf, err := os.CreateTemp("", "starnet-key-*.pem") + if err != nil { + _ = cf.Close() + _ = os.Remove(cf.Name()) + t.Fatalf("CreateTemp key: %v", err) + } + + if _, err := cf.Write(certPEM); err != nil { + t.Fatalf("write cert: %v", err) + } + if _, err := kf.Write(keyPEM); err != nil { + t.Fatalf("write key: %v", err) + } + _ = cf.Close() + _ = kf.Close() + + return cf.Name(), kf.Name(), func() { + _ = os.Remove(cf.Name()) + _ = os.Remove(kf.Name()) + } +} + +// ---------- server helpers ---------- + +func startEchoServer(t *testing.T, cfg ListenerConfig) (*Listener, string, func()) { + t.Helper() + + ln, err := ListenWithConfig("tcp", "127.0.0.1:0", cfg) + if err != nil { + t.Fatalf("ListenWithConfig: %v", err) + } + + var wg sync.WaitGroup + stop := make(chan struct{}) + + wg.Add(1) + go func() { + defer wg.Done() + for { + c, err := ln.Accept() + if err != nil { + select { + case <-stop: + return + default: + return + } + } + go func(conn net.Conn) { + defer conn.Close() + _, _ = io.Copy(conn, conn) + }(c) + } + }() + + cleanup := func() { + close(stop) + _ = ln.Close() + wg.Wait() + } + return ln, ln.Addr().String(), cleanup +} + +// ---------- tests ---------- + +func TestListen(t *testing.T) { + ln, err := Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Listen: %v", err) + } + defer ln.Close() + + go func() { + c, err := ln.Accept() + if err != nil { + return + } + defer c.Close() + _, _ = io.Copy(c, c) + }() + + c, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatalf("dial: %v", err) + } + defer c.Close() + + msg := []byte("x") + if _, err := c.Write(msg); err != nil { + t.Fatalf("write: %v", err) + } + buf := make([]byte, 1) + if _, err := io.ReadFull(c, buf); err != nil { + t.Fatalf("read: %v", err) + } +} + +func TestListenWithListenConfig(t *testing.T) { + lc := net.ListenConfig{} + cfg := DefaultListenerConfig() + cfg.AllowNonTLS = true + + ln, err := ListenWithListenConfig(lc, "tcp", "127.0.0.1:0", cfg) + if err != nil { + t.Fatalf("ListenWithListenConfig: %v", err) + } + defer ln.Close() + + go func() { + c, err := ln.Accept() + if err != nil { + return + } + defer c.Close() + _, _ = io.Copy(c, c) + }() + + c, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatalf("dial: %v", err) + } + defer c.Close() + + msg := []byte("ok") + if _, err := c.Write(msg); err != nil { + t.Fatalf("write: %v", err) + } + buf := make([]byte, 2) + if _, err := io.ReadFull(c, buf); err != nil { + t.Fatalf("read: %v", err) + } +} + +func TestListenerSetConfig(t *testing.T) { + cfg := DefaultListenerConfig() + cfg.AllowNonTLS = true + + ln, err := ListenWithConfig("tcp", "127.0.0.1:0", cfg) + if err != nil { + t.Fatalf("listen: %v", err) + } + defer ln.Close() + + cfg2 := cfg + cfg2.SniffTimeout = time.Second + ln.SetConfig(cfg2) + + got := ln.Config() + if got.SniffTimeout != time.Second { + t.Fatalf("SetConfig not applied") + } +} + +func TestPlainAllowed(t *testing.T) { + cfg := DefaultListenerConfig() + cfg.AllowNonTLS = true + cfg.BaseTLSConfig = nil + + _, addr, cleanup := startEchoServer(t, cfg) + defer cleanup() + + c, err := net.Dial("tcp", addr) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer c.Close() + + msg := []byte("hello-plain") + if _, err := c.Write(msg); err != nil { + t.Fatalf("write: %v", err) + } + buf := make([]byte, len(msg)) + if _, err := io.ReadFull(c, buf); err != nil { + t.Fatalf("read: %v", err) + } + if string(buf) != string(msg) { + t.Fatalf("echo mismatch: got=%q want=%q", string(buf), string(msg)) + } +} + +func TestPlainRejectedWhenNonTLSDisabled(t *testing.T) { + cert := genSelfSignedCert(t, "localhost") + base := TLSDefaults() + base.Certificates = []tls.Certificate{cert} + + cfg := DefaultListenerConfig() + cfg.AllowNonTLS = false + cfg.BaseTLSConfig = base + + _, addr, cleanup := startEchoServer(t, cfg) + defer cleanup() + + c, err := net.Dial("tcp", addr) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer c.Close() + + _, _ = c.Write([]byte("plain")) + _ = c.SetReadDeadline(time.Now().Add(500 * time.Millisecond)) + b := make([]byte, 1) + _, err = c.Read(b) + if err == nil { + t.Fatalf("expected read error due to non-tls rejection") + } +} + +func TestTLSHandshakeAndEcho(t *testing.T) { + cert := genSelfSignedCert(t, "localhost") + base := TLSDefaults() + base.Certificates = []tls.Certificate{cert} + + cfg := DefaultListenerConfig() + cfg.AllowNonTLS = false + cfg.BaseTLSConfig = base + + _, addr, cleanup := startEchoServer(t, cfg) + defer cleanup() + + tc, err := tls.Dial("tcp", addr, &tls.Config{ + InsecureSkipVerify: true, + ServerName: "localhost", + MinVersion: tls.VersionTLS12, + }) + if err != nil { + t.Fatalf("tls dial: %v", err) + } + defer tc.Close() + + msg := []byte("hello-tls") + if _, err := tc.Write(msg); err != nil { + t.Fatalf("tls write: %v", err) + } + buf := make([]byte, len(msg)) + if _, err := io.ReadFull(tc, buf); err != nil { + t.Fatalf("tls read: %v", err) + } + if string(buf) != string(msg) { + t.Fatalf("tls echo mismatch: got=%q want=%q", string(buf), string(msg)) + } +} + +func TestDynamicConfigBySNI(t *testing.T) { + certA := genSelfSignedCert(t, "a.local") + certB := genSelfSignedCert(t, "b.local") + + base := TLSDefaults() + base.Certificates = []tls.Certificate{certA} + + cfg := DefaultListenerConfig() + cfg.AllowNonTLS = false + cfg.BaseTLSConfig = base + cfg.GetConfigForClient = func(host string) (*tls.Config, error) { + if host == "b.local" { + b := TLSDefaults() + b.Certificates = []tls.Certificate{certB} + return b, nil + } + return nil, nil + } + + _, addr, cleanup := startEchoServer(t, cfg) + defer cleanup() + + tc, err := tls.Dial("tcp", addr, &tls.Config{ + InsecureSkipVerify: true, + ServerName: "b.local", + MinVersion: tls.VersionTLS12, + }) + if err != nil { + t.Fatalf("tls dial: %v", err) + } + defer tc.Close() + + if !tc.ConnectionState().HandshakeComplete { + t.Fatalf("handshake not complete") + } +} + +func TestGetConfigForClientError(t *testing.T) { + cert := genSelfSignedCert(t, "localhost") + base := TLSDefaults() + base.Certificates = []tls.Certificate{cert} + + cfg := DefaultListenerConfig() + cfg.AllowNonTLS = false + cfg.BaseTLSConfig = base + cfg.GetConfigForClient = func(host string) (*tls.Config, error) { + return nil, errors.New("boom") + } + + _, addr, cleanup := startEchoServer(t, cfg) + defer cleanup() + + _, err := tls.Dial("tcp", addr, &tls.Config{ + InsecureSkipVerify: true, + ServerName: "localhost", + }) + if err == nil { + t.Fatalf("expected tls dial failure due to selector error") + } +} + +func TestAcceptContextCancel(t *testing.T) { + cfg := DefaultListenerConfig() + cfg.AllowNonTLS = true + + ln, err := ListenWithConfig("tcp", "127.0.0.1:0", cfg) + if err != nil { + t.Fatalf("listen: %v", err) + } + defer ln.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + _, err = ln.AcceptContext(ctx) + if err == nil { + t.Fatalf("expected context timeout/cancel") + } +} + +func TestListenerStats(t *testing.T) { + cfg := DefaultListenerConfig() + cfg.AllowNonTLS = true + + ln, addr, cleanup := startEchoServer(t, cfg) + defer cleanup() + + c, err := net.Dial("tcp", addr) + if err != nil { + t.Fatalf("dial: %v", err) + } + _, _ = c.Write([]byte("x")) + _ = c.Close() + + time.Sleep(100 * time.Millisecond) + + s := ln.Stats() + if s.Accepted == 0 { + t.Fatalf("expected accepted > 0") + } +} + +func TestDialAndDialWithConfig(t *testing.T) { + nl, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen: %v", err) + } + defer nl.Close() + + go func() { + c, err := nl.Accept() + if err != nil { + return + } + defer c.Close() + _, _ = io.Copy(c, c) + }() + + c1, err := Dial("tcp", nl.Addr().String()) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer c1.Close() + + msg := []byte("abc") + if _, err := c1.Write(msg); err != nil { + t.Fatalf("c1 write: %v", err) + } + got := make([]byte, 3) + if _, err := io.ReadFull(c1, got); err != nil { + t.Fatalf("c1 read: %v", err) + } + + c2, err := DialWithConfig("tcp", nl.Addr().String(), DialConfig{Timeout: time.Second}) + if err != nil { + t.Fatalf("DialWithConfig: %v", err) + } + defer c2.Close() +} + +func TestListenTLS_FileAPI(t *testing.T) { + certFile, keyFile, cleanupFiles := writeTempCertFiles(t, "localhost") + defer cleanupFiles() + + ln, err := ListenTLS("tcp", "127.0.0.1:0", certFile, keyFile, false) + if err != nil { + t.Fatalf("ListenTLS: %v", err) + } + defer ln.Close() + + go func() { + c, err := ln.Accept() + if err != nil { + return + } + defer c.Close() + _, _ = io.Copy(c, c) + }() + + tc, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{ + InsecureSkipVerify: true, + ServerName: "localhost", + }) + if err != nil { + t.Fatalf("tls dial: %v", err) + } + defer tc.Close() + + msg := []byte("hi") + if _, err := tc.Write(msg); err != nil { + t.Fatalf("tls write: %v", err) + } + out := make([]byte, 2) + if _, err := io.ReadFull(tc, out); err != nil { + t.Fatalf("tls read: %v", err) + } +} + +func TestDialTLSWithConfig(t *testing.T) { + cert := genSelfSignedCert(t, "localhost") + base := TLSDefaults() + base.Certificates = []tls.Certificate{cert} + + cfg := DefaultListenerConfig() + cfg.BaseTLSConfig = base + cfg.AllowNonTLS = false + + ln, err := ListenWithConfig("tcp", "127.0.0.1:0", cfg) + if err != nil { + t.Fatalf("listen: %v", err) + } + defer ln.Close() + + go func() { + c, err := ln.Accept() + if err != nil { + return + } + defer c.Close() + _, _ = io.Copy(c, c) + }() + + clientCfg := &tls.Config{ + InsecureSkipVerify: true, + ServerName: "localhost", + } + c, err := DialTLSWithConfig("tcp", ln.Addr().String(), clientCfg, time.Second) + if err != nil { + t.Fatalf("DialTLSWithConfig: %v", err) + } + defer c.Close() + + if !c.IsTLS() { + t.Fatalf("expected IsTLS true") + } +} + +func TestDialTLS_FileAPI(t *testing.T) { + cert := genSelfSignedCert(t, "localhost") + base := TLSDefaults() + base.Certificates = []tls.Certificate{cert} + cfg := DefaultListenerConfig() + cfg.BaseTLSConfig = base + cfg.AllowNonTLS = false + + ln, err := ListenWithConfig("tcp", "127.0.0.1:0", cfg) + if err != nil { + t.Fatalf("listen: %v", err) + } + defer ln.Close() + + go func() { + c, err := ln.Accept() + if err != nil { + return + } + defer c.Close() + _, _ = io.Copy(c, c) + }() + + clientCertFile, clientKeyFile, cleanupFiles := writeTempCertFiles(t, "localhost") + defer cleanupFiles() + + c, err := DialTLS("tcp", ln.Addr().String(), clientCertFile, clientKeyFile) + if err != nil { + t.Fatalf("DialTLS: %v", err) + } + defer c.Close() + + if !c.IsTLS() { + t.Fatalf("expected IsTLS true") + } +} + +func TestConnIsTLS_PlainAndTLS(t *testing.T) { + // ---- plain case ---- + plainCfg := DefaultListenerConfig() + plainCfg.AllowNonTLS = true + + ln1, err := ListenWithConfig("tcp", "127.0.0.1:0", plainCfg) + if err != nil { + t.Fatalf("listen1: %v", err) + } + defer ln1.Close() + + plainDone := make(chan *Conn, 1) + plainErr := make(chan error, 1) + + go func() { + nc, err := ln1.Accept() + if err != nil { + plainErr <- err + return + } + sc, ok := nc.(*Conn) + if !ok { + _ = nc.Close() + plainErr <- errors.New("accepted conn is not *Conn") + return + } + plainDone <- sc + + // block until client sends one byte, then close + buf := make([]byte, 1) + _, _ = sc.Read(buf) + _ = sc.Close() + }() + + c1, err := net.Dial("tcp", ln1.Addr().String()) + if err != nil { + t.Fatalf("dial1: %v", err) + } + if _, err := c1.Write([]byte("p")); err != nil { + _ = c1.Close() + t.Fatalf("plain client write: %v", err) + } + _ = c1.Close() + + select { + case err := <-plainErr: + t.Fatalf("plain server error: %v", err) + case sc1 := <-plainDone: + if sc1.IsTLS() { + t.Fatalf("plain conn should not be TLS") + } + case <-time.After(2 * time.Second): + t.Fatalf("timeout waiting plain side") + } + + // ---- tls case ---- + cert := genSelfSignedCert(t, "localhost") + tlsBase := TLSDefaults() + tlsBase.Certificates = []tls.Certificate{cert} + + tlsCfg := DefaultListenerConfig() + tlsCfg.BaseTLSConfig = tlsBase + tlsCfg.AllowNonTLS = false + + ln2, err := ListenWithConfig("tcp", "127.0.0.1:0", tlsCfg) + if err != nil { + t.Fatalf("listen2: %v", err) + } + defer ln2.Close() + + tlsDone := make(chan *Conn, 1) + tlsErr := make(chan error, 1) + + go func() { + nc, err := ln2.Accept() + if err != nil { + tlsErr <- err + return + } + sc, ok := nc.(*Conn) + if !ok { + _ = nc.Close() + tlsErr <- errors.New("accepted conn is not *Conn") + return + } + tlsDone <- sc + + // key point: wait for real data to ensure TLS handshake/path is executed + buf := make([]byte, 1) + _, _ = sc.Read(buf) + _ = sc.Close() + }() + + d := &net.Dialer{Timeout: 2 * time.Second} + tc, err := tls.DialWithDialer(d, "tcp", ln2.Addr().String(), &tls.Config{ + InsecureSkipVerify: true, // test only + ServerName: "localhost", + MinVersion: tls.VersionTLS12, + }) + if err != nil { + t.Fatalf("tls dial: %v", err) + } + if _, err := tc.Write([]byte("t")); err != nil { + _ = tc.Close() + t.Fatalf("tls client write: %v", err) + } + _ = tc.Close() + + select { + case err := <-tlsErr: + t.Fatalf("tls server error: %v", err) + case sc2 := <-tlsDone: + if !sc2.IsTLS() { + t.Fatalf("tls conn should be TLS") + } + case <-time.After(3 * time.Second): + t.Fatalf("timeout waiting tls side") + } +} diff --git a/tlsstats.go b/tlsstats.go new file mode 100644 index 0000000..01ad51e --- /dev/null +++ b/tlsstats.go @@ -0,0 +1,43 @@ +package starnet + +import "sync/atomic" + +// StatsSnapshot is a read-only copy of runtime counters. +type StatsSnapshot struct { + Accepted uint64 + TLSDetected uint64 + PlainDetected uint64 + InitFailures uint64 + Closed uint64 +} + +// Stats provides lock-free counters. +type Stats struct { + accepted uint64 + tlsDetected uint64 + plainDetected uint64 + initFailures uint64 + closed uint64 +} + +func (s *Stats) incAccepted() { atomic.AddUint64(&s.accepted, 1) } +func (s *Stats) incTLSDetected() { atomic.AddUint64(&s.tlsDetected, 1) } +func (s *Stats) incPlainDetected() { atomic.AddUint64(&s.plainDetected, 1) } +func (s *Stats) incInitFailures() { atomic.AddUint64(&s.initFailures, 1) } +func (s *Stats) incClosed() { atomic.AddUint64(&s.closed, 1) } + +// Snapshot returns a stable view of counters. +func (s *Stats) Snapshot() StatsSnapshot { + return StatsSnapshot{ + Accepted: atomic.LoadUint64(&s.accepted), + TLSDetected: atomic.LoadUint64(&s.tlsDetected), + PlainDetected: atomic.LoadUint64(&s.plainDetected), + InitFailures: atomic.LoadUint64(&s.initFailures), + Closed: atomic.LoadUint64(&s.closed), + } +} + +// Logger is a minimal logging abstraction. +type Logger interface { + Printf(format string, v ...interface{}) +} diff --git a/transport.go b/transport.go new file mode 100644 index 0000000..504a9d9 --- /dev/null +++ b/transport.go @@ -0,0 +1,97 @@ +package starnet + +import ( + "net/http" + "net/url" + "sync" + "time" +) + +// Transport 自定义 Transport(支持请求级配置) +type Transport struct { + base *http.Transport + mu sync.RWMutex +} + +// RoundTrip 实现 http.RoundTripper 接口 +func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { + // 确保 base 已初始化 + if t.base == nil { + t.mu.Lock() + if t.base == nil { + t.base = &http.Transport{ + ForceAttemptHTTP2: true, + MaxIdleConns: 100, + MaxIdleConnsPerHost: 10, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + } + } + t.mu.Unlock() + } + + // 提取请求级别的配置 + reqCtx := getRequestContext(req.Context()) + + // 优先级1:完全自定义的 transport + if reqCtx.Transport != nil { + return reqCtx.Transport.RoundTrip(req) + } + + // 优先级2:需要动态配置 + if needsDynamicTransport(reqCtx) { + dynamicTransport := t.buildDynamicTransport(reqCtx) + return dynamicTransport.RoundTrip(req) + } + + // 优先级3:使用基础 transport + t.mu.RLock() + defer t.mu.RUnlock() + return t.base.RoundTrip(req) +} + +// buildDynamicTransport 构建动态 Transport +func (t *Transport) buildDynamicTransport(rc *RequestContext) *http.Transport { + t.mu.RLock() + transport := t.base.Clone() + t.mu.RUnlock() + + // 应用 TLS 配置(即使为 nil 也要检查 SkipVerify) + if rc.TLSConfig != nil { + transport.TLSClientConfig = rc.TLSConfig + } + + // 应用代理配置 + if rc.Proxy != "" { + proxyURL, err := url.Parse(rc.Proxy) + if err == nil { + transport.Proxy = http.ProxyURL(proxyURL) + } + } + + // 应用自定义 Dial 函数 + if rc.DialFn != nil { + transport.DialContext = rc.DialFn + } else if len(rc.CustomIP) > 0 || len(rc.CustomDNS) > 0 || rc.DialTimeout > 0 || rc.LookupIPFn != nil { + // 使用默认 Dial 函数(会从 context 读取配置) + transport.DialContext = defaultDialFunc + transport.DialTLSContext = defaultDialTLSFunc + } + + return transport +} + +// Base 获取基础 Transport +func (t *Transport) Base() *http.Transport { + t.mu.RLock() + defer t.mu.RUnlock() + return t.base +} + +// SetBase 设置基础 Transport +func (t *Transport) SetBase(base *http.Transport) { + t.mu.Lock() + t.base = base + t.mu.Unlock() +} diff --git a/types.go b/types.go new file mode 100644 index 0000000..11c5e0f --- /dev/null +++ b/types.go @@ -0,0 +1,131 @@ +package starnet + +import ( + "context" + "crypto/tls" + "io" + "net" + "net/http" + "time" +) + +// HTTP Content-Type 常量 +const ( + ContentTypeFormURLEncoded = "application/x-www-form-urlencoded" + ContentTypeFormData = "multipart/form-data" + ContentTypeJSON = "application/json" + ContentTypeXML = "application/xml" + ContentTypePlain = "text/plain" + ContentTypeHTML = "text/html" + ContentTypeOctetStream = "application/octet-stream" +) + +// 默认配置 +const ( + DefaultDialTimeout = 5 * time.Second + DefaultTimeout = 10 * time.Second + DefaultUserAgent = "Starnet/1.0.0" + DefaultFetchRespBody = false +) + +// RequestFile 表示要上传的文件 +type RequestFile struct { + FormName string // 表单字段名 + FileName string // 文件名 + FilePath string // 文件路径(如果从文件读取) + FileData io.Reader // 文件数据流 + FileSize int64 // 文件大小 + FileType string // MIME 类型 +} + +// UploadProgressFunc 文件上传进度回调函数 +type UploadProgressFunc func(filename string, uploaded int64, total int64) + +// NetworkConfig 网络配置 +type NetworkConfig struct { + Proxy string // 代理地址 + DialTimeout time.Duration // 连接超时 + Timeout time.Duration // 总超时 + DialFunc func(ctx context.Context, network, addr string) (net.Conn, error) +} + +// TLSConfig TLS 配置 +type TLSConfig struct { + Config *tls.Config // TLS 配置 + SkipVerify bool // 跳过证书验证 +} + +// DNSConfig DNS 配置 +type DNSConfig struct { + CustomIP []string // 直接指定 IP(最高优先级) + CustomDNS []string // 自定义 DNS 服务器 + LookupFunc func(ctx context.Context, host string) ([]net.IPAddr, error) // 自定义解析函数 +} + +// BodyConfig 请求体配置 +type BodyConfig struct { + Bytes []byte // 原始字节 + Reader io.Reader // 数据流 + FormData map[string][]string // 表单数据 + Files []RequestFile // 文件列表 +} + +// RequestConfig 请求配置(内部使用) +type RequestConfig struct { + Network NetworkConfig + TLS TLSConfig + DNS DNSConfig + Body BodyConfig + Headers http.Header + Cookies []*http.Cookie + Queries map[string][]string + + // 其他配置 + BasicAuth [2]string // Basic 认证 + ContentLength int64 // 手动设置的 Content-Length + AutoCalcContentLength bool // 自动计算 Content-Length + UploadProgress UploadProgressFunc // 上传进度回调 + + // Transport 配置 + CustomTransport bool // 是否使用自定义 Transport + Transport *http.Transport // 自定义 Transport +} + +// Clone 克隆配置 +func (c *RequestConfig) Clone() *RequestConfig { + return &RequestConfig{ + Network: NetworkConfig{ + Proxy: c.Network.Proxy, + DialTimeout: c.Network.DialTimeout, + Timeout: c.Network.Timeout, + DialFunc: c.Network.DialFunc, + }, + TLS: TLSConfig{ + Config: cloneTLSConfig(c.TLS.Config), + SkipVerify: c.TLS.SkipVerify, + }, + DNS: DNSConfig{ + CustomIP: cloneStringSlice(c.DNS.CustomIP), + CustomDNS: cloneStringSlice(c.DNS.CustomDNS), + LookupFunc: c.DNS.LookupFunc, + }, + Body: BodyConfig{ + Bytes: cloneBytes(c.Body.Bytes), + Reader: c.Body.Reader, // Reader 不可克隆 + FormData: cloneStringMapSlice(c.Body.FormData), + Files: cloneFiles(c.Body.Files), + }, + Headers: cloneHeader(c.Headers), + Cookies: cloneCookies(c.Cookies), + Queries: cloneStringMapSlice(c.Queries), + BasicAuth: c.BasicAuth, + ContentLength: c.ContentLength, + AutoCalcContentLength: c.AutoCalcContentLength, + UploadProgress: c.UploadProgress, + CustomTransport: c.CustomTransport, + Transport: c.Transport, // Transport 共享 + } +} + +// RequestOpt 请求选项函数 +type RequestOpt func(*Request) error diff --git a/utils.go b/utils.go new file mode 100644 index 0000000..5ffc470 --- /dev/null +++ b/utils.go @@ -0,0 +1,212 @@ +package starnet + +import ( + "context" + "crypto/tls" + "io" + "net/http" + "net/url" + "strings" +) + +// validMethod 验证 HTTP 方法是否有效 +func validMethod(method string) bool { + return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1 +} + +// isNotToken 检查字符是否不是 token 字符 +func isNotToken(r rune) bool { + return !isTokenRune(r) +} + +// isTokenRune 检查字符是否是 token 字符 +func isTokenRune(r rune) bool { + i := int(r) + return i < 127 && isTokenTable[i] +} + +// isTokenTable token 字符表 +var isTokenTable = [127]bool{ + '!': true, '#': true, '$': true, '%': true, '&': true, '\'': true, '*': true, + '+': true, '-': true, '.': true, '0': true, '1': true, '2': true, '3': true, + '4': true, '5': true, '6': true, '7': true, '8': true, '9': true, 'A': true, + 'B': true, 'C': true, 'D': true, 'E': true, 'F': true, 'G': true, 'H': true, + 'I': true, 'J': true, 'K': true, 'L': true, 'M': true, 'N': true, 'O': true, + 'P': true, 'Q': true, 'R': true, 'S': true, 'T': true, 'U': true, 'V': true, + 'W': true, 'X': true, 'Y': true, 'Z': true, '^': true, '_': true, '`': true, + 'a': true, 'b': true, 'c': true, 'd': true, 'e': true, 'f': true, 'g': true, + 'h': true, 'i': true, 'j': true, 'k': true, 'l': true, 'm': true, 'n': true, + 'o': true, 'p': true, 'q': true, 'r': true, 's': true, 't': true, 'u': true, + 'v': true, 'w': true, 'x': true, 'y': true, 'z': true, '|': true, '~': true, +} + +// hasPort 检查地址是否包含端口 +func hasPort(s string) bool { + return strings.LastIndex(s, ":") > strings.LastIndex(s, "]") +} + +// removeEmptyPort 移除空端口 +func removeEmptyPort(host string) string { + if hasPort(host) { + return strings.TrimSuffix(host, ":") + } + return host +} + +// UrlEncode URL 编码 +func UrlEncode(str string) string { + return url.QueryEscape(str) +} + +// UrlEncodeRaw URL 编码(空格编码为 %20) +func UrlEncodeRaw(str string) string { + return strings.Replace(url.QueryEscape(str), "+", "%20", -1) +} + +// UrlDecode URL 解码 +func UrlDecode(str string) (string, error) { + return url.QueryUnescape(str) +} + +// BuildQuery 构建查询字符串 +func BuildQuery(data map[string]string) string { + query := url.Values{} + for k, v := range data { + query.Add(k, v) + } + return query.Encode() +} + +// BuildPostForm 构建 POST 表单数据 +func BuildPostForm(data map[string]string) []byte { + return []byte(BuildQuery(data)) +} + +// cloneHeader 克隆 Header +func cloneHeader(h http.Header) http.Header { + if h == nil { + return make(http.Header) + } + newHeader := make(http.Header, len(h)) + for k, v := range h { + newHeader[k] = append([]string(nil), v...) + } + return newHeader +} + +// cloneCookies 克隆 Cookies +func cloneCookies(cookies []*http.Cookie) []*http.Cookie { + if cookies == nil { + return nil + } + newCookies := make([]*http.Cookie, len(cookies)) + for i, c := range cookies { + newCookies[i] = &http.Cookie{ + Name: c.Name, + Value: c.Value, + Path: c.Path, + Domain: c.Domain, + Expires: c.Expires, + RawExpires: c.RawExpires, + MaxAge: c.MaxAge, + Secure: c.Secure, + HttpOnly: c.HttpOnly, + SameSite: c.SameSite, + Raw: c.Raw, + Unparsed: append([]string(nil), c.Unparsed...), + } + } + return newCookies +} + +// cloneStringMapSlice 克隆 map[string][]string +func cloneStringMapSlice(m map[string][]string) map[string][]string { + if m == nil { + return make(map[string][]string) + } + newMap := make(map[string][]string, len(m)) + for k, v := range m { + newMap[k] = append([]string(nil), v...) + } + return newMap +} + +// cloneBytes 克隆字节切片 +func cloneBytes(b []byte) []byte { + if b == nil { + return nil + } + newBytes := make([]byte, len(b)) + copy(newBytes, b) + return newBytes +} + +// cloneStringSlice 克隆字符串切片 +func cloneStringSlice(s []string) []string { + if s == nil { + return nil + } + newSlice := make([]string, len(s)) + copy(newSlice, s) + return newSlice +} + +// cloneFiles 克隆文件列表 +func cloneFiles(files []RequestFile) []RequestFile { + if files == nil { + return nil + } + newFiles := make([]RequestFile, len(files)) + copy(newFiles, files) + return newFiles +} + +// cloneTLSConfig 克隆 TLS 配置 +func cloneTLSConfig(cfg *tls.Config) *tls.Config { + if cfg == nil { + return nil + } + return cfg.Clone() +} + +// copyWithProgress 带进度的复制 +func copyWithProgress(ctx context.Context, dst io.Writer, src io.Reader, filename string, total int64, progress UploadProgressFunc) (int64, error) { + if progress == nil { + return io.Copy(dst, src) + } + + var written int64 + buf := make([]byte, 32*1024) // 32KB buffer + + for { + select { + case <-ctx.Done(): + return written, ctx.Err() + default: + } + + nr, err := src.Read(buf) + if nr > 0 { + nw, ew := dst.Write(buf[:nr]) + if nw > 0 { + written += int64(nw) + // 同步调用进度回调(不使用 goroutine) + progress(filename, written, total) + } + if ew != nil { + return written, ew + } + if nr != nw { + return written, io.ErrShortWrite + } + } + if err != nil { + if err == io.EOF { + // 最后一次进度回调 + progress(filename, written, total) + return written, nil + } + return written, err + } + } +} diff --git a/utils_test.go b/utils_test.go new file mode 100644 index 0000000..4cbc71a --- /dev/null +++ b/utils_test.go @@ -0,0 +1,284 @@ +package starnet + +import ( + "net/http" + "testing" + "time" +) + +func TestUrlEncodeRaw(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "basic string with space", + input: "hello world", + expected: "hello%20world", + }, + { + name: "special characters", + input: "hello world!@#$%^&*()_+-=~`", + expected: "hello%20world%21%40%23%24%25%5E%26%2A%28%29_%2B-%3D~%60", + }, + { + name: "empty string", + input: "", + expected: "", + }, + { + name: "chinese characters", + input: "你好世界", + expected: "%E4%BD%A0%E5%A5%BD%E4%B8%96%E7%95%8C", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := UrlEncodeRaw(tt.input) + if result != tt.expected { + t.Errorf("UrlEncodeRaw(%q) = %q; want %q", tt.input, result, tt.expected) + } + }) + } +} + +func TestUrlEncode(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "space encoded as plus", + input: "hello world", + expected: "hello+world", + }, + { + name: "special characters", + input: "hello world!@#$%^&*()_+-=~`", + expected: "hello+world%21%40%23%24%25%5E%26%2A%28%29_%2B-%3D~%60", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := UrlEncode(tt.input) + if result != tt.expected { + t.Errorf("UrlEncode(%q) = %q; want %q", tt.input, result, tt.expected) + } + }) + } +} + +func TestUrlDecode(t *testing.T) { + tests := []struct { + name string + input string + expected string + expectErr bool + }{ + { + name: "basic decode", + input: "hello%20world", + expected: "hello world", + expectErr: false, + }, + { + name: "plus to space", + input: "hello+world", + expected: "hello world", + expectErr: false, + }, + { + name: "special characters", + input: "hello%20world%21%40%23%24%25%5E%26*%28%29_%2B-%3D~%60", + expected: "hello world!@#$%^&*()_+-=~`", + expectErr: false, + }, + { + name: "invalid encoding", + input: "%zz", + expected: "", + expectErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := UrlDecode(tt.input) + if tt.expectErr { + if err == nil { + t.Errorf("UrlDecode(%q) expected error, got nil", tt.input) + } + } else { + if err != nil { + t.Errorf("UrlDecode(%q) unexpected error: %v", tt.input, err) + } + if result != tt.expected { + t.Errorf("UrlDecode(%q) = %q; want %q", tt.input, result, tt.expected) + } + } + }) + } +} + +func TestBuildQuery(t *testing.T) { + tests := []struct { + name string + input map[string]string + expected string + }{ + { + name: "single parameter", + input: map[string]string{ + "key": "value", + }, + expected: "key=value", + }, + { + name: "empty map", + input: map[string]string{}, + expected: "", + }, + { + name: "nil map", + input: nil, + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := BuildQuery(tt.input) + if result != tt.expected { + t.Errorf("BuildQuery(%v) = %q; want %q", tt.input, result, tt.expected) + } + }) + } +} + +func TestBuildPostForm(t *testing.T) { + tests := []struct { + name string + input map[string]string + expected []byte + }{ + { + name: "basic form", + input: map[string]string{ + "key1": "value1", + }, + expected: []byte("key1=value1"), + }, + { + name: "empty map", + input: map[string]string{}, + expected: []byte(""), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := BuildPostForm(tt.input) + if string(result) != string(tt.expected) { + t.Errorf("BuildPostForm(%v) = %v; want %v", tt.input, result, tt.expected) + } + }) + } +} + +func TestValidMethod(t *testing.T) { + tests := []struct { + name string + method string + expected bool + }{ + {"GET", "GET", true}, + {"POST", "POST", true}, + {"PUT", "PUT", true}, + {"DELETE", "DELETE", true}, + {"PATCH", "PATCH", true}, + {"OPTIONS", "OPTIONS", true}, + {"HEAD", "HEAD", true}, + {"TRACE", "TRACE", true}, + {"CONNECT", "CONNECT", true}, + {"invalid with space", "GET POST", false}, + {"invalid with special char", "GET<>", false}, + {"empty", "", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := validMethod(tt.method) + if result != tt.expected { + t.Errorf("validMethod(%q) = %v; want %v", tt.method, result, tt.expected) + } + }) + } +} + +func TestCloneCookies_FullFields(t *testing.T) { + expire := time.Now().Add(2 * time.Hour) + + src := []*http.Cookie{ + { + Name: "sid", + Value: "abc123", + Path: "/", + Domain: "example.com", + Expires: expire, + RawExpires: expire.UTC().Format(time.RFC1123), + MaxAge: 3600, + Secure: true, + HttpOnly: true, + SameSite: http.SameSiteLaxMode, + Raw: "sid=abc123; Path=/; HttpOnly", + Unparsed: []string{"Priority=High", "Partitioned"}, + }, + } + + got := cloneCookies(src) + if got == nil || len(got) != 1 { + t.Fatalf("cloneCookies() len=%v; want 1", len(got)) + } + + // 指针应不同(不是浅拷贝) + if got[0] == src[0] { + t.Fatal("cookie pointer should be different (deep copy expected)") + } + + // 字段值应一致 + s := src[0] + g := got[0] + if g.Name != s.Name || + g.Value != s.Value || + g.Path != s.Path || + g.Domain != s.Domain || + !g.Expires.Equal(s.Expires) || + g.RawExpires != s.RawExpires || + g.MaxAge != s.MaxAge || + g.Secure != s.Secure || + g.HttpOnly != s.HttpOnly || + g.SameSite != s.SameSite || + g.Raw != s.Raw { + t.Fatalf("cloned cookie fields mismatch:\n got=%+v\n src=%+v", g, s) + } + + // Unparsed 内容一致 + if len(g.Unparsed) != len(s.Unparsed) { + t.Fatalf("Unparsed len=%d; want %d", len(g.Unparsed), len(s.Unparsed)) + } + for i := range s.Unparsed { + if g.Unparsed[i] != s.Unparsed[i] { + t.Fatalf("Unparsed[%d]=%q; want %q", i, g.Unparsed[i], s.Unparsed[i]) + } + } + + // 验证 Unparsed 是深拷贝(修改源不影响目标) + src[0].Unparsed[0] = "Modified=Yes" + if got[0].Unparsed[0] == "Modified=Yes" { + t.Fatal("Unparsed should be deep-copied, but was affected by source mutation") + } +}