1. 优化ping功能
2. 新增重试机制 3. 优化错误处理逻辑
This commit is contained in:
parent
4568e17f06
commit
b5bd7595a1
3
.gitignore
vendored
3
.gitignore
vendored
@ -1 +1,4 @@
|
|||||||
.idea
|
.idea
|
||||||
|
.sentrux/
|
||||||
|
agent_readme.md
|
||||||
|
target.md
|
||||||
|
|||||||
201
LICENSE
Normal file
201
LICENSE
Normal file
@ -0,0 +1,201 @@
|
|||||||
|
Apache License
|
||||||
|
Version 2.0, January 2004
|
||||||
|
http://www.apache.org/licenses/
|
||||||
|
|
||||||
|
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||||
|
|
||||||
|
1. Definitions.
|
||||||
|
|
||||||
|
"License" shall mean the terms and conditions for use, reproduction,
|
||||||
|
and distribution as defined by Sections 1 through 9 of this document.
|
||||||
|
|
||||||
|
"Licensor" shall mean the copyright owner or entity authorized by
|
||||||
|
the copyright owner that is granting the License.
|
||||||
|
|
||||||
|
"Legal Entity" shall mean the union of the acting entity and all
|
||||||
|
other entities that control, are controlled by, or are under common
|
||||||
|
control with that entity. For the purposes of this definition,
|
||||||
|
"control" means (i) the power, direct or indirect, to cause the
|
||||||
|
direction or management of such entity, whether by contract or
|
||||||
|
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||||
|
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||||
|
|
||||||
|
"You" (or "Your") shall mean an individual or Legal Entity
|
||||||
|
exercising permissions granted by this License.
|
||||||
|
|
||||||
|
"Source" form shall mean the preferred form for making modifications,
|
||||||
|
including but not limited to software source code, documentation
|
||||||
|
source, and configuration files.
|
||||||
|
|
||||||
|
"Object" form shall mean any form resulting from mechanical
|
||||||
|
transformation or translation of a Source form, including but
|
||||||
|
not limited to compiled object code, generated documentation,
|
||||||
|
and conversions to other media types.
|
||||||
|
|
||||||
|
"Work" shall mean the work of authorship, whether in Source or
|
||||||
|
Object form, made available under the License, as indicated by a
|
||||||
|
copyright notice that is included in or attached to the work
|
||||||
|
(an example is provided in the Appendix below).
|
||||||
|
|
||||||
|
"Derivative Works" shall mean any work, whether in Source or Object
|
||||||
|
form, that is based on (or derived from) the Work and for which the
|
||||||
|
editorial revisions, annotations, elaborations, or other modifications
|
||||||
|
represent, as a whole, an original work of authorship. For the purposes
|
||||||
|
of this License, Derivative Works shall not include works that remain
|
||||||
|
separable from, or merely link (or bind by name) to the interfaces of,
|
||||||
|
the Work and Derivative Works thereof.
|
||||||
|
|
||||||
|
"Contribution" shall mean any work of authorship, including
|
||||||
|
the original version of the Work and any modifications or additions
|
||||||
|
to that Work or Derivative Works thereof, that is intentionally
|
||||||
|
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||||
|
or by an individual or Legal Entity authorized to submit on behalf of
|
||||||
|
the copyright owner. For the purposes of this definition, "submitted"
|
||||||
|
means any form of electronic, verbal, or written communication sent
|
||||||
|
to the Licensor or its representatives, including but not limited to
|
||||||
|
communication on electronic mailing lists, source code control systems,
|
||||||
|
and issue tracking systems that are managed by, or on behalf of, the
|
||||||
|
Licensor for the purpose of discussing and improving the Work, but
|
||||||
|
excluding communication that is conspicuously marked or otherwise
|
||||||
|
designated in writing by the copyright owner as "Not a Contribution."
|
||||||
|
|
||||||
|
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||||
|
on behalf of whom a Contribution has been received by Licensor and
|
||||||
|
subsequently incorporated within the Work.
|
||||||
|
|
||||||
|
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
copyright license to reproduce, prepare Derivative Works of,
|
||||||
|
publicly display, publicly perform, sublicense, and distribute the
|
||||||
|
Work and such Derivative Works in Source or Object form.
|
||||||
|
|
||||||
|
3. Grant of Patent License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
(except as stated in this section) patent license to make, have made,
|
||||||
|
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||||
|
where such license applies only to those patent claims licensable
|
||||||
|
by such Contributor that are necessarily infringed by their
|
||||||
|
Contribution(s) alone or by combination of their Contribution(s)
|
||||||
|
with the Work to which such Contribution(s) was submitted. If You
|
||||||
|
institute patent litigation against any entity (including a
|
||||||
|
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||||
|
or a Contribution incorporated within the Work constitutes direct
|
||||||
|
or contributory patent infringement, then any patent licenses
|
||||||
|
granted to You under this License for that Work shall terminate
|
||||||
|
as of the date such litigation is filed.
|
||||||
|
|
||||||
|
4. Redistribution. You may reproduce and distribute copies of the
|
||||||
|
Work or Derivative Works thereof in any medium, with or without
|
||||||
|
modifications, and in Source or Object form, provided that You
|
||||||
|
meet the following conditions:
|
||||||
|
|
||||||
|
(a) You must give any other recipients of the Work or
|
||||||
|
Derivative Works a copy of this License; and
|
||||||
|
|
||||||
|
(b) You must cause any modified files to carry prominent notices
|
||||||
|
stating that You changed the files; and
|
||||||
|
|
||||||
|
(c) You must retain, in the Source form of any Derivative Works
|
||||||
|
that You distribute, all copyright, patent, trademark, and
|
||||||
|
attribution notices from the Source form of the Work,
|
||||||
|
excluding those notices that do not pertain to any part of
|
||||||
|
the Derivative Works; and
|
||||||
|
|
||||||
|
(d) If the Work includes a "NOTICE" text file as part of its
|
||||||
|
distribution, then any Derivative Works that You distribute must
|
||||||
|
include a readable copy of the attribution notices contained
|
||||||
|
within such NOTICE file, excluding those notices that do not
|
||||||
|
pertain to any part of the Derivative Works, in at least one
|
||||||
|
of the following places: within a NOTICE text file distributed
|
||||||
|
as part of the Derivative Works; within the Source form or
|
||||||
|
documentation, if provided along with the Derivative Works; or,
|
||||||
|
within a display generated by the Derivative Works, if and
|
||||||
|
wherever such third-party notices normally appear. The contents
|
||||||
|
of the NOTICE file are for informational purposes only and
|
||||||
|
do not modify the License. You may add Your own attribution
|
||||||
|
notices within Derivative Works that You distribute, alongside
|
||||||
|
or as an addendum to the NOTICE text from the Work, provided
|
||||||
|
that such additional attribution notices cannot be construed
|
||||||
|
as modifying the License.
|
||||||
|
|
||||||
|
You may add Your own copyright statement to Your modifications and
|
||||||
|
may provide additional or different license terms and conditions
|
||||||
|
for use, reproduction, or distribution of Your modifications, or
|
||||||
|
for any such Derivative Works as a whole, provided Your use,
|
||||||
|
reproduction, and distribution of the Work otherwise complies with
|
||||||
|
the conditions stated in this License.
|
||||||
|
|
||||||
|
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||||
|
any Contribution intentionally submitted for inclusion in the Work
|
||||||
|
by You to the Licensor shall be under the terms and conditions of
|
||||||
|
this License, without any additional terms or conditions.
|
||||||
|
Notwithstanding the above, nothing herein shall supersede or modify
|
||||||
|
the terms of any separate license agreement you may have executed
|
||||||
|
with Licensor regarding such Contributions.
|
||||||
|
|
||||||
|
6. Trademarks. This License does not grant permission to use the trade
|
||||||
|
names, trademarks, service marks, or product names of the Licensor,
|
||||||
|
except as required for reasonable and customary use in describing the
|
||||||
|
origin of the Work and reproducing the content of the NOTICE file.
|
||||||
|
|
||||||
|
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||||
|
agreed to in writing, Licensor provides the Work (and each
|
||||||
|
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
implied, including, without limitation, any warranties or conditions
|
||||||
|
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||||
|
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||||
|
appropriateness of using or redistributing the Work and assume any
|
||||||
|
risks associated with Your exercise of permissions under this License.
|
||||||
|
|
||||||
|
8. Limitation of Liability. In no event and under no legal theory,
|
||||||
|
whether in tort (including negligence), contract, or otherwise,
|
||||||
|
unless required by applicable law (such as deliberate and grossly
|
||||||
|
negligent acts) or agreed to in writing, shall any Contributor be
|
||||||
|
liable to You for damages, including any direct, indirect, special,
|
||||||
|
incidental, or consequential damages of any character arising as a
|
||||||
|
result of this License or out of the use or inability to use the
|
||||||
|
Work (including but not limited to damages for loss of goodwill,
|
||||||
|
work stoppage, computer failure or malfunction, or any and all
|
||||||
|
other commercial damages or losses), even if such Contributor
|
||||||
|
has been advised of the possibility of such damages.
|
||||||
|
|
||||||
|
9. Accepting Warranty or Additional Liability. While redistributing
|
||||||
|
the Work or Derivative Works thereof, You may choose to offer,
|
||||||
|
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||||
|
or other liability obligations and/or rights consistent with this
|
||||||
|
License. However, in accepting such obligations, You may act only
|
||||||
|
on Your own behalf and on Your sole responsibility, not on behalf
|
||||||
|
of any other Contributor, and only if You agree to indemnify,
|
||||||
|
defend, and hold each Contributor harmless for any liability
|
||||||
|
incurred by, or claims asserted against, such Contributor by reason
|
||||||
|
of your accepting any such warranty or additional liability.
|
||||||
|
|
||||||
|
END OF TERMS AND CONDITIONS
|
||||||
|
|
||||||
|
APPENDIX: How to apply the Apache License to your work.
|
||||||
|
|
||||||
|
To apply the Apache License to your work, attach the following
|
||||||
|
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||||
|
replaced with your own identifying information. (Don't include
|
||||||
|
the brackets!) The text should be enclosed in the appropriate
|
||||||
|
comment syntax for the file format. We also recommend that a
|
||||||
|
file or class name and description of purpose be included on the
|
||||||
|
same "printed page" as the copyright notice for easier
|
||||||
|
identification within third-party archives.
|
||||||
|
|
||||||
|
Copyright 2026 starnet contributors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
106
README.md
Normal file
106
README.md
Normal file
@ -0,0 +1,106 @@
|
|||||||
|
# starnet
|
||||||
|
|
||||||
|
`starnet` is a Go network toolkit focused on practical HTTP request control, TLS sniff utilities, and ICMP ping capabilities.
|
||||||
|
|
||||||
|
## Highlights
|
||||||
|
|
||||||
|
- Request-level timeout by context (without mutating shared `http.Client` timeout)
|
||||||
|
- Fine-grained network controls: custom DNS/IP, dial timeout, proxy, TLS config
|
||||||
|
- Built-in retry with replay safety checks and configurable backoff/jitter/statuses
|
||||||
|
- Response body safety guard via max body bytes limit
|
||||||
|
- Error classification helpers (`ClassifyError`, `IsTimeout`, `IsDNS`, `IsTLS`, `IsProxy`, `IsCanceled`)
|
||||||
|
- TLS sniffer listener/dialer utilities for mixed TLS/plain traffic scenarios
|
||||||
|
- ICMP ping with IPv4/IPv6 target handling and option-based probing API
|
||||||
|
|
||||||
|
## Main Features
|
||||||
|
|
||||||
|
### HTTP Client and Request
|
||||||
|
|
||||||
|
- Fluent APIs with both `WithXxx` options and `SetXxx` chain methods
|
||||||
|
- Methods: `Get/Post/Put/Delete/Head/Patch/Options/Trace/Connect`
|
||||||
|
- Request body helpers: JSON, form data, multipart file upload, stream body
|
||||||
|
- Header/cookie/query helpers with defensive copy on key setters
|
||||||
|
- Request cloning for safe reuse in concurrent or variant calls
|
||||||
|
|
||||||
|
### Timeout and Retry
|
||||||
|
|
||||||
|
- Request timeout is applied by context deadline, not global client timeout
|
||||||
|
- Retry supports:
|
||||||
|
- max attempts
|
||||||
|
- backoff factor/base/max
|
||||||
|
- jitter
|
||||||
|
- retry status whitelist
|
||||||
|
- idempotent-only guard
|
||||||
|
- custom retry-on-error callback
|
||||||
|
- Retry keeps original request pointer in final response for consistency
|
||||||
|
|
||||||
|
### Response Handling
|
||||||
|
|
||||||
|
- `Bytes/String/JSON/Reader` helpers
|
||||||
|
- optional auto-fetch mode
|
||||||
|
- configurable max response body bytes to prevent oversized reads
|
||||||
|
|
||||||
|
### Ping Module
|
||||||
|
|
||||||
|
- `Ping`, `PingWithContext`, `Pingable`, and compatibility helper `IsIpPingable`
|
||||||
|
- `PingOptions` for count/timeout/interval/deadline/address preference/source IP/payload size
|
||||||
|
- explicit error semantics for permission/protocol/timeout/resolve failures
|
||||||
|
|
||||||
|
## Install
|
||||||
|
|
||||||
|
```bash
|
||||||
|
go get b612.me/starnet
|
||||||
|
```
|
||||||
|
|
||||||
|
## Quick Example
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"b612.me/starnet"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
resp, err := starnet.Get(
|
||||||
|
"https://example.com",
|
||||||
|
starnet.WithTimeout(2*time.Second),
|
||||||
|
starnet.WithRetry(2,
|
||||||
|
starnet.WithRetryBackoff(100*time.Millisecond, 1*time.Second, 2),
|
||||||
|
starnet.WithRetryJitter(0.1),
|
||||||
|
),
|
||||||
|
starnet.WithMaxRespBodyBytes(1<<20),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println("request failed:", starnet.ClassifyError(err), err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
fmt.Println("status:", resp.StatusCode)
|
||||||
|
_, _ = resp.Body().Bytes()
|
||||||
|
|
||||||
|
ok, pingErr := starnet.Pingable("example.com", &starnet.PingOptions{
|
||||||
|
Count: 2,
|
||||||
|
Timeout: 2 * time.Second,
|
||||||
|
})
|
||||||
|
fmt.Println("pingable:", ok, pingErr == nil)
|
||||||
|
|
||||||
|
_ = http.MethodGet
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Stability Notes
|
||||||
|
|
||||||
|
- Raw ICMP ping may require elevated privileges on some systems.
|
||||||
|
- Integration tests that rely on external network are environment-dependent.
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
This project is licensed under the Apache License 2.0.
|
||||||
|
See [LICENSE](./LICENSE).
|
||||||
|
|
||||||
@ -78,7 +78,6 @@ func needsDynamicTransport(rc *RequestContext) bool {
|
|||||||
rc.Proxy != "" ||
|
rc.Proxy != "" ||
|
||||||
rc.DialFn != nil ||
|
rc.DialFn != nil ||
|
||||||
(rc.DialTimeout > 0 && rc.DialTimeout != DefaultDialTimeout) ||
|
(rc.DialTimeout > 0 && rc.DialTimeout != DefaultDialTimeout) ||
|
||||||
(rc.Timeout > 0 && rc.Timeout != DefaultTimeout) ||
|
|
||||||
len(rc.CustomIP) > 0 ||
|
len(rc.CustomIP) > 0 ||
|
||||||
len(rc.CustomDNS) > 0 ||
|
len(rc.CustomDNS) > 0 ||
|
||||||
rc.LookupIPFn != nil
|
rc.LookupIPFn != nil
|
||||||
@ -122,13 +121,10 @@ func injectRequestConfig(ctx context.Context, config *RequestConfig) context.Con
|
|||||||
execCtx = context.WithValue(execCtx, ctxKeyCustomDNS, config.DNS.CustomDNS)
|
execCtx = context.WithValue(execCtx, ctxKeyCustomDNS, config.DNS.CustomDNS)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 总是注入 DialTimeout 和 Timeout(与原始代码一致)
|
// 总是注入 DialTimeout(与原始代码一致)
|
||||||
if config.Network.DialTimeout > 0 {
|
if config.Network.DialTimeout > 0 {
|
||||||
execCtx = context.WithValue(execCtx, ctxKeyDialTimeout, config.Network.DialTimeout)
|
execCtx = context.WithValue(execCtx, ctxKeyDialTimeout, config.Network.DialTimeout)
|
||||||
}
|
}
|
||||||
if config.Network.Timeout > 0 {
|
|
||||||
execCtx = context.WithValue(execCtx, ctxKeyTimeout, config.Network.Timeout)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 注入 DNS 解析函数
|
// 注入 DNS 解析函数
|
||||||
if config.DNS.LookupFunc != nil {
|
if config.DNS.LookupFunc != nil {
|
||||||
|
|||||||
59
defensive_copy_test.go
Normal file
59
defensive_copy_test.go
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestWithRawRequestNil(t *testing.T) {
|
||||||
|
_, err := NewRequest("http://example.com", "GET", WithRawRequest(nil))
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error when WithRawRequest(nil)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetHeadersDefensiveCopy(t *testing.T) {
|
||||||
|
req := NewSimpleRequest("http://example.com", "GET")
|
||||||
|
headers := http.Header{
|
||||||
|
"X-Test": []string{"v1"},
|
||||||
|
}
|
||||||
|
|
||||||
|
req.SetHeaders(headers)
|
||||||
|
headers.Set("X-Test", "v2")
|
||||||
|
|
||||||
|
if got := req.GetHeader("X-Test"); got != "v1" {
|
||||||
|
t.Fatalf("header mutated by external map change: got=%q want=%q", got, "v1")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetQueriesDefensiveCopy(t *testing.T) {
|
||||||
|
req := NewSimpleRequest("http://example.com", "GET")
|
||||||
|
queries := map[string][]string{
|
||||||
|
"k": []string{"v1"},
|
||||||
|
}
|
||||||
|
|
||||||
|
req.SetQueries(queries)
|
||||||
|
queries["k"][0] = "v2"
|
||||||
|
queries["k"] = append(queries["k"], "v3")
|
||||||
|
|
||||||
|
got := req.config.Queries["k"]
|
||||||
|
if len(got) != 1 || got[0] != "v1" {
|
||||||
|
t.Fatalf("queries mutated by external map change: got=%v want=[v1]", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetFormDataDefensiveCopy(t *testing.T) {
|
||||||
|
req := NewSimpleRequest("http://example.com", "POST")
|
||||||
|
form := map[string][]string{
|
||||||
|
"name": []string{"alice"},
|
||||||
|
}
|
||||||
|
|
||||||
|
req.SetFormData(form)
|
||||||
|
form["name"][0] = "bob"
|
||||||
|
form["name"] = append(form["name"], "carol")
|
||||||
|
|
||||||
|
got := req.config.Body.FormData["name"]
|
||||||
|
if len(got) != 1 || got[0] != "alice" {
|
||||||
|
t.Fatalf("form data mutated by external map change: got=%v want=[alice]", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
21
dialer.go
21
dialer.go
@ -19,11 +19,6 @@ func defaultDialFunc(ctx context.Context, network, addr string) (net.Conn, error
|
|||||||
dialTimeout = DefaultDialTimeout
|
dialTimeout = DefaultDialTimeout
|
||||||
}
|
}
|
||||||
|
|
||||||
timeout := reqCtx.Timeout
|
|
||||||
if timeout == 0 {
|
|
||||||
timeout = DefaultTimeout
|
|
||||||
}
|
|
||||||
|
|
||||||
// 解析地址
|
// 解析地址
|
||||||
host, port, err := net.SplitHostPort(addr)
|
host, port, err := net.SplitHostPort(addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -47,12 +42,13 @@ func defaultDialFunc(ctx context.Context, network, addr string) (net.Conn, error
|
|||||||
ipAddrs, err = reqCtx.LookupIPFn(ctx, host)
|
ipAddrs, err = reqCtx.LookupIPFn(ctx, host)
|
||||||
} else if len(reqCtx.CustomDNS) > 0 {
|
} else if len(reqCtx.CustomDNS) > 0 {
|
||||||
// 使用自定义 DNS 服务器
|
// 使用自定义 DNS 服务器
|
||||||
|
dialer := &net.Dialer{Timeout: dialTimeout}
|
||||||
resolver := &net.Resolver{
|
resolver := &net.Resolver{
|
||||||
PreferGo: true,
|
PreferGo: true,
|
||||||
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
|
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
var lastErr error
|
var lastErr error
|
||||||
for _, dnsServer := range reqCtx.CustomDNS {
|
for _, dnsServer := range reqCtx.CustomDNS {
|
||||||
conn, err := net.Dial("udp", net.JoinHostPort(dnsServer, "53"))
|
conn, err := dialer.DialContext(ctx, "udp", net.JoinHostPort(dnsServer, "53"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lastErr = err
|
lastErr = err
|
||||||
continue
|
continue
|
||||||
@ -78,19 +74,15 @@ func defaultDialFunc(ctx context.Context, network, addr string) (net.Conn, error
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 尝试连接所有地址
|
// 尝试连接所有地址
|
||||||
|
dialer := &net.Dialer{Timeout: dialTimeout}
|
||||||
var lastErr error
|
var lastErr error
|
||||||
for _, addr := range addrs {
|
for _, addr := range addrs {
|
||||||
conn, err := net.DialTimeout(network, addr, dialTimeout)
|
conn, err := dialer.DialContext(ctx, network, addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lastErr = err
|
lastErr = err
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// 设置总超时
|
|
||||||
if timeout > 0 {
|
|
||||||
conn.SetDeadline(time.Now().Add(timeout))
|
|
||||||
}
|
|
||||||
|
|
||||||
return conn, nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -131,6 +123,11 @@ func defaultDialTLSFunc(ctx context.Context, network, addr string) (net.Conn, er
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 执行 TLS 握手
|
// 执行 TLS 握手
|
||||||
|
if deadline, ok := ctx.Deadline(); ok {
|
||||||
|
_ = conn.SetDeadline(deadline)
|
||||||
|
defer conn.SetDeadline(time.Time{})
|
||||||
|
}
|
||||||
|
|
||||||
tlsConn := tls.Client(conn, tlsConfig)
|
tlsConn := tls.Client(conn, tlsConfig)
|
||||||
if err := tlsConn.Handshake(); err != nil {
|
if err := tlsConn.Handshake(); err != nil {
|
||||||
conn.Close()
|
conn.Close()
|
||||||
|
|||||||
192
errors.go
192
errors.go
@ -1,8 +1,14 @@
|
|||||||
package starnet
|
package starnet
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -32,6 +38,21 @@ var (
|
|||||||
|
|
||||||
// ErrBodyAlreadyConsumed Body 已被消费
|
// ErrBodyAlreadyConsumed Body 已被消费
|
||||||
ErrBodyAlreadyConsumed = errors.New("starnet: response body already consumed")
|
ErrBodyAlreadyConsumed = errors.New("starnet: response body already consumed")
|
||||||
|
|
||||||
|
// ErrRespBodyTooLarge 响应体超过允许上限
|
||||||
|
ErrRespBodyTooLarge = errors.New("starnet: response body too large")
|
||||||
|
|
||||||
|
// ErrPingInvalidTimeout ping 超时参数无效
|
||||||
|
ErrPingInvalidTimeout = errors.New("starnet: invalid ping timeout")
|
||||||
|
|
||||||
|
// ErrPingPermissionDenied ping 需要更高权限(raw socket)
|
||||||
|
ErrPingPermissionDenied = errors.New("starnet: ping permission denied")
|
||||||
|
|
||||||
|
// ErrPingProtocolUnsupported ping 协议/地址族不受当前平台支持
|
||||||
|
ErrPingProtocolUnsupported = errors.New("starnet: ping protocol unsupported")
|
||||||
|
|
||||||
|
// ErrPingNoResolvedTarget ping 目标无法解析为可用地址
|
||||||
|
ErrPingNoResolvedTarget = errors.New("starnet: ping target not resolved")
|
||||||
)
|
)
|
||||||
|
|
||||||
// wrapError 包装错误,添加上下文信息
|
// wrapError 包装错误,添加上下文信息
|
||||||
@ -56,3 +77,174 @@ var (
|
|||||||
// ErrNoTLSConfig indicates TLS was detected but no usable TLS config is available.
|
// ErrNoTLSConfig indicates TLS was detected but no usable TLS config is available.
|
||||||
ErrNoTLSConfig = errors.New("starnet: no TLS config available")
|
ErrNoTLSConfig = errors.New("starnet: no TLS config available")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ErrorKind is a normalized high-level category for request errors.
|
||||||
|
type ErrorKind string
|
||||||
|
|
||||||
|
const (
|
||||||
|
ErrorKindNone ErrorKind = "none"
|
||||||
|
ErrorKindCanceled ErrorKind = "canceled"
|
||||||
|
ErrorKindTimeout ErrorKind = "timeout"
|
||||||
|
ErrorKindDNS ErrorKind = "dns"
|
||||||
|
ErrorKindTLS ErrorKind = "tls"
|
||||||
|
ErrorKindProxy ErrorKind = "proxy"
|
||||||
|
ErrorKindOther ErrorKind = "other"
|
||||||
|
)
|
||||||
|
|
||||||
|
// IsCanceled reports whether err is a cancellation-related error.
|
||||||
|
func IsCanceled(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if errors.Is(err, context.Canceled) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := strings.ToLower(err.Error())
|
||||||
|
return strings.Contains(msg, "context canceled") ||
|
||||||
|
strings.Contains(msg, "operation was canceled") ||
|
||||||
|
strings.Contains(msg, "request canceled")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClassifyError maps low-level errors to a stable category for business handling.
|
||||||
|
func ClassifyError(err error) ErrorKind {
|
||||||
|
if err == nil {
|
||||||
|
return ErrorKindNone
|
||||||
|
}
|
||||||
|
if IsCanceled(err) {
|
||||||
|
return ErrorKindCanceled
|
||||||
|
}
|
||||||
|
if IsProxy(err) {
|
||||||
|
return ErrorKindProxy
|
||||||
|
}
|
||||||
|
if IsDNS(err) {
|
||||||
|
return ErrorKindDNS
|
||||||
|
}
|
||||||
|
if IsTLS(err) {
|
||||||
|
return ErrorKindTLS
|
||||||
|
}
|
||||||
|
if IsTimeout(err) {
|
||||||
|
return ErrorKindTimeout
|
||||||
|
}
|
||||||
|
return ErrorKindOther
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsTimeout reports whether err is a timeout-related error.
|
||||||
|
func IsTimeout(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
var uerr *url.Error
|
||||||
|
if errors.As(err, &uerr) && uerr.Timeout() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
var nerr net.Error
|
||||||
|
if errors.As(err, &nerr) && nerr.Timeout() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := strings.ToLower(err.Error())
|
||||||
|
return strings.Contains(msg, "timeout") || strings.Contains(msg, "deadline exceeded")
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsDNS reports whether err is a DNS resolution related error.
|
||||||
|
func IsDNS(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
var derr *net.DNSError
|
||||||
|
if errors.As(err, &derr) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := strings.ToLower(err.Error())
|
||||||
|
if strings.Contains(msg, "no such host") ||
|
||||||
|
strings.Contains(msg, "server misbehaving") ||
|
||||||
|
strings.Contains(msg, "temporary failure in name resolution") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.Contains(msg, "lookup ") &&
|
||||||
|
(strings.Contains(msg, "dns") || strings.Contains(msg, "i/o timeout"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsTLS reports whether err is TLS/Certificate related.
|
||||||
|
func IsTLS(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if errors.Is(err, ErrNotTLS) || errors.Is(err, ErrNoTLSConfig) || errors.Is(err, ErrNonTLSNotAllowed) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
var recErr tls.RecordHeaderError
|
||||||
|
if errors.As(err, &recErr) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
var uaErr x509.UnknownAuthorityError
|
||||||
|
if errors.As(err, &uaErr) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
var hnErr x509.HostnameError
|
||||||
|
if errors.As(err, &hnErr) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
var certErr x509.CertificateInvalidError
|
||||||
|
if errors.As(err, &certErr) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
var rootsErr x509.SystemRootsError
|
||||||
|
if errors.As(err, &rootsErr) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := strings.ToLower(err.Error())
|
||||||
|
return strings.Contains(msg, "tls:") || strings.Contains(msg, "x509:")
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsProxy reports whether err is proxy related.
|
||||||
|
func IsProxy(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if isProxyMessage(strings.ToLower(err.Error())) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
var uerr *url.Error
|
||||||
|
if errors.As(err, &uerr) {
|
||||||
|
if strings.Contains(strings.ToLower(uerr.Op), "proxy") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if uerr.Err != nil && isProxyMessage(strings.ToLower(uerr.Err.Error())) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var opErr *net.OpError
|
||||||
|
if errors.As(err, &opErr) && strings.Contains(strings.ToLower(opErr.Op), "proxy") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func isProxyMessage(msg string) bool {
|
||||||
|
return strings.Contains(msg, "proxyconnect") ||
|
||||||
|
strings.Contains(msg, "proxy error") ||
|
||||||
|
strings.Contains(msg, "proxy authentication required") ||
|
||||||
|
strings.Contains(msg, "proxy: unknown scheme") ||
|
||||||
|
strings.Contains(msg, "socks connect") ||
|
||||||
|
strings.Contains(msg, "socks5")
|
||||||
|
}
|
||||||
|
|||||||
116
errors_classify_test.go
Normal file
116
errors_classify_test.go
Normal file
@ -0,0 +1,116 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/url"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type timeoutErr struct{}
|
||||||
|
|
||||||
|
func (timeoutErr) Error() string { return "i/o timeout" }
|
||||||
|
func (timeoutErr) Timeout() bool { return true }
|
||||||
|
func (timeoutErr) Temporary() bool { return true }
|
||||||
|
|
||||||
|
func TestIsTimeout(t *testing.T) {
|
||||||
|
if !IsTimeout(context.DeadlineExceeded) {
|
||||||
|
t.Fatal("context deadline should be timeout")
|
||||||
|
}
|
||||||
|
|
||||||
|
uerr := &url.Error{
|
||||||
|
Op: "Get",
|
||||||
|
URL: "http://example.com",
|
||||||
|
Err: timeoutErr{},
|
||||||
|
}
|
||||||
|
if !IsTimeout(uerr) {
|
||||||
|
t.Fatal("url timeout error should be timeout")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !IsTimeout(fmt.Errorf("wrapped: %w", uerr)) {
|
||||||
|
t.Fatal("wrapped timeout should be timeout")
|
||||||
|
}
|
||||||
|
|
||||||
|
if IsTimeout(errors.New("plain error")) {
|
||||||
|
t.Fatal("plain error must not be timeout")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsDNS(t *testing.T) {
|
||||||
|
dnsErr := &net.DNSError{
|
||||||
|
Err: "no such host",
|
||||||
|
Name: "example.invalid",
|
||||||
|
IsNotFound: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
if !IsDNS(dnsErr) {
|
||||||
|
t.Fatal("dns error should be dns")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !IsDNS(fmt.Errorf("wrapped: %w", dnsErr)) {
|
||||||
|
t.Fatal("wrapped dns error should be dns")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !IsDNS(errors.New("lookup example.invalid: no such host")) {
|
||||||
|
t.Fatal("lookup no such host should be dns")
|
||||||
|
}
|
||||||
|
|
||||||
|
if IsDNS(errors.New("connection reset by peer")) {
|
||||||
|
t.Fatal("non dns error should not be dns")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsTLS(t *testing.T) {
|
||||||
|
tlsErr := tls.RecordHeaderError{Msg: "first record does not look like a TLS handshake"}
|
||||||
|
if !IsTLS(tlsErr) {
|
||||||
|
t.Fatal("tls record header error should be tls")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !IsTLS(fmt.Errorf("wrapped: %w", tlsErr)) {
|
||||||
|
t.Fatal("wrapped tls error should be tls")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !IsTLS(errors.New("x509: certificate signed by unknown authority")) {
|
||||||
|
t.Fatal("x509 error text should be tls")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !IsTLS(ErrNotTLS) {
|
||||||
|
t.Fatal("ErrNotTLS should be tls related")
|
||||||
|
}
|
||||||
|
|
||||||
|
if IsTLS(errors.New("plain error")) {
|
||||||
|
t.Fatal("plain error should not be tls")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsProxy(t *testing.T) {
|
||||||
|
raw := errors.New("proxyconnect tcp: dial tcp 127.0.0.1:8080: connect: connection refused")
|
||||||
|
if !IsProxy(raw) {
|
||||||
|
t.Fatal("proxyconnect error should be proxy")
|
||||||
|
}
|
||||||
|
|
||||||
|
uerr := &url.Error{
|
||||||
|
Op: "Get",
|
||||||
|
URL: "http://example.com",
|
||||||
|
Err: raw,
|
||||||
|
}
|
||||||
|
if !IsProxy(uerr) {
|
||||||
|
t.Fatal("wrapped proxy error should be proxy")
|
||||||
|
}
|
||||||
|
|
||||||
|
opErr := &net.OpError{
|
||||||
|
Op: "proxyconnect",
|
||||||
|
Net: "tcp",
|
||||||
|
Err: errors.New("connect failed"),
|
||||||
|
}
|
||||||
|
if !IsProxy(opErr) {
|
||||||
|
t.Fatal("net.OpError proxyconnect should be proxy")
|
||||||
|
}
|
||||||
|
|
||||||
|
if IsProxy(errors.New("dial tcp 127.0.0.1:8080: connect: connection refused")) {
|
||||||
|
t.Fatal("non proxy dial error should not be proxy")
|
||||||
|
}
|
||||||
|
}
|
||||||
49
errors_kind_test.go
Normal file
49
errors_kind_test.go
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestIsCanceled(t *testing.T) {
|
||||||
|
if !IsCanceled(context.Canceled) {
|
||||||
|
t.Fatal("context canceled should be canceled")
|
||||||
|
}
|
||||||
|
if !IsCanceled(fmt.Errorf("wrapped: %w", context.Canceled)) {
|
||||||
|
t.Fatal("wrapped context canceled should be canceled")
|
||||||
|
}
|
||||||
|
if IsCanceled(context.DeadlineExceeded) {
|
||||||
|
t.Fatal("deadline exceeded must not be canceled")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClassifyError(t *testing.T) {
|
||||||
|
dnsErr := &net.DNSError{Err: "no such host", Name: "example.invalid", IsNotFound: true}
|
||||||
|
tlsErr := tls.RecordHeaderError{Msg: "bad tls record"}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
err error
|
||||||
|
want ErrorKind
|
||||||
|
}{
|
||||||
|
{name: "nil", err: nil, want: ErrorKindNone},
|
||||||
|
{name: "canceled", err: context.Canceled, want: ErrorKindCanceled},
|
||||||
|
{name: "proxy", err: errors.New("proxyconnect tcp: dial tcp 127.0.0.1:8080: i/o timeout"), want: ErrorKindProxy},
|
||||||
|
{name: "dns", err: dnsErr, want: ErrorKindDNS},
|
||||||
|
{name: "tls", err: tlsErr, want: ErrorKindTLS},
|
||||||
|
{name: "timeout", err: context.DeadlineExceeded, want: ErrorKindTimeout},
|
||||||
|
{name: "other", err: errors.New("boom"), want: ErrorKindOther},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := ClassifyError(tt.err); got != tt.want {
|
||||||
|
t.Fatalf("ClassifyError()=%s want=%s err=%v", got, tt.want, tt.err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -4,6 +4,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"b612.me/starnet"
|
"b612.me/starnet"
|
||||||
@ -198,3 +199,58 @@ func ExampleWithTimeout() {
|
|||||||
fmt.Println(resp.StatusCode)
|
fmt.Println(resp.StatusCode)
|
||||||
// Output: 200
|
// Output: 200
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ExampleWithRetry() {
|
||||||
|
var hits int32
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
n := atomic.AddInt32(&hits, 1)
|
||||||
|
if n <= 2 {
|
||||||
|
w.WriteHeader(http.StatusServiceUnavailable)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
resp, err := starnet.Get(server.URL,
|
||||||
|
starnet.WithRetry(2,
|
||||||
|
starnet.WithRetryBackoff(0, 0, 1),
|
||||||
|
starnet.WithRetryJitter(0),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
fmt.Println(resp.StatusCode, atomic.LoadInt32(&hits))
|
||||||
|
// Output: 200 3
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExampleRequest_SetRetry() {
|
||||||
|
var hits int32
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
n := atomic.AddInt32(&hits, 1)
|
||||||
|
if n == 1 {
|
||||||
|
w.WriteHeader(http.StatusServiceUnavailable)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
resp, err := starnet.NewSimpleRequest(server.URL, http.MethodPost).
|
||||||
|
SetBodyString("hello"). // 可重放 body 才能安全重试
|
||||||
|
SetRetry(1).
|
||||||
|
SetRetryIdempotentOnly(false).
|
||||||
|
SetRetryBackoff(0, 0, 1).
|
||||||
|
SetRetryJitter(0).
|
||||||
|
Do()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
fmt.Println(resp.StatusCode, atomic.LoadInt32(&hits))
|
||||||
|
// Output: 200 2
|
||||||
|
}
|
||||||
|
|||||||
21
options.go
21
options.go
@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
@ -12,9 +13,9 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// WithTimeout 设置请求总超时时间
|
// WithTimeout 设置请求总超时时间
|
||||||
// timeout > 0: 使用该超时
|
// timeout > 0: 为本次请求注入 context 超时
|
||||||
// timeout = 0: 使用 Client 默认超时
|
// timeout = 0: 不额外设置请求总超时
|
||||||
// timeout < 0: 禁用本次请求超时(覆盖 Client.Timeout=0)
|
// timeout < 0: 禁用 starnet 默认总超时
|
||||||
func WithTimeout(timeout time.Duration) RequestOpt {
|
func WithTimeout(timeout time.Duration) RequestOpt {
|
||||||
return func(r *Request) error {
|
return func(r *Request) error {
|
||||||
r.config.Network.Timeout = timeout
|
r.config.Network.Timeout = timeout
|
||||||
@ -371,9 +372,23 @@ func WithAutoFetch(auto bool) RequestOpt {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithMaxRespBodyBytes 设置响应体最大读取字节数(<=0 表示不限制)
|
||||||
|
func WithMaxRespBodyBytes(maxBytes int64) RequestOpt {
|
||||||
|
return func(r *Request) error {
|
||||||
|
if maxBytes < 0 {
|
||||||
|
return fmt.Errorf("max response body bytes must be >= 0")
|
||||||
|
}
|
||||||
|
r.config.MaxRespBodyBytes = maxBytes
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// WithRawRequest 设置原始请求
|
// WithRawRequest 设置原始请求
|
||||||
func WithRawRequest(httpReq *http.Request) RequestOpt {
|
func WithRawRequest(httpReq *http.Request) RequestOpt {
|
||||||
return func(r *Request) error {
|
return func(r *Request) error {
|
||||||
|
if httpReq == nil {
|
||||||
|
return fmt.Errorf("httpReq cannot be nil")
|
||||||
|
}
|
||||||
r.httpReq = httpReq
|
r.httpReq = httpReq
|
||||||
r.doRaw = true
|
r.doRaw = true
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
522
ping.go
522
ping.go
@ -1,12 +1,31 @@
|
|||||||
package starnet
|
package starnet
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"context"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
icmpTypeEchoReplyV4 = 0
|
||||||
|
icmpTypeEchoRequestV4 = 8
|
||||||
|
icmpTypeEchoRequestV6 = 128
|
||||||
|
icmpTypeEchoReplyV6 = 129
|
||||||
|
|
||||||
|
icmpHeaderLen = 8
|
||||||
|
icmpReadBufSz = 1500
|
||||||
|
|
||||||
|
defaultPingAttemptTimeout = 2 * time.Second
|
||||||
|
defaultPingableCount = 3
|
||||||
|
maxPingPayloadSize = 65499 // 65507 - ICMP header(8)
|
||||||
|
)
|
||||||
|
|
||||||
type ICMP struct {
|
type ICMP struct {
|
||||||
Type uint8
|
Type uint8
|
||||||
Code uint8
|
Code uint8
|
||||||
@ -15,52 +34,126 @@ type ICMP struct {
|
|||||||
SequenceNum uint16
|
SequenceNum uint16
|
||||||
}
|
}
|
||||||
|
|
||||||
func getICMP(seq uint16) ICMP {
|
type pingSocketSpec struct {
|
||||||
|
network string
|
||||||
|
family int
|
||||||
|
requestType uint8
|
||||||
|
replyType uint8
|
||||||
|
}
|
||||||
|
|
||||||
|
// PingOptions controls ping probing behavior.
|
||||||
|
type PingOptions struct {
|
||||||
|
Count int // ping attempts for Pingable, default 3
|
||||||
|
Timeout time.Duration // per-attempt timeout, default 2s
|
||||||
|
Interval time.Duration // delay between attempts, default 0
|
||||||
|
Deadline time.Time // overall deadline for Pingable/PingWithContext
|
||||||
|
PreferIPv4 bool // prefer IPv4 targets
|
||||||
|
PreferIPv6 bool // prefer IPv6 targets
|
||||||
|
SourceIP net.IP // optional source IP for raw socket bind
|
||||||
|
PayloadSize int // ICMP payload bytes, default 0
|
||||||
|
}
|
||||||
|
|
||||||
|
type PingResult struct {
|
||||||
|
Duration time.Duration
|
||||||
|
RecvCount int
|
||||||
|
RemoteIP string
|
||||||
|
}
|
||||||
|
|
||||||
|
var pingIdentifierSeed uint32
|
||||||
|
|
||||||
|
func nextPingIdentifier() uint16 {
|
||||||
|
pid := uint32(os.Getpid() & 0xffff)
|
||||||
|
n := atomic.AddUint32(&pingIdentifierSeed, 1)
|
||||||
|
return uint16((pid + n) & 0xffff)
|
||||||
|
}
|
||||||
|
|
||||||
|
func pingPayload(size int) []byte {
|
||||||
|
if size <= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
payload := make([]byte, size)
|
||||||
|
for i := 0; i < len(payload); i++ {
|
||||||
|
payload[i] = byte(i)
|
||||||
|
}
|
||||||
|
return payload
|
||||||
|
}
|
||||||
|
|
||||||
|
func getICMP(seq, identifier uint16, typ uint8, payload []byte) ICMP {
|
||||||
icmp := ICMP{
|
icmp := ICMP{
|
||||||
Type: 8,
|
Type: typ,
|
||||||
Code: 0,
|
Code: 0,
|
||||||
CheckSum: 0,
|
CheckSum: 0,
|
||||||
Identifier: 0,
|
Identifier: identifier,
|
||||||
SequenceNum: seq,
|
SequenceNum: seq,
|
||||||
}
|
}
|
||||||
var buffer bytes.Buffer
|
buf := marshalICMPPacket(icmp, payload)
|
||||||
binary.Write(&buffer, binary.BigEndian, icmp)
|
icmp.CheckSum = checkSum(buf)
|
||||||
icmp.CheckSum = checkSum(buffer.Bytes())
|
|
||||||
buffer.Reset()
|
|
||||||
|
|
||||||
return icmp
|
return icmp
|
||||||
}
|
}
|
||||||
|
|
||||||
func sendICMPRequest(icmp ICMP, destAddr *net.IPAddr, timeout time.Duration) (PingResult, error) {
|
func sendICMPRequest(ctx context.Context, icmp ICMP, payload []byte, destAddr *net.IPAddr, sourceIP net.IP, spec pingSocketSpec, timeout time.Duration) (PingResult, error) {
|
||||||
var res PingResult
|
var res PingResult
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
if err := ctx.Err(); err != nil {
|
||||||
|
return res, wrapError(err, "ping context done")
|
||||||
|
}
|
||||||
|
if destAddr == nil || destAddr.IP == nil {
|
||||||
|
return res, fmt.Errorf("destination ip is nil")
|
||||||
|
}
|
||||||
res.RemoteIP = destAddr.String()
|
res.RemoteIP = destAddr.String()
|
||||||
conn, err := net.DialIP("ip:icmp", nil, destAddr)
|
|
||||||
|
localAddr, err := localIPAddrForFamily(sourceIP, spec.family)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return res, err
|
return res, err
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
|
||||||
var buffer bytes.Buffer
|
|
||||||
binary.Write(&buffer, binary.BigEndian, icmp)
|
|
||||||
|
|
||||||
if _, err := conn.Write(buffer.Bytes()); err != nil {
|
conn, err := net.DialIP(spec.network, localAddr, destAddr)
|
||||||
return res, err
|
if err != nil {
|
||||||
|
return res, normalizePingDialError(err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
packet := marshalICMPPacket(icmp, payload)
|
||||||
|
if _, err := conn.Write(packet); err != nil {
|
||||||
|
return res, wrapError(err, "ping write request")
|
||||||
}
|
}
|
||||||
|
|
||||||
tStart := time.Now()
|
tStart := time.Now()
|
||||||
|
deadline := tStart.Add(timeout)
|
||||||
conn.SetReadDeadline((time.Now().Add(timeout)))
|
if d, ok := ctx.Deadline(); ok && d.Before(deadline) {
|
||||||
|
deadline = d
|
||||||
recv := make([]byte, 1024)
|
}
|
||||||
res.RecvCount, err = conn.Read(recv)
|
if err := conn.SetReadDeadline(deadline); err != nil {
|
||||||
|
return res, wrapError(err, "ping set read deadline")
|
||||||
if err != nil {
|
|
||||||
return res, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
tEnd := time.Now()
|
doneCh := make(chan struct{})
|
||||||
res.Duration = tEnd.Sub(tStart)
|
go func() {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
_ = conn.SetReadDeadline(time.Now())
|
||||||
|
case <-doneCh:
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
defer close(doneCh)
|
||||||
|
|
||||||
return res, err
|
recv := make([]byte, icmpReadBufSz)
|
||||||
|
for {
|
||||||
|
n, err := conn.Read(recv)
|
||||||
|
if err != nil {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return res, wrapError(ctx.Err(), "ping context done")
|
||||||
|
}
|
||||||
|
return res, wrapError(err, "ping read reply")
|
||||||
|
}
|
||||||
|
if isExpectedEchoReply(recv[:n], spec.family, spec.replyType, icmp.Identifier, icmp.SequenceNum) {
|
||||||
|
res.RecvCount = n
|
||||||
|
res.Duration = time.Since(tStart)
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkSum(data []byte) uint16 {
|
func checkSum(data []byte) uint16 {
|
||||||
@ -75,36 +168,375 @@ func checkSum(data []byte) uint16 {
|
|||||||
length -= 2
|
length -= 2
|
||||||
}
|
}
|
||||||
if length > 0 {
|
if length > 0 {
|
||||||
sum += uint32(data[index])
|
sum += uint32(data[index]) << 8
|
||||||
|
}
|
||||||
|
for sum>>16 != 0 {
|
||||||
|
sum = (sum & 0xffff) + (sum >> 16)
|
||||||
}
|
}
|
||||||
sum += (sum >> 16)
|
|
||||||
|
|
||||||
return uint16(^sum)
|
return uint16(^sum)
|
||||||
}
|
}
|
||||||
|
|
||||||
type PingResult struct {
|
func marshalICMP(icmp ICMP) []byte {
|
||||||
Duration time.Duration
|
return marshalICMPPacket(icmp, nil)
|
||||||
RecvCount int
|
|
||||||
RemoteIP string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func Ping(ip string, seq int, timeout time.Duration) (PingResult, error) {
|
func marshalICMPPacket(icmp ICMP, payload []byte) []byte {
|
||||||
var res PingResult
|
buf := make([]byte, icmpHeaderLen+len(payload))
|
||||||
ipAddr, err := net.ResolveIPAddr("ip", ip)
|
buf[0] = icmp.Type
|
||||||
if err != nil {
|
buf[1] = icmp.Code
|
||||||
return res, err
|
binary.BigEndian.PutUint16(buf[2:], icmp.CheckSum)
|
||||||
}
|
binary.BigEndian.PutUint16(buf[4:], icmp.Identifier)
|
||||||
icmp := getICMP(uint16(seq))
|
binary.BigEndian.PutUint16(buf[6:], icmp.SequenceNum)
|
||||||
return sendICMPRequest(icmp, ipAddr, timeout)
|
copy(buf[icmpHeaderLen:], payload)
|
||||||
|
return buf
|
||||||
}
|
}
|
||||||
|
|
||||||
func IsIpPingable(ip string, timeout time.Duration, retryLimit int) bool {
|
func isExpectedEchoReply(packet []byte, family int, expectedType uint8, identifier, seq uint16) bool {
|
||||||
for i := 0; i < retryLimit; i++ {
|
for _, off := range candidateICMPOffsets(packet, family) {
|
||||||
_, err := Ping(ip, 29, timeout)
|
if off < 0 || off+icmpHeaderLen > len(packet) {
|
||||||
if err != nil {
|
continue
|
||||||
|
}
|
||||||
|
if packet[off] != expectedType || packet[off+1] != 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if binary.BigEndian.Uint16(packet[off+4:off+6]) != identifier {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if binary.BigEndian.Uint16(packet[off+6:off+8]) != seq {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func candidateICMPOffsets(packet []byte, family int) []int {
|
||||||
|
offsets := []int{0}
|
||||||
|
if len(packet) == 0 {
|
||||||
|
return offsets
|
||||||
|
}
|
||||||
|
|
||||||
|
ver := packet[0] >> 4
|
||||||
|
if ver == 4 && len(packet) >= 20 {
|
||||||
|
ihl := int(packet[0]&0x0f) * 4
|
||||||
|
if ihl >= 20 && ihl <= len(packet)-icmpHeaderLen {
|
||||||
|
offsets = append(offsets, ihl)
|
||||||
|
}
|
||||||
|
} else if ver == 6 && len(packet) >= 40+icmpHeaderLen {
|
||||||
|
offsets = append(offsets, 40)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 某些平台/内核可能回包含链路层头部,保守再尝试常见偏移。
|
||||||
|
if family == 4 && len(packet) >= 20+icmpHeaderLen {
|
||||||
|
offsets = append(offsets, 20)
|
||||||
|
}
|
||||||
|
if family == 6 && len(packet) >= 40+icmpHeaderLen {
|
||||||
|
offsets = append(offsets, 40)
|
||||||
|
}
|
||||||
|
|
||||||
|
return dedupOffsets(offsets)
|
||||||
|
}
|
||||||
|
|
||||||
|
func dedupOffsets(offsets []int) []int {
|
||||||
|
if len(offsets) <= 1 {
|
||||||
|
return offsets
|
||||||
|
}
|
||||||
|
m := make(map[int]struct{}, len(offsets))
|
||||||
|
out := make([]int, 0, len(offsets))
|
||||||
|
for _, off := range offsets {
|
||||||
|
if _, ok := m[off]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
m[off] = struct{}{}
|
||||||
|
out = append(out, off)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func socketSpecForIP(ip net.IP) (pingSocketSpec, error) {
|
||||||
|
if ip == nil {
|
||||||
|
return pingSocketSpec{}, wrapError(ErrInvalidIP, "ip is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if ip4 := ip.To4(); ip4 != nil {
|
||||||
|
return pingSocketSpec{
|
||||||
|
network: "ip4:icmp",
|
||||||
|
family: 4,
|
||||||
|
requestType: icmpTypeEchoRequestV4,
|
||||||
|
replyType: icmpTypeEchoReplyV4,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if ip16 := ip.To16(); ip16 != nil {
|
||||||
|
return pingSocketSpec{
|
||||||
|
network: "ip6:ipv6-icmp",
|
||||||
|
family: 6,
|
||||||
|
requestType: icmpTypeEchoRequestV6,
|
||||||
|
replyType: icmpTypeEchoReplyV6,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return pingSocketSpec{}, wrapError(ErrInvalidIP, "invalid ip: %q", ip.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func localIPAddrForFamily(sourceIP net.IP, family int) (*net.IPAddr, error) {
|
||||||
|
if sourceIP == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
if sourceIP.To16() == nil {
|
||||||
|
return nil, wrapError(ErrInvalidIP, "invalid source ip: %q", sourceIP.String())
|
||||||
|
}
|
||||||
|
if family == 4 && sourceIP.To4() == nil {
|
||||||
|
return nil, wrapError(ErrInvalidIP, "source ip family mismatch with IPv4 target")
|
||||||
|
}
|
||||||
|
if family == 6 && sourceIP.To4() != nil {
|
||||||
|
return nil, wrapError(ErrInvalidIP, "source ip family mismatch with IPv6 target")
|
||||||
|
}
|
||||||
|
return &net.IPAddr{IP: sourceIP}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolvePingTargets(host string, preferIPv4, preferIPv6 bool) ([]*net.IPAddr, error) {
|
||||||
|
if parsed := net.ParseIP(host); parsed != nil {
|
||||||
|
return []*net.IPAddr{{IP: parsed}}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var targets []*net.IPAddr
|
||||||
|
var err4 error
|
||||||
|
var err6 error
|
||||||
|
|
||||||
|
if ip4, e := net.ResolveIPAddr("ip4", host); e == nil && ip4 != nil && ip4.IP != nil {
|
||||||
|
targets = append(targets, ip4)
|
||||||
|
} else {
|
||||||
|
err4 = e
|
||||||
|
}
|
||||||
|
|
||||||
|
if ip6, e := net.ResolveIPAddr("ip6", host); e == nil && ip6 != nil && ip6.IP != nil {
|
||||||
|
targets = append(targets, ip6)
|
||||||
|
} else {
|
||||||
|
err6 = e
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(targets) > 0 {
|
||||||
|
return orderPingTargets(targets, preferIPv4, preferIPv6), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err4 != nil {
|
||||||
|
return nil, err4
|
||||||
|
}
|
||||||
|
if err6 != nil {
|
||||||
|
return nil, err6
|
||||||
|
}
|
||||||
|
return nil, ErrPingNoResolvedTarget
|
||||||
|
}
|
||||||
|
|
||||||
|
func orderPingTargets(targets []*net.IPAddr, preferIPv4, preferIPv6 bool) []*net.IPAddr {
|
||||||
|
if len(targets) <= 1 || preferIPv4 == preferIPv6 {
|
||||||
|
return targets
|
||||||
|
}
|
||||||
|
|
||||||
|
ordered := make([]*net.IPAddr, 0, len(targets))
|
||||||
|
if preferIPv4 {
|
||||||
|
for _, t := range targets {
|
||||||
|
if t != nil && t.IP != nil && t.IP.To4() != nil {
|
||||||
|
ordered = append(ordered, t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, t := range targets {
|
||||||
|
if t != nil && t.IP != nil && t.IP.To4() == nil {
|
||||||
|
ordered = append(ordered, t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ordered
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, t := range targets {
|
||||||
|
if t != nil && t.IP != nil && t.IP.To4() == nil {
|
||||||
|
ordered = append(ordered, t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, t := range targets {
|
||||||
|
if t != nil && t.IP != nil && t.IP.To4() != nil {
|
||||||
|
ordered = append(ordered, t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ordered
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizePingDialError(err error) error {
|
||||||
|
if err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := strings.ToLower(err.Error())
|
||||||
|
if errors.Is(err, os.ErrPermission) ||
|
||||||
|
strings.Contains(msg, "operation not permitted") ||
|
||||||
|
strings.Contains(msg, "permission denied") {
|
||||||
|
return fmt.Errorf("%w: %v", ErrPingPermissionDenied, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.Contains(msg, "unknown network") ||
|
||||||
|
strings.Contains(msg, "protocol not available") ||
|
||||||
|
strings.Contains(msg, "address family not supported by protocol") ||
|
||||||
|
strings.Contains(msg, "socket type not supported") {
|
||||||
|
return fmt.Errorf("%w: %v", ErrPingProtocolUnsupported, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return wrapError(err, "ping dial")
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizePingOptions(opts *PingOptions, defaultCount int, defaultTimeout time.Duration) (PingOptions, error) {
|
||||||
|
out := PingOptions{
|
||||||
|
Count: defaultCount,
|
||||||
|
Timeout: defaultTimeout,
|
||||||
|
Interval: 0,
|
||||||
|
PayloadSize: 0,
|
||||||
|
}
|
||||||
|
if opts != nil {
|
||||||
|
out = *opts
|
||||||
|
if out.Count == 0 {
|
||||||
|
out.Count = defaultCount
|
||||||
|
}
|
||||||
|
if out.Timeout == 0 {
|
||||||
|
out.Timeout = defaultTimeout
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if out.Count < 0 {
|
||||||
|
return out, fmt.Errorf("ping count must be >= 0")
|
||||||
|
}
|
||||||
|
if out.Timeout <= 0 {
|
||||||
|
return out, wrapError(ErrPingInvalidTimeout, "timeout must be > 0")
|
||||||
|
}
|
||||||
|
if out.Interval < 0 {
|
||||||
|
return out, fmt.Errorf("ping interval must be >= 0")
|
||||||
|
}
|
||||||
|
if out.PayloadSize < 0 || out.PayloadSize > maxPingPayloadSize {
|
||||||
|
return out, fmt.Errorf("ping payload size must be in [0,%d]", maxPingPayloadSize)
|
||||||
|
}
|
||||||
|
if out.SourceIP != nil && out.SourceIP.To16() == nil {
|
||||||
|
return out, wrapError(ErrInvalidIP, "invalid source ip: %q", out.SourceIP.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func pingOnceWithOptions(ctx context.Context, host string, seq int, opts PingOptions) (PingResult, error) {
|
||||||
|
var res PingResult
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
if err := ctx.Err(); err != nil {
|
||||||
|
return res, wrapError(err, "ping context done")
|
||||||
|
}
|
||||||
|
|
||||||
|
targets, err := resolvePingTargets(host, opts.PreferIPv4, opts.PreferIPv6)
|
||||||
|
if err != nil {
|
||||||
|
return res, wrapError(err, "resolve ping target")
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := pingPayload(opts.PayloadSize)
|
||||||
|
var lastErr error
|
||||||
|
for _, target := range targets {
|
||||||
|
spec, err := socketSpecForIP(target.IP)
|
||||||
|
if err != nil {
|
||||||
|
lastErr = err
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
icmp := getICMP(uint16(seq), nextPingIdentifier(), spec.requestType, payload)
|
||||||
|
resp, err := sendICMPRequest(ctx, icmp, payload, target, opts.SourceIP, spec, opts.Timeout)
|
||||||
|
if err == nil {
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 权限问题通常与地址族无关,继续重试意义不大。
|
||||||
|
if errors.Is(err, ErrPingPermissionDenied) {
|
||||||
|
return res, err
|
||||||
|
}
|
||||||
|
lastErr = err
|
||||||
|
}
|
||||||
|
|
||||||
|
if lastErr != nil {
|
||||||
|
return res, wrapError(lastErr, "ping all resolved targets failed")
|
||||||
|
}
|
||||||
|
return res, ErrPingNoResolvedTarget
|
||||||
|
}
|
||||||
|
|
||||||
|
// PingWithContext sends one ICMP echo request with context cancel support.
|
||||||
|
func PingWithContext(ctx context.Context, host string, seq int, timeout time.Duration) (PingResult, error) {
|
||||||
|
opts, err := normalizePingOptions(&PingOptions{
|
||||||
|
Count: 1,
|
||||||
|
Timeout: timeout,
|
||||||
|
}, 1, timeout)
|
||||||
|
if err != nil {
|
||||||
|
return PingResult{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !opts.Deadline.IsZero() {
|
||||||
|
var cancel context.CancelFunc
|
||||||
|
ctx, cancel = context.WithDeadline(ctx, opts.Deadline)
|
||||||
|
defer cancel()
|
||||||
|
}
|
||||||
|
return pingOnceWithOptions(ctx, host, seq, opts)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ping sends one ICMP echo request.
|
||||||
|
func Ping(ip string, seq int, timeout time.Duration) (PingResult, error) {
|
||||||
|
return PingWithContext(context.Background(), ip, seq, timeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pingable checks host reachability with retry options.
|
||||||
|
func Pingable(host string, opts *PingOptions) (bool, error) {
|
||||||
|
cfg, err := normalizePingOptions(opts, defaultPingableCount, defaultPingAttemptTimeout)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
if !cfg.Deadline.IsZero() {
|
||||||
|
var cancel context.CancelFunc
|
||||||
|
ctx, cancel = context.WithDeadline(ctx, cfg.Deadline)
|
||||||
|
defer cancel()
|
||||||
|
}
|
||||||
|
|
||||||
|
var lastErr error
|
||||||
|
for i := 0; i < cfg.Count; i++ {
|
||||||
|
_, err := pingOnceWithOptions(ctx, host, 29+i, cfg)
|
||||||
|
if err == nil {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
lastErr = err
|
||||||
|
|
||||||
|
if errors.Is(err, ErrPingPermissionDenied) || errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if i < cfg.Count-1 && cfg.Interval > 0 {
|
||||||
|
timer := time.NewTimer(cfg.Interval)
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
timer.Stop()
|
||||||
|
return false, wrapError(ctx.Err(), "pingable context done")
|
||||||
|
case <-timer.C:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if lastErr == nil {
|
||||||
|
lastErr = ErrPingNoResolvedTarget
|
||||||
|
}
|
||||||
|
return false, lastErr
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsIpPingable keeps backward-compatible bool-only behavior.
|
||||||
|
func IsIpPingable(ip string, timeout time.Duration, retryLimit int) bool {
|
||||||
|
if retryLimit <= 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
ok, _ := Pingable(ip, &PingOptions{
|
||||||
|
Count: retryLimit,
|
||||||
|
Timeout: timeout,
|
||||||
|
})
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|||||||
214
ping_logic_test.go
Normal file
214
ping_logic_test.go
Normal file
@ -0,0 +1,214 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func buildICMPPacket(typ uint8, identifier, seq uint16) []byte {
|
||||||
|
icmp := ICMP{
|
||||||
|
Type: typ,
|
||||||
|
Code: 0,
|
||||||
|
CheckSum: 0,
|
||||||
|
Identifier: identifier,
|
||||||
|
SequenceNum: seq,
|
||||||
|
}
|
||||||
|
buf := marshalICMPPacket(icmp, nil)
|
||||||
|
cs := checkSum(buf)
|
||||||
|
buf[2] = byte(cs >> 8)
|
||||||
|
buf[3] = byte(cs)
|
||||||
|
return buf
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNextPingIdentifierChanges(t *testing.T) {
|
||||||
|
id1 := nextPingIdentifier()
|
||||||
|
id2 := nextPingIdentifier()
|
||||||
|
if id1 == id2 {
|
||||||
|
t.Fatalf("identifier should change between calls: %d == %d", id1, id2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsExpectedEchoReplyIPv4(t *testing.T) {
|
||||||
|
identifier := uint16(0x1234)
|
||||||
|
seq := uint16(0x0102)
|
||||||
|
reply := buildICMPPacket(icmpTypeEchoReplyV4, identifier, seq)
|
||||||
|
|
||||||
|
if !isExpectedEchoReply(reply, 4, icmpTypeEchoReplyV4, identifier, seq) {
|
||||||
|
t.Fatal("expected IPv4 echo reply to match")
|
||||||
|
}
|
||||||
|
if isExpectedEchoReply(reply, 4, icmpTypeEchoReplyV4, identifier, seq+1) {
|
||||||
|
t.Fatal("mismatched sequence should not match")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsExpectedEchoReplyIPv4WithIPHeader(t *testing.T) {
|
||||||
|
identifier := uint16(0x1111)
|
||||||
|
seq := uint16(0x2222)
|
||||||
|
|
||||||
|
ipHeader := make([]byte, 20)
|
||||||
|
ipHeader[0] = 0x45 // version=4, ihl=5
|
||||||
|
|
||||||
|
reply := buildICMPPacket(icmpTypeEchoReplyV4, identifier, seq)
|
||||||
|
packet := append(ipHeader, reply...)
|
||||||
|
|
||||||
|
if !isExpectedEchoReply(packet, 4, icmpTypeEchoReplyV4, identifier, seq) {
|
||||||
|
t.Fatal("expected IPv4 packet with header to match")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsExpectedEchoReplyIPv6WithHeader(t *testing.T) {
|
||||||
|
identifier := uint16(0xabcd)
|
||||||
|
seq := uint16(0x00ff)
|
||||||
|
|
||||||
|
ipv6Header := make([]byte, 40)
|
||||||
|
ipv6Header[0] = 0x60 // version=6
|
||||||
|
|
||||||
|
reply := buildICMPPacket(icmpTypeEchoReplyV6, identifier, seq)
|
||||||
|
packet := append(ipv6Header, reply...)
|
||||||
|
|
||||||
|
if !isExpectedEchoReply(packet, 6, icmpTypeEchoReplyV6, identifier, seq) {
|
||||||
|
t.Fatal("expected IPv6 packet with header to match")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPingInvalidTimeout(t *testing.T) {
|
||||||
|
_, err := Ping("127.0.0.1", 1, 0)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for non-positive timeout")
|
||||||
|
}
|
||||||
|
if !errors.Is(err, ErrPingInvalidTimeout) {
|
||||||
|
t.Fatalf("expected ErrPingInvalidTimeout, got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsIPPingableInvalidRetry(t *testing.T) {
|
||||||
|
if IsIpPingable("127.0.0.1", time.Millisecond, 0) {
|
||||||
|
t.Fatal("retryLimit=0 should return false")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSocketSpecForIP(t *testing.T) {
|
||||||
|
v4, err := socketSpecForIP(net.ParseIP("127.0.0.1"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected v4 error: %v", err)
|
||||||
|
}
|
||||||
|
if v4.network != "ip4:icmp" || v4.family != 4 || v4.requestType != icmpTypeEchoRequestV4 || v4.replyType != icmpTypeEchoReplyV4 {
|
||||||
|
t.Fatalf("unexpected v4 spec: %+v", v4)
|
||||||
|
}
|
||||||
|
|
||||||
|
v6, err := socketSpecForIP(net.ParseIP("::1"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected v6 error: %v", err)
|
||||||
|
}
|
||||||
|
if v6.network != "ip6:ipv6-icmp" || v6.family != 6 || v6.requestType != icmpTypeEchoRequestV6 || v6.replyType != icmpTypeEchoReplyV6 {
|
||||||
|
t.Fatalf("unexpected v6 spec: %+v", v6)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = socketSpecForIP(nil)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for nil ip")
|
||||||
|
}
|
||||||
|
if !errors.Is(err, ErrInvalidIP) {
|
||||||
|
t.Fatalf("expected ErrInvalidIP, got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolvePingTargetsLiteral(t *testing.T) {
|
||||||
|
v4, err := resolvePingTargets("127.0.0.1", false, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected v4 resolve error: %v", err)
|
||||||
|
}
|
||||||
|
if len(v4) != 1 || v4[0] == nil || v4[0].IP == nil || v4[0].IP.To4() == nil {
|
||||||
|
t.Fatalf("unexpected v4 targets: %+v", v4)
|
||||||
|
}
|
||||||
|
|
||||||
|
v6, err := resolvePingTargets("::1", false, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected v6 resolve error: %v", err)
|
||||||
|
}
|
||||||
|
if len(v6) != 1 || v6[0] == nil || v6[0].IP == nil || v6[0].IP.To16() == nil || v6[0].IP.To4() != nil {
|
||||||
|
t.Fatalf("unexpected v6 targets: %+v", v6)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizePingDialError(t *testing.T) {
|
||||||
|
perr := normalizePingDialError(errors.New("socket: operation not permitted"))
|
||||||
|
if !errors.Is(perr, ErrPingPermissionDenied) {
|
||||||
|
t.Fatalf("expected ErrPingPermissionDenied, got: %v", perr)
|
||||||
|
}
|
||||||
|
|
||||||
|
uerr := normalizePingDialError(errors.New("unknown network ip6:ipv6-icmp"))
|
||||||
|
if !errors.Is(uerr, ErrPingProtocolUnsupported) {
|
||||||
|
t.Fatalf("expected ErrPingProtocolUnsupported, got: %v", uerr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOrderPingTargets(t *testing.T) {
|
||||||
|
targets := []*net.IPAddr{
|
||||||
|
{IP: net.ParseIP("::1")},
|
||||||
|
{IP: net.ParseIP("127.0.0.1")},
|
||||||
|
}
|
||||||
|
|
||||||
|
v4First := orderPingTargets(targets, true, false)
|
||||||
|
if v4First[0].IP.To4() == nil {
|
||||||
|
t.Fatalf("expected IPv4 first, got: %v", v4First[0].IP)
|
||||||
|
}
|
||||||
|
|
||||||
|
v6First := orderPingTargets(targets, false, true)
|
||||||
|
if v6First[0].IP.To4() != nil {
|
||||||
|
t.Fatalf("expected IPv6 first, got: %v", v6First[0].IP)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizePingOptions(t *testing.T) {
|
||||||
|
opts, err := normalizePingOptions(nil, 3, 2*time.Second)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if opts.Count != 3 || opts.Timeout != 2*time.Second {
|
||||||
|
t.Fatalf("unexpected defaults: %+v", opts)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = normalizePingOptions(&PingOptions{Count: -1}, 3, 2*time.Second)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for negative count")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = normalizePingOptions(&PingOptions{Timeout: -1}, 3, 2*time.Second)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for negative timeout")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = normalizePingOptions(&PingOptions{PayloadSize: maxPingPayloadSize + 1}, 3, 2*time.Second)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for too large payload")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPingWithContextCanceled(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
_, err := PingWithContext(ctx, "127.0.0.1", 1, time.Second)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected canceled error")
|
||||||
|
}
|
||||||
|
if !errors.Is(err, context.Canceled) {
|
||||||
|
t.Fatalf("expected context.Canceled, got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPingableInvalidOptions(t *testing.T) {
|
||||||
|
_, err := Pingable("127.0.0.1", &PingOptions{Count: -1})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected invalid count error")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = Pingable("127.0.0.1", &PingOptions{Interval: -1})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected invalid interval error")
|
||||||
|
}
|
||||||
|
}
|
||||||
58
request.go
58
request.go
@ -12,6 +12,7 @@ import (
|
|||||||
type Request struct {
|
type Request struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
execCtx context.Context // 执行时的 context(注入了配置)
|
execCtx context.Context // 执行时的 context(注入了配置)
|
||||||
|
cancel context.CancelFunc
|
||||||
url string
|
url string
|
||||||
method string
|
method string
|
||||||
err error // 累积的错误
|
err error // 累积的错误
|
||||||
@ -20,6 +21,7 @@ type Request struct {
|
|||||||
client *Client
|
client *Client
|
||||||
httpClient *http.Client
|
httpClient *http.Client
|
||||||
httpReq *http.Request
|
httpReq *http.Request
|
||||||
|
retry *retryPolicy
|
||||||
|
|
||||||
applied bool // 是否已应用配置
|
applied bool // 是否已应用配置
|
||||||
doRaw bool // 是否使用原始请求(不修改)
|
doRaw bool // 是否使用原始请求(不修改)
|
||||||
@ -160,6 +162,7 @@ func (r *Request) Clone() *Request {
|
|||||||
config: r.config.Clone(),
|
config: r.config.Clone(),
|
||||||
client: r.client,
|
client: r.client,
|
||||||
httpClient: r.httpClient,
|
httpClient: r.httpClient,
|
||||||
|
retry: cloneRetryPolicy(r.retry),
|
||||||
applied: false, // 重置应用状态
|
applied: false, // 重置应用状态
|
||||||
doRaw: r.doRaw,
|
doRaw: r.doRaw,
|
||||||
autoFetch: r.autoFetch,
|
autoFetch: r.autoFetch,
|
||||||
@ -318,6 +321,14 @@ func (r *Request) Do() (*Response, error) {
|
|||||||
return nil, r.err
|
return nil, r.err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if r.hasRetryPolicy() {
|
||||||
|
return r.doWithRetry()
|
||||||
|
}
|
||||||
|
|
||||||
|
return r.doOnce()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Request) doOnce() (*Response, error) {
|
||||||
// 准备请求
|
// 准备请求
|
||||||
if err := r.prepare(); err != nil {
|
if err := r.prepare(); err != nil {
|
||||||
return nil, wrapError(err, "prepare request")
|
return nil, wrapError(err, "prepare request")
|
||||||
@ -326,6 +337,10 @@ func (r *Request) Do() (*Response, error) {
|
|||||||
// 执行请求
|
// 执行请求
|
||||||
httpResp, err := r.httpClient.Do(r.httpReq)
|
httpResp, err := r.httpClient.Do(r.httpReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if r.cancel != nil {
|
||||||
|
r.cancel()
|
||||||
|
r.cancel = nil
|
||||||
|
}
|
||||||
return &Response{
|
return &Response{
|
||||||
Response: &http.Response{},
|
Response: &http.Response{},
|
||||||
request: r,
|
request: r,
|
||||||
@ -334,19 +349,33 @@ func (r *Request) Do() (*Response, error) {
|
|||||||
}, wrapError(err, "do request")
|
}, wrapError(err, "do request")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
rawBody := httpResp.Body
|
||||||
|
if r.cancel != nil {
|
||||||
|
rawBody = &cancelReadCloser{
|
||||||
|
ReadCloser: httpResp.Body,
|
||||||
|
cancel: r.cancel,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 创建响应
|
// 创建响应
|
||||||
resp := &Response{
|
resp := &Response{
|
||||||
Response: httpResp,
|
Response: httpResp,
|
||||||
request: r,
|
request: r,
|
||||||
httpClient: r.httpClient,
|
httpClient: r.httpClient,
|
||||||
|
cancel: r.cancel,
|
||||||
body: &Body{
|
body: &Body{
|
||||||
raw: httpResp.Body,
|
raw: rawBody,
|
||||||
|
maxBytes: r.config.MaxRespBodyBytes,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
r.cancel = nil
|
||||||
|
|
||||||
// 自动获取响应体
|
// 自动获取响应体
|
||||||
if r.autoFetch {
|
if r.autoFetch {
|
||||||
resp.body.readAll()
|
if err := resp.body.readAll(); err != nil {
|
||||||
|
_ = resp.Close()
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return resp, nil
|
return resp, nil
|
||||||
@ -371,3 +400,28 @@ func (r *Request) Put() (*Response, error) {
|
|||||||
func (r *Request) Delete() (*Response, error) {
|
func (r *Request) Delete() (*Response, error) {
|
||||||
return r.SetMethod(http.MethodDelete).Do()
|
return r.SetMethod(http.MethodDelete).Do()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Head 发送 HEAD 请求
|
||||||
|
func (r *Request) Head() (*Response, error) {
|
||||||
|
return r.SetMethod(http.MethodHead).Do()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Patch 发送 PATCH 请求
|
||||||
|
func (r *Request) Patch() (*Response, error) {
|
||||||
|
return r.SetMethod(http.MethodPatch).Do()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Options 发送 OPTIONS 请求
|
||||||
|
func (r *Request) Options() (*Response, error) {
|
||||||
|
return r.SetMethod(http.MethodOptions).Do()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Trace 发送 TRACE 请求
|
||||||
|
func (r *Request) Trace() (*Response, error) {
|
||||||
|
return r.SetMethod(http.MethodTrace).Do()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect 发送 CONNECT 请求
|
||||||
|
func (r *Request) Connect() (*Response, error) {
|
||||||
|
return r.SetMethod(http.MethodConnect).Do()
|
||||||
|
}
|
||||||
|
|||||||
155
request_body.go
155
request_body.go
@ -2,7 +2,9 @@ package starnet
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"mime/multipart"
|
"mime/multipart"
|
||||||
"net/http"
|
"net/http"
|
||||||
@ -65,7 +67,7 @@ func (r *Request) SetFormData(data map[string][]string) *Request {
|
|||||||
if r.doRaw {
|
if r.doRaw {
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
r.config.Body.FormData = data
|
r.config.Body.FormData = cloneStringMapSlice(data)
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -345,69 +347,81 @@ func (r *Request) prepare() error {
|
|||||||
return err // ← 失败时不设置 applied
|
return err // ← 失败时不设置 applied
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// 原始模式不修改请求内容
|
|
||||||
if r.doRaw {
|
if r.httpReq == nil {
|
||||||
r.applied = true
|
return fmt.Errorf("http request is nil")
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 应用查询参数
|
// 原始模式不修改请求内容
|
||||||
if len(r.config.Queries) > 0 {
|
if !r.doRaw {
|
||||||
q := r.httpReq.URL.Query()
|
// 应用查询参数
|
||||||
for k, values := range r.config.Queries {
|
if len(r.config.Queries) > 0 {
|
||||||
|
q := r.httpReq.URL.Query()
|
||||||
|
for k, values := range r.config.Queries {
|
||||||
|
for _, v := range values {
|
||||||
|
q.Add(k, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
r.httpReq.URL.RawQuery = q.Encode()
|
||||||
|
}
|
||||||
|
|
||||||
|
// 应用 Headers
|
||||||
|
for k, values := range r.config.Headers {
|
||||||
for _, v := range values {
|
for _, v := range values {
|
||||||
q.Add(k, v)
|
r.httpReq.Header.Add(k, v)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
r.httpReq.URL.RawQuery = q.Encode()
|
|
||||||
}
|
|
||||||
|
|
||||||
// 应用 Headers
|
// 应用 Cookies
|
||||||
for k, values := range r.config.Headers {
|
for _, cookie := range r.config.Cookies {
|
||||||
for _, v := range values {
|
r.httpReq.AddCookie(cookie)
|
||||||
r.httpReq.Header.Add(k, v)
|
}
|
||||||
|
|
||||||
|
// 应用 Basic Auth
|
||||||
|
if r.config.BasicAuth[0] != "" || r.config.BasicAuth[1] != "" {
|
||||||
|
r.httpReq.SetBasicAuth(r.config.BasicAuth[0], r.config.BasicAuth[1])
|
||||||
|
}
|
||||||
|
|
||||||
|
// 应用请求体
|
||||||
|
if err := r.applyBody(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 应用 Content-Length
|
||||||
|
if r.config.ContentLength > 0 {
|
||||||
|
r.httpReq.ContentLength = r.config.ContentLength
|
||||||
|
} else if r.config.ContentLength < 0 {
|
||||||
|
r.httpReq.ContentLength = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// 自动计算 Content-Length
|
||||||
|
if r.config.AutoCalcContentLength && r.httpReq.Body != nil {
|
||||||
|
data, err := io.ReadAll(r.httpReq.Body)
|
||||||
|
if err != nil {
|
||||||
|
return wrapError(err, "read body for content length")
|
||||||
|
}
|
||||||
|
r.httpReq.ContentLength = int64(len(data))
|
||||||
|
r.httpReq.Body = io.NopCloser(bytes.NewBuffer(data))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 设置 TLS ServerName(如果有 TLS Config)
|
||||||
|
if r.config.TLS.Config != nil && r.httpReq.URL != nil {
|
||||||
|
r.config.TLS.Config.ServerName = r.httpReq.URL.Hostname()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 应用 Cookies
|
execCtx := r.ctx
|
||||||
for _, cookie := range r.config.Cookies {
|
if !r.doRaw {
|
||||||
r.httpReq.AddCookie(cookie)
|
// raw 模式下不注入请求级网络配置,只应用 context/超时。
|
||||||
|
execCtx = injectRequestConfig(execCtx, r.config)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 应用 Basic Auth
|
// 请求级总超时通过 context 控制,避免污染共享 http.Client。
|
||||||
if r.config.BasicAuth[0] != "" || r.config.BasicAuth[1] != "" {
|
if r.config.Network.Timeout > 0 {
|
||||||
r.httpReq.SetBasicAuth(r.config.BasicAuth[0], r.config.BasicAuth[1])
|
execCtx, r.cancel = context.WithTimeout(execCtx, r.config.Network.Timeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 应用请求体
|
r.execCtx = execCtx
|
||||||
if err := r.applyBody(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// 应用 Content-Length
|
|
||||||
if r.config.ContentLength > 0 {
|
|
||||||
r.httpReq.ContentLength = r.config.ContentLength
|
|
||||||
} else if r.config.ContentLength < 0 {
|
|
||||||
r.httpReq.ContentLength = 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// 自动计算 Content-Length
|
|
||||||
if r.config.AutoCalcContentLength && r.httpReq.Body != nil {
|
|
||||||
data, err := io.ReadAll(r.httpReq.Body)
|
|
||||||
if err != nil {
|
|
||||||
return wrapError(err, "read body for content length")
|
|
||||||
}
|
|
||||||
r.httpReq.ContentLength = int64(len(data))
|
|
||||||
r.httpReq.Body = io.NopCloser(bytes.NewBuffer(data))
|
|
||||||
}
|
|
||||||
|
|
||||||
// 设置 TLS ServerName(如果有 TLS Config)
|
|
||||||
if r.config.TLS.Config != nil && r.httpReq.URL != nil {
|
|
||||||
r.config.TLS.Config.ServerName = r.httpReq.URL.Hostname()
|
|
||||||
}
|
|
||||||
|
|
||||||
// 注入配置到 context
|
|
||||||
r.execCtx = injectRequestConfig(r.ctx, r.config)
|
|
||||||
r.httpReq = r.httpReq.WithContext(r.execCtx)
|
r.httpReq = r.httpReq.WithContext(r.execCtx)
|
||||||
|
|
||||||
r.applied = true
|
r.applied = true
|
||||||
@ -416,50 +430,19 @@ func (r *Request) prepare() error {
|
|||||||
|
|
||||||
// buildHTTPClient 构建 HTTP Client
|
// buildHTTPClient 构建 HTTP Client
|
||||||
func (r *Request) buildHTTPClient() (*http.Client, error) {
|
func (r *Request) buildHTTPClient() (*http.Client, error) {
|
||||||
applyTimeoutOverride := func(base *http.Client) *http.Client {
|
|
||||||
// 没有 base 时兜底
|
|
||||||
if base == nil {
|
|
||||||
base = &http.Client{}
|
|
||||||
}
|
|
||||||
|
|
||||||
rt := r.config.Network.Timeout
|
|
||||||
|
|
||||||
// 语义:
|
|
||||||
// rt < 0 : 本次请求禁用超时(Timeout = 0)
|
|
||||||
// rt = 0 : 沿用 base.Timeout
|
|
||||||
// rt > 0 : 本次请求超时覆盖
|
|
||||||
if rt == 0 {
|
|
||||||
return base
|
|
||||||
}
|
|
||||||
|
|
||||||
clone := &http.Client{
|
|
||||||
Transport: base.Transport,
|
|
||||||
CheckRedirect: base.CheckRedirect,
|
|
||||||
Jar: base.Jar,
|
|
||||||
}
|
|
||||||
|
|
||||||
if rt < 0 {
|
|
||||||
clone.Timeout = 0
|
|
||||||
} else {
|
|
||||||
clone.Timeout = rt
|
|
||||||
}
|
|
||||||
return clone
|
|
||||||
}
|
|
||||||
|
|
||||||
// 优先使用请求关联的 Client
|
// 优先使用请求关联的 Client
|
||||||
if r.client != nil {
|
if r.client != nil {
|
||||||
return applyTimeoutOverride(r.client.HTTPClient()), nil
|
return r.client.HTTPClient(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// 自定义 Transport
|
// 自定义 Transport
|
||||||
if r.config.CustomTransport && r.config.Transport != nil {
|
if r.config.CustomTransport && r.config.Transport != nil {
|
||||||
base := &http.Client{
|
return &http.Client{
|
||||||
Transport: &Transport{base: r.config.Transport},
|
Transport: &Transport{base: r.config.Transport},
|
||||||
Timeout: 0,
|
Timeout: 0,
|
||||||
}
|
}, nil
|
||||||
return applyTimeoutOverride(base), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 默认全局 client
|
// 默认全局 client
|
||||||
return applyTimeoutOverride(DefaultHTTPClient()), nil
|
return DefaultHTTPClient(), nil
|
||||||
}
|
}
|
||||||
|
|||||||
@ -10,9 +10,9 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// SetTimeout 设置请求总超时时间
|
// SetTimeout 设置请求总超时时间
|
||||||
// timeout > 0: 使用该超时
|
// timeout > 0: 为本次请求注入 context 超时
|
||||||
// timeout = 0: 使用 Client 默认超时
|
// timeout = 0: 不额外设置请求总超时
|
||||||
// timeout < 0: 禁用本次请求超时(覆盖 Client.Timeout=0)
|
// timeout < 0: 禁用 starnet 默认总超时
|
||||||
func (r *Request) SetTimeout(timeout time.Duration) *Request {
|
func (r *Request) SetTimeout(timeout time.Duration) *Request {
|
||||||
if r.err != nil {
|
if r.err != nil {
|
||||||
return r
|
return r
|
||||||
@ -194,6 +194,19 @@ func (r *Request) SetUploadProgress(fn UploadProgressFunc) *Request {
|
|||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetMaxRespBodyBytes 设置响应体最大读取字节数(<=0 表示不限制)
|
||||||
|
func (r *Request) SetMaxRespBodyBytes(maxBytes int64) *Request {
|
||||||
|
if r.err != nil {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
if maxBytes < 0 {
|
||||||
|
r.err = fmt.Errorf("max response body bytes must be >= 0")
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
r.config.MaxRespBodyBytes = maxBytes
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
// AddQuery 添加查询参数
|
// AddQuery 添加查询参数
|
||||||
func (r *Request) AddQuery(key, value string) *Request {
|
func (r *Request) AddQuery(key, value string) *Request {
|
||||||
if r.err != nil {
|
if r.err != nil {
|
||||||
@ -217,7 +230,7 @@ func (r *Request) SetQueries(queries map[string][]string) *Request {
|
|||||||
if r.err != nil {
|
if r.err != nil {
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
r.config.Queries = queries
|
r.config.Queries = cloneStringMapSlice(queries)
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -36,7 +36,7 @@ func (r *Request) SetHeaders(headers http.Header) *Request {
|
|||||||
if r.doRaw {
|
if r.doRaw {
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
r.config.Headers = headers
|
r.config.Headers = cloneHeader(headers)
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
43
request_methods_ext_test.go
Normal file
43
request_methods_ext_test.go
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRequestConvenienceMethods(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("X-Method", r.Method)
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("ok"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
method string
|
||||||
|
do func(r *Request) (*Response, error)
|
||||||
|
}{
|
||||||
|
{name: "Head", method: http.MethodHead, do: (*Request).Head},
|
||||||
|
{name: "Patch", method: http.MethodPatch, do: (*Request).Patch},
|
||||||
|
{name: "Options", method: http.MethodOptions, do: (*Request).Options},
|
||||||
|
{name: "Trace", method: http.MethodTrace, do: (*Request).Trace},
|
||||||
|
{name: "Connect", method: http.MethodConnect, do: (*Request).Connect},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
req := NewSimpleRequest(server.URL, http.MethodGet)
|
||||||
|
resp, err := tt.do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
if got := resp.Header.Get("X-Method"); got != tt.method {
|
||||||
|
t.Fatalf("method=%s want=%s", got, tt.method)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
36
response.go
36
response.go
@ -13,6 +13,7 @@ type Response struct {
|
|||||||
*http.Response
|
*http.Response
|
||||||
request *Request
|
request *Request
|
||||||
httpClient *http.Client
|
httpClient *http.Client
|
||||||
|
cancel func()
|
||||||
body *Body
|
body *Body
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -21,9 +22,26 @@ type Body struct {
|
|||||||
raw io.ReadCloser
|
raw io.ReadCloser
|
||||||
data []byte
|
data []byte
|
||||||
consumed bool
|
consumed bool
|
||||||
|
maxBytes int64
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type cancelReadCloser struct {
|
||||||
|
io.ReadCloser
|
||||||
|
cancel func()
|
||||||
|
once sync.Once
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *cancelReadCloser) Close() error {
|
||||||
|
err := c.ReadCloser.Close()
|
||||||
|
c.once.Do(func() {
|
||||||
|
if c.cancel != nil {
|
||||||
|
c.cancel()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// Request 获取原始请求
|
// Request 获取原始请求
|
||||||
func (r *Response) Request() *Request {
|
func (r *Response) Request() *Request {
|
||||||
return r.request
|
return r.request
|
||||||
@ -42,6 +60,10 @@ func (r *Response) Close() error {
|
|||||||
if r.body != nil && r.body.raw != nil {
|
if r.body != nil && r.body.raw != nil {
|
||||||
return r.body.raw.Close()
|
return r.body.raw.Close()
|
||||||
}
|
}
|
||||||
|
if r.cancel != nil {
|
||||||
|
r.cancel()
|
||||||
|
r.cancel = nil
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -70,14 +92,24 @@ func (b *Body) readAll() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := io.ReadAll(b.raw)
|
reader := io.Reader(b.raw)
|
||||||
|
if b.maxBytes > 0 {
|
||||||
|
reader = io.LimitReader(b.raw, b.maxBytes+1)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := io.ReadAll(reader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return wrapError(err, "read response body")
|
return wrapError(err, "read response body")
|
||||||
}
|
}
|
||||||
|
if b.maxBytes > 0 && int64(len(data)) > b.maxBytes {
|
||||||
|
b.consumed = true
|
||||||
|
_ = b.raw.Close()
|
||||||
|
return wrapError(ErrRespBodyTooLarge, "response body exceeds max bytes: %d > %d", len(data), b.maxBytes)
|
||||||
|
}
|
||||||
|
|
||||||
b.data = data
|
b.data = data
|
||||||
b.consumed = true
|
b.consumed = true
|
||||||
b.raw.Close()
|
_ = b.raw.Close()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
84
response_limit_test.go
Normal file
84
response_limit_test.go
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestWithMaxRespBodyBytes(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
_, _ = w.Write([]byte("123456"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
resp, err := Get(server.URL, WithMaxRespBodyBytes(4))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected request error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
_, err = resp.Body().Bytes()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected body too large error")
|
||||||
|
}
|
||||||
|
if !errors.Is(err, ErrRespBodyTooLarge) {
|
||||||
|
t.Fatalf("expected ErrRespBodyTooLarge, got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetMaxRespBodyBytes(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
_, _ = w.Write([]byte("1234"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
resp, err := NewSimpleRequest(server.URL, http.MethodGet).
|
||||||
|
SetMaxRespBodyBytes(4).
|
||||||
|
Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
body, err := resp.Body().String()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected read error: %v", err)
|
||||||
|
}
|
||||||
|
if body != "1234" {
|
||||||
|
t.Fatalf("body=%q want=1234", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetMaxRespBodyBytesWithAutoFetch(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
_, _ = w.Write([]byte("123456"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
_, err := NewSimpleRequest(server.URL, http.MethodGet).
|
||||||
|
SetAutoFetch(true).
|
||||||
|
SetMaxRespBodyBytes(4).
|
||||||
|
Do()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected body too large error with auto fetch")
|
||||||
|
}
|
||||||
|
if !errors.Is(err, ErrRespBodyTooLarge) {
|
||||||
|
t.Fatalf("expected ErrRespBodyTooLarge, got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetMaxRespBodyBytesInvalid(t *testing.T) {
|
||||||
|
req := NewSimpleRequest("http://example.com", http.MethodGet).SetMaxRespBodyBytes(-1)
|
||||||
|
if req.Err() == nil {
|
||||||
|
t.Fatal("expected error for negative max bytes")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithMaxRespBodyBytesInvalid(t *testing.T) {
|
||||||
|
_, err := NewRequest("http://example.com", http.MethodGet, WithMaxRespBodyBytes(-1))
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for negative max bytes")
|
||||||
|
}
|
||||||
|
}
|
||||||
423
retry.go
Normal file
423
retry.go
Normal file
@ -0,0 +1,423 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"math"
|
||||||
|
"math/rand"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RetryOpt func(*retryPolicy) error
|
||||||
|
|
||||||
|
type retryPolicy struct {
|
||||||
|
maxRetries int
|
||||||
|
baseDelay time.Duration
|
||||||
|
maxDelay time.Duration
|
||||||
|
factor float64
|
||||||
|
jitter float64
|
||||||
|
idempotentOnly bool
|
||||||
|
statuses map[int]struct{}
|
||||||
|
onError func(error) bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func cloneRetryPolicy(p *retryPolicy) *retryPolicy {
|
||||||
|
if p == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
cloned := &retryPolicy{
|
||||||
|
maxRetries: p.maxRetries,
|
||||||
|
baseDelay: p.baseDelay,
|
||||||
|
maxDelay: p.maxDelay,
|
||||||
|
factor: p.factor,
|
||||||
|
jitter: p.jitter,
|
||||||
|
idempotentOnly: p.idempotentOnly,
|
||||||
|
onError: p.onError,
|
||||||
|
}
|
||||||
|
if p.statuses != nil {
|
||||||
|
cloned.statuses = make(map[int]struct{}, len(p.statuses))
|
||||||
|
for code := range p.statuses {
|
||||||
|
cloned.statuses[code] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return cloned
|
||||||
|
}
|
||||||
|
|
||||||
|
func defaultRetryPolicy(max int) *retryPolicy {
|
||||||
|
return &retryPolicy{
|
||||||
|
maxRetries: max,
|
||||||
|
baseDelay: 100 * time.Millisecond,
|
||||||
|
maxDelay: 2 * time.Second,
|
||||||
|
factor: 2.0,
|
||||||
|
jitter: 0.1,
|
||||||
|
idempotentOnly: true,
|
||||||
|
statuses: map[int]struct{}{
|
||||||
|
http.StatusRequestTimeout: {},
|
||||||
|
http.StatusTooEarly: {},
|
||||||
|
http.StatusTooManyRequests: {},
|
||||||
|
http.StatusInternalServerError: {},
|
||||||
|
http.StatusBadGateway: {},
|
||||||
|
http.StatusServiceUnavailable: {},
|
||||||
|
http.StatusGatewayTimeout: {},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildRetryPolicy(max int, opts ...RetryOpt) (*retryPolicy, error) {
|
||||||
|
if max < 0 {
|
||||||
|
return nil, fmt.Errorf("max retry must be >= 0")
|
||||||
|
}
|
||||||
|
if max == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
policy := defaultRetryPolicy(max)
|
||||||
|
for _, opt := range opts {
|
||||||
|
if opt == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := opt(policy); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return policy, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithRetry(max int, opts ...RetryOpt) RequestOpt {
|
||||||
|
return func(r *Request) error {
|
||||||
|
policy, err := buildRetryPolicy(max, opts...)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
r.retry = policy
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Request) SetRetry(max int, opts ...RetryOpt) *Request {
|
||||||
|
if r.err != nil {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
policy, err := buildRetryPolicy(max, opts...)
|
||||||
|
if err != nil {
|
||||||
|
r.err = err
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
r.retry = policy
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Request) DisableRetry() *Request {
|
||||||
|
if r.err != nil {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
r.retry = nil
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Request) applyRetryOpt(opt RetryOpt) *Request {
|
||||||
|
if r.err != nil {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
if opt == nil {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
if r.retry == nil {
|
||||||
|
r.err = fmt.Errorf("retry policy is not enabled, call SetRetry first")
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
if err := opt(r.retry); err != nil {
|
||||||
|
r.err = err
|
||||||
|
}
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Request) SetRetryBackoff(base, max time.Duration, factor float64) *Request {
|
||||||
|
return r.applyRetryOpt(WithRetryBackoff(base, max, factor))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Request) SetRetryJitter(ratio float64) *Request {
|
||||||
|
return r.applyRetryOpt(WithRetryJitter(ratio))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Request) SetRetryStatuses(codes ...int) *Request {
|
||||||
|
return r.applyRetryOpt(WithRetryStatuses(codes...))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Request) SetRetryIdempotentOnly(enabled bool) *Request {
|
||||||
|
return r.applyRetryOpt(WithRetryIdempotentOnly(enabled))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Request) SetRetryOnError(fn func(error) bool) *Request {
|
||||||
|
return r.applyRetryOpt(WithRetryOnError(fn))
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithRetryBackoff(base, max time.Duration, factor float64) RetryOpt {
|
||||||
|
return func(p *retryPolicy) error {
|
||||||
|
if base < 0 {
|
||||||
|
return fmt.Errorf("retry base delay must be >= 0")
|
||||||
|
}
|
||||||
|
if max < 0 {
|
||||||
|
return fmt.Errorf("retry max delay must be >= 0")
|
||||||
|
}
|
||||||
|
if factor <= 0 {
|
||||||
|
return fmt.Errorf("retry factor must be > 0")
|
||||||
|
}
|
||||||
|
p.baseDelay = base
|
||||||
|
p.maxDelay = max
|
||||||
|
p.factor = factor
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithRetryJitter(ratio float64) RetryOpt {
|
||||||
|
return func(p *retryPolicy) error {
|
||||||
|
if ratio < 0 || ratio > 1 {
|
||||||
|
return fmt.Errorf("retry jitter ratio must be in [0,1]")
|
||||||
|
}
|
||||||
|
p.jitter = ratio
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithRetryStatuses(codes ...int) RetryOpt {
|
||||||
|
return func(p *retryPolicy) error {
|
||||||
|
statuses := make(map[int]struct{}, len(codes))
|
||||||
|
for _, code := range codes {
|
||||||
|
if code < 100 || code > 999 {
|
||||||
|
return fmt.Errorf("invalid retry status code: %d", code)
|
||||||
|
}
|
||||||
|
statuses[code] = struct{}{}
|
||||||
|
}
|
||||||
|
p.statuses = statuses
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithRetryIdempotentOnly(enabled bool) RetryOpt {
|
||||||
|
return func(p *retryPolicy) error {
|
||||||
|
p.idempotentOnly = enabled
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithRetryOnError(fn func(error) bool) RetryOpt {
|
||||||
|
return func(p *retryPolicy) error {
|
||||||
|
p.onError = fn
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Request) hasRetryPolicy() bool {
|
||||||
|
return r.retry != nil && r.retry.maxRetries > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Request) doWithRetry() (*Response, error) {
|
||||||
|
policy := cloneRetryPolicy(r.retry)
|
||||||
|
if policy == nil || policy.maxRetries <= 0 {
|
||||||
|
return r.doOnce()
|
||||||
|
}
|
||||||
|
|
||||||
|
if !policy.canRetryRequest(r) {
|
||||||
|
return r.doOnce()
|
||||||
|
}
|
||||||
|
|
||||||
|
retryCtx := r.ctx
|
||||||
|
retryCancel := func() {}
|
||||||
|
if r.config.Network.Timeout > 0 {
|
||||||
|
retryCtx, retryCancel = context.WithTimeout(r.ctx, r.config.Network.Timeout)
|
||||||
|
}
|
||||||
|
defer retryCancel()
|
||||||
|
|
||||||
|
maxAttempts := policy.maxRetries + 1
|
||||||
|
var lastResp *Response
|
||||||
|
var lastErr error
|
||||||
|
|
||||||
|
for attempt := 0; attempt < maxAttempts; attempt++ {
|
||||||
|
attemptReq, err := r.newRetryAttempt(retryCtx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, wrapError(err, "build retry attempt")
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := attemptReq.doOnce()
|
||||||
|
if resp != nil {
|
||||||
|
resp.request = r
|
||||||
|
}
|
||||||
|
|
||||||
|
if !policy.shouldRetry(resp, err, attempt, maxAttempts, retryCtx) {
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
|
||||||
|
lastResp = resp
|
||||||
|
lastErr = err
|
||||||
|
if lastResp != nil {
|
||||||
|
_ = lastResp.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
delay := policy.nextDelay(attempt)
|
||||||
|
if delay <= 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
timer := time.NewTimer(delay)
|
||||||
|
select {
|
||||||
|
case <-retryCtx.Done():
|
||||||
|
timer.Stop()
|
||||||
|
return lastResp, wrapError(retryCtx.Err(), "retry context done")
|
||||||
|
case <-timer.C:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return lastResp, lastErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Request) newRetryAttempt(ctx context.Context) (*Request, error) {
|
||||||
|
attempt := r.Clone()
|
||||||
|
attempt.retry = nil
|
||||||
|
attempt.cancel = nil
|
||||||
|
attempt.applied = false
|
||||||
|
attempt.execCtx = nil
|
||||||
|
attempt.ctx = ctx
|
||||||
|
|
||||||
|
// 共享总超时上下文后,避免每次 attempt 再创建一次 timeout context。
|
||||||
|
if attempt.config != nil && attempt.config.Network.Timeout > 0 {
|
||||||
|
attempt.config.Network.Timeout = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
if !attempt.doRaw {
|
||||||
|
attempt.httpReq = attempt.httpReq.WithContext(ctx)
|
||||||
|
return attempt, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.httpReq == nil {
|
||||||
|
return nil, fmt.Errorf("http request is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
raw := r.httpReq.Clone(ctx)
|
||||||
|
if r.httpReq.GetBody != nil {
|
||||||
|
body, err := r.httpReq.GetBody()
|
||||||
|
if err != nil {
|
||||||
|
return nil, wrapError(err, "get raw request body")
|
||||||
|
}
|
||||||
|
raw.Body = body
|
||||||
|
} else if r.httpReq.Body != nil && r.httpReq.Body != http.NoBody {
|
||||||
|
return nil, fmt.Errorf("raw request body is not replayable")
|
||||||
|
}
|
||||||
|
|
||||||
|
attempt.httpReq = raw
|
||||||
|
return attempt, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *retryPolicy) canRetryRequest(r *Request) bool {
|
||||||
|
if p.idempotentOnly && !isIdempotentMethod(r.method) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return isReplayableRequest(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func isIdempotentMethod(method string) bool {
|
||||||
|
switch method {
|
||||||
|
case http.MethodGet, http.MethodHead, http.MethodPut, http.MethodDelete, http.MethodOptions, http.MethodTrace:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func isReplayableRequest(r *Request) bool {
|
||||||
|
if r == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.doRaw {
|
||||||
|
if r.httpReq == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if r.httpReq.Body == nil || r.httpReq.Body == http.NoBody {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return r.httpReq.GetBody != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.config == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reader / stream body 通常不可重放,保守地不重试。
|
||||||
|
if r.config.Body.Reader != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, f := range r.config.Body.Files {
|
||||||
|
if f.FileData != nil || f.FilePath == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *retryPolicy) shouldRetry(resp *Response, err error, attempt, maxAttempts int, ctx context.Context) bool {
|
||||||
|
if attempt >= maxAttempts-1 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if ctx != nil && ctx.Err() != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if p.onError != nil {
|
||||||
|
return p.onError(err)
|
||||||
|
}
|
||||||
|
return isRetryableError(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp == nil || resp.Response == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
_, ok := p.statuses[resp.StatusCode]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func isRetryableError(err error) bool {
|
||||||
|
var netErr net.Error
|
||||||
|
if errors.As(err, &netErr) {
|
||||||
|
if netErr.Timeout() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if netErr.Temporary() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *retryPolicy) nextDelay(attempt int) time.Duration {
|
||||||
|
if p.baseDelay <= 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
delay := time.Duration(float64(p.baseDelay) * math.Pow(p.factor, float64(attempt)))
|
||||||
|
if p.maxDelay > 0 && delay > p.maxDelay {
|
||||||
|
delay = p.maxDelay
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.jitter <= 0 {
|
||||||
|
return delay
|
||||||
|
}
|
||||||
|
|
||||||
|
low := 1 - p.jitter
|
||||||
|
if low < 0 {
|
||||||
|
low = 0
|
||||||
|
}
|
||||||
|
high := 1 + p.jitter
|
||||||
|
scale := low + rand.Float64()*(high-low)
|
||||||
|
return time.Duration(float64(delay) * scale)
|
||||||
|
}
|
||||||
298
retry_test.go
Normal file
298
retry_test.go
Normal file
@ -0,0 +1,298 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestWithRetrySmokeGet(t *testing.T) {
|
||||||
|
var hits int32
|
||||||
|
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
n := atomic.AddInt32(&hits, 1)
|
||||||
|
if n <= 2 {
|
||||||
|
w.WriteHeader(http.StatusServiceUnavailable)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("ok"))
|
||||||
|
}))
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
resp, err := Get(s.URL,
|
||||||
|
WithRetry(2,
|
||||||
|
WithRetryBackoff(0, 0, 1),
|
||||||
|
WithRetryJitter(0),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("status=%d want=%d", resp.StatusCode, http.StatusOK)
|
||||||
|
}
|
||||||
|
if atomic.LoadInt32(&hits) != 3 {
|
||||||
|
t.Fatalf("hits=%d want=3", hits)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithRetryResponseRequestPointerStable(t *testing.T) {
|
||||||
|
var hits int32
|
||||||
|
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
n := atomic.AddInt32(&hits, 1)
|
||||||
|
if n == 1 {
|
||||||
|
w.WriteHeader(http.StatusServiceUnavailable)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
req := NewSimpleRequest(s.URL, http.MethodGet).
|
||||||
|
SetRetry(1, WithRetryBackoff(0, 0, 1), WithRetryJitter(0))
|
||||||
|
resp, err := req.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
if resp.Request() != req {
|
||||||
|
t.Fatal("response request pointer should point to original request")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithRetryNoRetryForNonReplayableBodyReader(t *testing.T) {
|
||||||
|
var hits int32
|
||||||
|
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
atomic.AddInt32(&hits, 1)
|
||||||
|
w.WriteHeader(http.StatusServiceUnavailable)
|
||||||
|
}))
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
req := NewSimpleRequest(s.URL, http.MethodPost).
|
||||||
|
SetBodyReader(strings.NewReader("payload")).
|
||||||
|
SetRetry(3,
|
||||||
|
WithRetryIdempotentOnly(false),
|
||||||
|
WithRetryBackoff(0, 0, 1),
|
||||||
|
WithRetryJitter(0),
|
||||||
|
)
|
||||||
|
resp, err := req.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusServiceUnavailable {
|
||||||
|
t.Fatalf("status=%d want=%d", resp.StatusCode, http.StatusServiceUnavailable)
|
||||||
|
}
|
||||||
|
if atomic.LoadInt32(&hits) != 1 {
|
||||||
|
t.Fatalf("hits=%d want=1", hits)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithRetryPostWhenIdempotentDisabled(t *testing.T) {
|
||||||
|
var hits int32
|
||||||
|
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
n := atomic.AddInt32(&hits, 1)
|
||||||
|
_, _ = io.Copy(io.Discard, r.Body)
|
||||||
|
if n == 1 {
|
||||||
|
w.WriteHeader(http.StatusServiceUnavailable)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
req := NewSimpleRequest(s.URL, http.MethodPost).
|
||||||
|
SetBodyString("hello").
|
||||||
|
SetRetry(1,
|
||||||
|
WithRetryIdempotentOnly(false),
|
||||||
|
WithRetryBackoff(0, 0, 1),
|
||||||
|
WithRetryJitter(0),
|
||||||
|
)
|
||||||
|
resp, err := req.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("status=%d want=%d", resp.StatusCode, http.StatusOK)
|
||||||
|
}
|
||||||
|
if atomic.LoadInt32(&hits) != 2 {
|
||||||
|
t.Fatalf("hits=%d want=2", hits)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithRetryRawWithoutGetBodyNoRetry(t *testing.T) {
|
||||||
|
var hits int32
|
||||||
|
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
atomic.AddInt32(&hits, 1)
|
||||||
|
w.WriteHeader(http.StatusServiceUnavailable)
|
||||||
|
}))
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
rawReq, _ := http.NewRequest(http.MethodPost, s.URL, io.MultiReader(strings.NewReader("raw")))
|
||||||
|
if rawReq.GetBody != nil {
|
||||||
|
t.Fatal("raw request GetBody should be nil in this test")
|
||||||
|
}
|
||||||
|
|
||||||
|
req := NewSimpleRequest("", http.MethodPost, WithRawRequest(rawReq)).
|
||||||
|
SetRetry(2,
|
||||||
|
WithRetryIdempotentOnly(false),
|
||||||
|
WithRetryBackoff(0, 0, 1),
|
||||||
|
WithRetryJitter(0),
|
||||||
|
)
|
||||||
|
resp, err := req.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
if atomic.LoadInt32(&hits) != 1 {
|
||||||
|
t.Fatalf("hits=%d want=1", hits)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithRetryRespectsTotalTimeoutBudget(t *testing.T) {
|
||||||
|
var hits int32
|
||||||
|
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
atomic.AddInt32(&hits, 1)
|
||||||
|
time.Sleep(80 * time.Millisecond)
|
||||||
|
w.WriteHeader(http.StatusServiceUnavailable)
|
||||||
|
}))
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
req := NewSimpleRequest(s.URL, http.MethodGet).
|
||||||
|
SetTimeout(120*time.Millisecond).
|
||||||
|
SetRetry(3,
|
||||||
|
WithRetryBackoff(0, 0, 1),
|
||||||
|
WithRetryJitter(0),
|
||||||
|
)
|
||||||
|
|
||||||
|
_, err := req.Do()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected timeout error")
|
||||||
|
}
|
||||||
|
if !errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
t.Fatalf("expected context deadline exceeded, got: %v", err)
|
||||||
|
}
|
||||||
|
if h := atomic.LoadInt32(&hits); h > 2 {
|
||||||
|
t.Fatalf("hits=%d want<=2 under tight timeout budget", h)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetRetryInvalidMax(t *testing.T) {
|
||||||
|
req := NewSimpleRequest("http://example.com", http.MethodGet).SetRetry(-1)
|
||||||
|
if req.Err() == nil {
|
||||||
|
t.Fatal("expected error for negative retry max")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetRetrySeriesSmokeGet(t *testing.T) {
|
||||||
|
var hits int32
|
||||||
|
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
n := atomic.AddInt32(&hits, 1)
|
||||||
|
if n <= 2 {
|
||||||
|
w.WriteHeader(http.StatusTooManyRequests)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
req := NewSimpleRequest(s.URL, http.MethodGet).
|
||||||
|
SetRetry(2).
|
||||||
|
SetRetryBackoff(0, 0, 1).
|
||||||
|
SetRetryJitter(0).
|
||||||
|
SetRetryStatuses(http.StatusTooManyRequests)
|
||||||
|
resp, err := req.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("status=%d want=%d", resp.StatusCode, http.StatusOK)
|
||||||
|
}
|
||||||
|
if h := atomic.LoadInt32(&hits); h != 3 {
|
||||||
|
t.Fatalf("hits=%d want=3", h)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetRetryIdempotentOnlyWithPost(t *testing.T) {
|
||||||
|
var hits int32
|
||||||
|
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
n := atomic.AddInt32(&hits, 1)
|
||||||
|
_, _ = io.Copy(io.Discard, r.Body)
|
||||||
|
if n == 1 {
|
||||||
|
w.WriteHeader(http.StatusServiceUnavailable)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
req := NewSimpleRequest(s.URL, http.MethodPost).
|
||||||
|
SetBodyString("hello").
|
||||||
|
SetRetry(1).
|
||||||
|
SetRetryIdempotentOnly(false).
|
||||||
|
SetRetryBackoff(0, 0, 1).
|
||||||
|
SetRetryJitter(0)
|
||||||
|
resp, err := req.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("status=%d want=%d", resp.StatusCode, http.StatusOK)
|
||||||
|
}
|
||||||
|
if h := atomic.LoadInt32(&hits); h != 2 {
|
||||||
|
t.Fatalf("hits=%d want=2", h)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetRetryOnErrorOverridesDefault(t *testing.T) {
|
||||||
|
var dials int32
|
||||||
|
dialErr := errors.New("dial failed")
|
||||||
|
|
||||||
|
req := NewSimpleRequest("http://example.com", http.MethodGet).
|
||||||
|
SetDialFunc(func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
atomic.AddInt32(&dials, 1)
|
||||||
|
return nil, dialErr
|
||||||
|
}).
|
||||||
|
SetRetry(1).
|
||||||
|
SetRetryBackoff(0, 0, 1).
|
||||||
|
SetRetryJitter(0).
|
||||||
|
SetRetryOnError(func(err error) bool {
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
_, err := req.Do()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error")
|
||||||
|
}
|
||||||
|
if h := atomic.LoadInt32(&dials); h != 2 {
|
||||||
|
t.Fatalf("dial attempts=%d want=2", h)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetRetryOptionRequireEnableRetry(t *testing.T) {
|
||||||
|
req := NewSimpleRequest("http://example.com", http.MethodGet).SetRetryBackoff(10*time.Millisecond, 100*time.Millisecond, 2)
|
||||||
|
if req.Err() == nil {
|
||||||
|
t.Fatal("expected error when setting retry options before SetRetry")
|
||||||
|
}
|
||||||
|
if !strings.Contains(req.Err().Error(), "call SetRetry first") {
|
||||||
|
t.Fatalf("unexpected error: %v", req.Err())
|
||||||
|
}
|
||||||
|
}
|
||||||
115
timeout_refactor_test.go
Normal file
115
timeout_refactor_test.go
Normal file
@ -0,0 +1,115 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRequestTimeoutDoesNotMutateClientTimeout(t *testing.T) {
|
||||||
|
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
client := NewClientNoErr()
|
||||||
|
baseTimeout := client.HTTPClient().Timeout
|
||||||
|
|
||||||
|
_, err := client.Get(s.URL, WithTimeout(80*time.Millisecond))
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected request timeout error")
|
||||||
|
}
|
||||||
|
|
||||||
|
if client.HTTPClient().Timeout != baseTimeout {
|
||||||
|
t.Fatalf("client timeout mutated: got=%v want=%v", client.HTTPClient().Timeout, baseTimeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := client.Get(s.URL, WithTimeout(400*time.Millisecond))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("second request should succeed with larger timeout: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestTimeoutNoLingeringConnDeadline(t *testing.T) {
|
||||||
|
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
_, _ = io.Copy(io.Discard, r.Body)
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("ok"))
|
||||||
|
}))
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
client := NewClientNoErr()
|
||||||
|
|
||||||
|
resp1, err := client.Post(s.URL, WithBodyString("first"), WithTimeout(120*time.Millisecond))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("first request should succeed: %v", err)
|
||||||
|
}
|
||||||
|
_ = resp1.Close()
|
||||||
|
|
||||||
|
// 如果请求超时依赖连接级绝对 deadline,经过该等待后复用连接会出现误超时。
|
||||||
|
time.Sleep(220 * time.Millisecond)
|
||||||
|
|
||||||
|
resp2, err := client.Post(s.URL, WithBodyString("second"), WithTimeout(1*time.Second))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("second request should not be affected by previous timeout window: %v", err)
|
||||||
|
}
|
||||||
|
_ = resp2.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnReadDeadlineTimeoutAndRecover(t *testing.T) {
|
||||||
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("listen error: %v", err)
|
||||||
|
}
|
||||||
|
defer ln.Close()
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer close(done)
|
||||||
|
conn, err := ln.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
time.Sleep(180 * time.Millisecond)
|
||||||
|
_, _ = conn.Write([]byte("x"))
|
||||||
|
}()
|
||||||
|
|
||||||
|
c, err := Dial("tcp", ln.Addr().String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("dial error: %v", err)
|
||||||
|
}
|
||||||
|
defer c.Close()
|
||||||
|
|
||||||
|
if err := c.SetReadDeadline(time.Now().Add(60 * time.Millisecond)); err != nil {
|
||||||
|
t.Fatalf("set read deadline error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := make([]byte, 1)
|
||||||
|
_, err = c.Read(buf)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected read timeout error")
|
||||||
|
}
|
||||||
|
if ne, ok := err.(net.Error); !ok || !ne.Timeout() {
|
||||||
|
t.Fatalf("expected net timeout error, got: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.SetReadDeadline(time.Time{}); err != nil {
|
||||||
|
t.Fatalf("clear read deadline error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := io.ReadFull(c, buf); err != nil {
|
||||||
|
t.Fatalf("read after clearing deadline should succeed: %v", err)
|
||||||
|
}
|
||||||
|
if string(buf) != "x" {
|
||||||
|
t.Fatalf("unexpected payload: %q", string(buf))
|
||||||
|
}
|
||||||
|
|
||||||
|
<-done
|
||||||
|
}
|
||||||
2
types.go
2
types.go
@ -84,6 +84,7 @@ type RequestConfig struct {
|
|||||||
BasicAuth [2]string // Basic 认证
|
BasicAuth [2]string // Basic 认证
|
||||||
ContentLength int64 // 手动设置的 Content-Length
|
ContentLength int64 // 手动设置的 Content-Length
|
||||||
AutoCalcContentLength bool // 自动计算 Content-Length
|
AutoCalcContentLength bool // 自动计算 Content-Length
|
||||||
|
MaxRespBodyBytes int64 // 响应体最大读取字节数(<=0 表示不限制)
|
||||||
UploadProgress UploadProgressFunc // 上传进度回调
|
UploadProgress UploadProgressFunc // 上传进度回调
|
||||||
|
|
||||||
// Transport 配置
|
// Transport 配置
|
||||||
@ -121,6 +122,7 @@ func (c *RequestConfig) Clone() *RequestConfig {
|
|||||||
BasicAuth: c.BasicAuth,
|
BasicAuth: c.BasicAuth,
|
||||||
ContentLength: c.ContentLength,
|
ContentLength: c.ContentLength,
|
||||||
AutoCalcContentLength: c.AutoCalcContentLength,
|
AutoCalcContentLength: c.AutoCalcContentLength,
|
||||||
|
MaxRespBodyBytes: c.MaxRespBodyBytes,
|
||||||
UploadProgress: c.UploadProgress,
|
UploadProgress: c.UploadProgress,
|
||||||
CustomTransport: c.CustomTransport,
|
CustomTransport: c.CustomTransport,
|
||||||
Transport: c.Transport, // Transport 共享
|
Transport: c.Transport, // Transport 共享
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user