From b5bd7595a15e0e39b2b244bb3eec8ef49d447729 Mon Sep 17 00:00:00 2001 From: starainrt Date: Thu, 19 Mar 2026 16:42:45 +0800 Subject: [PATCH] =?UTF-8?q?1.=20=E4=BC=98=E5=8C=96ping=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=202.=20=E6=96=B0=E5=A2=9E=E9=87=8D=E8=AF=95=E6=9C=BA=E5=88=B6?= =?UTF-8?q?=203.=20=E4=BC=98=E5=8C=96=E9=94=99=E8=AF=AF=E5=A4=84=E7=90=86?= =?UTF-8?q?=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 5 +- LICENSE | 201 ++++++++++++++ README.md | 106 ++++++++ context.go | 6 +- defensive_copy_test.go | 59 ++++ dialer.go | 21 +- errors.go | 192 +++++++++++++ errors_classify_test.go | 116 ++++++++ errors_kind_test.go | 49 ++++ example_test.go | 56 ++++ options.go | 21 +- ping.go | 522 ++++++++++++++++++++++++++++++++---- ping_logic_test.go | 214 +++++++++++++++ request.go | 58 +++- request_body.go | 157 +++++------ request_config.go | 21 +- request_header.go | 2 +- request_methods_ext_test.go | 43 +++ response.go | 36 ++- response_limit_test.go | 84 ++++++ retry.go | 423 +++++++++++++++++++++++++++++ retry_test.go | 298 ++++++++++++++++++++ timeout_refactor_test.go | 115 ++++++++ types.go | 2 + 24 files changed, 2645 insertions(+), 162 deletions(-) create mode 100644 LICENSE create mode 100644 README.md create mode 100644 defensive_copy_test.go create mode 100644 errors_classify_test.go create mode 100644 errors_kind_test.go create mode 100644 ping_logic_test.go create mode 100644 request_methods_ext_test.go create mode 100644 response_limit_test.go create mode 100644 retry.go create mode 100644 retry_test.go create mode 100644 timeout_refactor_test.go diff --git a/.gitignore b/.gitignore index 723ef36..fa67efa 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,4 @@ -.idea \ No newline at end of file +.idea +.sentrux/ +agent_readme.md +target.md diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..e8856d9 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2026 starnet contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md new file mode 100644 index 0000000..6123cb4 --- /dev/null +++ b/README.md @@ -0,0 +1,106 @@ +# starnet + +`starnet` is a Go network toolkit focused on practical HTTP request control, TLS sniff utilities, and ICMP ping capabilities. + +## Highlights + +- Request-level timeout by context (without mutating shared `http.Client` timeout) +- Fine-grained network controls: custom DNS/IP, dial timeout, proxy, TLS config +- Built-in retry with replay safety checks and configurable backoff/jitter/statuses +- Response body safety guard via max body bytes limit +- Error classification helpers (`ClassifyError`, `IsTimeout`, `IsDNS`, `IsTLS`, `IsProxy`, `IsCanceled`) +- TLS sniffer listener/dialer utilities for mixed TLS/plain traffic scenarios +- ICMP ping with IPv4/IPv6 target handling and option-based probing API + +## Main Features + +### HTTP Client and Request + +- Fluent APIs with both `WithXxx` options and `SetXxx` chain methods +- Methods: `Get/Post/Put/Delete/Head/Patch/Options/Trace/Connect` +- Request body helpers: JSON, form data, multipart file upload, stream body +- Header/cookie/query helpers with defensive copy on key setters +- Request cloning for safe reuse in concurrent or variant calls + +### Timeout and Retry + +- Request timeout is applied by context deadline, not global client timeout +- Retry supports: + - max attempts + - backoff factor/base/max + - jitter + - retry status whitelist + - idempotent-only guard + - custom retry-on-error callback +- Retry keeps original request pointer in final response for consistency + +### Response Handling + +- `Bytes/String/JSON/Reader` helpers +- optional auto-fetch mode +- configurable max response body bytes to prevent oversized reads + +### Ping Module + +- `Ping`, `PingWithContext`, `Pingable`, and compatibility helper `IsIpPingable` +- `PingOptions` for count/timeout/interval/deadline/address preference/source IP/payload size +- explicit error semantics for permission/protocol/timeout/resolve failures + +## Install + +```bash +go get b612.me/starnet +``` + +## Quick Example + +```go +package main + +import ( + "fmt" + "net/http" + "time" + + "b612.me/starnet" +) + +func main() { + resp, err := starnet.Get( + "https://example.com", + starnet.WithTimeout(2*time.Second), + starnet.WithRetry(2, + starnet.WithRetryBackoff(100*time.Millisecond, 1*time.Second, 2), + starnet.WithRetryJitter(0.1), + ), + starnet.WithMaxRespBodyBytes(1<<20), + ) + if err != nil { + fmt.Println("request failed:", starnet.ClassifyError(err), err) + return + } + defer resp.Close() + + fmt.Println("status:", resp.StatusCode) + _, _ = resp.Body().Bytes() + + ok, pingErr := starnet.Pingable("example.com", &starnet.PingOptions{ + Count: 2, + Timeout: 2 * time.Second, + }) + fmt.Println("pingable:", ok, pingErr == nil) + + _ = http.MethodGet +} +``` + +## Stability Notes + +- Raw ICMP ping may require elevated privileges on some systems. +- Integration tests that rely on external network are environment-dependent. + +## License + +This project is licensed under the Apache License 2.0. +See [LICENSE](./LICENSE). + diff --git a/context.go b/context.go index b88023f..69e6458 100644 --- a/context.go +++ b/context.go @@ -78,7 +78,6 @@ func needsDynamicTransport(rc *RequestContext) bool { rc.Proxy != "" || rc.DialFn != nil || (rc.DialTimeout > 0 && rc.DialTimeout != DefaultDialTimeout) || - (rc.Timeout > 0 && rc.Timeout != DefaultTimeout) || len(rc.CustomIP) > 0 || len(rc.CustomDNS) > 0 || rc.LookupIPFn != nil @@ -122,13 +121,10 @@ func injectRequestConfig(ctx context.Context, config *RequestConfig) context.Con execCtx = context.WithValue(execCtx, ctxKeyCustomDNS, config.DNS.CustomDNS) } - // 总是注入 DialTimeout 和 Timeout(与原始代码一致) + // 总是注入 DialTimeout(与原始代码一致) if config.Network.DialTimeout > 0 { execCtx = context.WithValue(execCtx, ctxKeyDialTimeout, config.Network.DialTimeout) } - if config.Network.Timeout > 0 { - execCtx = context.WithValue(execCtx, ctxKeyTimeout, config.Network.Timeout) - } // 注入 DNS 解析函数 if config.DNS.LookupFunc != nil { diff --git a/defensive_copy_test.go b/defensive_copy_test.go new file mode 100644 index 0000000..3d0fdcc --- /dev/null +++ b/defensive_copy_test.go @@ -0,0 +1,59 @@ +package starnet + +import ( + "net/http" + "testing" +) + +func TestWithRawRequestNil(t *testing.T) { + _, err := NewRequest("http://example.com", "GET", WithRawRequest(nil)) + if err == nil { + t.Fatal("expected error when WithRawRequest(nil)") + } +} + +func TestSetHeadersDefensiveCopy(t *testing.T) { + req := NewSimpleRequest("http://example.com", "GET") + headers := http.Header{ + "X-Test": []string{"v1"}, + } + + req.SetHeaders(headers) + headers.Set("X-Test", "v2") + + if got := req.GetHeader("X-Test"); got != "v1" { + t.Fatalf("header mutated by external map change: got=%q want=%q", got, "v1") + } +} + +func TestSetQueriesDefensiveCopy(t *testing.T) { + req := NewSimpleRequest("http://example.com", "GET") + queries := map[string][]string{ + "k": []string{"v1"}, + } + + req.SetQueries(queries) + queries["k"][0] = "v2" + queries["k"] = append(queries["k"], "v3") + + got := req.config.Queries["k"] + if len(got) != 1 || got[0] != "v1" { + t.Fatalf("queries mutated by external map change: got=%v want=[v1]", got) + } +} + +func TestSetFormDataDefensiveCopy(t *testing.T) { + req := NewSimpleRequest("http://example.com", "POST") + form := map[string][]string{ + "name": []string{"alice"}, + } + + req.SetFormData(form) + form["name"][0] = "bob" + form["name"] = append(form["name"], "carol") + + got := req.config.Body.FormData["name"] + if len(got) != 1 || got[0] != "alice" { + t.Fatalf("form data mutated by external map change: got=%v want=[alice]", got) + } +} diff --git a/dialer.go b/dialer.go index 0e32b0d..d1ea6ce 100644 --- a/dialer.go +++ b/dialer.go @@ -19,11 +19,6 @@ func defaultDialFunc(ctx context.Context, network, addr string) (net.Conn, error dialTimeout = DefaultDialTimeout } - timeout := reqCtx.Timeout - if timeout == 0 { - timeout = DefaultTimeout - } - // 解析地址 host, port, err := net.SplitHostPort(addr) if err != nil { @@ -47,12 +42,13 @@ func defaultDialFunc(ctx context.Context, network, addr string) (net.Conn, error ipAddrs, err = reqCtx.LookupIPFn(ctx, host) } else if len(reqCtx.CustomDNS) > 0 { // 使用自定义 DNS 服务器 + dialer := &net.Dialer{Timeout: dialTimeout} resolver := &net.Resolver{ PreferGo: true, Dial: func(ctx context.Context, network, address string) (net.Conn, error) { var lastErr error for _, dnsServer := range reqCtx.CustomDNS { - conn, err := net.Dial("udp", net.JoinHostPort(dnsServer, "53")) + conn, err := dialer.DialContext(ctx, "udp", net.JoinHostPort(dnsServer, "53")) if err != nil { lastErr = err continue @@ -78,19 +74,15 @@ func defaultDialFunc(ctx context.Context, network, addr string) (net.Conn, error } // 尝试连接所有地址 + dialer := &net.Dialer{Timeout: dialTimeout} var lastErr error for _, addr := range addrs { - conn, err := net.DialTimeout(network, addr, dialTimeout) + conn, err := dialer.DialContext(ctx, network, addr) if err != nil { lastErr = err continue } - // 设置总超时 - if timeout > 0 { - conn.SetDeadline(time.Now().Add(timeout)) - } - return conn, nil } @@ -131,6 +123,11 @@ func defaultDialTLSFunc(ctx context.Context, network, addr string) (net.Conn, er } // 执行 TLS 握手 + if deadline, ok := ctx.Deadline(); ok { + _ = conn.SetDeadline(deadline) + defer conn.SetDeadline(time.Time{}) + } + tlsConn := tls.Client(conn, tlsConfig) if err := tlsConn.Handshake(); err != nil { conn.Close() diff --git a/errors.go b/errors.go index 3bba190..582d51b 100644 --- a/errors.go +++ b/errors.go @@ -1,8 +1,14 @@ package starnet import ( + "context" + "crypto/tls" + "crypto/x509" "errors" "fmt" + "net" + "net/url" + "strings" ) var ( @@ -32,6 +38,21 @@ var ( // ErrBodyAlreadyConsumed Body 已被消费 ErrBodyAlreadyConsumed = errors.New("starnet: response body already consumed") + + // ErrRespBodyTooLarge 响应体超过允许上限 + ErrRespBodyTooLarge = errors.New("starnet: response body too large") + + // ErrPingInvalidTimeout ping 超时参数无效 + ErrPingInvalidTimeout = errors.New("starnet: invalid ping timeout") + + // ErrPingPermissionDenied ping 需要更高权限(raw socket) + ErrPingPermissionDenied = errors.New("starnet: ping permission denied") + + // ErrPingProtocolUnsupported ping 协议/地址族不受当前平台支持 + ErrPingProtocolUnsupported = errors.New("starnet: ping protocol unsupported") + + // ErrPingNoResolvedTarget ping 目标无法解析为可用地址 + ErrPingNoResolvedTarget = errors.New("starnet: ping target not resolved") ) // wrapError 包装错误,添加上下文信息 @@ -56,3 +77,174 @@ var ( // ErrNoTLSConfig indicates TLS was detected but no usable TLS config is available. ErrNoTLSConfig = errors.New("starnet: no TLS config available") ) + +// ErrorKind is a normalized high-level category for request errors. +type ErrorKind string + +const ( + ErrorKindNone ErrorKind = "none" + ErrorKindCanceled ErrorKind = "canceled" + ErrorKindTimeout ErrorKind = "timeout" + ErrorKindDNS ErrorKind = "dns" + ErrorKindTLS ErrorKind = "tls" + ErrorKindProxy ErrorKind = "proxy" + ErrorKindOther ErrorKind = "other" +) + +// IsCanceled reports whether err is a cancellation-related error. +func IsCanceled(err error) bool { + if err == nil { + return false + } + if errors.Is(err, context.Canceled) { + return true + } + + msg := strings.ToLower(err.Error()) + return strings.Contains(msg, "context canceled") || + strings.Contains(msg, "operation was canceled") || + strings.Contains(msg, "request canceled") +} + +// ClassifyError maps low-level errors to a stable category for business handling. +func ClassifyError(err error) ErrorKind { + if err == nil { + return ErrorKindNone + } + if IsCanceled(err) { + return ErrorKindCanceled + } + if IsProxy(err) { + return ErrorKindProxy + } + if IsDNS(err) { + return ErrorKindDNS + } + if IsTLS(err) { + return ErrorKindTLS + } + if IsTimeout(err) { + return ErrorKindTimeout + } + return ErrorKindOther +} + +// IsTimeout reports whether err is a timeout-related error. +func IsTimeout(err error) bool { + if err == nil { + return false + } + if errors.Is(err, context.DeadlineExceeded) { + return true + } + + var uerr *url.Error + if errors.As(err, &uerr) && uerr.Timeout() { + return true + } + + var nerr net.Error + if errors.As(err, &nerr) && nerr.Timeout() { + return true + } + + msg := strings.ToLower(err.Error()) + return strings.Contains(msg, "timeout") || strings.Contains(msg, "deadline exceeded") +} + +// IsDNS reports whether err is a DNS resolution related error. +func IsDNS(err error) bool { + if err == nil { + return false + } + + var derr *net.DNSError + if errors.As(err, &derr) { + return true + } + + msg := strings.ToLower(err.Error()) + if strings.Contains(msg, "no such host") || + strings.Contains(msg, "server misbehaving") || + strings.Contains(msg, "temporary failure in name resolution") { + return true + } + + return strings.Contains(msg, "lookup ") && + (strings.Contains(msg, "dns") || strings.Contains(msg, "i/o timeout")) +} + +// IsTLS reports whether err is TLS/Certificate related. +func IsTLS(err error) bool { + if err == nil { + return false + } + if errors.Is(err, ErrNotTLS) || errors.Is(err, ErrNoTLSConfig) || errors.Is(err, ErrNonTLSNotAllowed) { + return true + } + + var recErr tls.RecordHeaderError + if errors.As(err, &recErr) { + return true + } + + var uaErr x509.UnknownAuthorityError + if errors.As(err, &uaErr) { + return true + } + + var hnErr x509.HostnameError + if errors.As(err, &hnErr) { + return true + } + + var certErr x509.CertificateInvalidError + if errors.As(err, &certErr) { + return true + } + + var rootsErr x509.SystemRootsError + if errors.As(err, &rootsErr) { + return true + } + + msg := strings.ToLower(err.Error()) + return strings.Contains(msg, "tls:") || strings.Contains(msg, "x509:") +} + +// IsProxy reports whether err is proxy related. +func IsProxy(err error) bool { + if err == nil { + return false + } + + if isProxyMessage(strings.ToLower(err.Error())) { + return true + } + + var uerr *url.Error + if errors.As(err, &uerr) { + if strings.Contains(strings.ToLower(uerr.Op), "proxy") { + return true + } + if uerr.Err != nil && isProxyMessage(strings.ToLower(uerr.Err.Error())) { + return true + } + } + + var opErr *net.OpError + if errors.As(err, &opErr) && strings.Contains(strings.ToLower(opErr.Op), "proxy") { + return true + } + + return false +} + +func isProxyMessage(msg string) bool { + return strings.Contains(msg, "proxyconnect") || + strings.Contains(msg, "proxy error") || + strings.Contains(msg, "proxy authentication required") || + strings.Contains(msg, "proxy: unknown scheme") || + strings.Contains(msg, "socks connect") || + strings.Contains(msg, "socks5") +} diff --git a/errors_classify_test.go b/errors_classify_test.go new file mode 100644 index 0000000..2529cd1 --- /dev/null +++ b/errors_classify_test.go @@ -0,0 +1,116 @@ +package starnet + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "net" + "net/url" + "testing" +) + +type timeoutErr struct{} + +func (timeoutErr) Error() string { return "i/o timeout" } +func (timeoutErr) Timeout() bool { return true } +func (timeoutErr) Temporary() bool { return true } + +func TestIsTimeout(t *testing.T) { + if !IsTimeout(context.DeadlineExceeded) { + t.Fatal("context deadline should be timeout") + } + + uerr := &url.Error{ + Op: "Get", + URL: "http://example.com", + Err: timeoutErr{}, + } + if !IsTimeout(uerr) { + t.Fatal("url timeout error should be timeout") + } + + if !IsTimeout(fmt.Errorf("wrapped: %w", uerr)) { + t.Fatal("wrapped timeout should be timeout") + } + + if IsTimeout(errors.New("plain error")) { + t.Fatal("plain error must not be timeout") + } +} + +func TestIsDNS(t *testing.T) { + dnsErr := &net.DNSError{ + Err: "no such host", + Name: "example.invalid", + IsNotFound: true, + } + + if !IsDNS(dnsErr) { + t.Fatal("dns error should be dns") + } + + if !IsDNS(fmt.Errorf("wrapped: %w", dnsErr)) { + t.Fatal("wrapped dns error should be dns") + } + + if !IsDNS(errors.New("lookup example.invalid: no such host")) { + t.Fatal("lookup no such host should be dns") + } + + if IsDNS(errors.New("connection reset by peer")) { + t.Fatal("non dns error should not be dns") + } +} + +func TestIsTLS(t *testing.T) { + tlsErr := tls.RecordHeaderError{Msg: "first record does not look like a TLS handshake"} + if !IsTLS(tlsErr) { + t.Fatal("tls record header error should be tls") + } + + if !IsTLS(fmt.Errorf("wrapped: %w", tlsErr)) { + t.Fatal("wrapped tls error should be tls") + } + + if !IsTLS(errors.New("x509: certificate signed by unknown authority")) { + t.Fatal("x509 error text should be tls") + } + + if !IsTLS(ErrNotTLS) { + t.Fatal("ErrNotTLS should be tls related") + } + + if IsTLS(errors.New("plain error")) { + t.Fatal("plain error should not be tls") + } +} + +func TestIsProxy(t *testing.T) { + raw := errors.New("proxyconnect tcp: dial tcp 127.0.0.1:8080: connect: connection refused") + if !IsProxy(raw) { + t.Fatal("proxyconnect error should be proxy") + } + + uerr := &url.Error{ + Op: "Get", + URL: "http://example.com", + Err: raw, + } + if !IsProxy(uerr) { + t.Fatal("wrapped proxy error should be proxy") + } + + opErr := &net.OpError{ + Op: "proxyconnect", + Net: "tcp", + Err: errors.New("connect failed"), + } + if !IsProxy(opErr) { + t.Fatal("net.OpError proxyconnect should be proxy") + } + + if IsProxy(errors.New("dial tcp 127.0.0.1:8080: connect: connection refused")) { + t.Fatal("non proxy dial error should not be proxy") + } +} diff --git a/errors_kind_test.go b/errors_kind_test.go new file mode 100644 index 0000000..e5b8cf1 --- /dev/null +++ b/errors_kind_test.go @@ -0,0 +1,49 @@ +package starnet + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "net" + "testing" +) + +func TestIsCanceled(t *testing.T) { + if !IsCanceled(context.Canceled) { + t.Fatal("context canceled should be canceled") + } + if !IsCanceled(fmt.Errorf("wrapped: %w", context.Canceled)) { + t.Fatal("wrapped context canceled should be canceled") + } + if IsCanceled(context.DeadlineExceeded) { + t.Fatal("deadline exceeded must not be canceled") + } +} + +func TestClassifyError(t *testing.T) { + dnsErr := &net.DNSError{Err: "no such host", Name: "example.invalid", IsNotFound: true} + tlsErr := tls.RecordHeaderError{Msg: "bad tls record"} + + tests := []struct { + name string + err error + want ErrorKind + }{ + {name: "nil", err: nil, want: ErrorKindNone}, + {name: "canceled", err: context.Canceled, want: ErrorKindCanceled}, + {name: "proxy", err: errors.New("proxyconnect tcp: dial tcp 127.0.0.1:8080: i/o timeout"), want: ErrorKindProxy}, + {name: "dns", err: dnsErr, want: ErrorKindDNS}, + {name: "tls", err: tlsErr, want: ErrorKindTLS}, + {name: "timeout", err: context.DeadlineExceeded, want: ErrorKindTimeout}, + {name: "other", err: errors.New("boom"), want: ErrorKindOther}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := ClassifyError(tt.err); got != tt.want { + t.Fatalf("ClassifyError()=%s want=%s err=%v", got, tt.want, tt.err) + } + }) + } +} diff --git a/example_test.go b/example_test.go index 6bfdecd..87dd568 100644 --- a/example_test.go +++ b/example_test.go @@ -4,6 +4,7 @@ import ( "fmt" "net/http" "net/http/httptest" + "sync/atomic" "time" "b612.me/starnet" @@ -198,3 +199,58 @@ func ExampleWithTimeout() { fmt.Println(resp.StatusCode) // Output: 200 } + +func ExampleWithRetry() { + var hits int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + n := atomic.AddInt32(&hits, 1) + if n <= 2 { + w.WriteHeader(http.StatusServiceUnavailable) + return + } + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + resp, err := starnet.Get(server.URL, + starnet.WithRetry(2, + starnet.WithRetryBackoff(0, 0, 1), + starnet.WithRetryJitter(0), + ), + ) + if err != nil { + panic(err) + } + defer resp.Close() + + fmt.Println(resp.StatusCode, atomic.LoadInt32(&hits)) + // Output: 200 3 +} + +func ExampleRequest_SetRetry() { + var hits int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + n := atomic.AddInt32(&hits, 1) + if n == 1 { + w.WriteHeader(http.StatusServiceUnavailable) + return + } + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + resp, err := starnet.NewSimpleRequest(server.URL, http.MethodPost). + SetBodyString("hello"). // 可重放 body 才能安全重试 + SetRetry(1). + SetRetryIdempotentOnly(false). + SetRetryBackoff(0, 0, 1). + SetRetryJitter(0). + Do() + if err != nil { + panic(err) + } + defer resp.Close() + + fmt.Println(resp.StatusCode, atomic.LoadInt32(&hits)) + // Output: 200 2 +} diff --git a/options.go b/options.go index 1b55197..70b29c3 100644 --- a/options.go +++ b/options.go @@ -4,6 +4,7 @@ import ( "context" "crypto/tls" "encoding/json" + "fmt" "io" "net" "net/http" @@ -12,9 +13,9 @@ import ( ) // WithTimeout 设置请求总超时时间 -// timeout > 0: 使用该超时 -// timeout = 0: 使用 Client 默认超时 -// timeout < 0: 禁用本次请求超时(覆盖 Client.Timeout=0) +// timeout > 0: 为本次请求注入 context 超时 +// timeout = 0: 不额外设置请求总超时 +// timeout < 0: 禁用 starnet 默认总超时 func WithTimeout(timeout time.Duration) RequestOpt { return func(r *Request) error { r.config.Network.Timeout = timeout @@ -371,9 +372,23 @@ func WithAutoFetch(auto bool) RequestOpt { } } +// WithMaxRespBodyBytes 设置响应体最大读取字节数(<=0 表示不限制) +func WithMaxRespBodyBytes(maxBytes int64) RequestOpt { + return func(r *Request) error { + if maxBytes < 0 { + return fmt.Errorf("max response body bytes must be >= 0") + } + r.config.MaxRespBodyBytes = maxBytes + return nil + } +} + // WithRawRequest 设置原始请求 func WithRawRequest(httpReq *http.Request) RequestOpt { return func(r *Request) error { + if httpReq == nil { + return fmt.Errorf("httpReq cannot be nil") + } r.httpReq = httpReq r.doRaw = true return nil diff --git a/ping.go b/ping.go index 2fcb51e..b31be44 100644 --- a/ping.go +++ b/ping.go @@ -1,12 +1,31 @@ package starnet import ( - "bytes" + "context" "encoding/binary" + "errors" + "fmt" "net" + "os" + "strings" + "sync/atomic" "time" ) +const ( + icmpTypeEchoReplyV4 = 0 + icmpTypeEchoRequestV4 = 8 + icmpTypeEchoRequestV6 = 128 + icmpTypeEchoReplyV6 = 129 + + icmpHeaderLen = 8 + icmpReadBufSz = 1500 + + defaultPingAttemptTimeout = 2 * time.Second + defaultPingableCount = 3 + maxPingPayloadSize = 65499 // 65507 - ICMP header(8) +) + type ICMP struct { Type uint8 Code uint8 @@ -15,52 +34,126 @@ type ICMP struct { SequenceNum uint16 } -func getICMP(seq uint16) ICMP { +type pingSocketSpec struct { + network string + family int + requestType uint8 + replyType uint8 +} + +// PingOptions controls ping probing behavior. +type PingOptions struct { + Count int // ping attempts for Pingable, default 3 + Timeout time.Duration // per-attempt timeout, default 2s + Interval time.Duration // delay between attempts, default 0 + Deadline time.Time // overall deadline for Pingable/PingWithContext + PreferIPv4 bool // prefer IPv4 targets + PreferIPv6 bool // prefer IPv6 targets + SourceIP net.IP // optional source IP for raw socket bind + PayloadSize int // ICMP payload bytes, default 0 +} + +type PingResult struct { + Duration time.Duration + RecvCount int + RemoteIP string +} + +var pingIdentifierSeed uint32 + +func nextPingIdentifier() uint16 { + pid := uint32(os.Getpid() & 0xffff) + n := atomic.AddUint32(&pingIdentifierSeed, 1) + return uint16((pid + n) & 0xffff) +} + +func pingPayload(size int) []byte { + if size <= 0 { + return nil + } + payload := make([]byte, size) + for i := 0; i < len(payload); i++ { + payload[i] = byte(i) + } + return payload +} + +func getICMP(seq, identifier uint16, typ uint8, payload []byte) ICMP { icmp := ICMP{ - Type: 8, + Type: typ, Code: 0, CheckSum: 0, - Identifier: 0, + Identifier: identifier, SequenceNum: seq, } - var buffer bytes.Buffer - binary.Write(&buffer, binary.BigEndian, icmp) - icmp.CheckSum = checkSum(buffer.Bytes()) - buffer.Reset() - + buf := marshalICMPPacket(icmp, payload) + icmp.CheckSum = checkSum(buf) return icmp } -func sendICMPRequest(icmp ICMP, destAddr *net.IPAddr, timeout time.Duration) (PingResult, error) { +func sendICMPRequest(ctx context.Context, icmp ICMP, payload []byte, destAddr *net.IPAddr, sourceIP net.IP, spec pingSocketSpec, timeout time.Duration) (PingResult, error) { var res PingResult + if ctx == nil { + ctx = context.Background() + } + if err := ctx.Err(); err != nil { + return res, wrapError(err, "ping context done") + } + if destAddr == nil || destAddr.IP == nil { + return res, fmt.Errorf("destination ip is nil") + } res.RemoteIP = destAddr.String() - conn, err := net.DialIP("ip:icmp", nil, destAddr) + + localAddr, err := localIPAddrForFamily(sourceIP, spec.family) if err != nil { return res, err } - defer conn.Close() - var buffer bytes.Buffer - binary.Write(&buffer, binary.BigEndian, icmp) - if _, err := conn.Write(buffer.Bytes()); err != nil { - return res, err + conn, err := net.DialIP(spec.network, localAddr, destAddr) + if err != nil { + return res, normalizePingDialError(err) + } + defer conn.Close() + + packet := marshalICMPPacket(icmp, payload) + if _, err := conn.Write(packet); err != nil { + return res, wrapError(err, "ping write request") } tStart := time.Now() - - conn.SetReadDeadline((time.Now().Add(timeout))) - - recv := make([]byte, 1024) - res.RecvCount, err = conn.Read(recv) - - if err != nil { - return res, err + deadline := tStart.Add(timeout) + if d, ok := ctx.Deadline(); ok && d.Before(deadline) { + deadline = d + } + if err := conn.SetReadDeadline(deadline); err != nil { + return res, wrapError(err, "ping set read deadline") } - tEnd := time.Now() - res.Duration = tEnd.Sub(tStart) + doneCh := make(chan struct{}) + go func() { + select { + case <-ctx.Done(): + _ = conn.SetReadDeadline(time.Now()) + case <-doneCh: + } + }() + defer close(doneCh) - return res, err + recv := make([]byte, icmpReadBufSz) + for { + n, err := conn.Read(recv) + if err != nil { + if ctx.Err() != nil { + return res, wrapError(ctx.Err(), "ping context done") + } + return res, wrapError(err, "ping read reply") + } + if isExpectedEchoReply(recv[:n], spec.family, spec.replyType, icmp.Identifier, icmp.SequenceNum) { + res.RecvCount = n + res.Duration = time.Since(tStart) + return res, nil + } + } } func checkSum(data []byte) uint16 { @@ -75,36 +168,375 @@ func checkSum(data []byte) uint16 { length -= 2 } if length > 0 { - sum += uint32(data[index]) + sum += uint32(data[index]) << 8 + } + for sum>>16 != 0 { + sum = (sum & 0xffff) + (sum >> 16) } - sum += (sum >> 16) return uint16(^sum) } -type PingResult struct { - Duration time.Duration - RecvCount int - RemoteIP string +func marshalICMP(icmp ICMP) []byte { + return marshalICMPPacket(icmp, nil) } -func Ping(ip string, seq int, timeout time.Duration) (PingResult, error) { - var res PingResult - ipAddr, err := net.ResolveIPAddr("ip", ip) - if err != nil { - return res, err - } - icmp := getICMP(uint16(seq)) - return sendICMPRequest(icmp, ipAddr, timeout) +func marshalICMPPacket(icmp ICMP, payload []byte) []byte { + buf := make([]byte, icmpHeaderLen+len(payload)) + buf[0] = icmp.Type + buf[1] = icmp.Code + binary.BigEndian.PutUint16(buf[2:], icmp.CheckSum) + binary.BigEndian.PutUint16(buf[4:], icmp.Identifier) + binary.BigEndian.PutUint16(buf[6:], icmp.SequenceNum) + copy(buf[icmpHeaderLen:], payload) + return buf } -func IsIpPingable(ip string, timeout time.Duration, retryLimit int) bool { - for i := 0; i < retryLimit; i++ { - _, err := Ping(ip, 29, timeout) - if err != nil { +func isExpectedEchoReply(packet []byte, family int, expectedType uint8, identifier, seq uint16) bool { + for _, off := range candidateICMPOffsets(packet, family) { + if off < 0 || off+icmpHeaderLen > len(packet) { + continue + } + if packet[off] != expectedType || packet[off+1] != 0 { + continue + } + if binary.BigEndian.Uint16(packet[off+4:off+6]) != identifier { + continue + } + if binary.BigEndian.Uint16(packet[off+6:off+8]) != seq { continue } return true } return false } + +func candidateICMPOffsets(packet []byte, family int) []int { + offsets := []int{0} + if len(packet) == 0 { + return offsets + } + + ver := packet[0] >> 4 + if ver == 4 && len(packet) >= 20 { + ihl := int(packet[0]&0x0f) * 4 + if ihl >= 20 && ihl <= len(packet)-icmpHeaderLen { + offsets = append(offsets, ihl) + } + } else if ver == 6 && len(packet) >= 40+icmpHeaderLen { + offsets = append(offsets, 40) + } + + // 某些平台/内核可能回包含链路层头部,保守再尝试常见偏移。 + if family == 4 && len(packet) >= 20+icmpHeaderLen { + offsets = append(offsets, 20) + } + if family == 6 && len(packet) >= 40+icmpHeaderLen { + offsets = append(offsets, 40) + } + + return dedupOffsets(offsets) +} + +func dedupOffsets(offsets []int) []int { + if len(offsets) <= 1 { + return offsets + } + m := make(map[int]struct{}, len(offsets)) + out := make([]int, 0, len(offsets)) + for _, off := range offsets { + if _, ok := m[off]; ok { + continue + } + m[off] = struct{}{} + out = append(out, off) + } + return out +} + +func socketSpecForIP(ip net.IP) (pingSocketSpec, error) { + if ip == nil { + return pingSocketSpec{}, wrapError(ErrInvalidIP, "ip is nil") + } + + if ip4 := ip.To4(); ip4 != nil { + return pingSocketSpec{ + network: "ip4:icmp", + family: 4, + requestType: icmpTypeEchoRequestV4, + replyType: icmpTypeEchoReplyV4, + }, nil + } + + if ip16 := ip.To16(); ip16 != nil { + return pingSocketSpec{ + network: "ip6:ipv6-icmp", + family: 6, + requestType: icmpTypeEchoRequestV6, + replyType: icmpTypeEchoReplyV6, + }, nil + } + + return pingSocketSpec{}, wrapError(ErrInvalidIP, "invalid ip: %q", ip.String()) +} + +func localIPAddrForFamily(sourceIP net.IP, family int) (*net.IPAddr, error) { + if sourceIP == nil { + return nil, nil + } + if sourceIP.To16() == nil { + return nil, wrapError(ErrInvalidIP, "invalid source ip: %q", sourceIP.String()) + } + if family == 4 && sourceIP.To4() == nil { + return nil, wrapError(ErrInvalidIP, "source ip family mismatch with IPv4 target") + } + if family == 6 && sourceIP.To4() != nil { + return nil, wrapError(ErrInvalidIP, "source ip family mismatch with IPv6 target") + } + return &net.IPAddr{IP: sourceIP}, nil +} + +func resolvePingTargets(host string, preferIPv4, preferIPv6 bool) ([]*net.IPAddr, error) { + if parsed := net.ParseIP(host); parsed != nil { + return []*net.IPAddr{{IP: parsed}}, nil + } + + var targets []*net.IPAddr + var err4 error + var err6 error + + if ip4, e := net.ResolveIPAddr("ip4", host); e == nil && ip4 != nil && ip4.IP != nil { + targets = append(targets, ip4) + } else { + err4 = e + } + + if ip6, e := net.ResolveIPAddr("ip6", host); e == nil && ip6 != nil && ip6.IP != nil { + targets = append(targets, ip6) + } else { + err6 = e + } + + if len(targets) > 0 { + return orderPingTargets(targets, preferIPv4, preferIPv6), nil + } + + if err4 != nil { + return nil, err4 + } + if err6 != nil { + return nil, err6 + } + return nil, ErrPingNoResolvedTarget +} + +func orderPingTargets(targets []*net.IPAddr, preferIPv4, preferIPv6 bool) []*net.IPAddr { + if len(targets) <= 1 || preferIPv4 == preferIPv6 { + return targets + } + + ordered := make([]*net.IPAddr, 0, len(targets)) + if preferIPv4 { + for _, t := range targets { + if t != nil && t.IP != nil && t.IP.To4() != nil { + ordered = append(ordered, t) + } + } + for _, t := range targets { + if t != nil && t.IP != nil && t.IP.To4() == nil { + ordered = append(ordered, t) + } + } + return ordered + } + + for _, t := range targets { + if t != nil && t.IP != nil && t.IP.To4() == nil { + ordered = append(ordered, t) + } + } + for _, t := range targets { + if t != nil && t.IP != nil && t.IP.To4() != nil { + ordered = append(ordered, t) + } + } + return ordered +} + +func normalizePingDialError(err error) error { + if err == nil { + return nil + } + + msg := strings.ToLower(err.Error()) + if errors.Is(err, os.ErrPermission) || + strings.Contains(msg, "operation not permitted") || + strings.Contains(msg, "permission denied") { + return fmt.Errorf("%w: %v", ErrPingPermissionDenied, err) + } + + if strings.Contains(msg, "unknown network") || + strings.Contains(msg, "protocol not available") || + strings.Contains(msg, "address family not supported by protocol") || + strings.Contains(msg, "socket type not supported") { + return fmt.Errorf("%w: %v", ErrPingProtocolUnsupported, err) + } + + return wrapError(err, "ping dial") +} + +func normalizePingOptions(opts *PingOptions, defaultCount int, defaultTimeout time.Duration) (PingOptions, error) { + out := PingOptions{ + Count: defaultCount, + Timeout: defaultTimeout, + Interval: 0, + PayloadSize: 0, + } + if opts != nil { + out = *opts + if out.Count == 0 { + out.Count = defaultCount + } + if out.Timeout == 0 { + out.Timeout = defaultTimeout + } + } + + if out.Count < 0 { + return out, fmt.Errorf("ping count must be >= 0") + } + if out.Timeout <= 0 { + return out, wrapError(ErrPingInvalidTimeout, "timeout must be > 0") + } + if out.Interval < 0 { + return out, fmt.Errorf("ping interval must be >= 0") + } + if out.PayloadSize < 0 || out.PayloadSize > maxPingPayloadSize { + return out, fmt.Errorf("ping payload size must be in [0,%d]", maxPingPayloadSize) + } + if out.SourceIP != nil && out.SourceIP.To16() == nil { + return out, wrapError(ErrInvalidIP, "invalid source ip: %q", out.SourceIP.String()) + } + + return out, nil +} + +func pingOnceWithOptions(ctx context.Context, host string, seq int, opts PingOptions) (PingResult, error) { + var res PingResult + if ctx == nil { + ctx = context.Background() + } + if err := ctx.Err(); err != nil { + return res, wrapError(err, "ping context done") + } + + targets, err := resolvePingTargets(host, opts.PreferIPv4, opts.PreferIPv6) + if err != nil { + return res, wrapError(err, "resolve ping target") + } + + payload := pingPayload(opts.PayloadSize) + var lastErr error + for _, target := range targets { + spec, err := socketSpecForIP(target.IP) + if err != nil { + lastErr = err + continue + } + + icmp := getICMP(uint16(seq), nextPingIdentifier(), spec.requestType, payload) + resp, err := sendICMPRequest(ctx, icmp, payload, target, opts.SourceIP, spec, opts.Timeout) + if err == nil { + return resp, nil + } + + // 权限问题通常与地址族无关,继续重试意义不大。 + if errors.Is(err, ErrPingPermissionDenied) { + return res, err + } + lastErr = err + } + + if lastErr != nil { + return res, wrapError(lastErr, "ping all resolved targets failed") + } + return res, ErrPingNoResolvedTarget +} + +// PingWithContext sends one ICMP echo request with context cancel support. +func PingWithContext(ctx context.Context, host string, seq int, timeout time.Duration) (PingResult, error) { + opts, err := normalizePingOptions(&PingOptions{ + Count: 1, + Timeout: timeout, + }, 1, timeout) + if err != nil { + return PingResult{}, err + } + + if !opts.Deadline.IsZero() { + var cancel context.CancelFunc + ctx, cancel = context.WithDeadline(ctx, opts.Deadline) + defer cancel() + } + return pingOnceWithOptions(ctx, host, seq, opts) +} + +// Ping sends one ICMP echo request. +func Ping(ip string, seq int, timeout time.Duration) (PingResult, error) { + return PingWithContext(context.Background(), ip, seq, timeout) +} + +// Pingable checks host reachability with retry options. +func Pingable(host string, opts *PingOptions) (bool, error) { + cfg, err := normalizePingOptions(opts, defaultPingableCount, defaultPingAttemptTimeout) + if err != nil { + return false, err + } + + ctx := context.Background() + if !cfg.Deadline.IsZero() { + var cancel context.CancelFunc + ctx, cancel = context.WithDeadline(ctx, cfg.Deadline) + defer cancel() + } + + var lastErr error + for i := 0; i < cfg.Count; i++ { + _, err := pingOnceWithOptions(ctx, host, 29+i, cfg) + if err == nil { + return true, nil + } + lastErr = err + + if errors.Is(err, ErrPingPermissionDenied) || errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + break + } + + if i < cfg.Count-1 && cfg.Interval > 0 { + timer := time.NewTimer(cfg.Interval) + select { + case <-ctx.Done(): + timer.Stop() + return false, wrapError(ctx.Err(), "pingable context done") + case <-timer.C: + } + } + } + + if lastErr == nil { + lastErr = ErrPingNoResolvedTarget + } + return false, lastErr +} + +// IsIpPingable keeps backward-compatible bool-only behavior. +func IsIpPingable(ip string, timeout time.Duration, retryLimit int) bool { + if retryLimit <= 0 { + return false + } + ok, _ := Pingable(ip, &PingOptions{ + Count: retryLimit, + Timeout: timeout, + }) + return ok +} diff --git a/ping_logic_test.go b/ping_logic_test.go new file mode 100644 index 0000000..14942c1 --- /dev/null +++ b/ping_logic_test.go @@ -0,0 +1,214 @@ +package starnet + +import ( + "context" + "errors" + "net" + "testing" + "time" +) + +func buildICMPPacket(typ uint8, identifier, seq uint16) []byte { + icmp := ICMP{ + Type: typ, + Code: 0, + CheckSum: 0, + Identifier: identifier, + SequenceNum: seq, + } + buf := marshalICMPPacket(icmp, nil) + cs := checkSum(buf) + buf[2] = byte(cs >> 8) + buf[3] = byte(cs) + return buf +} + +func TestNextPingIdentifierChanges(t *testing.T) { + id1 := nextPingIdentifier() + id2 := nextPingIdentifier() + if id1 == id2 { + t.Fatalf("identifier should change between calls: %d == %d", id1, id2) + } +} + +func TestIsExpectedEchoReplyIPv4(t *testing.T) { + identifier := uint16(0x1234) + seq := uint16(0x0102) + reply := buildICMPPacket(icmpTypeEchoReplyV4, identifier, seq) + + if !isExpectedEchoReply(reply, 4, icmpTypeEchoReplyV4, identifier, seq) { + t.Fatal("expected IPv4 echo reply to match") + } + if isExpectedEchoReply(reply, 4, icmpTypeEchoReplyV4, identifier, seq+1) { + t.Fatal("mismatched sequence should not match") + } +} + +func TestIsExpectedEchoReplyIPv4WithIPHeader(t *testing.T) { + identifier := uint16(0x1111) + seq := uint16(0x2222) + + ipHeader := make([]byte, 20) + ipHeader[0] = 0x45 // version=4, ihl=5 + + reply := buildICMPPacket(icmpTypeEchoReplyV4, identifier, seq) + packet := append(ipHeader, reply...) + + if !isExpectedEchoReply(packet, 4, icmpTypeEchoReplyV4, identifier, seq) { + t.Fatal("expected IPv4 packet with header to match") + } +} + +func TestIsExpectedEchoReplyIPv6WithHeader(t *testing.T) { + identifier := uint16(0xabcd) + seq := uint16(0x00ff) + + ipv6Header := make([]byte, 40) + ipv6Header[0] = 0x60 // version=6 + + reply := buildICMPPacket(icmpTypeEchoReplyV6, identifier, seq) + packet := append(ipv6Header, reply...) + + if !isExpectedEchoReply(packet, 6, icmpTypeEchoReplyV6, identifier, seq) { + t.Fatal("expected IPv6 packet with header to match") + } +} + +func TestPingInvalidTimeout(t *testing.T) { + _, err := Ping("127.0.0.1", 1, 0) + if err == nil { + t.Fatal("expected error for non-positive timeout") + } + if !errors.Is(err, ErrPingInvalidTimeout) { + t.Fatalf("expected ErrPingInvalidTimeout, got: %v", err) + } +} + +func TestIsIPPingableInvalidRetry(t *testing.T) { + if IsIpPingable("127.0.0.1", time.Millisecond, 0) { + t.Fatal("retryLimit=0 should return false") + } +} + +func TestSocketSpecForIP(t *testing.T) { + v4, err := socketSpecForIP(net.ParseIP("127.0.0.1")) + if err != nil { + t.Fatalf("unexpected v4 error: %v", err) + } + if v4.network != "ip4:icmp" || v4.family != 4 || v4.requestType != icmpTypeEchoRequestV4 || v4.replyType != icmpTypeEchoReplyV4 { + t.Fatalf("unexpected v4 spec: %+v", v4) + } + + v6, err := socketSpecForIP(net.ParseIP("::1")) + if err != nil { + t.Fatalf("unexpected v6 error: %v", err) + } + if v6.network != "ip6:ipv6-icmp" || v6.family != 6 || v6.requestType != icmpTypeEchoRequestV6 || v6.replyType != icmpTypeEchoReplyV6 { + t.Fatalf("unexpected v6 spec: %+v", v6) + } + + _, err = socketSpecForIP(nil) + if err == nil { + t.Fatal("expected error for nil ip") + } + if !errors.Is(err, ErrInvalidIP) { + t.Fatalf("expected ErrInvalidIP, got: %v", err) + } +} + +func TestResolvePingTargetsLiteral(t *testing.T) { + v4, err := resolvePingTargets("127.0.0.1", false, false) + if err != nil { + t.Fatalf("unexpected v4 resolve error: %v", err) + } + if len(v4) != 1 || v4[0] == nil || v4[0].IP == nil || v4[0].IP.To4() == nil { + t.Fatalf("unexpected v4 targets: %+v", v4) + } + + v6, err := resolvePingTargets("::1", false, false) + if err != nil { + t.Fatalf("unexpected v6 resolve error: %v", err) + } + if len(v6) != 1 || v6[0] == nil || v6[0].IP == nil || v6[0].IP.To16() == nil || v6[0].IP.To4() != nil { + t.Fatalf("unexpected v6 targets: %+v", v6) + } +} + +func TestNormalizePingDialError(t *testing.T) { + perr := normalizePingDialError(errors.New("socket: operation not permitted")) + if !errors.Is(perr, ErrPingPermissionDenied) { + t.Fatalf("expected ErrPingPermissionDenied, got: %v", perr) + } + + uerr := normalizePingDialError(errors.New("unknown network ip6:ipv6-icmp")) + if !errors.Is(uerr, ErrPingProtocolUnsupported) { + t.Fatalf("expected ErrPingProtocolUnsupported, got: %v", uerr) + } +} + +func TestOrderPingTargets(t *testing.T) { + targets := []*net.IPAddr{ + {IP: net.ParseIP("::1")}, + {IP: net.ParseIP("127.0.0.1")}, + } + + v4First := orderPingTargets(targets, true, false) + if v4First[0].IP.To4() == nil { + t.Fatalf("expected IPv4 first, got: %v", v4First[0].IP) + } + + v6First := orderPingTargets(targets, false, true) + if v6First[0].IP.To4() != nil { + t.Fatalf("expected IPv6 first, got: %v", v6First[0].IP) + } +} + +func TestNormalizePingOptions(t *testing.T) { + opts, err := normalizePingOptions(nil, 3, 2*time.Second) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if opts.Count != 3 || opts.Timeout != 2*time.Second { + t.Fatalf("unexpected defaults: %+v", opts) + } + + _, err = normalizePingOptions(&PingOptions{Count: -1}, 3, 2*time.Second) + if err == nil { + t.Fatal("expected error for negative count") + } + + _, err = normalizePingOptions(&PingOptions{Timeout: -1}, 3, 2*time.Second) + if err == nil { + t.Fatal("expected error for negative timeout") + } + + _, err = normalizePingOptions(&PingOptions{PayloadSize: maxPingPayloadSize + 1}, 3, 2*time.Second) + if err == nil { + t.Fatal("expected error for too large payload") + } +} + +func TestPingWithContextCanceled(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, err := PingWithContext(ctx, "127.0.0.1", 1, time.Second) + if err == nil { + t.Fatal("expected canceled error") + } + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected context.Canceled, got: %v", err) + } +} + +func TestPingableInvalidOptions(t *testing.T) { + _, err := Pingable("127.0.0.1", &PingOptions{Count: -1}) + if err == nil { + t.Fatal("expected invalid count error") + } + + _, err = Pingable("127.0.0.1", &PingOptions{Interval: -1}) + if err == nil { + t.Fatal("expected invalid interval error") + } +} diff --git a/request.go b/request.go index 486b9ef..9484cd5 100644 --- a/request.go +++ b/request.go @@ -12,6 +12,7 @@ import ( type Request struct { ctx context.Context execCtx context.Context // 执行时的 context(注入了配置) + cancel context.CancelFunc url string method string err error // 累积的错误 @@ -20,6 +21,7 @@ type Request struct { client *Client httpClient *http.Client httpReq *http.Request + retry *retryPolicy applied bool // 是否已应用配置 doRaw bool // 是否使用原始请求(不修改) @@ -160,6 +162,7 @@ func (r *Request) Clone() *Request { config: r.config.Clone(), client: r.client, httpClient: r.httpClient, + retry: cloneRetryPolicy(r.retry), applied: false, // 重置应用状态 doRaw: r.doRaw, autoFetch: r.autoFetch, @@ -318,6 +321,14 @@ func (r *Request) Do() (*Response, error) { return nil, r.err } + if r.hasRetryPolicy() { + return r.doWithRetry() + } + + return r.doOnce() +} + +func (r *Request) doOnce() (*Response, error) { // 准备请求 if err := r.prepare(); err != nil { return nil, wrapError(err, "prepare request") @@ -326,6 +337,10 @@ func (r *Request) Do() (*Response, error) { // 执行请求 httpResp, err := r.httpClient.Do(r.httpReq) if err != nil { + if r.cancel != nil { + r.cancel() + r.cancel = nil + } return &Response{ Response: &http.Response{}, request: r, @@ -334,19 +349,33 @@ func (r *Request) Do() (*Response, error) { }, wrapError(err, "do request") } + rawBody := httpResp.Body + if r.cancel != nil { + rawBody = &cancelReadCloser{ + ReadCloser: httpResp.Body, + cancel: r.cancel, + } + } + // 创建响应 resp := &Response{ Response: httpResp, request: r, httpClient: r.httpClient, + cancel: r.cancel, body: &Body{ - raw: httpResp.Body, + raw: rawBody, + maxBytes: r.config.MaxRespBodyBytes, }, } + r.cancel = nil // 自动获取响应体 if r.autoFetch { - resp.body.readAll() + if err := resp.body.readAll(); err != nil { + _ = resp.Close() + return resp, err + } } return resp, nil @@ -371,3 +400,28 @@ func (r *Request) Put() (*Response, error) { func (r *Request) Delete() (*Response, error) { return r.SetMethod(http.MethodDelete).Do() } + +// Head 发送 HEAD 请求 +func (r *Request) Head() (*Response, error) { + return r.SetMethod(http.MethodHead).Do() +} + +// Patch 发送 PATCH 请求 +func (r *Request) Patch() (*Response, error) { + return r.SetMethod(http.MethodPatch).Do() +} + +// Options 发送 OPTIONS 请求 +func (r *Request) Options() (*Response, error) { + return r.SetMethod(http.MethodOptions).Do() +} + +// Trace 发送 TRACE 请求 +func (r *Request) Trace() (*Response, error) { + return r.SetMethod(http.MethodTrace).Do() +} + +// Connect 发送 CONNECT 请求 +func (r *Request) Connect() (*Response, error) { + return r.SetMethod(http.MethodConnect).Do() +} diff --git a/request_body.go b/request_body.go index 120d399..63e5ea3 100644 --- a/request_body.go +++ b/request_body.go @@ -2,7 +2,9 @@ package starnet import ( "bytes" + "context" "encoding/json" + "fmt" "io" "mime/multipart" "net/http" @@ -65,7 +67,7 @@ func (r *Request) SetFormData(data map[string][]string) *Request { if r.doRaw { return r } - r.config.Body.FormData = data + r.config.Body.FormData = cloneStringMapSlice(data) return r } @@ -345,121 +347,102 @@ func (r *Request) prepare() error { return err // ← 失败时不设置 applied } } - // 原始模式不修改请求内容 - if r.doRaw { - r.applied = true - return nil + + if r.httpReq == nil { + return fmt.Errorf("http request is nil") } - // 应用查询参数 - if len(r.config.Queries) > 0 { - q := r.httpReq.URL.Query() - for k, values := range r.config.Queries { + // 原始模式不修改请求内容 + if !r.doRaw { + // 应用查询参数 + if len(r.config.Queries) > 0 { + q := r.httpReq.URL.Query() + for k, values := range r.config.Queries { + for _, v := range values { + q.Add(k, v) + } + } + r.httpReq.URL.RawQuery = q.Encode() + } + + // 应用 Headers + for k, values := range r.config.Headers { for _, v := range values { - q.Add(k, v) + r.httpReq.Header.Add(k, v) } } - r.httpReq.URL.RawQuery = q.Encode() - } - // 应用 Headers - for k, values := range r.config.Headers { - for _, v := range values { - r.httpReq.Header.Add(k, v) + // 应用 Cookies + for _, cookie := range r.config.Cookies { + r.httpReq.AddCookie(cookie) + } + + // 应用 Basic Auth + if r.config.BasicAuth[0] != "" || r.config.BasicAuth[1] != "" { + r.httpReq.SetBasicAuth(r.config.BasicAuth[0], r.config.BasicAuth[1]) + } + + // 应用请求体 + if err := r.applyBody(); err != nil { + return err + } + + // 应用 Content-Length + if r.config.ContentLength > 0 { + r.httpReq.ContentLength = r.config.ContentLength + } else if r.config.ContentLength < 0 { + r.httpReq.ContentLength = 0 + } + + // 自动计算 Content-Length + if r.config.AutoCalcContentLength && r.httpReq.Body != nil { + data, err := io.ReadAll(r.httpReq.Body) + if err != nil { + return wrapError(err, "read body for content length") + } + r.httpReq.ContentLength = int64(len(data)) + r.httpReq.Body = io.NopCloser(bytes.NewBuffer(data)) + } + + // 设置 TLS ServerName(如果有 TLS Config) + if r.config.TLS.Config != nil && r.httpReq.URL != nil { + r.config.TLS.Config.ServerName = r.httpReq.URL.Hostname() } } - // 应用 Cookies - for _, cookie := range r.config.Cookies { - r.httpReq.AddCookie(cookie) + execCtx := r.ctx + if !r.doRaw { + // raw 模式下不注入请求级网络配置,只应用 context/超时。 + execCtx = injectRequestConfig(execCtx, r.config) } - // 应用 Basic Auth - if r.config.BasicAuth[0] != "" || r.config.BasicAuth[1] != "" { - r.httpReq.SetBasicAuth(r.config.BasicAuth[0], r.config.BasicAuth[1]) + // 请求级总超时通过 context 控制,避免污染共享 http.Client。 + if r.config.Network.Timeout > 0 { + execCtx, r.cancel = context.WithTimeout(execCtx, r.config.Network.Timeout) } - // 应用请求体 - if err := r.applyBody(); err != nil { - return err - } - - // 应用 Content-Length - if r.config.ContentLength > 0 { - r.httpReq.ContentLength = r.config.ContentLength - } else if r.config.ContentLength < 0 { - r.httpReq.ContentLength = 0 - } - - // 自动计算 Content-Length - if r.config.AutoCalcContentLength && r.httpReq.Body != nil { - data, err := io.ReadAll(r.httpReq.Body) - if err != nil { - return wrapError(err, "read body for content length") - } - r.httpReq.ContentLength = int64(len(data)) - r.httpReq.Body = io.NopCloser(bytes.NewBuffer(data)) - } - - // 设置 TLS ServerName(如果有 TLS Config) - if r.config.TLS.Config != nil && r.httpReq.URL != nil { - r.config.TLS.Config.ServerName = r.httpReq.URL.Hostname() - } - - // 注入配置到 context - r.execCtx = injectRequestConfig(r.ctx, r.config) + r.execCtx = execCtx r.httpReq = r.httpReq.WithContext(r.execCtx) - + r.applied = true return nil } // buildHTTPClient 构建 HTTP Client func (r *Request) buildHTTPClient() (*http.Client, error) { - applyTimeoutOverride := func(base *http.Client) *http.Client { - // 没有 base 时兜底 - if base == nil { - base = &http.Client{} - } - - rt := r.config.Network.Timeout - - // 语义: - // rt < 0 : 本次请求禁用超时(Timeout = 0) - // rt = 0 : 沿用 base.Timeout - // rt > 0 : 本次请求超时覆盖 - if rt == 0 { - return base - } - - clone := &http.Client{ - Transport: base.Transport, - CheckRedirect: base.CheckRedirect, - Jar: base.Jar, - } - - if rt < 0 { - clone.Timeout = 0 - } else { - clone.Timeout = rt - } - return clone - } - // 优先使用请求关联的 Client if r.client != nil { - return applyTimeoutOverride(r.client.HTTPClient()), nil + return r.client.HTTPClient(), nil } // 自定义 Transport if r.config.CustomTransport && r.config.Transport != nil { - base := &http.Client{ + return &http.Client{ Transport: &Transport{base: r.config.Transport}, Timeout: 0, - } - return applyTimeoutOverride(base), nil + }, nil } // 默认全局 client - return applyTimeoutOverride(DefaultHTTPClient()), nil + return DefaultHTTPClient(), nil } diff --git a/request_config.go b/request_config.go index fed3fe0..195ef22 100644 --- a/request_config.go +++ b/request_config.go @@ -10,9 +10,9 @@ import ( ) // SetTimeout 设置请求总超时时间 -// timeout > 0: 使用该超时 -// timeout = 0: 使用 Client 默认超时 -// timeout < 0: 禁用本次请求超时(覆盖 Client.Timeout=0) +// timeout > 0: 为本次请求注入 context 超时 +// timeout = 0: 不额外设置请求总超时 +// timeout < 0: 禁用 starnet 默认总超时 func (r *Request) SetTimeout(timeout time.Duration) *Request { if r.err != nil { return r @@ -194,6 +194,19 @@ func (r *Request) SetUploadProgress(fn UploadProgressFunc) *Request { return r } +// SetMaxRespBodyBytes 设置响应体最大读取字节数(<=0 表示不限制) +func (r *Request) SetMaxRespBodyBytes(maxBytes int64) *Request { + if r.err != nil { + return r + } + if maxBytes < 0 { + r.err = fmt.Errorf("max response body bytes must be >= 0") + return r + } + r.config.MaxRespBodyBytes = maxBytes + return r +} + // AddQuery 添加查询参数 func (r *Request) AddQuery(key, value string) *Request { if r.err != nil { @@ -217,7 +230,7 @@ func (r *Request) SetQueries(queries map[string][]string) *Request { if r.err != nil { return r } - r.config.Queries = queries + r.config.Queries = cloneStringMapSlice(queries) return r } diff --git a/request_header.go b/request_header.go index 55f48be..68ae11e 100644 --- a/request_header.go +++ b/request_header.go @@ -36,7 +36,7 @@ func (r *Request) SetHeaders(headers http.Header) *Request { if r.doRaw { return r } - r.config.Headers = headers + r.config.Headers = cloneHeader(headers) return r } diff --git a/request_methods_ext_test.go b/request_methods_ext_test.go new file mode 100644 index 0000000..e8fea2e --- /dev/null +++ b/request_methods_ext_test.go @@ -0,0 +1,43 @@ +package starnet + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func TestRequestConvenienceMethods(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Method", r.Method) + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + defer server.Close() + + tests := []struct { + name string + method string + do func(r *Request) (*Response, error) + }{ + {name: "Head", method: http.MethodHead, do: (*Request).Head}, + {name: "Patch", method: http.MethodPatch, do: (*Request).Patch}, + {name: "Options", method: http.MethodOptions, do: (*Request).Options}, + {name: "Trace", method: http.MethodTrace, do: (*Request).Trace}, + {name: "Connect", method: http.MethodConnect, do: (*Request).Connect}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := NewSimpleRequest(server.URL, http.MethodGet) + resp, err := tt.do(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer resp.Close() + + if got := resp.Header.Get("X-Method"); got != tt.method { + t.Fatalf("method=%s want=%s", got, tt.method) + } + }) + } +} diff --git a/response.go b/response.go index d43d1ee..3942f7e 100644 --- a/response.go +++ b/response.go @@ -13,6 +13,7 @@ type Response struct { *http.Response request *Request httpClient *http.Client + cancel func() body *Body } @@ -21,9 +22,26 @@ type Body struct { raw io.ReadCloser data []byte consumed bool + maxBytes int64 mu sync.Mutex } +type cancelReadCloser struct { + io.ReadCloser + cancel func() + once sync.Once +} + +func (c *cancelReadCloser) Close() error { + err := c.ReadCloser.Close() + c.once.Do(func() { + if c.cancel != nil { + c.cancel() + } + }) + return err +} + // Request 获取原始请求 func (r *Response) Request() *Request { return r.request @@ -42,6 +60,10 @@ func (r *Response) Close() error { if r.body != nil && r.body.raw != nil { return r.body.raw.Close() } + if r.cancel != nil { + r.cancel() + r.cancel = nil + } return nil } @@ -70,14 +92,24 @@ func (b *Body) readAll() error { return nil } - data, err := io.ReadAll(b.raw) + reader := io.Reader(b.raw) + if b.maxBytes > 0 { + reader = io.LimitReader(b.raw, b.maxBytes+1) + } + + data, err := io.ReadAll(reader) if err != nil { return wrapError(err, "read response body") } + if b.maxBytes > 0 && int64(len(data)) > b.maxBytes { + b.consumed = true + _ = b.raw.Close() + return wrapError(ErrRespBodyTooLarge, "response body exceeds max bytes: %d > %d", len(data), b.maxBytes) + } b.data = data b.consumed = true - b.raw.Close() + _ = b.raw.Close() return nil } diff --git a/response_limit_test.go b/response_limit_test.go new file mode 100644 index 0000000..1008f83 --- /dev/null +++ b/response_limit_test.go @@ -0,0 +1,84 @@ +package starnet + +import ( + "errors" + "net/http" + "net/http/httptest" + "testing" +) + +func TestWithMaxRespBodyBytes(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte("123456")) + })) + defer server.Close() + + resp, err := Get(server.URL, WithMaxRespBodyBytes(4)) + if err != nil { + t.Fatalf("unexpected request error: %v", err) + } + defer resp.Close() + + _, err = resp.Body().Bytes() + if err == nil { + t.Fatal("expected body too large error") + } + if !errors.Is(err, ErrRespBodyTooLarge) { + t.Fatalf("expected ErrRespBodyTooLarge, got: %v", err) + } +} + +func TestSetMaxRespBodyBytes(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte("1234")) + })) + defer server.Close() + + resp, err := NewSimpleRequest(server.URL, http.MethodGet). + SetMaxRespBodyBytes(4). + Do() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer resp.Close() + + body, err := resp.Body().String() + if err != nil { + t.Fatalf("unexpected read error: %v", err) + } + if body != "1234" { + t.Fatalf("body=%q want=1234", body) + } +} + +func TestSetMaxRespBodyBytesWithAutoFetch(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte("123456")) + })) + defer server.Close() + + _, err := NewSimpleRequest(server.URL, http.MethodGet). + SetAutoFetch(true). + SetMaxRespBodyBytes(4). + Do() + if err == nil { + t.Fatal("expected body too large error with auto fetch") + } + if !errors.Is(err, ErrRespBodyTooLarge) { + t.Fatalf("expected ErrRespBodyTooLarge, got: %v", err) + } +} + +func TestSetMaxRespBodyBytesInvalid(t *testing.T) { + req := NewSimpleRequest("http://example.com", http.MethodGet).SetMaxRespBodyBytes(-1) + if req.Err() == nil { + t.Fatal("expected error for negative max bytes") + } +} + +func TestWithMaxRespBodyBytesInvalid(t *testing.T) { + _, err := NewRequest("http://example.com", http.MethodGet, WithMaxRespBodyBytes(-1)) + if err == nil { + t.Fatal("expected error for negative max bytes") + } +} diff --git a/retry.go b/retry.go new file mode 100644 index 0000000..3e5e415 --- /dev/null +++ b/retry.go @@ -0,0 +1,423 @@ +package starnet + +import ( + "context" + "errors" + "fmt" + "io" + "math" + "math/rand" + "net" + "net/http" + "time" +) + +type RetryOpt func(*retryPolicy) error + +type retryPolicy struct { + maxRetries int + baseDelay time.Duration + maxDelay time.Duration + factor float64 + jitter float64 + idempotentOnly bool + statuses map[int]struct{} + onError func(error) bool +} + +func cloneRetryPolicy(p *retryPolicy) *retryPolicy { + if p == nil { + return nil + } + cloned := &retryPolicy{ + maxRetries: p.maxRetries, + baseDelay: p.baseDelay, + maxDelay: p.maxDelay, + factor: p.factor, + jitter: p.jitter, + idempotentOnly: p.idempotentOnly, + onError: p.onError, + } + if p.statuses != nil { + cloned.statuses = make(map[int]struct{}, len(p.statuses)) + for code := range p.statuses { + cloned.statuses[code] = struct{}{} + } + } + return cloned +} + +func defaultRetryPolicy(max int) *retryPolicy { + return &retryPolicy{ + maxRetries: max, + baseDelay: 100 * time.Millisecond, + maxDelay: 2 * time.Second, + factor: 2.0, + jitter: 0.1, + idempotentOnly: true, + statuses: map[int]struct{}{ + http.StatusRequestTimeout: {}, + http.StatusTooEarly: {}, + http.StatusTooManyRequests: {}, + http.StatusInternalServerError: {}, + http.StatusBadGateway: {}, + http.StatusServiceUnavailable: {}, + http.StatusGatewayTimeout: {}, + }, + } +} + +func buildRetryPolicy(max int, opts ...RetryOpt) (*retryPolicy, error) { + if max < 0 { + return nil, fmt.Errorf("max retry must be >= 0") + } + if max == 0 { + return nil, nil + } + + policy := defaultRetryPolicy(max) + for _, opt := range opts { + if opt == nil { + continue + } + if err := opt(policy); err != nil { + return nil, err + } + } + return policy, nil +} + +func WithRetry(max int, opts ...RetryOpt) RequestOpt { + return func(r *Request) error { + policy, err := buildRetryPolicy(max, opts...) + if err != nil { + return err + } + r.retry = policy + return nil + } +} + +func (r *Request) SetRetry(max int, opts ...RetryOpt) *Request { + if r.err != nil { + return r + } + policy, err := buildRetryPolicy(max, opts...) + if err != nil { + r.err = err + return r + } + r.retry = policy + return r +} + +func (r *Request) DisableRetry() *Request { + if r.err != nil { + return r + } + r.retry = nil + return r +} + +func (r *Request) applyRetryOpt(opt RetryOpt) *Request { + if r.err != nil { + return r + } + if opt == nil { + return r + } + if r.retry == nil { + r.err = fmt.Errorf("retry policy is not enabled, call SetRetry first") + return r + } + if err := opt(r.retry); err != nil { + r.err = err + } + return r +} + +func (r *Request) SetRetryBackoff(base, max time.Duration, factor float64) *Request { + return r.applyRetryOpt(WithRetryBackoff(base, max, factor)) +} + +func (r *Request) SetRetryJitter(ratio float64) *Request { + return r.applyRetryOpt(WithRetryJitter(ratio)) +} + +func (r *Request) SetRetryStatuses(codes ...int) *Request { + return r.applyRetryOpt(WithRetryStatuses(codes...)) +} + +func (r *Request) SetRetryIdempotentOnly(enabled bool) *Request { + return r.applyRetryOpt(WithRetryIdempotentOnly(enabled)) +} + +func (r *Request) SetRetryOnError(fn func(error) bool) *Request { + return r.applyRetryOpt(WithRetryOnError(fn)) +} + +func WithRetryBackoff(base, max time.Duration, factor float64) RetryOpt { + return func(p *retryPolicy) error { + if base < 0 { + return fmt.Errorf("retry base delay must be >= 0") + } + if max < 0 { + return fmt.Errorf("retry max delay must be >= 0") + } + if factor <= 0 { + return fmt.Errorf("retry factor must be > 0") + } + p.baseDelay = base + p.maxDelay = max + p.factor = factor + return nil + } +} + +func WithRetryJitter(ratio float64) RetryOpt { + return func(p *retryPolicy) error { + if ratio < 0 || ratio > 1 { + return fmt.Errorf("retry jitter ratio must be in [0,1]") + } + p.jitter = ratio + return nil + } +} + +func WithRetryStatuses(codes ...int) RetryOpt { + return func(p *retryPolicy) error { + statuses := make(map[int]struct{}, len(codes)) + for _, code := range codes { + if code < 100 || code > 999 { + return fmt.Errorf("invalid retry status code: %d", code) + } + statuses[code] = struct{}{} + } + p.statuses = statuses + return nil + } +} + +func WithRetryIdempotentOnly(enabled bool) RetryOpt { + return func(p *retryPolicy) error { + p.idempotentOnly = enabled + return nil + } +} + +func WithRetryOnError(fn func(error) bool) RetryOpt { + return func(p *retryPolicy) error { + p.onError = fn + return nil + } +} + +func (r *Request) hasRetryPolicy() bool { + return r.retry != nil && r.retry.maxRetries > 0 +} + +func (r *Request) doWithRetry() (*Response, error) { + policy := cloneRetryPolicy(r.retry) + if policy == nil || policy.maxRetries <= 0 { + return r.doOnce() + } + + if !policy.canRetryRequest(r) { + return r.doOnce() + } + + retryCtx := r.ctx + retryCancel := func() {} + if r.config.Network.Timeout > 0 { + retryCtx, retryCancel = context.WithTimeout(r.ctx, r.config.Network.Timeout) + } + defer retryCancel() + + maxAttempts := policy.maxRetries + 1 + var lastResp *Response + var lastErr error + + for attempt := 0; attempt < maxAttempts; attempt++ { + attemptReq, err := r.newRetryAttempt(retryCtx) + if err != nil { + return nil, wrapError(err, "build retry attempt") + } + + resp, err := attemptReq.doOnce() + if resp != nil { + resp.request = r + } + + if !policy.shouldRetry(resp, err, attempt, maxAttempts, retryCtx) { + return resp, err + } + + lastResp = resp + lastErr = err + if lastResp != nil { + _ = lastResp.Close() + } + + delay := policy.nextDelay(attempt) + if delay <= 0 { + continue + } + + timer := time.NewTimer(delay) + select { + case <-retryCtx.Done(): + timer.Stop() + return lastResp, wrapError(retryCtx.Err(), "retry context done") + case <-timer.C: + } + } + + return lastResp, lastErr +} + +func (r *Request) newRetryAttempt(ctx context.Context) (*Request, error) { + attempt := r.Clone() + attempt.retry = nil + attempt.cancel = nil + attempt.applied = false + attempt.execCtx = nil + attempt.ctx = ctx + + // 共享总超时上下文后,避免每次 attempt 再创建一次 timeout context。 + if attempt.config != nil && attempt.config.Network.Timeout > 0 { + attempt.config.Network.Timeout = 0 + } + + if !attempt.doRaw { + attempt.httpReq = attempt.httpReq.WithContext(ctx) + return attempt, nil + } + + if r.httpReq == nil { + return nil, fmt.Errorf("http request is nil") + } + + raw := r.httpReq.Clone(ctx) + if r.httpReq.GetBody != nil { + body, err := r.httpReq.GetBody() + if err != nil { + return nil, wrapError(err, "get raw request body") + } + raw.Body = body + } else if r.httpReq.Body != nil && r.httpReq.Body != http.NoBody { + return nil, fmt.Errorf("raw request body is not replayable") + } + + attempt.httpReq = raw + return attempt, nil +} + +func (p *retryPolicy) canRetryRequest(r *Request) bool { + if p.idempotentOnly && !isIdempotentMethod(r.method) { + return false + } + return isReplayableRequest(r) +} + +func isIdempotentMethod(method string) bool { + switch method { + case http.MethodGet, http.MethodHead, http.MethodPut, http.MethodDelete, http.MethodOptions, http.MethodTrace: + return true + default: + return false + } +} + +func isReplayableRequest(r *Request) bool { + if r == nil { + return false + } + + if r.doRaw { + if r.httpReq == nil { + return false + } + if r.httpReq.Body == nil || r.httpReq.Body == http.NoBody { + return true + } + return r.httpReq.GetBody != nil + } + + if r.config == nil { + return false + } + + // Reader / stream body 通常不可重放,保守地不重试。 + if r.config.Body.Reader != nil { + return false + } + + for _, f := range r.config.Body.Files { + if f.FileData != nil || f.FilePath == "" { + return false + } + } + + return true +} + +func (p *retryPolicy) shouldRetry(resp *Response, err error, attempt, maxAttempts int, ctx context.Context) bool { + if attempt >= maxAttempts-1 { + return false + } + if ctx != nil && ctx.Err() != nil { + return false + } + + if err != nil { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return false + } + if p.onError != nil { + return p.onError(err) + } + return isRetryableError(err) + } + + if resp == nil || resp.Response == nil { + return false + } + _, ok := p.statuses[resp.StatusCode] + return ok +} + +func isRetryableError(err error) bool { + var netErr net.Error + if errors.As(err, &netErr) { + if netErr.Timeout() { + return true + } + if netErr.Temporary() { + return true + } + } + return errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) +} + +func (p *retryPolicy) nextDelay(attempt int) time.Duration { + if p.baseDelay <= 0 { + return 0 + } + + delay := time.Duration(float64(p.baseDelay) * math.Pow(p.factor, float64(attempt))) + if p.maxDelay > 0 && delay > p.maxDelay { + delay = p.maxDelay + } + + if p.jitter <= 0 { + return delay + } + + low := 1 - p.jitter + if low < 0 { + low = 0 + } + high := 1 + p.jitter + scale := low + rand.Float64()*(high-low) + return time.Duration(float64(delay) * scale) +} diff --git a/retry_test.go b/retry_test.go new file mode 100644 index 0000000..7f1ffe6 --- /dev/null +++ b/retry_test.go @@ -0,0 +1,298 @@ +package starnet + +import ( + "context" + "errors" + "io" + "net" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" + "time" +) + +func TestWithRetrySmokeGet(t *testing.T) { + var hits int32 + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + n := atomic.AddInt32(&hits, 1) + if n <= 2 { + w.WriteHeader(http.StatusServiceUnavailable) + return + } + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + defer s.Close() + + resp, err := Get(s.URL, + WithRetry(2, + WithRetryBackoff(0, 0, 1), + WithRetryJitter(0), + ), + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer resp.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("status=%d want=%d", resp.StatusCode, http.StatusOK) + } + if atomic.LoadInt32(&hits) != 3 { + t.Fatalf("hits=%d want=3", hits) + } +} + +func TestWithRetryResponseRequestPointerStable(t *testing.T) { + var hits int32 + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + n := atomic.AddInt32(&hits, 1) + if n == 1 { + w.WriteHeader(http.StatusServiceUnavailable) + return + } + w.WriteHeader(http.StatusOK) + })) + defer s.Close() + + req := NewSimpleRequest(s.URL, http.MethodGet). + SetRetry(1, WithRetryBackoff(0, 0, 1), WithRetryJitter(0)) + resp, err := req.Do() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer resp.Close() + + if resp.Request() != req { + t.Fatal("response request pointer should point to original request") + } +} + +func TestWithRetryNoRetryForNonReplayableBodyReader(t *testing.T) { + var hits int32 + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&hits, 1) + w.WriteHeader(http.StatusServiceUnavailable) + })) + defer s.Close() + + req := NewSimpleRequest(s.URL, http.MethodPost). + SetBodyReader(strings.NewReader("payload")). + SetRetry(3, + WithRetryIdempotentOnly(false), + WithRetryBackoff(0, 0, 1), + WithRetryJitter(0), + ) + resp, err := req.Do() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer resp.Close() + + if resp.StatusCode != http.StatusServiceUnavailable { + t.Fatalf("status=%d want=%d", resp.StatusCode, http.StatusServiceUnavailable) + } + if atomic.LoadInt32(&hits) != 1 { + t.Fatalf("hits=%d want=1", hits) + } +} + +func TestWithRetryPostWhenIdempotentDisabled(t *testing.T) { + var hits int32 + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + n := atomic.AddInt32(&hits, 1) + _, _ = io.Copy(io.Discard, r.Body) + if n == 1 { + w.WriteHeader(http.StatusServiceUnavailable) + return + } + w.WriteHeader(http.StatusOK) + })) + defer s.Close() + + req := NewSimpleRequest(s.URL, http.MethodPost). + SetBodyString("hello"). + SetRetry(1, + WithRetryIdempotentOnly(false), + WithRetryBackoff(0, 0, 1), + WithRetryJitter(0), + ) + resp, err := req.Do() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer resp.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("status=%d want=%d", resp.StatusCode, http.StatusOK) + } + if atomic.LoadInt32(&hits) != 2 { + t.Fatalf("hits=%d want=2", hits) + } +} + +func TestWithRetryRawWithoutGetBodyNoRetry(t *testing.T) { + var hits int32 + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&hits, 1) + w.WriteHeader(http.StatusServiceUnavailable) + })) + defer s.Close() + + rawReq, _ := http.NewRequest(http.MethodPost, s.URL, io.MultiReader(strings.NewReader("raw"))) + if rawReq.GetBody != nil { + t.Fatal("raw request GetBody should be nil in this test") + } + + req := NewSimpleRequest("", http.MethodPost, WithRawRequest(rawReq)). + SetRetry(2, + WithRetryIdempotentOnly(false), + WithRetryBackoff(0, 0, 1), + WithRetryJitter(0), + ) + resp, err := req.Do() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer resp.Close() + + if atomic.LoadInt32(&hits) != 1 { + t.Fatalf("hits=%d want=1", hits) + } +} + +func TestWithRetryRespectsTotalTimeoutBudget(t *testing.T) { + var hits int32 + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&hits, 1) + time.Sleep(80 * time.Millisecond) + w.WriteHeader(http.StatusServiceUnavailable) + })) + defer s.Close() + + req := NewSimpleRequest(s.URL, http.MethodGet). + SetTimeout(120*time.Millisecond). + SetRetry(3, + WithRetryBackoff(0, 0, 1), + WithRetryJitter(0), + ) + + _, err := req.Do() + if err == nil { + t.Fatal("expected timeout error") + } + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("expected context deadline exceeded, got: %v", err) + } + if h := atomic.LoadInt32(&hits); h > 2 { + t.Fatalf("hits=%d want<=2 under tight timeout budget", h) + } +} + +func TestSetRetryInvalidMax(t *testing.T) { + req := NewSimpleRequest("http://example.com", http.MethodGet).SetRetry(-1) + if req.Err() == nil { + t.Fatal("expected error for negative retry max") + } +} + +func TestSetRetrySeriesSmokeGet(t *testing.T) { + var hits int32 + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + n := atomic.AddInt32(&hits, 1) + if n <= 2 { + w.WriteHeader(http.StatusTooManyRequests) + return + } + w.WriteHeader(http.StatusOK) + })) + defer s.Close() + + req := NewSimpleRequest(s.URL, http.MethodGet). + SetRetry(2). + SetRetryBackoff(0, 0, 1). + SetRetryJitter(0). + SetRetryStatuses(http.StatusTooManyRequests) + resp, err := req.Do() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer resp.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("status=%d want=%d", resp.StatusCode, http.StatusOK) + } + if h := atomic.LoadInt32(&hits); h != 3 { + t.Fatalf("hits=%d want=3", h) + } +} + +func TestSetRetryIdempotentOnlyWithPost(t *testing.T) { + var hits int32 + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + n := atomic.AddInt32(&hits, 1) + _, _ = io.Copy(io.Discard, r.Body) + if n == 1 { + w.WriteHeader(http.StatusServiceUnavailable) + return + } + w.WriteHeader(http.StatusOK) + })) + defer s.Close() + + req := NewSimpleRequest(s.URL, http.MethodPost). + SetBodyString("hello"). + SetRetry(1). + SetRetryIdempotentOnly(false). + SetRetryBackoff(0, 0, 1). + SetRetryJitter(0) + resp, err := req.Do() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer resp.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("status=%d want=%d", resp.StatusCode, http.StatusOK) + } + if h := atomic.LoadInt32(&hits); h != 2 { + t.Fatalf("hits=%d want=2", h) + } +} + +func TestSetRetryOnErrorOverridesDefault(t *testing.T) { + var dials int32 + dialErr := errors.New("dial failed") + + req := NewSimpleRequest("http://example.com", http.MethodGet). + SetDialFunc(func(ctx context.Context, network, addr string) (net.Conn, error) { + atomic.AddInt32(&dials, 1) + return nil, dialErr + }). + SetRetry(1). + SetRetryBackoff(0, 0, 1). + SetRetryJitter(0). + SetRetryOnError(func(err error) bool { + return true + }) + + _, err := req.Do() + if err == nil { + t.Fatal("expected error") + } + if h := atomic.LoadInt32(&dials); h != 2 { + t.Fatalf("dial attempts=%d want=2", h) + } +} + +func TestSetRetryOptionRequireEnableRetry(t *testing.T) { + req := NewSimpleRequest("http://example.com", http.MethodGet).SetRetryBackoff(10*time.Millisecond, 100*time.Millisecond, 2) + if req.Err() == nil { + t.Fatal("expected error when setting retry options before SetRetry") + } + if !strings.Contains(req.Err().Error(), "call SetRetry first") { + t.Fatalf("unexpected error: %v", req.Err()) + } +} diff --git a/timeout_refactor_test.go b/timeout_refactor_test.go new file mode 100644 index 0000000..8abda61 --- /dev/null +++ b/timeout_refactor_test.go @@ -0,0 +1,115 @@ +package starnet + +import ( + "io" + "net" + "net/http" + "net/http/httptest" + "testing" + "time" +) + +func TestRequestTimeoutDoesNotMutateClientTimeout(t *testing.T) { + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(200 * time.Millisecond) + w.WriteHeader(http.StatusOK) + })) + defer s.Close() + + client := NewClientNoErr() + baseTimeout := client.HTTPClient().Timeout + + _, err := client.Get(s.URL, WithTimeout(80*time.Millisecond)) + if err == nil { + t.Fatal("expected request timeout error") + } + + if client.HTTPClient().Timeout != baseTimeout { + t.Fatalf("client timeout mutated: got=%v want=%v", client.HTTPClient().Timeout, baseTimeout) + } + + resp, err := client.Get(s.URL, WithTimeout(400*time.Millisecond)) + if err != nil { + t.Fatalf("second request should succeed with larger timeout: %v", err) + } + defer resp.Close() +} + +func TestRequestTimeoutNoLingeringConnDeadline(t *testing.T) { + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.Copy(io.Discard, r.Body) + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + defer s.Close() + + client := NewClientNoErr() + + resp1, err := client.Post(s.URL, WithBodyString("first"), WithTimeout(120*time.Millisecond)) + if err != nil { + t.Fatalf("first request should succeed: %v", err) + } + _ = resp1.Close() + + // 如果请求超时依赖连接级绝对 deadline,经过该等待后复用连接会出现误超时。 + time.Sleep(220 * time.Millisecond) + + resp2, err := client.Post(s.URL, WithBodyString("second"), WithTimeout(1*time.Second)) + if err != nil { + t.Fatalf("second request should not be affected by previous timeout window: %v", err) + } + _ = resp2.Close() +} + +func TestConnReadDeadlineTimeoutAndRecover(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen error: %v", err) + } + defer ln.Close() + + done := make(chan struct{}) + go func() { + defer close(done) + conn, err := ln.Accept() + if err != nil { + return + } + defer conn.Close() + + time.Sleep(180 * time.Millisecond) + _, _ = conn.Write([]byte("x")) + }() + + c, err := Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatalf("dial error: %v", err) + } + defer c.Close() + + if err := c.SetReadDeadline(time.Now().Add(60 * time.Millisecond)); err != nil { + t.Fatalf("set read deadline error: %v", err) + } + + buf := make([]byte, 1) + _, err = c.Read(buf) + if err == nil { + t.Fatal("expected read timeout error") + } + if ne, ok := err.(net.Error); !ok || !ne.Timeout() { + t.Fatalf("expected net timeout error, got: %v", err) + } + + if err := c.SetReadDeadline(time.Time{}); err != nil { + t.Fatalf("clear read deadline error: %v", err) + } + + if _, err := io.ReadFull(c, buf); err != nil { + t.Fatalf("read after clearing deadline should succeed: %v", err) + } + if string(buf) != "x" { + t.Fatalf("unexpected payload: %q", string(buf)) + } + + <-done +} diff --git a/types.go b/types.go index 11c5e0f..425e763 100644 --- a/types.go +++ b/types.go @@ -84,6 +84,7 @@ type RequestConfig struct { BasicAuth [2]string // Basic 认证 ContentLength int64 // 手动设置的 Content-Length AutoCalcContentLength bool // 自动计算 Content-Length + MaxRespBodyBytes int64 // 响应体最大读取字节数(<=0 表示不限制) UploadProgress UploadProgressFunc // 上传进度回调 // Transport 配置 @@ -121,6 +122,7 @@ func (c *RequestConfig) Clone() *RequestConfig { BasicAuth: c.BasicAuth, ContentLength: c.ContentLength, AutoCalcContentLength: c.AutoCalcContentLength, + MaxRespBodyBytes: c.MaxRespBodyBytes, UploadProgress: c.UploadProgress, CustomTransport: c.CustomTransport, Transport: c.Transport, // Transport 共享