1. 优化ping功能

2. 新增重试机制
3. 优化错误处理逻辑
This commit is contained in:
兔子 2026-03-19 16:42:45 +08:00
parent 4568e17f06
commit b5bd7595a1
Signed by: b612
GPG Key ID: 99DD2222B612B612
24 changed files with 2645 additions and 162 deletions

3
.gitignore vendored
View File

@ -1 +1,4 @@
.idea .idea
.sentrux/
agent_readme.md
target.md

201
LICENSE Normal file
View File

@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2026 starnet contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

106
README.md Normal file
View File

@ -0,0 +1,106 @@
# starnet
`starnet` is a Go network toolkit focused on practical HTTP request control, TLS sniff utilities, and ICMP ping capabilities.
## Highlights
- Request-level timeout by context (without mutating shared `http.Client` timeout)
- Fine-grained network controls: custom DNS/IP, dial timeout, proxy, TLS config
- Built-in retry with replay safety checks and configurable backoff/jitter/statuses
- Response body safety guard via max body bytes limit
- Error classification helpers (`ClassifyError`, `IsTimeout`, `IsDNS`, `IsTLS`, `IsProxy`, `IsCanceled`)
- TLS sniffer listener/dialer utilities for mixed TLS/plain traffic scenarios
- ICMP ping with IPv4/IPv6 target handling and option-based probing API
## Main Features
### HTTP Client and Request
- Fluent APIs with both `WithXxx` options and `SetXxx` chain methods
- Methods: `Get/Post/Put/Delete/Head/Patch/Options/Trace/Connect`
- Request body helpers: JSON, form data, multipart file upload, stream body
- Header/cookie/query helpers with defensive copy on key setters
- Request cloning for safe reuse in concurrent or variant calls
### Timeout and Retry
- Request timeout is applied by context deadline, not global client timeout
- Retry supports:
- max attempts
- backoff factor/base/max
- jitter
- retry status whitelist
- idempotent-only guard
- custom retry-on-error callback
- Retry keeps original request pointer in final response for consistency
### Response Handling
- `Bytes/String/JSON/Reader` helpers
- optional auto-fetch mode
- configurable max response body bytes to prevent oversized reads
### Ping Module
- `Ping`, `PingWithContext`, `Pingable`, and compatibility helper `IsIpPingable`
- `PingOptions` for count/timeout/interval/deadline/address preference/source IP/payload size
- explicit error semantics for permission/protocol/timeout/resolve failures
## Install
```bash
go get b612.me/starnet
```
## Quick Example
```go
package main
import (
"fmt"
"net/http"
"time"
"b612.me/starnet"
)
func main() {
resp, err := starnet.Get(
"https://example.com",
starnet.WithTimeout(2*time.Second),
starnet.WithRetry(2,
starnet.WithRetryBackoff(100*time.Millisecond, 1*time.Second, 2),
starnet.WithRetryJitter(0.1),
),
starnet.WithMaxRespBodyBytes(1<<20),
)
if err != nil {
fmt.Println("request failed:", starnet.ClassifyError(err), err)
return
}
defer resp.Close()
fmt.Println("status:", resp.StatusCode)
_, _ = resp.Body().Bytes()
ok, pingErr := starnet.Pingable("example.com", &starnet.PingOptions{
Count: 2,
Timeout: 2 * time.Second,
})
fmt.Println("pingable:", ok, pingErr == nil)
_ = http.MethodGet
}
```
## Stability Notes
- Raw ICMP ping may require elevated privileges on some systems.
- Integration tests that rely on external network are environment-dependent.
## License
This project is licensed under the Apache License 2.0.
See [LICENSE](./LICENSE).

View File

@ -78,7 +78,6 @@ func needsDynamicTransport(rc *RequestContext) bool {
rc.Proxy != "" || rc.Proxy != "" ||
rc.DialFn != nil || rc.DialFn != nil ||
(rc.DialTimeout > 0 && rc.DialTimeout != DefaultDialTimeout) || (rc.DialTimeout > 0 && rc.DialTimeout != DefaultDialTimeout) ||
(rc.Timeout > 0 && rc.Timeout != DefaultTimeout) ||
len(rc.CustomIP) > 0 || len(rc.CustomIP) > 0 ||
len(rc.CustomDNS) > 0 || len(rc.CustomDNS) > 0 ||
rc.LookupIPFn != nil rc.LookupIPFn != nil
@ -122,13 +121,10 @@ func injectRequestConfig(ctx context.Context, config *RequestConfig) context.Con
execCtx = context.WithValue(execCtx, ctxKeyCustomDNS, config.DNS.CustomDNS) execCtx = context.WithValue(execCtx, ctxKeyCustomDNS, config.DNS.CustomDNS)
} }
// 总是注入 DialTimeout 和 Timeout(与原始代码一致) // 总是注入 DialTimeout(与原始代码一致)
if config.Network.DialTimeout > 0 { if config.Network.DialTimeout > 0 {
execCtx = context.WithValue(execCtx, ctxKeyDialTimeout, config.Network.DialTimeout) execCtx = context.WithValue(execCtx, ctxKeyDialTimeout, config.Network.DialTimeout)
} }
if config.Network.Timeout > 0 {
execCtx = context.WithValue(execCtx, ctxKeyTimeout, config.Network.Timeout)
}
// 注入 DNS 解析函数 // 注入 DNS 解析函数
if config.DNS.LookupFunc != nil { if config.DNS.LookupFunc != nil {

59
defensive_copy_test.go Normal file
View File

@ -0,0 +1,59 @@
package starnet
import (
"net/http"
"testing"
)
func TestWithRawRequestNil(t *testing.T) {
_, err := NewRequest("http://example.com", "GET", WithRawRequest(nil))
if err == nil {
t.Fatal("expected error when WithRawRequest(nil)")
}
}
func TestSetHeadersDefensiveCopy(t *testing.T) {
req := NewSimpleRequest("http://example.com", "GET")
headers := http.Header{
"X-Test": []string{"v1"},
}
req.SetHeaders(headers)
headers.Set("X-Test", "v2")
if got := req.GetHeader("X-Test"); got != "v1" {
t.Fatalf("header mutated by external map change: got=%q want=%q", got, "v1")
}
}
func TestSetQueriesDefensiveCopy(t *testing.T) {
req := NewSimpleRequest("http://example.com", "GET")
queries := map[string][]string{
"k": []string{"v1"},
}
req.SetQueries(queries)
queries["k"][0] = "v2"
queries["k"] = append(queries["k"], "v3")
got := req.config.Queries["k"]
if len(got) != 1 || got[0] != "v1" {
t.Fatalf("queries mutated by external map change: got=%v want=[v1]", got)
}
}
func TestSetFormDataDefensiveCopy(t *testing.T) {
req := NewSimpleRequest("http://example.com", "POST")
form := map[string][]string{
"name": []string{"alice"},
}
req.SetFormData(form)
form["name"][0] = "bob"
form["name"] = append(form["name"], "carol")
got := req.config.Body.FormData["name"]
if len(got) != 1 || got[0] != "alice" {
t.Fatalf("form data mutated by external map change: got=%v want=[alice]", got)
}
}

View File

@ -19,11 +19,6 @@ func defaultDialFunc(ctx context.Context, network, addr string) (net.Conn, error
dialTimeout = DefaultDialTimeout dialTimeout = DefaultDialTimeout
} }
timeout := reqCtx.Timeout
if timeout == 0 {
timeout = DefaultTimeout
}
// 解析地址 // 解析地址
host, port, err := net.SplitHostPort(addr) host, port, err := net.SplitHostPort(addr)
if err != nil { if err != nil {
@ -47,12 +42,13 @@ func defaultDialFunc(ctx context.Context, network, addr string) (net.Conn, error
ipAddrs, err = reqCtx.LookupIPFn(ctx, host) ipAddrs, err = reqCtx.LookupIPFn(ctx, host)
} else if len(reqCtx.CustomDNS) > 0 { } else if len(reqCtx.CustomDNS) > 0 {
// 使用自定义 DNS 服务器 // 使用自定义 DNS 服务器
dialer := &net.Dialer{Timeout: dialTimeout}
resolver := &net.Resolver{ resolver := &net.Resolver{
PreferGo: true, PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) { Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
var lastErr error var lastErr error
for _, dnsServer := range reqCtx.CustomDNS { for _, dnsServer := range reqCtx.CustomDNS {
conn, err := net.Dial("udp", net.JoinHostPort(dnsServer, "53")) conn, err := dialer.DialContext(ctx, "udp", net.JoinHostPort(dnsServer, "53"))
if err != nil { if err != nil {
lastErr = err lastErr = err
continue continue
@ -78,19 +74,15 @@ func defaultDialFunc(ctx context.Context, network, addr string) (net.Conn, error
} }
// 尝试连接所有地址 // 尝试连接所有地址
dialer := &net.Dialer{Timeout: dialTimeout}
var lastErr error var lastErr error
for _, addr := range addrs { for _, addr := range addrs {
conn, err := net.DialTimeout(network, addr, dialTimeout) conn, err := dialer.DialContext(ctx, network, addr)
if err != nil { if err != nil {
lastErr = err lastErr = err
continue continue
} }
// 设置总超时
if timeout > 0 {
conn.SetDeadline(time.Now().Add(timeout))
}
return conn, nil return conn, nil
} }
@ -131,6 +123,11 @@ func defaultDialTLSFunc(ctx context.Context, network, addr string) (net.Conn, er
} }
// 执行 TLS 握手 // 执行 TLS 握手
if deadline, ok := ctx.Deadline(); ok {
_ = conn.SetDeadline(deadline)
defer conn.SetDeadline(time.Time{})
}
tlsConn := tls.Client(conn, tlsConfig) tlsConn := tls.Client(conn, tlsConfig)
if err := tlsConn.Handshake(); err != nil { if err := tlsConn.Handshake(); err != nil {
conn.Close() conn.Close()

192
errors.go
View File

@ -1,8 +1,14 @@
package starnet package starnet
import ( import (
"context"
"crypto/tls"
"crypto/x509"
"errors" "errors"
"fmt" "fmt"
"net"
"net/url"
"strings"
) )
var ( var (
@ -32,6 +38,21 @@ var (
// ErrBodyAlreadyConsumed Body 已被消费 // ErrBodyAlreadyConsumed Body 已被消费
ErrBodyAlreadyConsumed = errors.New("starnet: response body already consumed") ErrBodyAlreadyConsumed = errors.New("starnet: response body already consumed")
// ErrRespBodyTooLarge 响应体超过允许上限
ErrRespBodyTooLarge = errors.New("starnet: response body too large")
// ErrPingInvalidTimeout ping 超时参数无效
ErrPingInvalidTimeout = errors.New("starnet: invalid ping timeout")
// ErrPingPermissionDenied ping 需要更高权限raw socket
ErrPingPermissionDenied = errors.New("starnet: ping permission denied")
// ErrPingProtocolUnsupported ping 协议/地址族不受当前平台支持
ErrPingProtocolUnsupported = errors.New("starnet: ping protocol unsupported")
// ErrPingNoResolvedTarget ping 目标无法解析为可用地址
ErrPingNoResolvedTarget = errors.New("starnet: ping target not resolved")
) )
// wrapError 包装错误,添加上下文信息 // wrapError 包装错误,添加上下文信息
@ -56,3 +77,174 @@ var (
// ErrNoTLSConfig indicates TLS was detected but no usable TLS config is available. // ErrNoTLSConfig indicates TLS was detected but no usable TLS config is available.
ErrNoTLSConfig = errors.New("starnet: no TLS config available") ErrNoTLSConfig = errors.New("starnet: no TLS config available")
) )
// ErrorKind is a normalized high-level category for request errors.
type ErrorKind string
const (
ErrorKindNone ErrorKind = "none"
ErrorKindCanceled ErrorKind = "canceled"
ErrorKindTimeout ErrorKind = "timeout"
ErrorKindDNS ErrorKind = "dns"
ErrorKindTLS ErrorKind = "tls"
ErrorKindProxy ErrorKind = "proxy"
ErrorKindOther ErrorKind = "other"
)
// IsCanceled reports whether err is a cancellation-related error.
func IsCanceled(err error) bool {
if err == nil {
return false
}
if errors.Is(err, context.Canceled) {
return true
}
msg := strings.ToLower(err.Error())
return strings.Contains(msg, "context canceled") ||
strings.Contains(msg, "operation was canceled") ||
strings.Contains(msg, "request canceled")
}
// ClassifyError maps low-level errors to a stable category for business handling.
func ClassifyError(err error) ErrorKind {
if err == nil {
return ErrorKindNone
}
if IsCanceled(err) {
return ErrorKindCanceled
}
if IsProxy(err) {
return ErrorKindProxy
}
if IsDNS(err) {
return ErrorKindDNS
}
if IsTLS(err) {
return ErrorKindTLS
}
if IsTimeout(err) {
return ErrorKindTimeout
}
return ErrorKindOther
}
// IsTimeout reports whether err is a timeout-related error.
func IsTimeout(err error) bool {
if err == nil {
return false
}
if errors.Is(err, context.DeadlineExceeded) {
return true
}
var uerr *url.Error
if errors.As(err, &uerr) && uerr.Timeout() {
return true
}
var nerr net.Error
if errors.As(err, &nerr) && nerr.Timeout() {
return true
}
msg := strings.ToLower(err.Error())
return strings.Contains(msg, "timeout") || strings.Contains(msg, "deadline exceeded")
}
// IsDNS reports whether err is a DNS resolution related error.
func IsDNS(err error) bool {
if err == nil {
return false
}
var derr *net.DNSError
if errors.As(err, &derr) {
return true
}
msg := strings.ToLower(err.Error())
if strings.Contains(msg, "no such host") ||
strings.Contains(msg, "server misbehaving") ||
strings.Contains(msg, "temporary failure in name resolution") {
return true
}
return strings.Contains(msg, "lookup ") &&
(strings.Contains(msg, "dns") || strings.Contains(msg, "i/o timeout"))
}
// IsTLS reports whether err is TLS/Certificate related.
func IsTLS(err error) bool {
if err == nil {
return false
}
if errors.Is(err, ErrNotTLS) || errors.Is(err, ErrNoTLSConfig) || errors.Is(err, ErrNonTLSNotAllowed) {
return true
}
var recErr tls.RecordHeaderError
if errors.As(err, &recErr) {
return true
}
var uaErr x509.UnknownAuthorityError
if errors.As(err, &uaErr) {
return true
}
var hnErr x509.HostnameError
if errors.As(err, &hnErr) {
return true
}
var certErr x509.CertificateInvalidError
if errors.As(err, &certErr) {
return true
}
var rootsErr x509.SystemRootsError
if errors.As(err, &rootsErr) {
return true
}
msg := strings.ToLower(err.Error())
return strings.Contains(msg, "tls:") || strings.Contains(msg, "x509:")
}
// IsProxy reports whether err is proxy related.
func IsProxy(err error) bool {
if err == nil {
return false
}
if isProxyMessage(strings.ToLower(err.Error())) {
return true
}
var uerr *url.Error
if errors.As(err, &uerr) {
if strings.Contains(strings.ToLower(uerr.Op), "proxy") {
return true
}
if uerr.Err != nil && isProxyMessage(strings.ToLower(uerr.Err.Error())) {
return true
}
}
var opErr *net.OpError
if errors.As(err, &opErr) && strings.Contains(strings.ToLower(opErr.Op), "proxy") {
return true
}
return false
}
func isProxyMessage(msg string) bool {
return strings.Contains(msg, "proxyconnect") ||
strings.Contains(msg, "proxy error") ||
strings.Contains(msg, "proxy authentication required") ||
strings.Contains(msg, "proxy: unknown scheme") ||
strings.Contains(msg, "socks connect") ||
strings.Contains(msg, "socks5")
}

116
errors_classify_test.go Normal file
View File

@ -0,0 +1,116 @@
package starnet
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"net/url"
"testing"
)
type timeoutErr struct{}
func (timeoutErr) Error() string { return "i/o timeout" }
func (timeoutErr) Timeout() bool { return true }
func (timeoutErr) Temporary() bool { return true }
func TestIsTimeout(t *testing.T) {
if !IsTimeout(context.DeadlineExceeded) {
t.Fatal("context deadline should be timeout")
}
uerr := &url.Error{
Op: "Get",
URL: "http://example.com",
Err: timeoutErr{},
}
if !IsTimeout(uerr) {
t.Fatal("url timeout error should be timeout")
}
if !IsTimeout(fmt.Errorf("wrapped: %w", uerr)) {
t.Fatal("wrapped timeout should be timeout")
}
if IsTimeout(errors.New("plain error")) {
t.Fatal("plain error must not be timeout")
}
}
func TestIsDNS(t *testing.T) {
dnsErr := &net.DNSError{
Err: "no such host",
Name: "example.invalid",
IsNotFound: true,
}
if !IsDNS(dnsErr) {
t.Fatal("dns error should be dns")
}
if !IsDNS(fmt.Errorf("wrapped: %w", dnsErr)) {
t.Fatal("wrapped dns error should be dns")
}
if !IsDNS(errors.New("lookup example.invalid: no such host")) {
t.Fatal("lookup no such host should be dns")
}
if IsDNS(errors.New("connection reset by peer")) {
t.Fatal("non dns error should not be dns")
}
}
func TestIsTLS(t *testing.T) {
tlsErr := tls.RecordHeaderError{Msg: "first record does not look like a TLS handshake"}
if !IsTLS(tlsErr) {
t.Fatal("tls record header error should be tls")
}
if !IsTLS(fmt.Errorf("wrapped: %w", tlsErr)) {
t.Fatal("wrapped tls error should be tls")
}
if !IsTLS(errors.New("x509: certificate signed by unknown authority")) {
t.Fatal("x509 error text should be tls")
}
if !IsTLS(ErrNotTLS) {
t.Fatal("ErrNotTLS should be tls related")
}
if IsTLS(errors.New("plain error")) {
t.Fatal("plain error should not be tls")
}
}
func TestIsProxy(t *testing.T) {
raw := errors.New("proxyconnect tcp: dial tcp 127.0.0.1:8080: connect: connection refused")
if !IsProxy(raw) {
t.Fatal("proxyconnect error should be proxy")
}
uerr := &url.Error{
Op: "Get",
URL: "http://example.com",
Err: raw,
}
if !IsProxy(uerr) {
t.Fatal("wrapped proxy error should be proxy")
}
opErr := &net.OpError{
Op: "proxyconnect",
Net: "tcp",
Err: errors.New("connect failed"),
}
if !IsProxy(opErr) {
t.Fatal("net.OpError proxyconnect should be proxy")
}
if IsProxy(errors.New("dial tcp 127.0.0.1:8080: connect: connection refused")) {
t.Fatal("non proxy dial error should not be proxy")
}
}

49
errors_kind_test.go Normal file
View File

@ -0,0 +1,49 @@
package starnet
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"testing"
)
func TestIsCanceled(t *testing.T) {
if !IsCanceled(context.Canceled) {
t.Fatal("context canceled should be canceled")
}
if !IsCanceled(fmt.Errorf("wrapped: %w", context.Canceled)) {
t.Fatal("wrapped context canceled should be canceled")
}
if IsCanceled(context.DeadlineExceeded) {
t.Fatal("deadline exceeded must not be canceled")
}
}
func TestClassifyError(t *testing.T) {
dnsErr := &net.DNSError{Err: "no such host", Name: "example.invalid", IsNotFound: true}
tlsErr := tls.RecordHeaderError{Msg: "bad tls record"}
tests := []struct {
name string
err error
want ErrorKind
}{
{name: "nil", err: nil, want: ErrorKindNone},
{name: "canceled", err: context.Canceled, want: ErrorKindCanceled},
{name: "proxy", err: errors.New("proxyconnect tcp: dial tcp 127.0.0.1:8080: i/o timeout"), want: ErrorKindProxy},
{name: "dns", err: dnsErr, want: ErrorKindDNS},
{name: "tls", err: tlsErr, want: ErrorKindTLS},
{name: "timeout", err: context.DeadlineExceeded, want: ErrorKindTimeout},
{name: "other", err: errors.New("boom"), want: ErrorKindOther},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := ClassifyError(tt.err); got != tt.want {
t.Fatalf("ClassifyError()=%s want=%s err=%v", got, tt.want, tt.err)
}
})
}
}

View File

@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"sync/atomic"
"time" "time"
"b612.me/starnet" "b612.me/starnet"
@ -198,3 +199,58 @@ func ExampleWithTimeout() {
fmt.Println(resp.StatusCode) fmt.Println(resp.StatusCode)
// Output: 200 // Output: 200
} }
func ExampleWithRetry() {
var hits int32
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
n := atomic.AddInt32(&hits, 1)
if n <= 2 {
w.WriteHeader(http.StatusServiceUnavailable)
return
}
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
resp, err := starnet.Get(server.URL,
starnet.WithRetry(2,
starnet.WithRetryBackoff(0, 0, 1),
starnet.WithRetryJitter(0),
),
)
if err != nil {
panic(err)
}
defer resp.Close()
fmt.Println(resp.StatusCode, atomic.LoadInt32(&hits))
// Output: 200 3
}
func ExampleRequest_SetRetry() {
var hits int32
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
n := atomic.AddInt32(&hits, 1)
if n == 1 {
w.WriteHeader(http.StatusServiceUnavailable)
return
}
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
resp, err := starnet.NewSimpleRequest(server.URL, http.MethodPost).
SetBodyString("hello"). // 可重放 body 才能安全重试
SetRetry(1).
SetRetryIdempotentOnly(false).
SetRetryBackoff(0, 0, 1).
SetRetryJitter(0).
Do()
if err != nil {
panic(err)
}
defer resp.Close()
fmt.Println(resp.StatusCode, atomic.LoadInt32(&hits))
// Output: 200 2
}

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"crypto/tls" "crypto/tls"
"encoding/json" "encoding/json"
"fmt"
"io" "io"
"net" "net"
"net/http" "net/http"
@ -12,9 +13,9 @@ import (
) )
// WithTimeout 设置请求总超时时间 // WithTimeout 设置请求总超时时间
// timeout > 0: 使用该超时 // timeout > 0: 为本次请求注入 context 超时
// timeout = 0: 使用 Client 默认超时 // timeout = 0: 不额外设置请求总超时
// timeout < 0: 禁用本次请求超时(覆盖 Client.Timeout=0 // timeout < 0: 禁用 starnet 默认总超时
func WithTimeout(timeout time.Duration) RequestOpt { func WithTimeout(timeout time.Duration) RequestOpt {
return func(r *Request) error { return func(r *Request) error {
r.config.Network.Timeout = timeout r.config.Network.Timeout = timeout
@ -371,9 +372,23 @@ func WithAutoFetch(auto bool) RequestOpt {
} }
} }
// WithMaxRespBodyBytes 设置响应体最大读取字节数(<=0 表示不限制)
func WithMaxRespBodyBytes(maxBytes int64) RequestOpt {
return func(r *Request) error {
if maxBytes < 0 {
return fmt.Errorf("max response body bytes must be >= 0")
}
r.config.MaxRespBodyBytes = maxBytes
return nil
}
}
// WithRawRequest 设置原始请求 // WithRawRequest 设置原始请求
func WithRawRequest(httpReq *http.Request) RequestOpt { func WithRawRequest(httpReq *http.Request) RequestOpt {
return func(r *Request) error { return func(r *Request) error {
if httpReq == nil {
return fmt.Errorf("httpReq cannot be nil")
}
r.httpReq = httpReq r.httpReq = httpReq
r.doRaw = true r.doRaw = true
return nil return nil

524
ping.go
View File

@ -1,12 +1,31 @@
package starnet package starnet
import ( import (
"bytes" "context"
"encoding/binary" "encoding/binary"
"errors"
"fmt"
"net" "net"
"os"
"strings"
"sync/atomic"
"time" "time"
) )
const (
icmpTypeEchoReplyV4 = 0
icmpTypeEchoRequestV4 = 8
icmpTypeEchoRequestV6 = 128
icmpTypeEchoReplyV6 = 129
icmpHeaderLen = 8
icmpReadBufSz = 1500
defaultPingAttemptTimeout = 2 * time.Second
defaultPingableCount = 3
maxPingPayloadSize = 65499 // 65507 - ICMP header(8)
)
type ICMP struct { type ICMP struct {
Type uint8 Type uint8
Code uint8 Code uint8
@ -15,52 +34,126 @@ type ICMP struct {
SequenceNum uint16 SequenceNum uint16
} }
func getICMP(seq uint16) ICMP { type pingSocketSpec struct {
network string
family int
requestType uint8
replyType uint8
}
// PingOptions controls ping probing behavior.
type PingOptions struct {
Count int // ping attempts for Pingable, default 3
Timeout time.Duration // per-attempt timeout, default 2s
Interval time.Duration // delay between attempts, default 0
Deadline time.Time // overall deadline for Pingable/PingWithContext
PreferIPv4 bool // prefer IPv4 targets
PreferIPv6 bool // prefer IPv6 targets
SourceIP net.IP // optional source IP for raw socket bind
PayloadSize int // ICMP payload bytes, default 0
}
type PingResult struct {
Duration time.Duration
RecvCount int
RemoteIP string
}
var pingIdentifierSeed uint32
func nextPingIdentifier() uint16 {
pid := uint32(os.Getpid() & 0xffff)
n := atomic.AddUint32(&pingIdentifierSeed, 1)
return uint16((pid + n) & 0xffff)
}
func pingPayload(size int) []byte {
if size <= 0 {
return nil
}
payload := make([]byte, size)
for i := 0; i < len(payload); i++ {
payload[i] = byte(i)
}
return payload
}
func getICMP(seq, identifier uint16, typ uint8, payload []byte) ICMP {
icmp := ICMP{ icmp := ICMP{
Type: 8, Type: typ,
Code: 0, Code: 0,
CheckSum: 0, CheckSum: 0,
Identifier: 0, Identifier: identifier,
SequenceNum: seq, SequenceNum: seq,
} }
var buffer bytes.Buffer buf := marshalICMPPacket(icmp, payload)
binary.Write(&buffer, binary.BigEndian, icmp) icmp.CheckSum = checkSum(buf)
icmp.CheckSum = checkSum(buffer.Bytes())
buffer.Reset()
return icmp return icmp
} }
func sendICMPRequest(icmp ICMP, destAddr *net.IPAddr, timeout time.Duration) (PingResult, error) { func sendICMPRequest(ctx context.Context, icmp ICMP, payload []byte, destAddr *net.IPAddr, sourceIP net.IP, spec pingSocketSpec, timeout time.Duration) (PingResult, error) {
var res PingResult var res PingResult
if ctx == nil {
ctx = context.Background()
}
if err := ctx.Err(); err != nil {
return res, wrapError(err, "ping context done")
}
if destAddr == nil || destAddr.IP == nil {
return res, fmt.Errorf("destination ip is nil")
}
res.RemoteIP = destAddr.String() res.RemoteIP = destAddr.String()
conn, err := net.DialIP("ip:icmp", nil, destAddr)
localAddr, err := localIPAddrForFamily(sourceIP, spec.family)
if err != nil { if err != nil {
return res, err return res, err
} }
defer conn.Close()
var buffer bytes.Buffer
binary.Write(&buffer, binary.BigEndian, icmp)
if _, err := conn.Write(buffer.Bytes()); err != nil { conn, err := net.DialIP(spec.network, localAddr, destAddr)
return res, err if err != nil {
return res, normalizePingDialError(err)
}
defer conn.Close()
packet := marshalICMPPacket(icmp, payload)
if _, err := conn.Write(packet); err != nil {
return res, wrapError(err, "ping write request")
} }
tStart := time.Now() tStart := time.Now()
deadline := tStart.Add(timeout)
conn.SetReadDeadline((time.Now().Add(timeout))) if d, ok := ctx.Deadline(); ok && d.Before(deadline) {
deadline = d
recv := make([]byte, 1024) }
res.RecvCount, err = conn.Read(recv) if err := conn.SetReadDeadline(deadline); err != nil {
return res, wrapError(err, "ping set read deadline")
if err != nil {
return res, err
} }
tEnd := time.Now() doneCh := make(chan struct{})
res.Duration = tEnd.Sub(tStart) go func() {
select {
case <-ctx.Done():
_ = conn.SetReadDeadline(time.Now())
case <-doneCh:
}
}()
defer close(doneCh)
return res, err recv := make([]byte, icmpReadBufSz)
for {
n, err := conn.Read(recv)
if err != nil {
if ctx.Err() != nil {
return res, wrapError(ctx.Err(), "ping context done")
}
return res, wrapError(err, "ping read reply")
}
if isExpectedEchoReply(recv[:n], spec.family, spec.replyType, icmp.Identifier, icmp.SequenceNum) {
res.RecvCount = n
res.Duration = time.Since(tStart)
return res, nil
}
}
} }
func checkSum(data []byte) uint16 { func checkSum(data []byte) uint16 {
@ -75,36 +168,375 @@ func checkSum(data []byte) uint16 {
length -= 2 length -= 2
} }
if length > 0 { if length > 0 {
sum += uint32(data[index]) sum += uint32(data[index]) << 8
}
for sum>>16 != 0 {
sum = (sum & 0xffff) + (sum >> 16)
} }
sum += (sum >> 16)
return uint16(^sum) return uint16(^sum)
} }
type PingResult struct { func marshalICMP(icmp ICMP) []byte {
Duration time.Duration return marshalICMPPacket(icmp, nil)
RecvCount int
RemoteIP string
} }
func Ping(ip string, seq int, timeout time.Duration) (PingResult, error) { func marshalICMPPacket(icmp ICMP, payload []byte) []byte {
var res PingResult buf := make([]byte, icmpHeaderLen+len(payload))
ipAddr, err := net.ResolveIPAddr("ip", ip) buf[0] = icmp.Type
if err != nil { buf[1] = icmp.Code
return res, err binary.BigEndian.PutUint16(buf[2:], icmp.CheckSum)
binary.BigEndian.PutUint16(buf[4:], icmp.Identifier)
binary.BigEndian.PutUint16(buf[6:], icmp.SequenceNum)
copy(buf[icmpHeaderLen:], payload)
return buf
}
func isExpectedEchoReply(packet []byte, family int, expectedType uint8, identifier, seq uint16) bool {
for _, off := range candidateICMPOffsets(packet, family) {
if off < 0 || off+icmpHeaderLen > len(packet) {
continue
} }
icmp := getICMP(uint16(seq)) if packet[off] != expectedType || packet[off+1] != 0 {
return sendICMPRequest(icmp, ipAddr, timeout) continue
} }
if binary.BigEndian.Uint16(packet[off+4:off+6]) != identifier {
func IsIpPingable(ip string, timeout time.Duration, retryLimit int) bool { continue
for i := 0; i < retryLimit; i++ { }
_, err := Ping(ip, 29, timeout) if binary.BigEndian.Uint16(packet[off+6:off+8]) != seq {
if err != nil {
continue continue
} }
return true return true
} }
return false return false
} }
func candidateICMPOffsets(packet []byte, family int) []int {
offsets := []int{0}
if len(packet) == 0 {
return offsets
}
ver := packet[0] >> 4
if ver == 4 && len(packet) >= 20 {
ihl := int(packet[0]&0x0f) * 4
if ihl >= 20 && ihl <= len(packet)-icmpHeaderLen {
offsets = append(offsets, ihl)
}
} else if ver == 6 && len(packet) >= 40+icmpHeaderLen {
offsets = append(offsets, 40)
}
// 某些平台/内核可能回包含链路层头部,保守再尝试常见偏移。
if family == 4 && len(packet) >= 20+icmpHeaderLen {
offsets = append(offsets, 20)
}
if family == 6 && len(packet) >= 40+icmpHeaderLen {
offsets = append(offsets, 40)
}
return dedupOffsets(offsets)
}
func dedupOffsets(offsets []int) []int {
if len(offsets) <= 1 {
return offsets
}
m := make(map[int]struct{}, len(offsets))
out := make([]int, 0, len(offsets))
for _, off := range offsets {
if _, ok := m[off]; ok {
continue
}
m[off] = struct{}{}
out = append(out, off)
}
return out
}
func socketSpecForIP(ip net.IP) (pingSocketSpec, error) {
if ip == nil {
return pingSocketSpec{}, wrapError(ErrInvalidIP, "ip is nil")
}
if ip4 := ip.To4(); ip4 != nil {
return pingSocketSpec{
network: "ip4:icmp",
family: 4,
requestType: icmpTypeEchoRequestV4,
replyType: icmpTypeEchoReplyV4,
}, nil
}
if ip16 := ip.To16(); ip16 != nil {
return pingSocketSpec{
network: "ip6:ipv6-icmp",
family: 6,
requestType: icmpTypeEchoRequestV6,
replyType: icmpTypeEchoReplyV6,
}, nil
}
return pingSocketSpec{}, wrapError(ErrInvalidIP, "invalid ip: %q", ip.String())
}
func localIPAddrForFamily(sourceIP net.IP, family int) (*net.IPAddr, error) {
if sourceIP == nil {
return nil, nil
}
if sourceIP.To16() == nil {
return nil, wrapError(ErrInvalidIP, "invalid source ip: %q", sourceIP.String())
}
if family == 4 && sourceIP.To4() == nil {
return nil, wrapError(ErrInvalidIP, "source ip family mismatch with IPv4 target")
}
if family == 6 && sourceIP.To4() != nil {
return nil, wrapError(ErrInvalidIP, "source ip family mismatch with IPv6 target")
}
return &net.IPAddr{IP: sourceIP}, nil
}
func resolvePingTargets(host string, preferIPv4, preferIPv6 bool) ([]*net.IPAddr, error) {
if parsed := net.ParseIP(host); parsed != nil {
return []*net.IPAddr{{IP: parsed}}, nil
}
var targets []*net.IPAddr
var err4 error
var err6 error
if ip4, e := net.ResolveIPAddr("ip4", host); e == nil && ip4 != nil && ip4.IP != nil {
targets = append(targets, ip4)
} else {
err4 = e
}
if ip6, e := net.ResolveIPAddr("ip6", host); e == nil && ip6 != nil && ip6.IP != nil {
targets = append(targets, ip6)
} else {
err6 = e
}
if len(targets) > 0 {
return orderPingTargets(targets, preferIPv4, preferIPv6), nil
}
if err4 != nil {
return nil, err4
}
if err6 != nil {
return nil, err6
}
return nil, ErrPingNoResolvedTarget
}
func orderPingTargets(targets []*net.IPAddr, preferIPv4, preferIPv6 bool) []*net.IPAddr {
if len(targets) <= 1 || preferIPv4 == preferIPv6 {
return targets
}
ordered := make([]*net.IPAddr, 0, len(targets))
if preferIPv4 {
for _, t := range targets {
if t != nil && t.IP != nil && t.IP.To4() != nil {
ordered = append(ordered, t)
}
}
for _, t := range targets {
if t != nil && t.IP != nil && t.IP.To4() == nil {
ordered = append(ordered, t)
}
}
return ordered
}
for _, t := range targets {
if t != nil && t.IP != nil && t.IP.To4() == nil {
ordered = append(ordered, t)
}
}
for _, t := range targets {
if t != nil && t.IP != nil && t.IP.To4() != nil {
ordered = append(ordered, t)
}
}
return ordered
}
func normalizePingDialError(err error) error {
if err == nil {
return nil
}
msg := strings.ToLower(err.Error())
if errors.Is(err, os.ErrPermission) ||
strings.Contains(msg, "operation not permitted") ||
strings.Contains(msg, "permission denied") {
return fmt.Errorf("%w: %v", ErrPingPermissionDenied, err)
}
if strings.Contains(msg, "unknown network") ||
strings.Contains(msg, "protocol not available") ||
strings.Contains(msg, "address family not supported by protocol") ||
strings.Contains(msg, "socket type not supported") {
return fmt.Errorf("%w: %v", ErrPingProtocolUnsupported, err)
}
return wrapError(err, "ping dial")
}
func normalizePingOptions(opts *PingOptions, defaultCount int, defaultTimeout time.Duration) (PingOptions, error) {
out := PingOptions{
Count: defaultCount,
Timeout: defaultTimeout,
Interval: 0,
PayloadSize: 0,
}
if opts != nil {
out = *opts
if out.Count == 0 {
out.Count = defaultCount
}
if out.Timeout == 0 {
out.Timeout = defaultTimeout
}
}
if out.Count < 0 {
return out, fmt.Errorf("ping count must be >= 0")
}
if out.Timeout <= 0 {
return out, wrapError(ErrPingInvalidTimeout, "timeout must be > 0")
}
if out.Interval < 0 {
return out, fmt.Errorf("ping interval must be >= 0")
}
if out.PayloadSize < 0 || out.PayloadSize > maxPingPayloadSize {
return out, fmt.Errorf("ping payload size must be in [0,%d]", maxPingPayloadSize)
}
if out.SourceIP != nil && out.SourceIP.To16() == nil {
return out, wrapError(ErrInvalidIP, "invalid source ip: %q", out.SourceIP.String())
}
return out, nil
}
func pingOnceWithOptions(ctx context.Context, host string, seq int, opts PingOptions) (PingResult, error) {
var res PingResult
if ctx == nil {
ctx = context.Background()
}
if err := ctx.Err(); err != nil {
return res, wrapError(err, "ping context done")
}
targets, err := resolvePingTargets(host, opts.PreferIPv4, opts.PreferIPv6)
if err != nil {
return res, wrapError(err, "resolve ping target")
}
payload := pingPayload(opts.PayloadSize)
var lastErr error
for _, target := range targets {
spec, err := socketSpecForIP(target.IP)
if err != nil {
lastErr = err
continue
}
icmp := getICMP(uint16(seq), nextPingIdentifier(), spec.requestType, payload)
resp, err := sendICMPRequest(ctx, icmp, payload, target, opts.SourceIP, spec, opts.Timeout)
if err == nil {
return resp, nil
}
// 权限问题通常与地址族无关,继续重试意义不大。
if errors.Is(err, ErrPingPermissionDenied) {
return res, err
}
lastErr = err
}
if lastErr != nil {
return res, wrapError(lastErr, "ping all resolved targets failed")
}
return res, ErrPingNoResolvedTarget
}
// PingWithContext sends one ICMP echo request with context cancel support.
func PingWithContext(ctx context.Context, host string, seq int, timeout time.Duration) (PingResult, error) {
opts, err := normalizePingOptions(&PingOptions{
Count: 1,
Timeout: timeout,
}, 1, timeout)
if err != nil {
return PingResult{}, err
}
if !opts.Deadline.IsZero() {
var cancel context.CancelFunc
ctx, cancel = context.WithDeadline(ctx, opts.Deadline)
defer cancel()
}
return pingOnceWithOptions(ctx, host, seq, opts)
}
// Ping sends one ICMP echo request.
func Ping(ip string, seq int, timeout time.Duration) (PingResult, error) {
return PingWithContext(context.Background(), ip, seq, timeout)
}
// Pingable checks host reachability with retry options.
func Pingable(host string, opts *PingOptions) (bool, error) {
cfg, err := normalizePingOptions(opts, defaultPingableCount, defaultPingAttemptTimeout)
if err != nil {
return false, err
}
ctx := context.Background()
if !cfg.Deadline.IsZero() {
var cancel context.CancelFunc
ctx, cancel = context.WithDeadline(ctx, cfg.Deadline)
defer cancel()
}
var lastErr error
for i := 0; i < cfg.Count; i++ {
_, err := pingOnceWithOptions(ctx, host, 29+i, cfg)
if err == nil {
return true, nil
}
lastErr = err
if errors.Is(err, ErrPingPermissionDenied) || errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
break
}
if i < cfg.Count-1 && cfg.Interval > 0 {
timer := time.NewTimer(cfg.Interval)
select {
case <-ctx.Done():
timer.Stop()
return false, wrapError(ctx.Err(), "pingable context done")
case <-timer.C:
}
}
}
if lastErr == nil {
lastErr = ErrPingNoResolvedTarget
}
return false, lastErr
}
// IsIpPingable keeps backward-compatible bool-only behavior.
func IsIpPingable(ip string, timeout time.Duration, retryLimit int) bool {
if retryLimit <= 0 {
return false
}
ok, _ := Pingable(ip, &PingOptions{
Count: retryLimit,
Timeout: timeout,
})
return ok
}

214
ping_logic_test.go Normal file
View File

@ -0,0 +1,214 @@
package starnet
import (
"context"
"errors"
"net"
"testing"
"time"
)
func buildICMPPacket(typ uint8, identifier, seq uint16) []byte {
icmp := ICMP{
Type: typ,
Code: 0,
CheckSum: 0,
Identifier: identifier,
SequenceNum: seq,
}
buf := marshalICMPPacket(icmp, nil)
cs := checkSum(buf)
buf[2] = byte(cs >> 8)
buf[3] = byte(cs)
return buf
}
func TestNextPingIdentifierChanges(t *testing.T) {
id1 := nextPingIdentifier()
id2 := nextPingIdentifier()
if id1 == id2 {
t.Fatalf("identifier should change between calls: %d == %d", id1, id2)
}
}
func TestIsExpectedEchoReplyIPv4(t *testing.T) {
identifier := uint16(0x1234)
seq := uint16(0x0102)
reply := buildICMPPacket(icmpTypeEchoReplyV4, identifier, seq)
if !isExpectedEchoReply(reply, 4, icmpTypeEchoReplyV4, identifier, seq) {
t.Fatal("expected IPv4 echo reply to match")
}
if isExpectedEchoReply(reply, 4, icmpTypeEchoReplyV4, identifier, seq+1) {
t.Fatal("mismatched sequence should not match")
}
}
func TestIsExpectedEchoReplyIPv4WithIPHeader(t *testing.T) {
identifier := uint16(0x1111)
seq := uint16(0x2222)
ipHeader := make([]byte, 20)
ipHeader[0] = 0x45 // version=4, ihl=5
reply := buildICMPPacket(icmpTypeEchoReplyV4, identifier, seq)
packet := append(ipHeader, reply...)
if !isExpectedEchoReply(packet, 4, icmpTypeEchoReplyV4, identifier, seq) {
t.Fatal("expected IPv4 packet with header to match")
}
}
func TestIsExpectedEchoReplyIPv6WithHeader(t *testing.T) {
identifier := uint16(0xabcd)
seq := uint16(0x00ff)
ipv6Header := make([]byte, 40)
ipv6Header[0] = 0x60 // version=6
reply := buildICMPPacket(icmpTypeEchoReplyV6, identifier, seq)
packet := append(ipv6Header, reply...)
if !isExpectedEchoReply(packet, 6, icmpTypeEchoReplyV6, identifier, seq) {
t.Fatal("expected IPv6 packet with header to match")
}
}
func TestPingInvalidTimeout(t *testing.T) {
_, err := Ping("127.0.0.1", 1, 0)
if err == nil {
t.Fatal("expected error for non-positive timeout")
}
if !errors.Is(err, ErrPingInvalidTimeout) {
t.Fatalf("expected ErrPingInvalidTimeout, got: %v", err)
}
}
func TestIsIPPingableInvalidRetry(t *testing.T) {
if IsIpPingable("127.0.0.1", time.Millisecond, 0) {
t.Fatal("retryLimit=0 should return false")
}
}
func TestSocketSpecForIP(t *testing.T) {
v4, err := socketSpecForIP(net.ParseIP("127.0.0.1"))
if err != nil {
t.Fatalf("unexpected v4 error: %v", err)
}
if v4.network != "ip4:icmp" || v4.family != 4 || v4.requestType != icmpTypeEchoRequestV4 || v4.replyType != icmpTypeEchoReplyV4 {
t.Fatalf("unexpected v4 spec: %+v", v4)
}
v6, err := socketSpecForIP(net.ParseIP("::1"))
if err != nil {
t.Fatalf("unexpected v6 error: %v", err)
}
if v6.network != "ip6:ipv6-icmp" || v6.family != 6 || v6.requestType != icmpTypeEchoRequestV6 || v6.replyType != icmpTypeEchoReplyV6 {
t.Fatalf("unexpected v6 spec: %+v", v6)
}
_, err = socketSpecForIP(nil)
if err == nil {
t.Fatal("expected error for nil ip")
}
if !errors.Is(err, ErrInvalidIP) {
t.Fatalf("expected ErrInvalidIP, got: %v", err)
}
}
func TestResolvePingTargetsLiteral(t *testing.T) {
v4, err := resolvePingTargets("127.0.0.1", false, false)
if err != nil {
t.Fatalf("unexpected v4 resolve error: %v", err)
}
if len(v4) != 1 || v4[0] == nil || v4[0].IP == nil || v4[0].IP.To4() == nil {
t.Fatalf("unexpected v4 targets: %+v", v4)
}
v6, err := resolvePingTargets("::1", false, false)
if err != nil {
t.Fatalf("unexpected v6 resolve error: %v", err)
}
if len(v6) != 1 || v6[0] == nil || v6[0].IP == nil || v6[0].IP.To16() == nil || v6[0].IP.To4() != nil {
t.Fatalf("unexpected v6 targets: %+v", v6)
}
}
func TestNormalizePingDialError(t *testing.T) {
perr := normalizePingDialError(errors.New("socket: operation not permitted"))
if !errors.Is(perr, ErrPingPermissionDenied) {
t.Fatalf("expected ErrPingPermissionDenied, got: %v", perr)
}
uerr := normalizePingDialError(errors.New("unknown network ip6:ipv6-icmp"))
if !errors.Is(uerr, ErrPingProtocolUnsupported) {
t.Fatalf("expected ErrPingProtocolUnsupported, got: %v", uerr)
}
}
func TestOrderPingTargets(t *testing.T) {
targets := []*net.IPAddr{
{IP: net.ParseIP("::1")},
{IP: net.ParseIP("127.0.0.1")},
}
v4First := orderPingTargets(targets, true, false)
if v4First[0].IP.To4() == nil {
t.Fatalf("expected IPv4 first, got: %v", v4First[0].IP)
}
v6First := orderPingTargets(targets, false, true)
if v6First[0].IP.To4() != nil {
t.Fatalf("expected IPv6 first, got: %v", v6First[0].IP)
}
}
func TestNormalizePingOptions(t *testing.T) {
opts, err := normalizePingOptions(nil, 3, 2*time.Second)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if opts.Count != 3 || opts.Timeout != 2*time.Second {
t.Fatalf("unexpected defaults: %+v", opts)
}
_, err = normalizePingOptions(&PingOptions{Count: -1}, 3, 2*time.Second)
if err == nil {
t.Fatal("expected error for negative count")
}
_, err = normalizePingOptions(&PingOptions{Timeout: -1}, 3, 2*time.Second)
if err == nil {
t.Fatal("expected error for negative timeout")
}
_, err = normalizePingOptions(&PingOptions{PayloadSize: maxPingPayloadSize + 1}, 3, 2*time.Second)
if err == nil {
t.Fatal("expected error for too large payload")
}
}
func TestPingWithContextCanceled(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err := PingWithContext(ctx, "127.0.0.1", 1, time.Second)
if err == nil {
t.Fatal("expected canceled error")
}
if !errors.Is(err, context.Canceled) {
t.Fatalf("expected context.Canceled, got: %v", err)
}
}
func TestPingableInvalidOptions(t *testing.T) {
_, err := Pingable("127.0.0.1", &PingOptions{Count: -1})
if err == nil {
t.Fatal("expected invalid count error")
}
_, err = Pingable("127.0.0.1", &PingOptions{Interval: -1})
if err == nil {
t.Fatal("expected invalid interval error")
}
}

View File

@ -12,6 +12,7 @@ import (
type Request struct { type Request struct {
ctx context.Context ctx context.Context
execCtx context.Context // 执行时的 context注入了配置 execCtx context.Context // 执行时的 context注入了配置
cancel context.CancelFunc
url string url string
method string method string
err error // 累积的错误 err error // 累积的错误
@ -20,6 +21,7 @@ type Request struct {
client *Client client *Client
httpClient *http.Client httpClient *http.Client
httpReq *http.Request httpReq *http.Request
retry *retryPolicy
applied bool // 是否已应用配置 applied bool // 是否已应用配置
doRaw bool // 是否使用原始请求(不修改) doRaw bool // 是否使用原始请求(不修改)
@ -160,6 +162,7 @@ func (r *Request) Clone() *Request {
config: r.config.Clone(), config: r.config.Clone(),
client: r.client, client: r.client,
httpClient: r.httpClient, httpClient: r.httpClient,
retry: cloneRetryPolicy(r.retry),
applied: false, // 重置应用状态 applied: false, // 重置应用状态
doRaw: r.doRaw, doRaw: r.doRaw,
autoFetch: r.autoFetch, autoFetch: r.autoFetch,
@ -318,6 +321,14 @@ func (r *Request) Do() (*Response, error) {
return nil, r.err return nil, r.err
} }
if r.hasRetryPolicy() {
return r.doWithRetry()
}
return r.doOnce()
}
func (r *Request) doOnce() (*Response, error) {
// 准备请求 // 准备请求
if err := r.prepare(); err != nil { if err := r.prepare(); err != nil {
return nil, wrapError(err, "prepare request") return nil, wrapError(err, "prepare request")
@ -326,6 +337,10 @@ func (r *Request) Do() (*Response, error) {
// 执行请求 // 执行请求
httpResp, err := r.httpClient.Do(r.httpReq) httpResp, err := r.httpClient.Do(r.httpReq)
if err != nil { if err != nil {
if r.cancel != nil {
r.cancel()
r.cancel = nil
}
return &Response{ return &Response{
Response: &http.Response{}, Response: &http.Response{},
request: r, request: r,
@ -334,19 +349,33 @@ func (r *Request) Do() (*Response, error) {
}, wrapError(err, "do request") }, wrapError(err, "do request")
} }
rawBody := httpResp.Body
if r.cancel != nil {
rawBody = &cancelReadCloser{
ReadCloser: httpResp.Body,
cancel: r.cancel,
}
}
// 创建响应 // 创建响应
resp := &Response{ resp := &Response{
Response: httpResp, Response: httpResp,
request: r, request: r,
httpClient: r.httpClient, httpClient: r.httpClient,
cancel: r.cancel,
body: &Body{ body: &Body{
raw: httpResp.Body, raw: rawBody,
maxBytes: r.config.MaxRespBodyBytes,
}, },
} }
r.cancel = nil
// 自动获取响应体 // 自动获取响应体
if r.autoFetch { if r.autoFetch {
resp.body.readAll() if err := resp.body.readAll(); err != nil {
_ = resp.Close()
return resp, err
}
} }
return resp, nil return resp, nil
@ -371,3 +400,28 @@ func (r *Request) Put() (*Response, error) {
func (r *Request) Delete() (*Response, error) { func (r *Request) Delete() (*Response, error) {
return r.SetMethod(http.MethodDelete).Do() return r.SetMethod(http.MethodDelete).Do()
} }
// Head 发送 HEAD 请求
func (r *Request) Head() (*Response, error) {
return r.SetMethod(http.MethodHead).Do()
}
// Patch 发送 PATCH 请求
func (r *Request) Patch() (*Response, error) {
return r.SetMethod(http.MethodPatch).Do()
}
// Options 发送 OPTIONS 请求
func (r *Request) Options() (*Response, error) {
return r.SetMethod(http.MethodOptions).Do()
}
// Trace 发送 TRACE 请求
func (r *Request) Trace() (*Response, error) {
return r.SetMethod(http.MethodTrace).Do()
}
// Connect 发送 CONNECT 请求
func (r *Request) Connect() (*Response, error) {
return r.SetMethod(http.MethodConnect).Do()
}

View File

@ -2,7 +2,9 @@ package starnet
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"fmt"
"io" "io"
"mime/multipart" "mime/multipart"
"net/http" "net/http"
@ -65,7 +67,7 @@ func (r *Request) SetFormData(data map[string][]string) *Request {
if r.doRaw { if r.doRaw {
return r return r
} }
r.config.Body.FormData = data r.config.Body.FormData = cloneStringMapSlice(data)
return r return r
} }
@ -345,12 +347,13 @@ func (r *Request) prepare() error {
return err // ← 失败时不设置 applied return err // ← 失败时不设置 applied
} }
} }
// 原始模式不修改请求内容
if r.doRaw { if r.httpReq == nil {
r.applied = true return fmt.Errorf("http request is nil")
return nil
} }
// 原始模式不修改请求内容
if !r.doRaw {
// 应用查询参数 // 应用查询参数
if len(r.config.Queries) > 0 { if len(r.config.Queries) > 0 {
q := r.httpReq.URL.Query() q := r.httpReq.URL.Query()
@ -405,9 +408,20 @@ func (r *Request) prepare() error {
if r.config.TLS.Config != nil && r.httpReq.URL != nil { if r.config.TLS.Config != nil && r.httpReq.URL != nil {
r.config.TLS.Config.ServerName = r.httpReq.URL.Hostname() r.config.TLS.Config.ServerName = r.httpReq.URL.Hostname()
} }
}
// 注入配置到 context execCtx := r.ctx
r.execCtx = injectRequestConfig(r.ctx, r.config) if !r.doRaw {
// raw 模式下不注入请求级网络配置,只应用 context/超时。
execCtx = injectRequestConfig(execCtx, r.config)
}
// 请求级总超时通过 context 控制,避免污染共享 http.Client。
if r.config.Network.Timeout > 0 {
execCtx, r.cancel = context.WithTimeout(execCtx, r.config.Network.Timeout)
}
r.execCtx = execCtx
r.httpReq = r.httpReq.WithContext(r.execCtx) r.httpReq = r.httpReq.WithContext(r.execCtx)
r.applied = true r.applied = true
@ -416,50 +430,19 @@ func (r *Request) prepare() error {
// buildHTTPClient 构建 HTTP Client // buildHTTPClient 构建 HTTP Client
func (r *Request) buildHTTPClient() (*http.Client, error) { func (r *Request) buildHTTPClient() (*http.Client, error) {
applyTimeoutOverride := func(base *http.Client) *http.Client {
// 没有 base 时兜底
if base == nil {
base = &http.Client{}
}
rt := r.config.Network.Timeout
// 语义:
// rt < 0 : 本次请求禁用超时Timeout = 0
// rt = 0 : 沿用 base.Timeout
// rt > 0 : 本次请求超时覆盖
if rt == 0 {
return base
}
clone := &http.Client{
Transport: base.Transport,
CheckRedirect: base.CheckRedirect,
Jar: base.Jar,
}
if rt < 0 {
clone.Timeout = 0
} else {
clone.Timeout = rt
}
return clone
}
// 优先使用请求关联的 Client // 优先使用请求关联的 Client
if r.client != nil { if r.client != nil {
return applyTimeoutOverride(r.client.HTTPClient()), nil return r.client.HTTPClient(), nil
} }
// 自定义 Transport // 自定义 Transport
if r.config.CustomTransport && r.config.Transport != nil { if r.config.CustomTransport && r.config.Transport != nil {
base := &http.Client{ return &http.Client{
Transport: &Transport{base: r.config.Transport}, Transport: &Transport{base: r.config.Transport},
Timeout: 0, Timeout: 0,
} }, nil
return applyTimeoutOverride(base), nil
} }
// 默认全局 client // 默认全局 client
return applyTimeoutOverride(DefaultHTTPClient()), nil return DefaultHTTPClient(), nil
} }

View File

@ -10,9 +10,9 @@ import (
) )
// SetTimeout 设置请求总超时时间 // SetTimeout 设置请求总超时时间
// timeout > 0: 使用该超时 // timeout > 0: 为本次请求注入 context 超时
// timeout = 0: 使用 Client 默认超时 // timeout = 0: 不额外设置请求总超时
// timeout < 0: 禁用本次请求超时(覆盖 Client.Timeout=0 // timeout < 0: 禁用 starnet 默认总超时
func (r *Request) SetTimeout(timeout time.Duration) *Request { func (r *Request) SetTimeout(timeout time.Duration) *Request {
if r.err != nil { if r.err != nil {
return r return r
@ -194,6 +194,19 @@ func (r *Request) SetUploadProgress(fn UploadProgressFunc) *Request {
return r return r
} }
// SetMaxRespBodyBytes 设置响应体最大读取字节数(<=0 表示不限制)
func (r *Request) SetMaxRespBodyBytes(maxBytes int64) *Request {
if r.err != nil {
return r
}
if maxBytes < 0 {
r.err = fmt.Errorf("max response body bytes must be >= 0")
return r
}
r.config.MaxRespBodyBytes = maxBytes
return r
}
// AddQuery 添加查询参数 // AddQuery 添加查询参数
func (r *Request) AddQuery(key, value string) *Request { func (r *Request) AddQuery(key, value string) *Request {
if r.err != nil { if r.err != nil {
@ -217,7 +230,7 @@ func (r *Request) SetQueries(queries map[string][]string) *Request {
if r.err != nil { if r.err != nil {
return r return r
} }
r.config.Queries = queries r.config.Queries = cloneStringMapSlice(queries)
return r return r
} }

View File

@ -36,7 +36,7 @@ func (r *Request) SetHeaders(headers http.Header) *Request {
if r.doRaw { if r.doRaw {
return r return r
} }
r.config.Headers = headers r.config.Headers = cloneHeader(headers)
return r return r
} }

View File

@ -0,0 +1,43 @@
package starnet
import (
"net/http"
"net/http/httptest"
"testing"
)
func TestRequestConvenienceMethods(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Method", r.Method)
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer server.Close()
tests := []struct {
name string
method string
do func(r *Request) (*Response, error)
}{
{name: "Head", method: http.MethodHead, do: (*Request).Head},
{name: "Patch", method: http.MethodPatch, do: (*Request).Patch},
{name: "Options", method: http.MethodOptions, do: (*Request).Options},
{name: "Trace", method: http.MethodTrace, do: (*Request).Trace},
{name: "Connect", method: http.MethodConnect, do: (*Request).Connect},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := NewSimpleRequest(server.URL, http.MethodGet)
resp, err := tt.do(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer resp.Close()
if got := resp.Header.Get("X-Method"); got != tt.method {
t.Fatalf("method=%s want=%s", got, tt.method)
}
})
}
}

View File

@ -13,6 +13,7 @@ type Response struct {
*http.Response *http.Response
request *Request request *Request
httpClient *http.Client httpClient *http.Client
cancel func()
body *Body body *Body
} }
@ -21,9 +22,26 @@ type Body struct {
raw io.ReadCloser raw io.ReadCloser
data []byte data []byte
consumed bool consumed bool
maxBytes int64
mu sync.Mutex mu sync.Mutex
} }
type cancelReadCloser struct {
io.ReadCloser
cancel func()
once sync.Once
}
func (c *cancelReadCloser) Close() error {
err := c.ReadCloser.Close()
c.once.Do(func() {
if c.cancel != nil {
c.cancel()
}
})
return err
}
// Request 获取原始请求 // Request 获取原始请求
func (r *Response) Request() *Request { func (r *Response) Request() *Request {
return r.request return r.request
@ -42,6 +60,10 @@ func (r *Response) Close() error {
if r.body != nil && r.body.raw != nil { if r.body != nil && r.body.raw != nil {
return r.body.raw.Close() return r.body.raw.Close()
} }
if r.cancel != nil {
r.cancel()
r.cancel = nil
}
return nil return nil
} }
@ -70,14 +92,24 @@ func (b *Body) readAll() error {
return nil return nil
} }
data, err := io.ReadAll(b.raw) reader := io.Reader(b.raw)
if b.maxBytes > 0 {
reader = io.LimitReader(b.raw, b.maxBytes+1)
}
data, err := io.ReadAll(reader)
if err != nil { if err != nil {
return wrapError(err, "read response body") return wrapError(err, "read response body")
} }
if b.maxBytes > 0 && int64(len(data)) > b.maxBytes {
b.consumed = true
_ = b.raw.Close()
return wrapError(ErrRespBodyTooLarge, "response body exceeds max bytes: %d > %d", len(data), b.maxBytes)
}
b.data = data b.data = data
b.consumed = true b.consumed = true
b.raw.Close() _ = b.raw.Close()
return nil return nil
} }

84
response_limit_test.go Normal file
View File

@ -0,0 +1,84 @@
package starnet
import (
"errors"
"net/http"
"net/http/httptest"
"testing"
)
func TestWithMaxRespBodyBytes(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte("123456"))
}))
defer server.Close()
resp, err := Get(server.URL, WithMaxRespBodyBytes(4))
if err != nil {
t.Fatalf("unexpected request error: %v", err)
}
defer resp.Close()
_, err = resp.Body().Bytes()
if err == nil {
t.Fatal("expected body too large error")
}
if !errors.Is(err, ErrRespBodyTooLarge) {
t.Fatalf("expected ErrRespBodyTooLarge, got: %v", err)
}
}
func TestSetMaxRespBodyBytes(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte("1234"))
}))
defer server.Close()
resp, err := NewSimpleRequest(server.URL, http.MethodGet).
SetMaxRespBodyBytes(4).
Do()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer resp.Close()
body, err := resp.Body().String()
if err != nil {
t.Fatalf("unexpected read error: %v", err)
}
if body != "1234" {
t.Fatalf("body=%q want=1234", body)
}
}
func TestSetMaxRespBodyBytesWithAutoFetch(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte("123456"))
}))
defer server.Close()
_, err := NewSimpleRequest(server.URL, http.MethodGet).
SetAutoFetch(true).
SetMaxRespBodyBytes(4).
Do()
if err == nil {
t.Fatal("expected body too large error with auto fetch")
}
if !errors.Is(err, ErrRespBodyTooLarge) {
t.Fatalf("expected ErrRespBodyTooLarge, got: %v", err)
}
}
func TestSetMaxRespBodyBytesInvalid(t *testing.T) {
req := NewSimpleRequest("http://example.com", http.MethodGet).SetMaxRespBodyBytes(-1)
if req.Err() == nil {
t.Fatal("expected error for negative max bytes")
}
}
func TestWithMaxRespBodyBytesInvalid(t *testing.T) {
_, err := NewRequest("http://example.com", http.MethodGet, WithMaxRespBodyBytes(-1))
if err == nil {
t.Fatal("expected error for negative max bytes")
}
}

423
retry.go Normal file
View File

@ -0,0 +1,423 @@
package starnet
import (
"context"
"errors"
"fmt"
"io"
"math"
"math/rand"
"net"
"net/http"
"time"
)
type RetryOpt func(*retryPolicy) error
type retryPolicy struct {
maxRetries int
baseDelay time.Duration
maxDelay time.Duration
factor float64
jitter float64
idempotentOnly bool
statuses map[int]struct{}
onError func(error) bool
}
func cloneRetryPolicy(p *retryPolicy) *retryPolicy {
if p == nil {
return nil
}
cloned := &retryPolicy{
maxRetries: p.maxRetries,
baseDelay: p.baseDelay,
maxDelay: p.maxDelay,
factor: p.factor,
jitter: p.jitter,
idempotentOnly: p.idempotentOnly,
onError: p.onError,
}
if p.statuses != nil {
cloned.statuses = make(map[int]struct{}, len(p.statuses))
for code := range p.statuses {
cloned.statuses[code] = struct{}{}
}
}
return cloned
}
func defaultRetryPolicy(max int) *retryPolicy {
return &retryPolicy{
maxRetries: max,
baseDelay: 100 * time.Millisecond,
maxDelay: 2 * time.Second,
factor: 2.0,
jitter: 0.1,
idempotentOnly: true,
statuses: map[int]struct{}{
http.StatusRequestTimeout: {},
http.StatusTooEarly: {},
http.StatusTooManyRequests: {},
http.StatusInternalServerError: {},
http.StatusBadGateway: {},
http.StatusServiceUnavailable: {},
http.StatusGatewayTimeout: {},
},
}
}
func buildRetryPolicy(max int, opts ...RetryOpt) (*retryPolicy, error) {
if max < 0 {
return nil, fmt.Errorf("max retry must be >= 0")
}
if max == 0 {
return nil, nil
}
policy := defaultRetryPolicy(max)
for _, opt := range opts {
if opt == nil {
continue
}
if err := opt(policy); err != nil {
return nil, err
}
}
return policy, nil
}
func WithRetry(max int, opts ...RetryOpt) RequestOpt {
return func(r *Request) error {
policy, err := buildRetryPolicy(max, opts...)
if err != nil {
return err
}
r.retry = policy
return nil
}
}
func (r *Request) SetRetry(max int, opts ...RetryOpt) *Request {
if r.err != nil {
return r
}
policy, err := buildRetryPolicy(max, opts...)
if err != nil {
r.err = err
return r
}
r.retry = policy
return r
}
func (r *Request) DisableRetry() *Request {
if r.err != nil {
return r
}
r.retry = nil
return r
}
func (r *Request) applyRetryOpt(opt RetryOpt) *Request {
if r.err != nil {
return r
}
if opt == nil {
return r
}
if r.retry == nil {
r.err = fmt.Errorf("retry policy is not enabled, call SetRetry first")
return r
}
if err := opt(r.retry); err != nil {
r.err = err
}
return r
}
func (r *Request) SetRetryBackoff(base, max time.Duration, factor float64) *Request {
return r.applyRetryOpt(WithRetryBackoff(base, max, factor))
}
func (r *Request) SetRetryJitter(ratio float64) *Request {
return r.applyRetryOpt(WithRetryJitter(ratio))
}
func (r *Request) SetRetryStatuses(codes ...int) *Request {
return r.applyRetryOpt(WithRetryStatuses(codes...))
}
func (r *Request) SetRetryIdempotentOnly(enabled bool) *Request {
return r.applyRetryOpt(WithRetryIdempotentOnly(enabled))
}
func (r *Request) SetRetryOnError(fn func(error) bool) *Request {
return r.applyRetryOpt(WithRetryOnError(fn))
}
func WithRetryBackoff(base, max time.Duration, factor float64) RetryOpt {
return func(p *retryPolicy) error {
if base < 0 {
return fmt.Errorf("retry base delay must be >= 0")
}
if max < 0 {
return fmt.Errorf("retry max delay must be >= 0")
}
if factor <= 0 {
return fmt.Errorf("retry factor must be > 0")
}
p.baseDelay = base
p.maxDelay = max
p.factor = factor
return nil
}
}
func WithRetryJitter(ratio float64) RetryOpt {
return func(p *retryPolicy) error {
if ratio < 0 || ratio > 1 {
return fmt.Errorf("retry jitter ratio must be in [0,1]")
}
p.jitter = ratio
return nil
}
}
func WithRetryStatuses(codes ...int) RetryOpt {
return func(p *retryPolicy) error {
statuses := make(map[int]struct{}, len(codes))
for _, code := range codes {
if code < 100 || code > 999 {
return fmt.Errorf("invalid retry status code: %d", code)
}
statuses[code] = struct{}{}
}
p.statuses = statuses
return nil
}
}
func WithRetryIdempotentOnly(enabled bool) RetryOpt {
return func(p *retryPolicy) error {
p.idempotentOnly = enabled
return nil
}
}
func WithRetryOnError(fn func(error) bool) RetryOpt {
return func(p *retryPolicy) error {
p.onError = fn
return nil
}
}
func (r *Request) hasRetryPolicy() bool {
return r.retry != nil && r.retry.maxRetries > 0
}
func (r *Request) doWithRetry() (*Response, error) {
policy := cloneRetryPolicy(r.retry)
if policy == nil || policy.maxRetries <= 0 {
return r.doOnce()
}
if !policy.canRetryRequest(r) {
return r.doOnce()
}
retryCtx := r.ctx
retryCancel := func() {}
if r.config.Network.Timeout > 0 {
retryCtx, retryCancel = context.WithTimeout(r.ctx, r.config.Network.Timeout)
}
defer retryCancel()
maxAttempts := policy.maxRetries + 1
var lastResp *Response
var lastErr error
for attempt := 0; attempt < maxAttempts; attempt++ {
attemptReq, err := r.newRetryAttempt(retryCtx)
if err != nil {
return nil, wrapError(err, "build retry attempt")
}
resp, err := attemptReq.doOnce()
if resp != nil {
resp.request = r
}
if !policy.shouldRetry(resp, err, attempt, maxAttempts, retryCtx) {
return resp, err
}
lastResp = resp
lastErr = err
if lastResp != nil {
_ = lastResp.Close()
}
delay := policy.nextDelay(attempt)
if delay <= 0 {
continue
}
timer := time.NewTimer(delay)
select {
case <-retryCtx.Done():
timer.Stop()
return lastResp, wrapError(retryCtx.Err(), "retry context done")
case <-timer.C:
}
}
return lastResp, lastErr
}
func (r *Request) newRetryAttempt(ctx context.Context) (*Request, error) {
attempt := r.Clone()
attempt.retry = nil
attempt.cancel = nil
attempt.applied = false
attempt.execCtx = nil
attempt.ctx = ctx
// 共享总超时上下文后,避免每次 attempt 再创建一次 timeout context。
if attempt.config != nil && attempt.config.Network.Timeout > 0 {
attempt.config.Network.Timeout = 0
}
if !attempt.doRaw {
attempt.httpReq = attempt.httpReq.WithContext(ctx)
return attempt, nil
}
if r.httpReq == nil {
return nil, fmt.Errorf("http request is nil")
}
raw := r.httpReq.Clone(ctx)
if r.httpReq.GetBody != nil {
body, err := r.httpReq.GetBody()
if err != nil {
return nil, wrapError(err, "get raw request body")
}
raw.Body = body
} else if r.httpReq.Body != nil && r.httpReq.Body != http.NoBody {
return nil, fmt.Errorf("raw request body is not replayable")
}
attempt.httpReq = raw
return attempt, nil
}
func (p *retryPolicy) canRetryRequest(r *Request) bool {
if p.idempotentOnly && !isIdempotentMethod(r.method) {
return false
}
return isReplayableRequest(r)
}
func isIdempotentMethod(method string) bool {
switch method {
case http.MethodGet, http.MethodHead, http.MethodPut, http.MethodDelete, http.MethodOptions, http.MethodTrace:
return true
default:
return false
}
}
func isReplayableRequest(r *Request) bool {
if r == nil {
return false
}
if r.doRaw {
if r.httpReq == nil {
return false
}
if r.httpReq.Body == nil || r.httpReq.Body == http.NoBody {
return true
}
return r.httpReq.GetBody != nil
}
if r.config == nil {
return false
}
// Reader / stream body 通常不可重放,保守地不重试。
if r.config.Body.Reader != nil {
return false
}
for _, f := range r.config.Body.Files {
if f.FileData != nil || f.FilePath == "" {
return false
}
}
return true
}
func (p *retryPolicy) shouldRetry(resp *Response, err error, attempt, maxAttempts int, ctx context.Context) bool {
if attempt >= maxAttempts-1 {
return false
}
if ctx != nil && ctx.Err() != nil {
return false
}
if err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return false
}
if p.onError != nil {
return p.onError(err)
}
return isRetryableError(err)
}
if resp == nil || resp.Response == nil {
return false
}
_, ok := p.statuses[resp.StatusCode]
return ok
}
func isRetryableError(err error) bool {
var netErr net.Error
if errors.As(err, &netErr) {
if netErr.Timeout() {
return true
}
if netErr.Temporary() {
return true
}
}
return errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF)
}
func (p *retryPolicy) nextDelay(attempt int) time.Duration {
if p.baseDelay <= 0 {
return 0
}
delay := time.Duration(float64(p.baseDelay) * math.Pow(p.factor, float64(attempt)))
if p.maxDelay > 0 && delay > p.maxDelay {
delay = p.maxDelay
}
if p.jitter <= 0 {
return delay
}
low := 1 - p.jitter
if low < 0 {
low = 0
}
high := 1 + p.jitter
scale := low + rand.Float64()*(high-low)
return time.Duration(float64(delay) * scale)
}

298
retry_test.go Normal file
View File

@ -0,0 +1,298 @@
package starnet
import (
"context"
"errors"
"io"
"net"
"net/http"
"net/http/httptest"
"strings"
"sync/atomic"
"testing"
"time"
)
func TestWithRetrySmokeGet(t *testing.T) {
var hits int32
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
n := atomic.AddInt32(&hits, 1)
if n <= 2 {
w.WriteHeader(http.StatusServiceUnavailable)
return
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer s.Close()
resp, err := Get(s.URL,
WithRetry(2,
WithRetryBackoff(0, 0, 1),
WithRetryJitter(0),
),
)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer resp.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("status=%d want=%d", resp.StatusCode, http.StatusOK)
}
if atomic.LoadInt32(&hits) != 3 {
t.Fatalf("hits=%d want=3", hits)
}
}
func TestWithRetryResponseRequestPointerStable(t *testing.T) {
var hits int32
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
n := atomic.AddInt32(&hits, 1)
if n == 1 {
w.WriteHeader(http.StatusServiceUnavailable)
return
}
w.WriteHeader(http.StatusOK)
}))
defer s.Close()
req := NewSimpleRequest(s.URL, http.MethodGet).
SetRetry(1, WithRetryBackoff(0, 0, 1), WithRetryJitter(0))
resp, err := req.Do()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer resp.Close()
if resp.Request() != req {
t.Fatal("response request pointer should point to original request")
}
}
func TestWithRetryNoRetryForNonReplayableBodyReader(t *testing.T) {
var hits int32
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
atomic.AddInt32(&hits, 1)
w.WriteHeader(http.StatusServiceUnavailable)
}))
defer s.Close()
req := NewSimpleRequest(s.URL, http.MethodPost).
SetBodyReader(strings.NewReader("payload")).
SetRetry(3,
WithRetryIdempotentOnly(false),
WithRetryBackoff(0, 0, 1),
WithRetryJitter(0),
)
resp, err := req.Do()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer resp.Close()
if resp.StatusCode != http.StatusServiceUnavailable {
t.Fatalf("status=%d want=%d", resp.StatusCode, http.StatusServiceUnavailable)
}
if atomic.LoadInt32(&hits) != 1 {
t.Fatalf("hits=%d want=1", hits)
}
}
func TestWithRetryPostWhenIdempotentDisabled(t *testing.T) {
var hits int32
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
n := atomic.AddInt32(&hits, 1)
_, _ = io.Copy(io.Discard, r.Body)
if n == 1 {
w.WriteHeader(http.StatusServiceUnavailable)
return
}
w.WriteHeader(http.StatusOK)
}))
defer s.Close()
req := NewSimpleRequest(s.URL, http.MethodPost).
SetBodyString("hello").
SetRetry(1,
WithRetryIdempotentOnly(false),
WithRetryBackoff(0, 0, 1),
WithRetryJitter(0),
)
resp, err := req.Do()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer resp.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("status=%d want=%d", resp.StatusCode, http.StatusOK)
}
if atomic.LoadInt32(&hits) != 2 {
t.Fatalf("hits=%d want=2", hits)
}
}
func TestWithRetryRawWithoutGetBodyNoRetry(t *testing.T) {
var hits int32
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
atomic.AddInt32(&hits, 1)
w.WriteHeader(http.StatusServiceUnavailable)
}))
defer s.Close()
rawReq, _ := http.NewRequest(http.MethodPost, s.URL, io.MultiReader(strings.NewReader("raw")))
if rawReq.GetBody != nil {
t.Fatal("raw request GetBody should be nil in this test")
}
req := NewSimpleRequest("", http.MethodPost, WithRawRequest(rawReq)).
SetRetry(2,
WithRetryIdempotentOnly(false),
WithRetryBackoff(0, 0, 1),
WithRetryJitter(0),
)
resp, err := req.Do()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer resp.Close()
if atomic.LoadInt32(&hits) != 1 {
t.Fatalf("hits=%d want=1", hits)
}
}
func TestWithRetryRespectsTotalTimeoutBudget(t *testing.T) {
var hits int32
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
atomic.AddInt32(&hits, 1)
time.Sleep(80 * time.Millisecond)
w.WriteHeader(http.StatusServiceUnavailable)
}))
defer s.Close()
req := NewSimpleRequest(s.URL, http.MethodGet).
SetTimeout(120*time.Millisecond).
SetRetry(3,
WithRetryBackoff(0, 0, 1),
WithRetryJitter(0),
)
_, err := req.Do()
if err == nil {
t.Fatal("expected timeout error")
}
if !errors.Is(err, context.DeadlineExceeded) {
t.Fatalf("expected context deadline exceeded, got: %v", err)
}
if h := atomic.LoadInt32(&hits); h > 2 {
t.Fatalf("hits=%d want<=2 under tight timeout budget", h)
}
}
func TestSetRetryInvalidMax(t *testing.T) {
req := NewSimpleRequest("http://example.com", http.MethodGet).SetRetry(-1)
if req.Err() == nil {
t.Fatal("expected error for negative retry max")
}
}
func TestSetRetrySeriesSmokeGet(t *testing.T) {
var hits int32
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
n := atomic.AddInt32(&hits, 1)
if n <= 2 {
w.WriteHeader(http.StatusTooManyRequests)
return
}
w.WriteHeader(http.StatusOK)
}))
defer s.Close()
req := NewSimpleRequest(s.URL, http.MethodGet).
SetRetry(2).
SetRetryBackoff(0, 0, 1).
SetRetryJitter(0).
SetRetryStatuses(http.StatusTooManyRequests)
resp, err := req.Do()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer resp.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("status=%d want=%d", resp.StatusCode, http.StatusOK)
}
if h := atomic.LoadInt32(&hits); h != 3 {
t.Fatalf("hits=%d want=3", h)
}
}
func TestSetRetryIdempotentOnlyWithPost(t *testing.T) {
var hits int32
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
n := atomic.AddInt32(&hits, 1)
_, _ = io.Copy(io.Discard, r.Body)
if n == 1 {
w.WriteHeader(http.StatusServiceUnavailable)
return
}
w.WriteHeader(http.StatusOK)
}))
defer s.Close()
req := NewSimpleRequest(s.URL, http.MethodPost).
SetBodyString("hello").
SetRetry(1).
SetRetryIdempotentOnly(false).
SetRetryBackoff(0, 0, 1).
SetRetryJitter(0)
resp, err := req.Do()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer resp.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("status=%d want=%d", resp.StatusCode, http.StatusOK)
}
if h := atomic.LoadInt32(&hits); h != 2 {
t.Fatalf("hits=%d want=2", h)
}
}
func TestSetRetryOnErrorOverridesDefault(t *testing.T) {
var dials int32
dialErr := errors.New("dial failed")
req := NewSimpleRequest("http://example.com", http.MethodGet).
SetDialFunc(func(ctx context.Context, network, addr string) (net.Conn, error) {
atomic.AddInt32(&dials, 1)
return nil, dialErr
}).
SetRetry(1).
SetRetryBackoff(0, 0, 1).
SetRetryJitter(0).
SetRetryOnError(func(err error) bool {
return true
})
_, err := req.Do()
if err == nil {
t.Fatal("expected error")
}
if h := atomic.LoadInt32(&dials); h != 2 {
t.Fatalf("dial attempts=%d want=2", h)
}
}
func TestSetRetryOptionRequireEnableRetry(t *testing.T) {
req := NewSimpleRequest("http://example.com", http.MethodGet).SetRetryBackoff(10*time.Millisecond, 100*time.Millisecond, 2)
if req.Err() == nil {
t.Fatal("expected error when setting retry options before SetRetry")
}
if !strings.Contains(req.Err().Error(), "call SetRetry first") {
t.Fatalf("unexpected error: %v", req.Err())
}
}

115
timeout_refactor_test.go Normal file
View File

@ -0,0 +1,115 @@
package starnet
import (
"io"
"net"
"net/http"
"net/http/httptest"
"testing"
"time"
)
func TestRequestTimeoutDoesNotMutateClientTimeout(t *testing.T) {
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(200 * time.Millisecond)
w.WriteHeader(http.StatusOK)
}))
defer s.Close()
client := NewClientNoErr()
baseTimeout := client.HTTPClient().Timeout
_, err := client.Get(s.URL, WithTimeout(80*time.Millisecond))
if err == nil {
t.Fatal("expected request timeout error")
}
if client.HTTPClient().Timeout != baseTimeout {
t.Fatalf("client timeout mutated: got=%v want=%v", client.HTTPClient().Timeout, baseTimeout)
}
resp, err := client.Get(s.URL, WithTimeout(400*time.Millisecond))
if err != nil {
t.Fatalf("second request should succeed with larger timeout: %v", err)
}
defer resp.Close()
}
func TestRequestTimeoutNoLingeringConnDeadline(t *testing.T) {
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.Copy(io.Discard, r.Body)
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer s.Close()
client := NewClientNoErr()
resp1, err := client.Post(s.URL, WithBodyString("first"), WithTimeout(120*time.Millisecond))
if err != nil {
t.Fatalf("first request should succeed: %v", err)
}
_ = resp1.Close()
// 如果请求超时依赖连接级绝对 deadline经过该等待后复用连接会出现误超时。
time.Sleep(220 * time.Millisecond)
resp2, err := client.Post(s.URL, WithBodyString("second"), WithTimeout(1*time.Second))
if err != nil {
t.Fatalf("second request should not be affected by previous timeout window: %v", err)
}
_ = resp2.Close()
}
func TestConnReadDeadlineTimeoutAndRecover(t *testing.T) {
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen error: %v", err)
}
defer ln.Close()
done := make(chan struct{})
go func() {
defer close(done)
conn, err := ln.Accept()
if err != nil {
return
}
defer conn.Close()
time.Sleep(180 * time.Millisecond)
_, _ = conn.Write([]byte("x"))
}()
c, err := Dial("tcp", ln.Addr().String())
if err != nil {
t.Fatalf("dial error: %v", err)
}
defer c.Close()
if err := c.SetReadDeadline(time.Now().Add(60 * time.Millisecond)); err != nil {
t.Fatalf("set read deadline error: %v", err)
}
buf := make([]byte, 1)
_, err = c.Read(buf)
if err == nil {
t.Fatal("expected read timeout error")
}
if ne, ok := err.(net.Error); !ok || !ne.Timeout() {
t.Fatalf("expected net timeout error, got: %v", err)
}
if err := c.SetReadDeadline(time.Time{}); err != nil {
t.Fatalf("clear read deadline error: %v", err)
}
if _, err := io.ReadFull(c, buf); err != nil {
t.Fatalf("read after clearing deadline should succeed: %v", err)
}
if string(buf) != "x" {
t.Fatalf("unexpected payload: %q", string(buf))
}
<-done
}

View File

@ -84,6 +84,7 @@ type RequestConfig struct {
BasicAuth [2]string // Basic 认证 BasicAuth [2]string // Basic 认证
ContentLength int64 // 手动设置的 Content-Length ContentLength int64 // 手动设置的 Content-Length
AutoCalcContentLength bool // 自动计算 Content-Length AutoCalcContentLength bool // 自动计算 Content-Length
MaxRespBodyBytes int64 // 响应体最大读取字节数(<=0 表示不限制)
UploadProgress UploadProgressFunc // 上传进度回调 UploadProgress UploadProgressFunc // 上传进度回调
// Transport 配置 // Transport 配置
@ -121,6 +122,7 @@ func (c *RequestConfig) Clone() *RequestConfig {
BasicAuth: c.BasicAuth, BasicAuth: c.BasicAuth,
ContentLength: c.ContentLength, ContentLength: c.ContentLength,
AutoCalcContentLength: c.AutoCalcContentLength, AutoCalcContentLength: c.AutoCalcContentLength,
MaxRespBodyBytes: c.MaxRespBodyBytes,
UploadProgress: c.UploadProgress, UploadProgress: c.UploadProgress,
CustomTransport: c.CustomTransport, CustomTransport: c.CustomTransport,
Transport: c.Transport, // Transport 共享 Transport: c.Transport, // Transport 共享