Compare commits
28 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
b026953c74
|
|||
|
2f4c7158cf
|
|||
|
732e81316c
|
|||
|
9ac9b65bc5
|
|||
|
b5bd7595a1
|
|||
|
4568e17f06
|
|||
|
1bb30514ec
|
|||
|
50aef48d49
|
|||
|
0e2f91eee2
|
|||
|
b90c59d6e7
|
|||
|
4e154cc17b
|
|||
|
67b0025f9c
|
|||
|
c4fa62536a
|
|||
|
260ceb90ed
|
|||
|
d260181adf
|
|||
|
e3b7369e12
|
|||
|
4e17fee681
|
|||
|
a8eed30db5
|
|||
| c1eaf43058 | |||
| 9f5aca124d | |||
| 54958724e7 | |||
| 7a17672149 | |||
| 44b807d3d1 | |||
| 0d847462b3 | |||
| deed4207ea | |||
| f6363fed07 | |||
| 1de78f2f06 | |||
| d0122a9771 |
@@ -0,0 +1,6 @@
|
|||||||
|
.idea
|
||||||
|
.sentrux/
|
||||||
|
agent_readme.md
|
||||||
|
target.md
|
||||||
|
agents.md
|
||||||
|
.codex
|
||||||
@@ -0,0 +1,201 @@
|
|||||||
|
Apache License
|
||||||
|
Version 2.0, January 2004
|
||||||
|
http://www.apache.org/licenses/
|
||||||
|
|
||||||
|
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||||
|
|
||||||
|
1. Definitions.
|
||||||
|
|
||||||
|
"License" shall mean the terms and conditions for use, reproduction,
|
||||||
|
and distribution as defined by Sections 1 through 9 of this document.
|
||||||
|
|
||||||
|
"Licensor" shall mean the copyright owner or entity authorized by
|
||||||
|
the copyright owner that is granting the License.
|
||||||
|
|
||||||
|
"Legal Entity" shall mean the union of the acting entity and all
|
||||||
|
other entities that control, are controlled by, or are under common
|
||||||
|
control with that entity. For the purposes of this definition,
|
||||||
|
"control" means (i) the power, direct or indirect, to cause the
|
||||||
|
direction or management of such entity, whether by contract or
|
||||||
|
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||||
|
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||||
|
|
||||||
|
"You" (or "Your") shall mean an individual or Legal Entity
|
||||||
|
exercising permissions granted by this License.
|
||||||
|
|
||||||
|
"Source" form shall mean the preferred form for making modifications,
|
||||||
|
including but not limited to software source code, documentation
|
||||||
|
source, and configuration files.
|
||||||
|
|
||||||
|
"Object" form shall mean any form resulting from mechanical
|
||||||
|
transformation or translation of a Source form, including but
|
||||||
|
not limited to compiled object code, generated documentation,
|
||||||
|
and conversions to other media types.
|
||||||
|
|
||||||
|
"Work" shall mean the work of authorship, whether in Source or
|
||||||
|
Object form, made available under the License, as indicated by a
|
||||||
|
copyright notice that is included in or attached to the work
|
||||||
|
(an example is provided in the Appendix below).
|
||||||
|
|
||||||
|
"Derivative Works" shall mean any work, whether in Source or Object
|
||||||
|
form, that is based on (or derived from) the Work and for which the
|
||||||
|
editorial revisions, annotations, elaborations, or other modifications
|
||||||
|
represent, as a whole, an original work of authorship. For the purposes
|
||||||
|
of this License, Derivative Works shall not include works that remain
|
||||||
|
separable from, or merely link (or bind by name) to the interfaces of,
|
||||||
|
the Work and Derivative Works thereof.
|
||||||
|
|
||||||
|
"Contribution" shall mean any work of authorship, including
|
||||||
|
the original version of the Work and any modifications or additions
|
||||||
|
to that Work or Derivative Works thereof, that is intentionally
|
||||||
|
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||||
|
or by an individual or Legal Entity authorized to submit on behalf of
|
||||||
|
the copyright owner. For the purposes of this definition, "submitted"
|
||||||
|
means any form of electronic, verbal, or written communication sent
|
||||||
|
to the Licensor or its representatives, including but not limited to
|
||||||
|
communication on electronic mailing lists, source code control systems,
|
||||||
|
and issue tracking systems that are managed by, or on behalf of, the
|
||||||
|
Licensor for the purpose of discussing and improving the Work, but
|
||||||
|
excluding communication that is conspicuously marked or otherwise
|
||||||
|
designated in writing by the copyright owner as "Not a Contribution."
|
||||||
|
|
||||||
|
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||||
|
on behalf of whom a Contribution has been received by Licensor and
|
||||||
|
subsequently incorporated within the Work.
|
||||||
|
|
||||||
|
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
copyright license to reproduce, prepare Derivative Works of,
|
||||||
|
publicly display, publicly perform, sublicense, and distribute the
|
||||||
|
Work and such Derivative Works in Source or Object form.
|
||||||
|
|
||||||
|
3. Grant of Patent License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
(except as stated in this section) patent license to make, have made,
|
||||||
|
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||||
|
where such license applies only to those patent claims licensable
|
||||||
|
by such Contributor that are necessarily infringed by their
|
||||||
|
Contribution(s) alone or by combination of their Contribution(s)
|
||||||
|
with the Work to which such Contribution(s) was submitted. If You
|
||||||
|
institute patent litigation against any entity (including a
|
||||||
|
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||||
|
or a Contribution incorporated within the Work constitutes direct
|
||||||
|
or contributory patent infringement, then any patent licenses
|
||||||
|
granted to You under this License for that Work shall terminate
|
||||||
|
as of the date such litigation is filed.
|
||||||
|
|
||||||
|
4. Redistribution. You may reproduce and distribute copies of the
|
||||||
|
Work or Derivative Works thereof in any medium, with or without
|
||||||
|
modifications, and in Source or Object form, provided that You
|
||||||
|
meet the following conditions:
|
||||||
|
|
||||||
|
(a) You must give any other recipients of the Work or
|
||||||
|
Derivative Works a copy of this License; and
|
||||||
|
|
||||||
|
(b) You must cause any modified files to carry prominent notices
|
||||||
|
stating that You changed the files; and
|
||||||
|
|
||||||
|
(c) You must retain, in the Source form of any Derivative Works
|
||||||
|
that You distribute, all copyright, patent, trademark, and
|
||||||
|
attribution notices from the Source form of the Work,
|
||||||
|
excluding those notices that do not pertain to any part of
|
||||||
|
the Derivative Works; and
|
||||||
|
|
||||||
|
(d) If the Work includes a "NOTICE" text file as part of its
|
||||||
|
distribution, then any Derivative Works that You distribute must
|
||||||
|
include a readable copy of the attribution notices contained
|
||||||
|
within such NOTICE file, excluding those notices that do not
|
||||||
|
pertain to any part of the Derivative Works, in at least one
|
||||||
|
of the following places: within a NOTICE text file distributed
|
||||||
|
as part of the Derivative Works; within the Source form or
|
||||||
|
documentation, if provided along with the Derivative Works; or,
|
||||||
|
within a display generated by the Derivative Works, if and
|
||||||
|
wherever such third-party notices normally appear. The contents
|
||||||
|
of the NOTICE file are for informational purposes only and
|
||||||
|
do not modify the License. You may add Your own attribution
|
||||||
|
notices within Derivative Works that You distribute, alongside
|
||||||
|
or as an addendum to the NOTICE text from the Work, provided
|
||||||
|
that such additional attribution notices cannot be construed
|
||||||
|
as modifying the License.
|
||||||
|
|
||||||
|
You may add Your own copyright statement to Your modifications and
|
||||||
|
may provide additional or different license terms and conditions
|
||||||
|
for use, reproduction, or distribution of Your modifications, or
|
||||||
|
for any such Derivative Works as a whole, provided Your use,
|
||||||
|
reproduction, and distribution of the Work otherwise complies with
|
||||||
|
the conditions stated in this License.
|
||||||
|
|
||||||
|
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||||
|
any Contribution intentionally submitted for inclusion in the Work
|
||||||
|
by You to the Licensor shall be under the terms and conditions of
|
||||||
|
this License, without any additional terms or conditions.
|
||||||
|
Notwithstanding the above, nothing herein shall supersede or modify
|
||||||
|
the terms of any separate license agreement you may have executed
|
||||||
|
with Licensor regarding such Contributions.
|
||||||
|
|
||||||
|
6. Trademarks. This License does not grant permission to use the trade
|
||||||
|
names, trademarks, service marks, or product names of the Licensor,
|
||||||
|
except as required for reasonable and customary use in describing the
|
||||||
|
origin of the Work and reproducing the content of the NOTICE file.
|
||||||
|
|
||||||
|
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||||
|
agreed to in writing, Licensor provides the Work (and each
|
||||||
|
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
implied, including, without limitation, any warranties or conditions
|
||||||
|
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||||
|
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||||
|
appropriateness of using or redistributing the Work and assume any
|
||||||
|
risks associated with Your exercise of permissions under this License.
|
||||||
|
|
||||||
|
8. Limitation of Liability. In no event and under no legal theory,
|
||||||
|
whether in tort (including negligence), contract, or otherwise,
|
||||||
|
unless required by applicable law (such as deliberate and grossly
|
||||||
|
negligent acts) or agreed to in writing, shall any Contributor be
|
||||||
|
liable to You for damages, including any direct, indirect, special,
|
||||||
|
incidental, or consequential damages of any character arising as a
|
||||||
|
result of this License or out of the use or inability to use the
|
||||||
|
Work (including but not limited to damages for loss of goodwill,
|
||||||
|
work stoppage, computer failure or malfunction, or any and all
|
||||||
|
other commercial damages or losses), even if such Contributor
|
||||||
|
has been advised of the possibility of such damages.
|
||||||
|
|
||||||
|
9. Accepting Warranty or Additional Liability. While redistributing
|
||||||
|
the Work or Derivative Works thereof, You may choose to offer,
|
||||||
|
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||||
|
or other liability obligations and/or rights consistent with this
|
||||||
|
License. However, in accepting such obligations, You may act only
|
||||||
|
on Your own behalf and on Your sole responsibility, not on behalf
|
||||||
|
of any other Contributor, and only if You agree to indemnify,
|
||||||
|
defend, and hold each Contributor harmless for any liability
|
||||||
|
incurred by, or claims asserted against, such Contributor by reason
|
||||||
|
of your accepting any such warranty or additional liability.
|
||||||
|
|
||||||
|
END OF TERMS AND CONDITIONS
|
||||||
|
|
||||||
|
APPENDIX: How to apply the Apache License to your work.
|
||||||
|
|
||||||
|
To apply the Apache License to your work, attach the following
|
||||||
|
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||||
|
replaced with your own identifying information. (Don't include
|
||||||
|
the brackets!) The text should be enclosed in the appropriate
|
||||||
|
comment syntax for the file format. We also recommend that a
|
||||||
|
file or class name and description of purpose be included on the
|
||||||
|
same "printed page" as the copyright notice for easier
|
||||||
|
identification within third-party archives.
|
||||||
|
|
||||||
|
Copyright 2026 starnet contributors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
@@ -0,0 +1,122 @@
|
|||||||
|
# starnet
|
||||||
|
|
||||||
|
`starnet` 是一个面向 Go 的网络工具库,提供 HTTP 请求控制、TLS 嗅探和 ICMP Ping 能力。
|
||||||
|
|
||||||
|
## 功能概览
|
||||||
|
|
||||||
|
- 基于 `context` 的请求级超时控制,不修改共享 `http.Client` 的全局超时
|
||||||
|
- 请求级网络控制:代理、自定义 IP / DNS、拨号超时、TLS 配置
|
||||||
|
- 支持请求级 `Host` 覆盖、显式 `TLSServerName/SNI` 控制,以及结构化 trace 回调 / 摘要
|
||||||
|
- 内置重试机制,支持重试次数、退避、抖动、状态码白名单和自定义错误判定
|
||||||
|
- 响应体大小限制,避免一次性读取过大内容
|
||||||
|
- 错误分类辅助:`ClassifyError`、`IsTimeout`、`IsDNS`、`IsTLS`、`IsProxy`、`IsCanceled`
|
||||||
|
- TLS 嗅探监听 / 拨号工具,适用于 TLS 与明文混合场景
|
||||||
|
- ICMP Ping,支持 IPv4 / IPv6 目标和选项化探测
|
||||||
|
|
||||||
|
## 主要能力
|
||||||
|
|
||||||
|
### HTTP 客户端与请求构建
|
||||||
|
|
||||||
|
- 同时提供 `WithXxx` 选项和 `SetXxx` 链式调用两套接口
|
||||||
|
- 支持 `Get`、`Post`、`Put`、`Delete`、`Head`、`Patch`、`Options`、`Trace`、`Connect`
|
||||||
|
- 支持 JSON、表单、`multipart/form-data`、流式请求体等常见请求体形态
|
||||||
|
- 支持显式 `Host` 覆盖与 `TLSServerName` 设置,便于直连 IP、虚拟主机和证书校验场景分离控制
|
||||||
|
- Header、Cookie、Query 等输入在关键路径上做防御性拷贝,降低外部可变状态污染风险
|
||||||
|
- `Request.Clone()` 可用于并发场景或同一基础请求的变体构造
|
||||||
|
|
||||||
|
### Trace 与诊断
|
||||||
|
|
||||||
|
- 支持 `TraceHooks`,可接收 DNS、建连、TLS 握手、写请求、首包等结构化事件
|
||||||
|
- 支持 `TraceRecorder` / `TraceSummary`,用于汇总一次请求的关键网络过程和 TLS 摘要
|
||||||
|
- `Request.TraceSummary()` 返回该请求最近一次执行的摘要快照,`Response.TraceSummary()` 返回当前响应对应的摘要快照
|
||||||
|
- 若多个请求共享同一个 `TraceRecorder`,其 `Summary()` 表示最近一次完成请求的摘要
|
||||||
|
|
||||||
|
### 超时与重试
|
||||||
|
|
||||||
|
- 请求超时通过 `context` 截止时间控制,不污染共享客户端配置
|
||||||
|
- 重试支持:
|
||||||
|
- 最大尝试次数
|
||||||
|
- 基础退避、最大退避和退避因子
|
||||||
|
- 抖动比例
|
||||||
|
- 可重试状态码集合
|
||||||
|
- 仅幂等方法重试
|
||||||
|
- 自定义错误判定函数
|
||||||
|
- 重试成功后返回的 `Response` 仍保持对原始 `Request` 的引用
|
||||||
|
|
||||||
|
### 响应处理
|
||||||
|
|
||||||
|
- 提供 `Bytes`、`String`、`JSON`、`Reader` 等响应体读取接口
|
||||||
|
- 支持自动预取响应体
|
||||||
|
- 支持按字节数限制响应体读取上限
|
||||||
|
|
||||||
|
### Ping 模块
|
||||||
|
|
||||||
|
- 提供 `Ping`、`PingWithContext`、`Pingable` 以及兼容函数 `IsIpPingable`
|
||||||
|
- `PingOptions` 支持次数、超时、间隔、截止时间、地址族偏好、源地址、负载长度等参数
|
||||||
|
- 对权限不足、协议不支持、超时、解析失败等情况提供明确错误语义
|
||||||
|
|
||||||
|
## 安装
|
||||||
|
|
||||||
|
```bash
|
||||||
|
go get b612.me/starnet
|
||||||
|
```
|
||||||
|
|
||||||
|
## 快速示例
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"b612.me/starnet"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
resp, err := starnet.Get(
|
||||||
|
"https://example.com",
|
||||||
|
starnet.WithTimeout(2*time.Second),
|
||||||
|
starnet.WithRetry(2,
|
||||||
|
starnet.WithRetryBackoff(100*time.Millisecond, 1*time.Second, 2),
|
||||||
|
starnet.WithRetryJitter(0.1),
|
||||||
|
),
|
||||||
|
starnet.WithMaxRespBodyBytes(1<<20),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println("request failed:", starnet.ClassifyError(err), err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
fmt.Println("status:", resp.StatusCode)
|
||||||
|
_, _ = resp.Body().Bytes()
|
||||||
|
|
||||||
|
ok, pingErr := starnet.Pingable("example.com", &starnet.PingOptions{
|
||||||
|
Count: 2,
|
||||||
|
Timeout: 2 * time.Second,
|
||||||
|
})
|
||||||
|
fmt.Println("pingable:", ok, pingErr == nil)
|
||||||
|
|
||||||
|
_ = http.MethodGet
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## 行为说明
|
||||||
|
|
||||||
|
- `NewClient`、`NewRequest` 以及请求构造相关接口在遇到非法选项时会直接返回错误,例如格式不合法的代理地址。
|
||||||
|
- `NewClientNoErr` 是便利构造函数;如果选项校验失败,仍可能返回一个占位 `Client`,需要严格校验配置时应优先使用 `NewClient`。
|
||||||
|
- `SetHost` / `WithHost` 只覆盖 HTTP 请求的 `Host`;如需单独控制 TLS SNI 或证书校验名,应配合 `SetTLSServerName` / `WithTLSServerName` 使用。
|
||||||
|
- 重试默认仅对幂等方法生效。即使显式关闭“仅幂等方法重试”,通过 `SetBodyReader` 或 `WithBodyReader` 构造的请求在非幂等方法上仍不会自动重试。
|
||||||
|
- 当同时使用 `proxy + custom IP/DNS` 且解析出多个目标地址时,自动目标回退仅对幂等请求生效,以避免重复写入。
|
||||||
|
- 绑定到请求上的 `TraceRecorder` 用于发布已完成请求的摘要;请求执行中的中间状态不保证通过共享 recorder 实时可见。
|
||||||
|
|
||||||
|
## 稳定性说明
|
||||||
|
|
||||||
|
- 原始 ICMP Ping 在部分系统上需要额外权限。
|
||||||
|
- 依赖外部网络环境的集成测试结果可能受运行环境影响。
|
||||||
|
|
||||||
|
## 许可证
|
||||||
|
|
||||||
|
本项目采用 Apache License 2.0,详见 [LICENSE](./LICENSE)。
|
||||||
+1657
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,224 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func BenchmarkGetRequest(b *testing.B) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write([]byte("OK"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
b.ReportAllocs()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
resp, err := Get(server.URL)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("Get() error: %v", err)
|
||||||
|
}
|
||||||
|
resp.Body().String()
|
||||||
|
resp.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkGetRequestWithHeaders(b *testing.B) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write([]byte("OK"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
b.ReportAllocs()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
resp, err := Get(server.URL,
|
||||||
|
WithHeader("X-Custom", "value"),
|
||||||
|
WithUserAgent("BenchmarkAgent"))
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("Get() error: %v", err)
|
||||||
|
}
|
||||||
|
resp.Body().String()
|
||||||
|
resp.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkPostRequest(b *testing.B) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write([]byte("OK"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
testData := []byte("test data for benchmark")
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
b.ReportAllocs()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
resp, err := Post(server.URL, WithBody(testData))
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("Post() error: %v", err)
|
||||||
|
}
|
||||||
|
resp.Body().String()
|
||||||
|
resp.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkJSONRequest(b *testing.B) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.Write([]byte(`{"status":"ok"}`))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
type TestData struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Value int `json:"value"`
|
||||||
|
}
|
||||||
|
|
||||||
|
data := TestData{Name: "test", Value: 123}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
b.ReportAllocs()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
resp, err := Post(server.URL, WithJSON(data))
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("Post() error: %v", err)
|
||||||
|
}
|
||||||
|
var result map[string]string
|
||||||
|
resp.Body().JSON(&result)
|
||||||
|
resp.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkConcurrentRequests(b *testing.B) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write([]byte("OK"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
b.ReportAllocs()
|
||||||
|
|
||||||
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
|
for pb.Next() {
|
||||||
|
resp, err := Get(server.URL)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("Get() error: %v", err)
|
||||||
|
}
|
||||||
|
resp.Body().String()
|
||||||
|
resp.Close()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkRequestClone(b *testing.B) {
|
||||||
|
req := NewSimpleRequest("https://example.com", "GET").
|
||||||
|
SetHeader("X-Custom", "value").
|
||||||
|
AddQuery("key", "value")
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
b.ReportAllocs()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_ = req.Clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkClientCreation(b *testing.B) {
|
||||||
|
b.ResetTimer()
|
||||||
|
b.ReportAllocs()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_ = NewClientNoErr()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkRequestCreation(b *testing.B) {
|
||||||
|
b.ResetTimer()
|
||||||
|
b.ReportAllocs()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_ = NewSimpleRequest("https://example.com", "GET")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkRequestPrepareDefaultPath(b *testing.B) {
|
||||||
|
b.ResetTimer()
|
||||||
|
b.ReportAllocs()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
req := NewSimpleRequest("https://example.com", "GET")
|
||||||
|
if err := req.prepare(); err != nil {
|
||||||
|
b.Fatalf("prepare() error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkRequestPrepareDynamicPath(b *testing.B) {
|
||||||
|
b.ResetTimer()
|
||||||
|
b.ReportAllocs()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
req := NewSimpleRequest("https://example.com", "GET",
|
||||||
|
WithCustomIP([]string{"127.0.0.1"}),
|
||||||
|
WithSkipTLSVerify(true),
|
||||||
|
)
|
||||||
|
if err := req.prepare(); err != nil {
|
||||||
|
b.Fatalf("prepare() error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkResponseBodyRead(b *testing.B) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write([]byte("test response data"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
// Pre-fetch response
|
||||||
|
resp, _ := Get(server.URL, WithAutoFetch(true))
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
b.ReportAllocs()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, _ = resp.Body().String()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkDifferentResponseSizes(b *testing.B) {
|
||||||
|
sizes := []int{100, 1024, 10240, 102400} // 100B, 1KB, 10KB, 100KB
|
||||||
|
|
||||||
|
for _, size := range sizes {
|
||||||
|
responseData := make([]byte, size)
|
||||||
|
for i := 0; i < size; i++ {
|
||||||
|
responseData[i] = 'A'
|
||||||
|
}
|
||||||
|
|
||||||
|
b.Run(fmt.Sprintf("Size_%d", size), func(b *testing.B) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write(responseData)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
b.ReportAllocs()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
resp, err := Get(server.URL)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("Get() error: %v", err)
|
||||||
|
}
|
||||||
|
resp.Body().Bytes()
|
||||||
|
resp.Close()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
+145
@@ -0,0 +1,145 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRequestBodyBytes(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
body, _ := io.ReadAll(r.Body)
|
||||||
|
w.Write(body)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
testData := []byte("test data")
|
||||||
|
req := NewSimpleRequest(server.URL, "POST").SetBody(testData)
|
||||||
|
|
||||||
|
resp, err := req.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
body, _ := resp.Body().Bytes()
|
||||||
|
if !bytes.Equal(body, testData) {
|
||||||
|
t.Errorf("Body = %v; want %v", body, testData)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestBodyString(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
body, _ := io.ReadAll(r.Body)
|
||||||
|
w.Write(body)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
testData := "test string data"
|
||||||
|
req := NewSimpleRequest(server.URL, "POST").SetBodyString(testData)
|
||||||
|
|
||||||
|
resp, err := req.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
body, _ := resp.Body().String()
|
||||||
|
if body != testData {
|
||||||
|
t.Errorf("Body = %v; want %v", body, testData)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestBodyReader(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
body, _ := io.ReadAll(r.Body)
|
||||||
|
w.Write(body)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
testData := "test reader data"
|
||||||
|
reader := strings.NewReader(testData)
|
||||||
|
req := NewSimpleRequest(server.URL, "POST").SetBodyReader(reader)
|
||||||
|
|
||||||
|
resp, err := req.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
body, _ := resp.Body().String()
|
||||||
|
if body != testData {
|
||||||
|
t.Errorf("Body = %v; want %v", body, testData)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestJSON(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Header.Get("Content-Type") != ContentTypeJSON {
|
||||||
|
t.Errorf("Content-Type = %v; want %v", r.Header.Get("Content-Type"), ContentTypeJSON)
|
||||||
|
}
|
||||||
|
|
||||||
|
var data map[string]string
|
||||||
|
json.NewDecoder(r.Body).Decode(&data)
|
||||||
|
json.NewEncoder(w).Encode(data)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
testData := map[string]string{
|
||||||
|
"name": "John",
|
||||||
|
"email": "john@example.com",
|
||||||
|
}
|
||||||
|
|
||||||
|
req := NewSimpleRequest(server.URL, "POST").SetJSON(testData)
|
||||||
|
|
||||||
|
resp, err := req.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
var result map[string]string
|
||||||
|
resp.Body().JSON(&result)
|
||||||
|
|
||||||
|
if result["name"] != testData["name"] {
|
||||||
|
t.Errorf("name = %v; want %v", result["name"], testData["name"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestFormData(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
r.ParseForm()
|
||||||
|
data := make(map[string]string)
|
||||||
|
for k, v := range r.Form {
|
||||||
|
if len(v) > 0 {
|
||||||
|
data[k] = v[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
json.NewEncoder(w).Encode(data)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
req := NewSimpleRequest(server.URL, "POST").
|
||||||
|
AddFormData("name", "John").
|
||||||
|
AddFormData("email", "john@example.com")
|
||||||
|
|
||||||
|
resp, err := req.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
var result map[string]string
|
||||||
|
resp.Body().JSON(&result)
|
||||||
|
|
||||||
|
if result["name"] != "John" {
|
||||||
|
t.Errorf("name = %v; want John", result["name"])
|
||||||
|
}
|
||||||
|
if result["email"] != "john@example.com" {
|
||||||
|
t.Errorf("email = %v; want john@example.com", result["email"])
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,349 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Client HTTP 客户端封装
|
||||||
|
type Client struct {
|
||||||
|
client *http.Client
|
||||||
|
opts []RequestOpt
|
||||||
|
mu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewClient 创建新的 Client
|
||||||
|
func NewClient(opts ...RequestOpt) (*Client, error) {
|
||||||
|
// 创建基础 Transport
|
||||||
|
baseTransport := newBaseHTTPTransport()
|
||||||
|
|
||||||
|
httpClient := &http.Client{
|
||||||
|
Transport: &Transport{base: baseTransport},
|
||||||
|
//Timeout: DefaultTimeout,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 应用选项(如果有)
|
||||||
|
if len(opts) > 0 {
|
||||||
|
// 创建临时请求以应用选项
|
||||||
|
req, err := newRequest(context.Background(), "", http.MethodGet, opts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, wrapError(err, "create client")
|
||||||
|
}
|
||||||
|
if req.err != nil {
|
||||||
|
return nil, wrapError(req.err, "create client")
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
// 如果选项中有自定义配置,应用到 httpClient
|
||||||
|
if req.config.Network.Timeout > 0 {
|
||||||
|
httpClient.Timeout = req.config.Network.Timeout
|
||||||
|
}
|
||||||
|
|
||||||
|
*/
|
||||||
|
|
||||||
|
// 如果有自定义 Transport
|
||||||
|
if req.config.CustomTransport && req.config.Transport != nil {
|
||||||
|
httpClient.Transport = &Transport{base: req.config.Transport}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Client{
|
||||||
|
client: httpClient,
|
||||||
|
opts: opts,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewClientNoErr 创建新的 Client(忽略错误)。
|
||||||
|
// 当 opts 校验失败时,它仍会返回一个可用的 Client 占位对象;
|
||||||
|
// 如果调用方需要感知选项错误或依赖默认 starnet Transport 行为,应优先使用 NewClient。
|
||||||
|
func NewClientNoErr(opts ...RequestOpt) *Client {
|
||||||
|
client, _ := NewClient(opts...)
|
||||||
|
if client == nil {
|
||||||
|
client = &Client{
|
||||||
|
client: &http.Client{},
|
||||||
|
opts: opts,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return client
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewClientFromHTTP 从 http.Client 创建 Client
|
||||||
|
func NewClientFromHTTP(httpClient *http.Client) (*Client, error) {
|
||||||
|
if httpClient == nil {
|
||||||
|
return nil, ErrNilClient
|
||||||
|
}
|
||||||
|
|
||||||
|
// 确保 Transport 是我们的自定义类型
|
||||||
|
if httpClient.Transport == nil {
|
||||||
|
httpClient.Transport = &Transport{
|
||||||
|
base: &http.Transport{},
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
switch t := httpClient.Transport.(type) {
|
||||||
|
case *Transport:
|
||||||
|
// 已经是我们的类型
|
||||||
|
if t.base == nil {
|
||||||
|
t.base = &http.Transport{}
|
||||||
|
}
|
||||||
|
case *http.Transport:
|
||||||
|
// 包装标准 Transport
|
||||||
|
httpClient.Transport = &Transport{
|
||||||
|
base: t,
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unsupported transport type: %T", t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Client{
|
||||||
|
client: httpClient,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// HTTPClient 获取底层 http.Client
|
||||||
|
func (c *Client) HTTPClient() *http.Client {
|
||||||
|
return c.client
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestOptions 获取默认选项(返回副本)
|
||||||
|
func (c *Client) RequestOptions() []RequestOpt {
|
||||||
|
c.mu.RLock()
|
||||||
|
defer c.mu.RUnlock()
|
||||||
|
|
||||||
|
opts := make([]RequestOpt, len(c.opts))
|
||||||
|
copy(opts, c.opts)
|
||||||
|
return opts
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetOptions 设置默认选项
|
||||||
|
func (c *Client) SetOptions(opts ...RequestOpt) *Client {
|
||||||
|
c.mu.Lock()
|
||||||
|
c.opts = opts
|
||||||
|
c.mu.Unlock()
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddOptions 追加默认选项
|
||||||
|
func (c *Client) AddOptions(opts ...RequestOpt) *Client {
|
||||||
|
c.mu.Lock()
|
||||||
|
c.opts = append(c.opts, opts...)
|
||||||
|
c.mu.Unlock()
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clone 克隆 Client(深拷贝)
|
||||||
|
func (c *Client) Clone() *Client {
|
||||||
|
c.mu.RLock()
|
||||||
|
defer c.mu.RUnlock()
|
||||||
|
|
||||||
|
// 克隆 Transport
|
||||||
|
var transport http.RoundTripper
|
||||||
|
if c.client.Transport != nil {
|
||||||
|
switch t := c.client.Transport.(type) {
|
||||||
|
case *Transport:
|
||||||
|
transport = &Transport{
|
||||||
|
base: t.base.Clone(),
|
||||||
|
}
|
||||||
|
case *http.Transport:
|
||||||
|
transport = t.Clone()
|
||||||
|
default:
|
||||||
|
transport = c.client.Transport
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Client{
|
||||||
|
client: &http.Client{
|
||||||
|
Transport: transport,
|
||||||
|
CheckRedirect: c.client.CheckRedirect,
|
||||||
|
Jar: c.client.Jar,
|
||||||
|
Timeout: c.client.Timeout,
|
||||||
|
},
|
||||||
|
opts: append([]RequestOpt(nil), c.opts...),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDefaultTLSConfig 设置默认 TLS 配置
|
||||||
|
func (c *Client) SetDefaultTLSConfig(tlsConfig *tls.Config) *Client {
|
||||||
|
if transport, ok := c.client.Transport.(*Transport); ok {
|
||||||
|
transport.mu.Lock()
|
||||||
|
transport.ensureBaseLocked()
|
||||||
|
if tlsConfig != nil {
|
||||||
|
transport.base.TLSClientConfig = tlsConfig.Clone()
|
||||||
|
} else {
|
||||||
|
transport.base.TLSClientConfig = nil
|
||||||
|
}
|
||||||
|
transport.resetDynamicTransportCacheLocked()
|
||||||
|
transport.mu.Unlock()
|
||||||
|
}
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDefaultSkipTLSVerify 设置默认跳过 TLS 验证
|
||||||
|
func (c *Client) SetDefaultSkipTLSVerify(skip bool) *Client {
|
||||||
|
if transport, ok := c.client.Transport.(*Transport); ok {
|
||||||
|
transport.mu.Lock()
|
||||||
|
transport.ensureBaseLocked()
|
||||||
|
if transport.base.TLSClientConfig == nil {
|
||||||
|
transport.base.TLSClientConfig = &tls.Config{}
|
||||||
|
} else {
|
||||||
|
transport.base.TLSClientConfig = transport.base.TLSClientConfig.Clone()
|
||||||
|
}
|
||||||
|
transport.base.TLSClientConfig.InsecureSkipVerify = skip
|
||||||
|
transport.resetDynamicTransportCacheLocked()
|
||||||
|
transport.mu.Unlock()
|
||||||
|
}
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// DisableRedirect 禁用重定向
|
||||||
|
func (c *Client) DisableRedirect() *Client {
|
||||||
|
c.client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
||||||
|
return http.ErrUseLastResponse
|
||||||
|
}
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// EnableRedirect 启用重定向
|
||||||
|
func (c *Client) EnableRedirect() *Client {
|
||||||
|
c.client.CheckRedirect = nil
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRequest 创建新请求
|
||||||
|
func (c *Client) NewRequest(url, method string, opts ...RequestOpt) (*Request, error) {
|
||||||
|
return c.NewRequestWithContext(context.Background(), url, method, opts...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRequestWithContext 创建新请求(带 context)
|
||||||
|
func (c *Client) NewRequestWithContext(ctx context.Context, url, method string, opts ...RequestOpt) (*Request, error) {
|
||||||
|
// 合并 Client 级别和请求级别的选项
|
||||||
|
c.mu.RLock()
|
||||||
|
allOpts := append(append([]RequestOpt(nil), c.opts...), opts...)
|
||||||
|
c.mu.RUnlock()
|
||||||
|
|
||||||
|
req, err := newRequest(ctx, url, method, allOpts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if req.err != nil {
|
||||||
|
return nil, req.err
|
||||||
|
}
|
||||||
|
|
||||||
|
req.client = c
|
||||||
|
req.httpClient = c.client
|
||||||
|
return req, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get 发送 GET 请求
|
||||||
|
func (c *Client) Get(url string, opts ...RequestOpt) (*Response, error) {
|
||||||
|
req, err := c.NewRequest(url, http.MethodGet, opts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return req.Do()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Post 发送 POST 请求
|
||||||
|
func (c *Client) Post(url string, opts ...RequestOpt) (*Response, error) {
|
||||||
|
req, err := c.NewRequest(url, http.MethodPost, opts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return req.Do()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Put 发送 PUT 请求
|
||||||
|
func (c *Client) Put(url string, opts ...RequestOpt) (*Response, error) {
|
||||||
|
req, err := c.NewRequest(url, http.MethodPut, opts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return req.Do()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete 发送 DELETE 请求
|
||||||
|
func (c *Client) Delete(url string, opts ...RequestOpt) (*Response, error) {
|
||||||
|
req, err := c.NewRequest(url, http.MethodDelete, opts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return req.Do()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Head 发送 HEAD 请求
|
||||||
|
func (c *Client) Head(url string, opts ...RequestOpt) (*Response, error) {
|
||||||
|
req, err := c.NewRequest(url, http.MethodHead, opts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return req.Do()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Patch 发送 PATCH 请求
|
||||||
|
func (c *Client) Patch(url string, opts ...RequestOpt) (*Response, error) {
|
||||||
|
req, err := c.NewRequest(url, http.MethodPatch, opts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return req.Do()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Options 发送 OPTIONS 请求
|
||||||
|
func (c *Client) Options(url string, opts ...RequestOpt) (*Response, error) {
|
||||||
|
req, err := c.NewRequest(url, http.MethodOptions, opts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return req.Do()
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSimpleRequest 创建新请求(忽略错误,支持链式调用)
|
||||||
|
func (c *Client) NewSimpleRequest(url, method string, opts ...RequestOpt) *Request {
|
||||||
|
return c.NewSimpleRequestWithContext(context.Background(), url, method, opts...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSimpleRequestWithContext 创建新请求(带 context,忽略错误)
|
||||||
|
func (c *Client) NewSimpleRequestWithContext(ctx context.Context, url, method string, opts ...RequestOpt) *Request {
|
||||||
|
req, err := c.NewRequestWithContext(ctx, url, method, opts...)
|
||||||
|
if err != nil {
|
||||||
|
// 返回一个带错误的请求,保持与全局 NewSimpleRequest 行为一致
|
||||||
|
return &Request{
|
||||||
|
ctx: ctx,
|
||||||
|
url: url,
|
||||||
|
method: method,
|
||||||
|
err: err,
|
||||||
|
config: &RequestConfig{
|
||||||
|
Headers: make(http.Header),
|
||||||
|
Queries: make(map[string][]string),
|
||||||
|
Body: BodyConfig{
|
||||||
|
FormData: make(map[string][]string),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
client: c,
|
||||||
|
httpClient: c.client,
|
||||||
|
autoFetch: DefaultFetchRespBody,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return req
|
||||||
|
}
|
||||||
|
|
||||||
|
// Trace 发送 TRACE 请求
|
||||||
|
func (c *Client) Trace(url string, opts ...RequestOpt) (*Response, error) {
|
||||||
|
req, err := c.NewRequest(url, http.MethodTrace, opts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return req.Do()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect 发送 CONNECT 请求
|
||||||
|
func (c *Client) Connect(url string, opts ...RequestOpt) (*Response, error) {
|
||||||
|
req, err := c.NewRequest(url, http.MethodConnect, opts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return req.Do()
|
||||||
|
}
|
||||||
+223
@@ -0,0 +1,223 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewClient(t *testing.T) {
|
||||||
|
client, err := NewClient()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewClient() error: %v", err)
|
||||||
|
}
|
||||||
|
if client == nil {
|
||||||
|
t.Fatal("NewClient() returned nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewClientNoErr(t *testing.T) {
|
||||||
|
client := NewClientNoErr()
|
||||||
|
if client == nil {
|
||||||
|
t.Fatal("NewClientNoErr() returned nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewClientFromHTTP(t *testing.T) {
|
||||||
|
httpClient := &http.Client{
|
||||||
|
Timeout: 10 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
client, err := NewClientFromHTTP(httpClient)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewClientFromHTTP() error: %v", err)
|
||||||
|
}
|
||||||
|
if client == nil {
|
||||||
|
t.Fatal("NewClientFromHTTP() returned nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test with nil client
|
||||||
|
_, err = NewClientFromHTTP(nil)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("NewClientFromHTTP(nil) should return error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientOptions(t *testing.T) {
|
||||||
|
client := NewClientNoErr()
|
||||||
|
|
||||||
|
// Set options
|
||||||
|
client.SetOptions(WithTimeout(5 * time.Second))
|
||||||
|
opts := client.RequestOptions()
|
||||||
|
if len(opts) != 1 {
|
||||||
|
t.Errorf("RequestOptions() length = %v; want 1", len(opts))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add options
|
||||||
|
client.AddOptions(WithUserAgent("TestAgent"))
|
||||||
|
opts = client.RequestOptions()
|
||||||
|
if len(opts) != 2 {
|
||||||
|
t.Errorf("RequestOptions() length = %v; want 2", len(opts))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientClone(t *testing.T) {
|
||||||
|
client := NewClientNoErr(WithTimeout(5 * time.Second))
|
||||||
|
cloned := client.Clone()
|
||||||
|
|
||||||
|
if cloned == nil {
|
||||||
|
t.Fatal("Clone() returned nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 修改克隆的 client
|
||||||
|
cloned.SetOptions(WithTimeout(10 * time.Second))
|
||||||
|
|
||||||
|
origOpts := client.RequestOptions()
|
||||||
|
clonedOpts := cloned.RequestOptions()
|
||||||
|
|
||||||
|
// 原 client 应该还是 1 个选项
|
||||||
|
if len(origOpts) != 1 {
|
||||||
|
t.Errorf("Original client options = %v; want 1", len(origOpts))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 克隆的 client 应该是 1 个选项(被 SetOptions 覆盖)
|
||||||
|
if len(clonedOpts) != 1 {
|
||||||
|
t.Errorf("Cloned client options = %v; want 1", len(clonedOpts))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientHTTPMethods(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte(r.Method))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
client := NewClientNoErr()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
method func(string, ...RequestOpt) (*Response, error)
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"GET", client.Get, "GET"},
|
||||||
|
{"POST", client.Post, "POST"},
|
||||||
|
{"PUT", client.Put, "PUT"},
|
||||||
|
{"DELETE", client.Delete, "DELETE"},
|
||||||
|
{"PATCH", client.Patch, "PATCH"},
|
||||||
|
{"HEAD", client.Head, "HEAD"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
resp, err := tt.method(server.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("%s() error: %v", tt.name, err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
if tt.want != "HEAD" {
|
||||||
|
body, _ := resp.Body().String()
|
||||||
|
if body != tt.want {
|
||||||
|
t.Errorf("Body = %v; want %v", body, tt.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientRedirect(t *testing.T) {
|
||||||
|
redirectCount := 0
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if redirectCount < 2 {
|
||||||
|
redirectCount++
|
||||||
|
http.Redirect(w, r, "/redirected", http.StatusFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte("final"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
// Test with redirect enabled (default)
|
||||||
|
client := NewClientNoErr()
|
||||||
|
resp, err := client.Get(server.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Get() error: %v", err)
|
||||||
|
}
|
||||||
|
resp.Close()
|
||||||
|
|
||||||
|
if redirectCount != 2 {
|
||||||
|
t.Errorf("Redirect count = %v; want 2", redirectCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test with redirect disabled
|
||||||
|
redirectCount = 0
|
||||||
|
client.DisableRedirect()
|
||||||
|
resp2, err := client.Get(server.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Get() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp2.Close()
|
||||||
|
|
||||||
|
if resp2.StatusCode != http.StatusFound {
|
||||||
|
t.Errorf("StatusCode = %v; want %v", resp2.StatusCode, http.StatusFound)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientTLSConfig(t *testing.T) {
|
||||||
|
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte("OK"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
// Without skip verify (should fail with self-signed cert)
|
||||||
|
client := NewClientNoErr()
|
||||||
|
_, err := client.Get(server.URL)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected TLS error with self-signed cert, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// With skip verify
|
||||||
|
client.SetDefaultSkipTLSVerify(true)
|
||||||
|
resp, err := client.Get(server.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Get() with skip verify error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("StatusCode = %v; want %v", resp.StatusCode, http.StatusOK)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientNewSimpleRequest(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte("OK"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
client := NewClientNoErr()
|
||||||
|
|
||||||
|
req := client.NewSimpleRequest(server.URL, "GET", WithHeader("X-Test", "v"))
|
||||||
|
if req == nil {
|
||||||
|
t.Fatal("NewSimpleRequest returned nil")
|
||||||
|
}
|
||||||
|
if req.Err() != nil {
|
||||||
|
t.Fatalf("NewSimpleRequest err: %v", req.Err())
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := req.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
body, _ := resp.Body().String()
|
||||||
|
if body != "OK" {
|
||||||
|
t.Errorf("Body = %v; want OK", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,111 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestConcurrentRequests(t *testing.T) {
|
||||||
|
var counter int64
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
atomic.AddInt64(&counter, 1)
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte("OK"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
client := NewClientNoErr()
|
||||||
|
concurrency := 100
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(concurrency)
|
||||||
|
|
||||||
|
for i := 0; i < concurrency; i++ {
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
resp, err := client.Get(server.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Get() error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp.Close()
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
if atomic.LoadInt64(&counter) != int64(concurrency) {
|
||||||
|
t.Errorf("counter = %v; want %v", counter, concurrency)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConcurrentClientModification(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
client := NewClientNoErr()
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(200)
|
||||||
|
|
||||||
|
// 100 goroutines reading
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
resp, err := client.Get(server.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Get() error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp.Close()
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// 100 goroutines modifying options
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
go func(i int) {
|
||||||
|
defer wg.Done()
|
||||||
|
if i%2 == 0 {
|
||||||
|
client.AddOptions(WithTimeout(5 * time.Second))
|
||||||
|
} else {
|
||||||
|
_ = client.RequestOptions()
|
||||||
|
}
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConcurrentRequestClone(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
baseReq := NewSimpleRequest(server.URL, "GET").SetHeader("X-Base", "value")
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(50)
|
||||||
|
|
||||||
|
for i := 0; i < 50; i++ {
|
||||||
|
go func(i int) {
|
||||||
|
defer wg.Done()
|
||||||
|
cloned := baseReq.Clone()
|
||||||
|
// 修复:使用有效的 header 值
|
||||||
|
cloned.SetHeader("X-Index", fmt.Sprintf("%d", i))
|
||||||
|
resp, err := cloned.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Do() error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp.Close()
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
+188
@@ -0,0 +1,188 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// contextKey 私有的 context key 类型(防止冲突)
|
||||||
|
type contextKey int
|
||||||
|
|
||||||
|
const (
|
||||||
|
ctxKeyTransport contextKey = iota
|
||||||
|
ctxKeyTLSConfig
|
||||||
|
ctxKeyTLSConfigCacheable
|
||||||
|
ctxKeyTLSServerName
|
||||||
|
ctxKeyProxy
|
||||||
|
ctxKeyCustomIP
|
||||||
|
ctxKeyCustomDNS
|
||||||
|
ctxKeyDialTimeout
|
||||||
|
ctxKeyTimeout
|
||||||
|
ctxKeyLookupIP
|
||||||
|
ctxKeyDialFunc
|
||||||
|
ctxKeyRequestContext
|
||||||
|
)
|
||||||
|
|
||||||
|
// RequestContext 从 context 中提取的请求配置
|
||||||
|
type RequestContext struct {
|
||||||
|
Transport *http.Transport
|
||||||
|
TLSConfig *tls.Config
|
||||||
|
TLSConfigCacheable bool
|
||||||
|
TLSServerName string
|
||||||
|
Proxy string
|
||||||
|
CustomIP []string
|
||||||
|
CustomDNS []string
|
||||||
|
DialTimeout time.Duration
|
||||||
|
Timeout time.Duration
|
||||||
|
LookupIPFn func(ctx context.Context, host string) ([]net.IPAddr, error)
|
||||||
|
DialFn func(ctx context.Context, network, addr string) (net.Conn, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
var emptyRequestContext = &RequestContext{}
|
||||||
|
|
||||||
|
// getRequestContext 从 context 中提取请求配置
|
||||||
|
func getRequestContext(ctx context.Context) *RequestContext {
|
||||||
|
if v := ctx.Value(ctxKeyRequestContext); v != nil {
|
||||||
|
if rc, ok := v.(*RequestContext); ok && rc != nil {
|
||||||
|
return rc
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var rc *RequestContext
|
||||||
|
ensure := func() *RequestContext {
|
||||||
|
if rc == nil {
|
||||||
|
rc = &RequestContext{}
|
||||||
|
}
|
||||||
|
return rc
|
||||||
|
}
|
||||||
|
if v := ctx.Value(ctxKeyTransport); v != nil {
|
||||||
|
ensure().Transport, _ = v.(*http.Transport)
|
||||||
|
}
|
||||||
|
if v := ctx.Value(ctxKeyTLSConfig); v != nil {
|
||||||
|
ensure().TLSConfig, _ = v.(*tls.Config)
|
||||||
|
}
|
||||||
|
if v := ctx.Value(ctxKeyTLSConfigCacheable); v != nil {
|
||||||
|
ensure().TLSConfigCacheable, _ = v.(bool)
|
||||||
|
}
|
||||||
|
if v := ctx.Value(ctxKeyTLSServerName); v != nil {
|
||||||
|
ensure().TLSServerName, _ = v.(string)
|
||||||
|
}
|
||||||
|
if v := ctx.Value(ctxKeyProxy); v != nil {
|
||||||
|
ensure().Proxy, _ = v.(string)
|
||||||
|
}
|
||||||
|
if v := ctx.Value(ctxKeyCustomIP); v != nil {
|
||||||
|
ensure().CustomIP, _ = v.([]string)
|
||||||
|
}
|
||||||
|
if v := ctx.Value(ctxKeyCustomDNS); v != nil {
|
||||||
|
ensure().CustomDNS, _ = v.([]string)
|
||||||
|
}
|
||||||
|
if v := ctx.Value(ctxKeyDialTimeout); v != nil {
|
||||||
|
ensure().DialTimeout, _ = v.(time.Duration)
|
||||||
|
}
|
||||||
|
if v := ctx.Value(ctxKeyTimeout); v != nil {
|
||||||
|
ensure().Timeout, _ = v.(time.Duration)
|
||||||
|
}
|
||||||
|
if v := ctx.Value(ctxKeyLookupIP); v != nil {
|
||||||
|
ensure().LookupIPFn, _ = v.(func(context.Context, string) ([]net.IPAddr, error))
|
||||||
|
}
|
||||||
|
if v := ctx.Value(ctxKeyDialFunc); v != nil {
|
||||||
|
ensure().DialFn, _ = v.(func(context.Context, string, string) (net.Conn, error))
|
||||||
|
}
|
||||||
|
if rc == nil {
|
||||||
|
return emptyRequestContext
|
||||||
|
}
|
||||||
|
return rc
|
||||||
|
}
|
||||||
|
|
||||||
|
func cloneRequestContext(rc *RequestContext) *RequestContext {
|
||||||
|
if rc == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
cloned := *rc
|
||||||
|
cloned.CustomIP = cloneStringSlice(rc.CustomIP)
|
||||||
|
cloned.CustomDNS = cloneStringSlice(rc.CustomDNS)
|
||||||
|
return &cloned
|
||||||
|
}
|
||||||
|
|
||||||
|
// needsDynamicTransport 判断是否需要动态 Transport
|
||||||
|
func needsDynamicTransport(rc *RequestContext) bool {
|
||||||
|
if rc == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return rc.Transport != nil ||
|
||||||
|
rc.TLSConfig != nil ||
|
||||||
|
rc.Proxy != "" ||
|
||||||
|
rc.DialFn != nil ||
|
||||||
|
(rc.DialTimeout > 0 && rc.DialTimeout != DefaultDialTimeout) ||
|
||||||
|
len(rc.CustomIP) > 0 ||
|
||||||
|
len(rc.CustomDNS) > 0 ||
|
||||||
|
rc.LookupIPFn != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildRequestContext(config *RequestConfig, defaultTLSServerName string) *RequestContext {
|
||||||
|
if config == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
rc := &RequestContext{
|
||||||
|
DialTimeout: config.Network.DialTimeout,
|
||||||
|
Timeout: config.Network.Timeout,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 处理 TLS 配置
|
||||||
|
var tlsConfig *tls.Config
|
||||||
|
tlsConfigCacheable := false
|
||||||
|
|
||||||
|
if config.TLS.Config != nil {
|
||||||
|
tlsConfig = config.TLS.Config.Clone()
|
||||||
|
} else if config.TLS.SkipVerify || config.TLS.ServerName != "" {
|
||||||
|
tlsConfig = &tls.Config{
|
||||||
|
NextProtos: []string{"h2", "http/1.1"},
|
||||||
|
}
|
||||||
|
tlsConfigCacheable = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.TLS.SkipVerify && tlsConfig != nil {
|
||||||
|
tlsConfig.InsecureSkipVerify = true
|
||||||
|
}
|
||||||
|
if config.TLS.ServerName != "" && tlsConfig != nil {
|
||||||
|
tlsConfig.ServerName = config.TLS.ServerName
|
||||||
|
}
|
||||||
|
if tlsConfig != nil {
|
||||||
|
rc.TLSConfig = tlsConfig
|
||||||
|
rc.TLSConfigCacheable = tlsConfigCacheable
|
||||||
|
}
|
||||||
|
if config.TLS.ServerName != "" {
|
||||||
|
rc.TLSServerName = config.TLS.ServerName
|
||||||
|
} else if defaultTLSServerName != "" {
|
||||||
|
rc.TLSServerName = defaultTLSServerName
|
||||||
|
}
|
||||||
|
|
||||||
|
rc.Proxy = config.Network.Proxy
|
||||||
|
rc.CustomIP = cloneStringSlice(config.DNS.CustomIP)
|
||||||
|
rc.CustomDNS = cloneStringSlice(config.DNS.CustomDNS)
|
||||||
|
rc.LookupIPFn = config.DNS.LookupFunc
|
||||||
|
rc.DialFn = config.Network.DialFunc
|
||||||
|
|
||||||
|
if config.CustomTransport && config.Transport != nil {
|
||||||
|
rc.Transport = config.Transport
|
||||||
|
}
|
||||||
|
|
||||||
|
if !needsDynamicTransport(rc) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return rc
|
||||||
|
}
|
||||||
|
|
||||||
|
// injectRequestConfig 将请求配置注入到 context
|
||||||
|
func injectRequestConfig(ctx context.Context, config *RequestConfig, defaultTLSServerName string) context.Context {
|
||||||
|
rc := buildRequestContext(config, defaultTLSServerName)
|
||||||
|
if rc == nil {
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
return context.WithValue(ctx, ctxKeyRequestContext, rc)
|
||||||
|
}
|
||||||
@@ -1,463 +0,0 @@
|
|||||||
package starnet
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"crypto/rand"
|
|
||||||
"crypto/tls"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"os"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"b612.me/stario"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
HEADER_FORM_URLENCODE = `application/x-www-form-urlencoded`
|
|
||||||
HEADER_FORM_DATA = `multipart/form-data`
|
|
||||||
HEADER_JSON = `application/json`
|
|
||||||
HEADER_PLAIN = `text/plain`
|
|
||||||
)
|
|
||||||
|
|
||||||
type RequestFile struct {
|
|
||||||
UploadFile string
|
|
||||||
UploadForm map[string]string
|
|
||||||
UploadName string
|
|
||||||
}
|
|
||||||
|
|
||||||
type Request struct {
|
|
||||||
Url string
|
|
||||||
RespURL string
|
|
||||||
Method string
|
|
||||||
RecvData []byte
|
|
||||||
RecvContentLength int64
|
|
||||||
RecvIo io.Writer
|
|
||||||
RespHeader http.Header
|
|
||||||
RespCookies []*http.Cookie
|
|
||||||
RespHttpCode int
|
|
||||||
Location *url.URL
|
|
||||||
CircleBuffer *stario.StarBuffer
|
|
||||||
respReader io.ReadCloser
|
|
||||||
respOrigin *http.Response
|
|
||||||
reqOrigin *http.Request
|
|
||||||
RequestOpts
|
|
||||||
}
|
|
||||||
|
|
||||||
type RequestOpts struct {
|
|
||||||
RequestFile
|
|
||||||
PostBuffer io.Reader
|
|
||||||
Process func(float64)
|
|
||||||
Proxy string
|
|
||||||
Timeout time.Duration
|
|
||||||
DialTimeout time.Duration
|
|
||||||
ReqHeader http.Header
|
|
||||||
ReqCookies []*http.Cookie
|
|
||||||
WriteRecvData bool
|
|
||||||
SkipTLSVerify bool
|
|
||||||
CustomTransport *http.Transport
|
|
||||||
Queries map[string]string
|
|
||||||
DisableRedirect bool
|
|
||||||
TlsConfig *tls.Config
|
|
||||||
}
|
|
||||||
|
|
||||||
type RequestOpt func(opt *RequestOpts)
|
|
||||||
|
|
||||||
func WithDialTimeout(timeout time.Duration) RequestOpt {
|
|
||||||
return func(opt *RequestOpts) {
|
|
||||||
opt.DialTimeout = timeout
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func WithTimeout(timeout time.Duration) RequestOpt {
|
|
||||||
return func(opt *RequestOpts) {
|
|
||||||
opt.Timeout = timeout
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func WithHeader(key, val string) RequestOpt {
|
|
||||||
return func(opt *RequestOpts) {
|
|
||||||
opt.ReqHeader.Set(key, val)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func WithTlsConfig(tlscfg *tls.Config) RequestOpt {
|
|
||||||
return func(opt *RequestOpts) {
|
|
||||||
opt.TlsConfig = tlscfg
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func WithHeaderMap(header map[string]string) RequestOpt {
|
|
||||||
return func(opt *RequestOpts) {
|
|
||||||
for key, val := range header {
|
|
||||||
opt.ReqHeader.Set(key, val)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func WithHeaderAdd(key, val string) RequestOpt {
|
|
||||||
return func(opt *RequestOpts) {
|
|
||||||
opt.ReqHeader.Add(key, val)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func WithReader(r io.Reader) RequestOpt {
|
|
||||||
return func(opt *RequestOpts) {
|
|
||||||
opt.PostBuffer = r
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func WithFetchRespBody(fetch bool) RequestOpt {
|
|
||||||
return func(opt *RequestOpts) {
|
|
||||||
opt.WriteRecvData = fetch
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func WithCookies(ck []*http.Cookie) RequestOpt {
|
|
||||||
return func(opt *RequestOpts) {
|
|
||||||
opt.ReqCookies = ck
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func WithCookie(key, val, path string) RequestOpt {
|
|
||||||
return func(opt *RequestOpts) {
|
|
||||||
opt.ReqCookies = append(opt.ReqCookies, &http.Cookie{Name: key, Value: val, Path: path})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func WithCookieMap(header map[string]string, path string) RequestOpt {
|
|
||||||
return func(opt *RequestOpts) {
|
|
||||||
for key, val := range header {
|
|
||||||
opt.ReqCookies = append(opt.ReqCookies, &http.Cookie{Name: key, Value: val, Path: path})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func WithQueries(queries map[string]string) RequestOpt {
|
|
||||||
return func(opt *RequestOpts) {
|
|
||||||
opt.Queries = queries
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func WithProxy(proxy string) RequestOpt {
|
|
||||||
return func(opt *RequestOpts) {
|
|
||||||
opt.Proxy = proxy
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func WithProcess(fn func(float64)) RequestOpt {
|
|
||||||
return func(opt *RequestOpts) {
|
|
||||||
opt.Process = fn
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func WithContentType(ct string) RequestOpt {
|
|
||||||
return func(opt *RequestOpts) {
|
|
||||||
opt.ReqHeader.Set("Content-Type", ct)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func WithUserAgent(ua string) RequestOpt {
|
|
||||||
return func(opt *RequestOpts) {
|
|
||||||
opt.ReqHeader.Set("User-Agent", ua)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func WithCustomTransport(hs *http.Transport) RequestOpt {
|
|
||||||
return func(opt *RequestOpts) {
|
|
||||||
opt.CustomTransport = hs
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func WithSkipTLSVerify(skip bool) RequestOpt {
|
|
||||||
return func(opt *RequestOpts) {
|
|
||||||
opt.SkipTLSVerify = skip
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func WithDisableRedirect(disable bool) RequestOpt {
|
|
||||||
return func(opt *RequestOpts) {
|
|
||||||
opt.DisableRedirect = disable
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewRequests(url string, rawdata []byte, method string, opts ...RequestOpt) Request {
|
|
||||||
req := Request{
|
|
||||||
RequestOpts: RequestOpts{
|
|
||||||
Timeout: 30 * time.Second,
|
|
||||||
DialTimeout: 15 * time.Second,
|
|
||||||
WriteRecvData: true,
|
|
||||||
},
|
|
||||||
Url: url,
|
|
||||||
Method: method,
|
|
||||||
}
|
|
||||||
if rawdata != nil {
|
|
||||||
req.PostBuffer = bytes.NewBuffer(rawdata)
|
|
||||||
}
|
|
||||||
req.ReqHeader = make(http.Header)
|
|
||||||
if strings.ToUpper(method) == "POST" {
|
|
||||||
req.ReqHeader.Set("Content-Type", HEADER_FORM_URLENCODE)
|
|
||||||
}
|
|
||||||
req.ReqHeader.Set("User-Agent", "B612 / 1.1.0")
|
|
||||||
for _, v := range opts {
|
|
||||||
v(&req.RequestOpts)
|
|
||||||
}
|
|
||||||
if req.CustomTransport == nil {
|
|
||||||
req.CustomTransport = &http.Transport{}
|
|
||||||
}
|
|
||||||
if req.SkipTLSVerify {
|
|
||||||
if req.CustomTransport.TLSClientConfig == nil {
|
|
||||||
req.CustomTransport.TLSClientConfig = &tls.Config{}
|
|
||||||
}
|
|
||||||
req.CustomTransport.TLSClientConfig.InsecureSkipVerify = true
|
|
||||||
}
|
|
||||||
if req.TlsConfig != nil {
|
|
||||||
req.CustomTransport.TLSClientConfig = req.TlsConfig
|
|
||||||
}
|
|
||||||
req.CustomTransport.DialContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
|
|
||||||
c, err := net.DialTimeout(netw, addr, req.DialTimeout)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if req.Timeout != 0 {
|
|
||||||
c.SetDeadline(time.Now().Add(req.Timeout))
|
|
||||||
}
|
|
||||||
return c, nil
|
|
||||||
}
|
|
||||||
return req
|
|
||||||
}
|
|
||||||
|
|
||||||
func (curl *Request) ResetReqHeader() {
|
|
||||||
curl.ReqHeader = make(http.Header)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (curl *Request) ResetReqCookies() {
|
|
||||||
curl.ReqCookies = []*http.Cookie{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (curl *Request) AddSimpleCookie(key, value string) {
|
|
||||||
curl.ReqCookies = append(curl.ReqCookies, &http.Cookie{Name: key, Value: value, Path: "/"})
|
|
||||||
}
|
|
||||||
func (curl *Request) AddCookie(key, value, path string) {
|
|
||||||
curl.ReqCookies = append(curl.ReqCookies, &http.Cookie{Name: key, Value: value, Path: path})
|
|
||||||
}
|
|
||||||
|
|
||||||
func randomBoundary() string {
|
|
||||||
var buf [30]byte
|
|
||||||
_, err := io.ReadFull(rand.Reader, buf[:])
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
return fmt.Sprintf("%x", buf[:])
|
|
||||||
}
|
|
||||||
|
|
||||||
func Curl(curl Request) (resps Request, err error) {
|
|
||||||
var fpsrc *os.File
|
|
||||||
if curl.RequestFile.UploadFile != "" {
|
|
||||||
fpsrc, err = os.Open(curl.UploadFile)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer fpsrc.Close()
|
|
||||||
boundary := randomBoundary()
|
|
||||||
boundarybytes := []byte("\r\n--" + boundary + "\r\n")
|
|
||||||
endbytes := []byte("\r\n--" + boundary + "--\r\n")
|
|
||||||
fpstat, _ := fpsrc.Stat()
|
|
||||||
filebig := float64(fpstat.Size())
|
|
||||||
sum, n := 0, 0
|
|
||||||
fpdst := stario.NewStarBuffer(1048576)
|
|
||||||
if curl.UploadForm != nil {
|
|
||||||
for k, v := range curl.UploadForm {
|
|
||||||
header := fmt.Sprintf("Content-Disposition: form-data; name=\"%s\";\r\nContent-Type: x-www-form-urlencoded \r\n\r\n", k)
|
|
||||||
fpdst.Write(boundarybytes)
|
|
||||||
fpdst.Write([]byte(header))
|
|
||||||
fpdst.Write([]byte(v))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
header := fmt.Sprintf("Content-Disposition: form-data; name=\"%s\"; filename=\"%s\"\r\nContent-Type: application/octet-stream\r\n\r\n", curl.UploadName, fpstat.Name())
|
|
||||||
fpdst.Write(boundarybytes)
|
|
||||||
fpdst.Write([]byte(header))
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
bufs := make([]byte, 393213)
|
|
||||||
n, err = fpsrc.Read(bufs)
|
|
||||||
if err != nil {
|
|
||||||
if err == io.EOF {
|
|
||||||
if n != 0 {
|
|
||||||
fpdst.Write(bufs[0:n])
|
|
||||||
if curl.Process != nil {
|
|
||||||
go curl.Process(float64(sum+n) / filebig * 100)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
sum += n
|
|
||||||
if curl.Process != nil {
|
|
||||||
go curl.Process(float64(sum+n) / filebig * 100)
|
|
||||||
}
|
|
||||||
fpdst.Write(bufs[0:n])
|
|
||||||
}
|
|
||||||
fpdst.Write(endbytes)
|
|
||||||
fpdst.Write(nil)
|
|
||||||
}()
|
|
||||||
curl.CircleBuffer = fpdst
|
|
||||||
curl.ReqHeader.Set("Content-Type", "multipart/form-data;boundary="+boundary)
|
|
||||||
}
|
|
||||||
req, resp, err := netcurl(curl)
|
|
||||||
if err != nil {
|
|
||||||
return Request{}, err
|
|
||||||
}
|
|
||||||
if resp.Request != nil && resp.Request.URL != nil {
|
|
||||||
curl.RespURL = resp.Request.URL.String()
|
|
||||||
}
|
|
||||||
curl.reqOrigin = req
|
|
||||||
curl.respOrigin = resp
|
|
||||||
curl.Location, _ = resp.Location()
|
|
||||||
curl.RespHttpCode = resp.StatusCode
|
|
||||||
curl.RespHeader = resp.Header
|
|
||||||
curl.RespCookies = resp.Cookies()
|
|
||||||
curl.RecvContentLength = resp.ContentLength
|
|
||||||
readFunc := func(reader io.ReadCloser, writer io.Writer) error {
|
|
||||||
lengthall := resp.ContentLength
|
|
||||||
defer reader.Close()
|
|
||||||
var lengthsum int
|
|
||||||
buf := make([]byte, 65535)
|
|
||||||
for {
|
|
||||||
n, err := reader.Read(buf)
|
|
||||||
if n != 0 {
|
|
||||||
_, err := writer.Write(buf[:n])
|
|
||||||
lengthsum += n
|
|
||||||
if curl.Process != nil {
|
|
||||||
go curl.Process(float64(lengthsum) / float64(lengthall) * 100.00)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err != nil && err != io.EOF {
|
|
||||||
return err
|
|
||||||
} else if err == io.EOF {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if curl.WriteRecvData {
|
|
||||||
buf := bytes.NewBuffer([]byte{})
|
|
||||||
err = readFunc(resp.Body, buf)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
curl.RecvData = buf.Bytes()
|
|
||||||
} else {
|
|
||||||
curl.respReader = resp.Body
|
|
||||||
}
|
|
||||||
if curl.RecvIo != nil {
|
|
||||||
if curl.WriteRecvData {
|
|
||||||
_, err = curl.RecvIo.Write(curl.RecvData)
|
|
||||||
} else {
|
|
||||||
err = readFunc(resp.Body, curl.RecvIo)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return curl, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// RespBodyReader Only works when WriteRecvData set to false
|
|
||||||
func (curl *Request) RespBodyReader() io.ReadCloser {
|
|
||||||
return curl.respReader
|
|
||||||
}
|
|
||||||
|
|
||||||
func netcurl(curl Request) (*http.Request, *http.Response, error) {
|
|
||||||
var req *http.Request
|
|
||||||
var err error
|
|
||||||
if curl.Method == "" {
|
|
||||||
return nil, nil, errors.New("Error Method Not Entered")
|
|
||||||
}
|
|
||||||
if curl.PostBuffer != nil {
|
|
||||||
req, err = http.NewRequest(curl.Method, curl.Url, curl.PostBuffer)
|
|
||||||
} else if curl.CircleBuffer != nil && curl.CircleBuffer.Len() > 0 {
|
|
||||||
req, err = http.NewRequest(curl.Method, curl.Url, curl.CircleBuffer)
|
|
||||||
} else {
|
|
||||||
req, err = http.NewRequest(curl.Method, curl.Url, nil)
|
|
||||||
}
|
|
||||||
if curl.Queries != nil {
|
|
||||||
sid := req.URL.Query()
|
|
||||||
for k, v := range curl.Queries {
|
|
||||||
sid.Add(k, v)
|
|
||||||
}
|
|
||||||
req.URL.RawQuery = sid.Encode()
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
req.Header = curl.ReqHeader
|
|
||||||
if len(curl.ReqCookies) != 0 {
|
|
||||||
for _, v := range curl.ReqCookies {
|
|
||||||
req.AddCookie(v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if curl.Proxy != "" {
|
|
||||||
purl, err := url.Parse(curl.Proxy)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
curl.CustomTransport.Proxy = http.ProxyURL(purl)
|
|
||||||
}
|
|
||||||
client := &http.Client{
|
|
||||||
Transport: curl.CustomTransport,
|
|
||||||
}
|
|
||||||
if curl.DisableRedirect {
|
|
||||||
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
|
||||||
return http.ErrUseLastResponse
|
|
||||||
}
|
|
||||||
}
|
|
||||||
resp, err := client.Do(req)
|
|
||||||
|
|
||||||
return req, resp, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func UrlEncodeRaw(str string) string {
|
|
||||||
strs := strings.Replace(url.QueryEscape(str), "+", "%20", -1)
|
|
||||||
return strs
|
|
||||||
}
|
|
||||||
|
|
||||||
func UrlEncode(str string) string {
|
|
||||||
return url.QueryEscape(str)
|
|
||||||
}
|
|
||||||
|
|
||||||
func UrlDecode(str string) (string, error) {
|
|
||||||
return url.QueryUnescape(str)
|
|
||||||
}
|
|
||||||
|
|
||||||
func BuildQuery(queryData map[string]string) string {
|
|
||||||
query := url.Values{}
|
|
||||||
for k, v := range queryData {
|
|
||||||
query.Add(k, v)
|
|
||||||
}
|
|
||||||
return query.Encode()
|
|
||||||
}
|
|
||||||
|
|
||||||
func BuildPostForm(queryMap map[string]string) []byte {
|
|
||||||
query := url.Values{}
|
|
||||||
for k, v := range queryMap {
|
|
||||||
query.Add(k, v)
|
|
||||||
}
|
|
||||||
return []byte(query.Encode())
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r Request) Resopnse() *http.Response {
|
|
||||||
return r.respOrigin
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r Request) Request() *http.Request {
|
|
||||||
return r.reqOrigin
|
|
||||||
}
|
|
||||||
+147
@@ -0,0 +1,147 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
defaultClient *Client
|
||||||
|
defaultHTTPClient *http.Client
|
||||||
|
defaultClientOnce sync.Once
|
||||||
|
defaultHTTPOnce sync.Once
|
||||||
|
|
||||||
|
defaultMu sync.RWMutex
|
||||||
|
)
|
||||||
|
|
||||||
|
// DefaultClient 获取默认 Client(单例)
|
||||||
|
func DefaultClient() *Client {
|
||||||
|
defaultMu.RLock()
|
||||||
|
if defaultClient != nil {
|
||||||
|
c := defaultClient
|
||||||
|
defaultMu.RUnlock()
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
defaultMu.RUnlock()
|
||||||
|
|
||||||
|
defaultClientOnce.Do(func() {
|
||||||
|
c := NewClientNoErr()
|
||||||
|
defaultMu.Lock()
|
||||||
|
defaultClient = c
|
||||||
|
defaultMu.Unlock()
|
||||||
|
})
|
||||||
|
|
||||||
|
defaultMu.RLock()
|
||||||
|
c := defaultClient
|
||||||
|
defaultMu.RUnlock()
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultHTTPClient 获取默认 http.Client(单例)
|
||||||
|
func DefaultHTTPClient() *http.Client {
|
||||||
|
defaultMu.RLock()
|
||||||
|
if defaultHTTPClient != nil {
|
||||||
|
c := defaultHTTPClient
|
||||||
|
defaultMu.RUnlock()
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
defaultMu.RUnlock()
|
||||||
|
|
||||||
|
defaultHTTPOnce.Do(func() {
|
||||||
|
c := &http.Client{
|
||||||
|
Transport: &Transport{
|
||||||
|
base: &http.Transport{
|
||||||
|
ForceAttemptHTTP2: true,
|
||||||
|
MaxIdleConns: 100,
|
||||||
|
MaxIdleConnsPerHost: 10,
|
||||||
|
IdleConnTimeout: 90 * time.Second,
|
||||||
|
TLSHandshakeTimeout: 10 * time.Second,
|
||||||
|
ExpectContinueTimeout: 1 * time.Second,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Timeout: 0, // 由请求级控制超时
|
||||||
|
}
|
||||||
|
defaultMu.Lock()
|
||||||
|
defaultHTTPClient = c
|
||||||
|
defaultMu.Unlock()
|
||||||
|
})
|
||||||
|
|
||||||
|
defaultMu.RLock()
|
||||||
|
c := defaultHTTPClient
|
||||||
|
defaultMu.RUnlock()
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDefaultClient 设置默认 Client
|
||||||
|
func SetDefaultClient(client *Client) {
|
||||||
|
defaultMu.Lock()
|
||||||
|
defer defaultMu.Unlock()
|
||||||
|
|
||||||
|
defaultClient = client
|
||||||
|
// 标记 once 已完成,避免后续 DefaultClient() 再次初始化覆盖
|
||||||
|
defaultClientOnce.Do(func() {})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDefaultHTTPClient 设置默认 http.Client
|
||||||
|
func SetDefaultHTTPClient(client *http.Client) {
|
||||||
|
defaultMu.Lock()
|
||||||
|
defer defaultMu.Unlock()
|
||||||
|
|
||||||
|
defaultHTTPClient = client
|
||||||
|
// 标记 once 已完成,避免后续 DefaultHTTPClient() 再次初始化覆盖
|
||||||
|
defaultHTTPOnce.Do(func() {})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get 发送 GET 请求(使用默认 Client)
|
||||||
|
func Get(url string, opts ...RequestOpt) (*Response, error) {
|
||||||
|
return DefaultClient().Get(url, opts...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Post 发送 POST 请求(使用默认 Client)
|
||||||
|
func Post(url string, opts ...RequestOpt) (*Response, error) {
|
||||||
|
return DefaultClient().Post(url, opts...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Put 发送 PUT 请求(使用默认 Client)
|
||||||
|
func Put(url string, opts ...RequestOpt) (*Response, error) {
|
||||||
|
return DefaultClient().Put(url, opts...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete 发送 DELETE 请求(使用默认 Client)
|
||||||
|
func Delete(url string, opts ...RequestOpt) (*Response, error) {
|
||||||
|
return DefaultClient().Delete(url, opts...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Head 发送 HEAD 请求(使用默认 Client)
|
||||||
|
func Head(url string, opts ...RequestOpt) (*Response, error) {
|
||||||
|
return DefaultClient().Head(url, opts...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Patch 发送 PATCH 请求(使用默认 Client)
|
||||||
|
func Patch(url string, opts ...RequestOpt) (*Response, error) {
|
||||||
|
return DefaultClient().Patch(url, opts...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Options 发送 OPTIONS 请求(使用默认 Client)
|
||||||
|
func Options(url string, opts ...RequestOpt) (*Response, error) {
|
||||||
|
return DefaultClient().Options(url, opts...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Trace 发送 TRACE 请求(使用默认 Client)
|
||||||
|
func Trace(url string, opts ...RequestOpt) (*Response, error) {
|
||||||
|
req, err := DefaultClient().NewRequest(url, http.MethodTrace, opts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return req.Do()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect 发送 CONNECT 请求(使用默认 Client)
|
||||||
|
func Connect(url string, opts ...RequestOpt) (*Response, error) {
|
||||||
|
req, err := DefaultClient().NewRequest(url, http.MethodConnect, opts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return req.Do()
|
||||||
|
}
|
||||||
@@ -0,0 +1,111 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestWithRawRequestNil(t *testing.T) {
|
||||||
|
_, err := NewRequest("http://example.com", "GET", WithRawRequest(nil))
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error when WithRawRequest(nil)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetHeadersDefensiveCopy(t *testing.T) {
|
||||||
|
req := NewSimpleRequest("http://example.com", "GET")
|
||||||
|
headers := http.Header{
|
||||||
|
"X-Test": []string{"v1"},
|
||||||
|
}
|
||||||
|
|
||||||
|
req.SetHeaders(headers)
|
||||||
|
headers.Set("X-Test", "v2")
|
||||||
|
|
||||||
|
if got := req.GetHeader("X-Test"); got != "v1" {
|
||||||
|
t.Fatalf("header mutated by external map change: got=%q want=%q", got, "v1")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetQueriesDefensiveCopy(t *testing.T) {
|
||||||
|
req := NewSimpleRequest("http://example.com", "GET")
|
||||||
|
queries := map[string][]string{
|
||||||
|
"k": []string{"v1"},
|
||||||
|
}
|
||||||
|
|
||||||
|
req.SetQueries(queries)
|
||||||
|
queries["k"][0] = "v2"
|
||||||
|
queries["k"] = append(queries["k"], "v3")
|
||||||
|
|
||||||
|
got := req.config.Queries["k"]
|
||||||
|
if len(got) != 1 || got[0] != "v1" {
|
||||||
|
t.Fatalf("queries mutated by external map change: got=%v want=[v1]", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetFormDataDefensiveCopy(t *testing.T) {
|
||||||
|
req := NewSimpleRequest("http://example.com", "POST")
|
||||||
|
form := map[string][]string{
|
||||||
|
"name": []string{"alice"},
|
||||||
|
}
|
||||||
|
|
||||||
|
req.SetFormData(form)
|
||||||
|
form["name"][0] = "bob"
|
||||||
|
form["name"] = append(form["name"], "carol")
|
||||||
|
|
||||||
|
got := req.config.Body.FormData["name"]
|
||||||
|
if len(got) != 1 || got[0] != "alice" {
|
||||||
|
t.Fatalf("form data mutated by external map change: got=%v want=[alice]", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithBodyDefensiveCopy(t *testing.T) {
|
||||||
|
body := []byte("hello")
|
||||||
|
|
||||||
|
req, err := NewRequest("http://example.com", "POST", WithBody(body))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewRequest() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
body[0] = 'j'
|
||||||
|
if string(req.config.Body.Bytes) != "hello" {
|
||||||
|
t.Fatalf("body mutated by external slice change: got=%q want=%q", string(req.config.Body.Bytes), "hello")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithFormDataDefensiveCopy(t *testing.T) {
|
||||||
|
form := map[string][]string{
|
||||||
|
"name": []string{"alice"},
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := NewRequest("http://example.com", "POST", WithFormData(form))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewRequest() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
form["name"][0] = "bob"
|
||||||
|
form["name"] = append(form["name"], "carol")
|
||||||
|
got := req.config.Body.FormData["name"]
|
||||||
|
if len(got) != 1 || got[0] != "alice" {
|
||||||
|
t.Fatalf("form data mutated by external map change: got=%v want=[alice]", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetCustomIPDefensiveCopy(t *testing.T) {
|
||||||
|
ips := []string{"1.1.1.1", "8.8.8.8"}
|
||||||
|
req := NewSimpleRequest("http://example.com", "GET").SetCustomIP(ips)
|
||||||
|
|
||||||
|
ips[0] = "9.9.9.9"
|
||||||
|
if got := req.config.DNS.CustomIP[0]; got != "1.1.1.1" {
|
||||||
|
t.Fatalf("custom ip mutated by external slice change: got=%q want=%q", got, "1.1.1.1")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetCustomDNSDefensiveCopy(t *testing.T) {
|
||||||
|
servers := []string{"8.8.8.8", "1.1.1.1"}
|
||||||
|
req := NewSimpleRequest("http://example.com", "GET").SetCustomDNS(servers)
|
||||||
|
|
||||||
|
servers[0] = "9.9.9.9"
|
||||||
|
if got := req.config.DNS.CustomDNS[0]; got != "8.8.8.8" {
|
||||||
|
t.Fatalf("custom dns mutated by external slice change: got=%q want=%q", got, "8.8.8.8")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,238 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func traceDNSLookup(traceState *traceState, host string, lookup func() ([]net.IPAddr, error)) ([]net.IPAddr, error) {
|
||||||
|
if traceState != nil {
|
||||||
|
traceState.beginManualDNS()
|
||||||
|
defer traceState.endManualDNS()
|
||||||
|
traceState.dnsStart(TraceDNSStartInfo{Host: host})
|
||||||
|
}
|
||||||
|
ipAddrs, err := lookup()
|
||||||
|
if traceState != nil {
|
||||||
|
traceState.dnsDone(TraceDNSDoneInfo{
|
||||||
|
Addrs: append([]net.IPAddr(nil), ipAddrs...),
|
||||||
|
Err: err,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return ipAddrs, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveDialAddresses(ctx context.Context, reqCtx *RequestContext, host, port string, traceState *traceState) ([]string, error) {
|
||||||
|
if reqCtx == nil {
|
||||||
|
reqCtx = &RequestContext{}
|
||||||
|
}
|
||||||
|
|
||||||
|
var addrs []string
|
||||||
|
|
||||||
|
if len(reqCtx.CustomIP) > 0 {
|
||||||
|
for _, ip := range reqCtx.CustomIP {
|
||||||
|
addrs = append(addrs, joinResolvedHostPort(ip, port))
|
||||||
|
}
|
||||||
|
return addrs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
ipAddrs []net.IPAddr
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
|
if reqCtx.LookupIPFn != nil {
|
||||||
|
ipAddrs, err = traceDNSLookup(traceState, host, func() ([]net.IPAddr, error) {
|
||||||
|
return reqCtx.LookupIPFn(ctx, host)
|
||||||
|
})
|
||||||
|
} else if len(reqCtx.CustomDNS) > 0 {
|
||||||
|
dialTimeout := reqCtx.DialTimeout
|
||||||
|
if dialTimeout == 0 {
|
||||||
|
dialTimeout = DefaultDialTimeout
|
||||||
|
}
|
||||||
|
dialer := &net.Dialer{Timeout: dialTimeout}
|
||||||
|
resolver := &net.Resolver{
|
||||||
|
PreferGo: true,
|
||||||
|
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
var lastErr error
|
||||||
|
for _, dnsServer := range reqCtx.CustomDNS {
|
||||||
|
conn, err := dialer.DialContext(ctx, "udp", net.JoinHostPort(dnsServer, "53"))
|
||||||
|
if err != nil {
|
||||||
|
lastErr = err
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
return nil, lastErr
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ipAddrs, err = traceDNSLookup(traceState, host, func() ([]net.IPAddr, error) {
|
||||||
|
return resolver.LookupIPAddr(ctx, host)
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
ipAddrs, err = traceDNSLookup(traceState, host, func() ([]net.IPAddr, error) {
|
||||||
|
return net.DefaultResolver.LookupIPAddr(ctx, host)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, wrapError(err, "lookup ip")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ipAddr := range ipAddrs {
|
||||||
|
addrs = append(addrs, joinResolvedHostPort(ipAddr.String(), port))
|
||||||
|
}
|
||||||
|
return addrs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func joinResolvedHostPort(host, port string) string {
|
||||||
|
if port == "" {
|
||||||
|
if ip := net.ParseIP(host); ip != nil && ip.To4() == nil {
|
||||||
|
return "[" + host + "]"
|
||||||
|
}
|
||||||
|
return host
|
||||||
|
}
|
||||||
|
return net.JoinHostPort(host, port)
|
||||||
|
}
|
||||||
|
|
||||||
|
// defaultDialFunc 默认 Dial 函数(支持自定义 IP 和 DNS)
|
||||||
|
func defaultDialFunc(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
// 提取配置
|
||||||
|
reqCtx := getRequestContext(ctx)
|
||||||
|
traceState := getTraceState(ctx)
|
||||||
|
|
||||||
|
dialTimeout := reqCtx.DialTimeout
|
||||||
|
if dialTimeout == 0 {
|
||||||
|
dialTimeout = DefaultDialTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析地址
|
||||||
|
host, port, err := net.SplitHostPort(addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, wrapError(err, "split host port")
|
||||||
|
}
|
||||||
|
|
||||||
|
addrs, err := resolveDialAddresses(ctx, reqCtx, host, port, traceState)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 尝试连接所有地址
|
||||||
|
dialer := &net.Dialer{Timeout: dialTimeout}
|
||||||
|
var lastErr error
|
||||||
|
for _, addr := range addrs {
|
||||||
|
conn, err := dialer.DialContext(ctx, network, addr)
|
||||||
|
if err != nil {
|
||||||
|
lastErr = err
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if lastErr != nil {
|
||||||
|
return nil, wrapError(lastErr, "dial all addresses failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("no addresses to dial")
|
||||||
|
}
|
||||||
|
|
||||||
|
// defaultDialTLSFunc 默认 TLS Dial 函数
|
||||||
|
func defaultDialTLSFunc(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
// 先建立 TCP 连接
|
||||||
|
conn, err := defaultDialFunc(ctx, network, addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 提取 TLS 配置
|
||||||
|
reqCtx := getRequestContext(ctx)
|
||||||
|
traceState := getTraceState(ctx)
|
||||||
|
tlsConfig := reqCtx.TLSConfig
|
||||||
|
if tlsConfig == nil {
|
||||||
|
tlsConfig = &tls.Config{}
|
||||||
|
}
|
||||||
|
|
||||||
|
serverName := tlsConfig.ServerName
|
||||||
|
if serverName == "" {
|
||||||
|
serverName = reqCtx.TLSServerName
|
||||||
|
}
|
||||||
|
if serverName == "" && !tlsConfig.InsecureSkipVerify {
|
||||||
|
host, _, err := net.SplitHostPort(addr)
|
||||||
|
if err != nil {
|
||||||
|
if idx := strings.LastIndex(addr, ":"); idx > 0 {
|
||||||
|
host = addr[:idx]
|
||||||
|
} else {
|
||||||
|
host = addr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
serverName = host
|
||||||
|
}
|
||||||
|
if serverName != "" && tlsConfig.ServerName != serverName {
|
||||||
|
tlsConfig = tlsConfig.Clone() // 避免修改原 config
|
||||||
|
tlsConfig.ServerName = serverName
|
||||||
|
}
|
||||||
|
if traceState != nil {
|
||||||
|
traceState.markCustomTLS()
|
||||||
|
traceState.tlsHandshakeStart(TraceTLSHandshakeStartInfo{
|
||||||
|
Network: network,
|
||||||
|
Addr: addr,
|
||||||
|
ServerName: serverName,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// 执行 TLS 握手
|
||||||
|
if deadline, ok := ctx.Deadline(); ok {
|
||||||
|
_ = conn.SetDeadline(deadline)
|
||||||
|
defer conn.SetDeadline(time.Time{})
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConn := tls.Client(conn, tlsConfig)
|
||||||
|
if err := tlsConn.Handshake(); err != nil {
|
||||||
|
if traceState != nil {
|
||||||
|
traceState.tlsHandshakeDone(TraceTLSHandshakeDoneInfo{
|
||||||
|
Network: network,
|
||||||
|
Addr: addr,
|
||||||
|
ServerName: serverName,
|
||||||
|
Err: err,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
conn.Close()
|
||||||
|
return nil, wrapError(err, "tls handshake")
|
||||||
|
}
|
||||||
|
if traceState != nil {
|
||||||
|
traceState.tlsHandshakeDone(TraceTLSHandshakeDoneInfo{
|
||||||
|
Network: network,
|
||||||
|
Addr: addr,
|
||||||
|
ServerName: serverName,
|
||||||
|
ConnectionState: tlsConn.ConnectionState(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return tlsConn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
// defaultProxyFunc 默认代理函数
|
||||||
|
func defaultProxyFunc(req *http.Request) (*url.URL, error) {
|
||||||
|
if req == nil {
|
||||||
|
return nil, fmt.Errorf("request is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
reqCtx := getRequestContext(req.Context())
|
||||||
|
if reqCtx.Proxy == "" {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
proxyURL, err := url.Parse(reqCtx.Proxy)
|
||||||
|
if err != nil {
|
||||||
|
return nil, wrapError(err, "parse proxy url")
|
||||||
|
}
|
||||||
|
|
||||||
|
return proxyURL, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
*/
|
||||||
+103
@@ -0,0 +1,103 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRequestCustomIP(t *testing.T) {
|
||||||
|
customIPs := []string{"1.2.3.4", "5.6.7.8"}
|
||||||
|
req := NewSimpleRequest("http://example.com", "GET").
|
||||||
|
SetCustomIP(customIPs)
|
||||||
|
|
||||||
|
if len(req.config.DNS.CustomIP) != 2 {
|
||||||
|
t.Errorf("CustomIP length = %v; want 2", len(req.config.DNS.CustomIP))
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, ip := range req.config.DNS.CustomIP {
|
||||||
|
if ip != customIPs[i] {
|
||||||
|
t.Errorf("CustomIP[%d] = %v; want %v", i, ip, customIPs[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestCustomIPInvalid(t *testing.T) {
|
||||||
|
req := NewSimpleRequest("http://example.com", "GET").
|
||||||
|
SetCustomIP([]string{"invalid-ip"})
|
||||||
|
|
||||||
|
if req.Err() == nil {
|
||||||
|
t.Error("Expected error for invalid IP, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestCustomDNS(t *testing.T) {
|
||||||
|
dnsServers := []string{"8.8.8.8", "1.1.1.1"}
|
||||||
|
req := NewSimpleRequest("http://example.com", "GET").
|
||||||
|
SetCustomDNS(dnsServers)
|
||||||
|
|
||||||
|
if len(req.config.DNS.CustomDNS) != 2 {
|
||||||
|
t.Errorf("CustomDNS length = %v; want 2", len(req.config.DNS.CustomDNS))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestCustomDNSInvalid(t *testing.T) {
|
||||||
|
req := NewSimpleRequest("http://example.com", "GET").
|
||||||
|
SetCustomDNS([]string{"invalid-dns"})
|
||||||
|
|
||||||
|
if req.Err() == nil {
|
||||||
|
t.Error("Expected error for invalid DNS, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestLookupFunc(t *testing.T) {
|
||||||
|
called := false
|
||||||
|
lookupFunc := func(ctx context.Context, host string) ([]net.IPAddr, error) {
|
||||||
|
called = true
|
||||||
|
return []net.IPAddr{
|
||||||
|
{IP: net.ParseIP("1.2.3.4")},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
req := NewSimpleRequest("http://example.com", "GET").
|
||||||
|
SetLookupFunc(lookupFunc)
|
||||||
|
|
||||||
|
if req.config.DNS.LookupFunc == nil {
|
||||||
|
t.Error("LookupFunc not set")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call the function to verify it works
|
||||||
|
ips, err := req.config.DNS.LookupFunc(context.Background(), "example.com")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("LookupFunc error: %v", err)
|
||||||
|
}
|
||||||
|
if !called {
|
||||||
|
t.Error("LookupFunc was not called")
|
||||||
|
}
|
||||||
|
if len(ips) != 1 {
|
||||||
|
t.Errorf("IPs length = %v; want 1", len(ips))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDNSPriority(t *testing.T) {
|
||||||
|
// CustomIP should have highest priority
|
||||||
|
req := NewSimpleRequest("http://example.com", "GET").
|
||||||
|
SetCustomIP([]string{"1.2.3.4"}).
|
||||||
|
SetCustomDNS([]string{"8.8.8.8"}).
|
||||||
|
SetLookupFunc(func(ctx context.Context, host string) ([]net.IPAddr, error) {
|
||||||
|
return []net.IPAddr{{IP: net.ParseIP("5.6.7.8")}}, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
// CustomIP should be set
|
||||||
|
if len(req.config.DNS.CustomIP) == 0 {
|
||||||
|
t.Error("CustomIP should be set")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Others should also be set (but CustomIP takes priority in actual use)
|
||||||
|
if len(req.config.DNS.CustomDNS) == 0 {
|
||||||
|
t.Error("CustomDNS should be set")
|
||||||
|
}
|
||||||
|
if req.config.DNS.LookupFunc == nil {
|
||||||
|
t.Error("LookupFunc should be set")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,144 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func BenchmarkDynamicTransportCustomIP(b *testing.B) {
|
||||||
|
server := newIPv4Server(b, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("OK"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
targetURL := benchmarkTargetURL(b, server.URL, "bench-custom-ip.test")
|
||||||
|
client := NewClientNoErr()
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
b.ReportAllocs()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
resp, err := client.Get(targetURL, WithCustomIP([]string{"127.0.0.1"}))
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("Get() error: %v", err)
|
||||||
|
}
|
||||||
|
_, _ = resp.Body().Bytes()
|
||||||
|
resp.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkDynamicTransportProxyTLSCacheable(b *testing.B) {
|
||||||
|
server := newIPv4TLSServer(b, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("OK"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
proxy := newIPv4ConnectProxyServer(b, nil)
|
||||||
|
defer proxy.Close()
|
||||||
|
|
||||||
|
targetURL := httpsURLForHost(b, server, "bench-proxy-cacheable.test")
|
||||||
|
client := NewClientNoErr()
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
b.ReportAllocs()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
resp, err := client.Get(targetURL,
|
||||||
|
WithProxy(proxy.URL),
|
||||||
|
WithCustomIP([]string{"127.0.0.1"}),
|
||||||
|
WithSkipTLSVerify(true),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("Get() error: %v", err)
|
||||||
|
}
|
||||||
|
_, _ = resp.Body().Bytes()
|
||||||
|
resp.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkDynamicTransportCustomIPTLSCacheable(b *testing.B) {
|
||||||
|
server := newIPv4TLSServer(b, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("OK"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
targetURL := httpsURLForHost(b, server, "bench-custom-ip-cacheable.test")
|
||||||
|
client := NewClientNoErr()
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
b.ReportAllocs()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
resp, err := client.Get(targetURL,
|
||||||
|
WithCustomIP([]string{"127.0.0.1"}),
|
||||||
|
WithSkipTLSVerify(true),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("Get() error: %v", err)
|
||||||
|
}
|
||||||
|
_, _ = resp.Body().Bytes()
|
||||||
|
resp.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkDynamicTransportCustomIPUserTLSConfig(b *testing.B) {
|
||||||
|
server := newIPv4TLSServer(b, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("OK"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
targetURL := httpsURLForHost(b, server, "bench-user-tls.test")
|
||||||
|
client := NewClientNoErr()
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
b.ReportAllocs()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
resp, err := client.Get(targetURL,
|
||||||
|
WithCustomIP([]string{"127.0.0.1"}),
|
||||||
|
WithTLSConfig(&tls.Config{InsecureSkipVerify: true}),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("Get() error: %v", err)
|
||||||
|
}
|
||||||
|
_, _ = resp.Body().Bytes()
|
||||||
|
resp.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func benchmarkTargetURL(tb testing.TB, rawURL, host string) string {
|
||||||
|
tb.Helper()
|
||||||
|
|
||||||
|
parsed, err := url.Parse(rawURL)
|
||||||
|
if err != nil {
|
||||||
|
tb.Fatalf("url.Parse() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
port := parsed.Port()
|
||||||
|
if port == "" {
|
||||||
|
switch parsed.Scheme {
|
||||||
|
case "https":
|
||||||
|
port = "443"
|
||||||
|
default:
|
||||||
|
port = "80"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return parsed.Scheme + "://" + host + ":" + port + pathWithQuery(parsed.Path, parsed.RawQuery)
|
||||||
|
}
|
||||||
|
|
||||||
|
func pathWithQuery(path, rawQuery string) string {
|
||||||
|
if path == "" {
|
||||||
|
path = "/"
|
||||||
|
}
|
||||||
|
if rawQuery == "" {
|
||||||
|
return path
|
||||||
|
}
|
||||||
|
return path + "?" + rawQuery
|
||||||
|
}
|
||||||
@@ -0,0 +1,257 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// ErrInvalidMethod 无效的 HTTP 方法
|
||||||
|
ErrInvalidMethod = errors.New("starnet: invalid HTTP method")
|
||||||
|
|
||||||
|
// ErrInvalidURL 无效的 URL
|
||||||
|
ErrInvalidURL = errors.New("starnet: invalid URL")
|
||||||
|
|
||||||
|
// ErrInvalidIP 无效的 IP 地址
|
||||||
|
ErrInvalidIP = errors.New("starnet: invalid IP address")
|
||||||
|
|
||||||
|
// ErrInvalidDNS 无效的 DNS 服务器
|
||||||
|
ErrInvalidDNS = errors.New("starnet: invalid DNS server")
|
||||||
|
|
||||||
|
// ErrNilClient HTTP Client 为 nil
|
||||||
|
ErrNilClient = errors.New("starnet: http client is nil")
|
||||||
|
|
||||||
|
// ErrNilReader Reader 为 nil
|
||||||
|
ErrNilReader = errors.New("starnet: reader is nil")
|
||||||
|
|
||||||
|
// ErrFileNotFound 文件不存在
|
||||||
|
ErrFileNotFound = errors.New("starnet: file not found")
|
||||||
|
|
||||||
|
// ErrRequestNotPrepared 请求未准备好
|
||||||
|
ErrRequestNotPrepared = errors.New("starnet: request not prepared")
|
||||||
|
|
||||||
|
// ErrBodyAlreadyConsumed Body 已被消费
|
||||||
|
ErrBodyAlreadyConsumed = errors.New("starnet: response body already consumed")
|
||||||
|
|
||||||
|
// ErrRespBodyTooLarge 响应体超过允许上限
|
||||||
|
ErrRespBodyTooLarge = errors.New("starnet: response body too large")
|
||||||
|
|
||||||
|
// ErrPingInvalidTimeout ping 超时参数无效
|
||||||
|
ErrPingInvalidTimeout = errors.New("starnet: invalid ping timeout")
|
||||||
|
|
||||||
|
// ErrPingPermissionDenied ping 需要更高权限(raw socket)
|
||||||
|
ErrPingPermissionDenied = errors.New("starnet: ping permission denied")
|
||||||
|
|
||||||
|
// ErrPingProtocolUnsupported ping 协议/地址族不受当前平台支持
|
||||||
|
ErrPingProtocolUnsupported = errors.New("starnet: ping protocol unsupported")
|
||||||
|
|
||||||
|
// ErrPingNoResolvedTarget ping 目标无法解析为可用地址
|
||||||
|
ErrPingNoResolvedTarget = errors.New("starnet: ping target not resolved")
|
||||||
|
)
|
||||||
|
|
||||||
|
// wrapError 包装错误,添加上下文信息
|
||||||
|
func wrapError(err error, format string, args ...interface{}) error {
|
||||||
|
if err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
msg := fmt.Sprintf(format, args...)
|
||||||
|
return fmt.Errorf("%s: %w", msg, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
// ErrNilConn indicates a nil net.Conn argument.
|
||||||
|
ErrNilConn = errors.New("starnet: nil connection")
|
||||||
|
|
||||||
|
// ErrTLSSniffFailed indicates TLS sniffing/parsing failed before handshake setup.
|
||||||
|
ErrTLSSniffFailed = errors.New("starnet: tls sniff failed")
|
||||||
|
|
||||||
|
// ErrTLSConfigSelectionFailed indicates dynamic TLS config selection failed.
|
||||||
|
ErrTLSConfigSelectionFailed = errors.New("starnet: tls config selection failed")
|
||||||
|
|
||||||
|
// ErrNonTLSNotAllowed indicates plain TCP was detected while non-TLS is forbidden.
|
||||||
|
ErrNonTLSNotAllowed = errors.New("starnet: non-TLS connection not allowed")
|
||||||
|
|
||||||
|
// ErrNotTLS indicates caller asked for TLS-only object but conn is plain TCP.
|
||||||
|
ErrNotTLS = errors.New("starnet: connection is not TLS")
|
||||||
|
|
||||||
|
// ErrNoTLSConfig indicates TLS was detected but no usable TLS config is available.
|
||||||
|
ErrNoTLSConfig = errors.New("starnet: no TLS config available")
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrorKind is a normalized high-level category for request errors.
|
||||||
|
type ErrorKind string
|
||||||
|
|
||||||
|
const (
|
||||||
|
ErrorKindNone ErrorKind = "none"
|
||||||
|
ErrorKindCanceled ErrorKind = "canceled"
|
||||||
|
ErrorKindTimeout ErrorKind = "timeout"
|
||||||
|
ErrorKindDNS ErrorKind = "dns"
|
||||||
|
ErrorKindTLS ErrorKind = "tls"
|
||||||
|
ErrorKindProxy ErrorKind = "proxy"
|
||||||
|
ErrorKindOther ErrorKind = "other"
|
||||||
|
)
|
||||||
|
|
||||||
|
// IsCanceled reports whether err is a cancellation-related error.
|
||||||
|
func IsCanceled(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if errors.Is(err, context.Canceled) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := strings.ToLower(err.Error())
|
||||||
|
return strings.Contains(msg, "context canceled") ||
|
||||||
|
strings.Contains(msg, "operation was canceled") ||
|
||||||
|
strings.Contains(msg, "request canceled")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClassifyError maps low-level errors to a stable category for business handling.
|
||||||
|
func ClassifyError(err error) ErrorKind {
|
||||||
|
if err == nil {
|
||||||
|
return ErrorKindNone
|
||||||
|
}
|
||||||
|
if IsCanceled(err) {
|
||||||
|
return ErrorKindCanceled
|
||||||
|
}
|
||||||
|
if IsProxy(err) {
|
||||||
|
return ErrorKindProxy
|
||||||
|
}
|
||||||
|
if IsDNS(err) {
|
||||||
|
return ErrorKindDNS
|
||||||
|
}
|
||||||
|
if IsTLS(err) {
|
||||||
|
return ErrorKindTLS
|
||||||
|
}
|
||||||
|
if IsTimeout(err) {
|
||||||
|
return ErrorKindTimeout
|
||||||
|
}
|
||||||
|
return ErrorKindOther
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsTimeout reports whether err is a timeout-related error.
|
||||||
|
func IsTimeout(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
var uerr *url.Error
|
||||||
|
if errors.As(err, &uerr) && uerr.Timeout() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
var nerr net.Error
|
||||||
|
if errors.As(err, &nerr) && nerr.Timeout() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := strings.ToLower(err.Error())
|
||||||
|
return strings.Contains(msg, "timeout") || strings.Contains(msg, "deadline exceeded")
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsDNS reports whether err is a DNS resolution related error.
|
||||||
|
func IsDNS(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
var derr *net.DNSError
|
||||||
|
if errors.As(err, &derr) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := strings.ToLower(err.Error())
|
||||||
|
if strings.Contains(msg, "no such host") ||
|
||||||
|
strings.Contains(msg, "server misbehaving") ||
|
||||||
|
strings.Contains(msg, "temporary failure in name resolution") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.Contains(msg, "lookup ") &&
|
||||||
|
(strings.Contains(msg, "dns") || strings.Contains(msg, "i/o timeout"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsTLS reports whether err is TLS/Certificate related.
|
||||||
|
func IsTLS(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if errors.Is(err, ErrNotTLS) || errors.Is(err, ErrNoTLSConfig) || errors.Is(err, ErrNonTLSNotAllowed) ||
|
||||||
|
errors.Is(err, ErrTLSSniffFailed) || errors.Is(err, ErrTLSConfigSelectionFailed) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
var recErr tls.RecordHeaderError
|
||||||
|
if errors.As(err, &recErr) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
var uaErr x509.UnknownAuthorityError
|
||||||
|
if errors.As(err, &uaErr) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
var hnErr x509.HostnameError
|
||||||
|
if errors.As(err, &hnErr) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
var certErr x509.CertificateInvalidError
|
||||||
|
if errors.As(err, &certErr) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
var rootsErr x509.SystemRootsError
|
||||||
|
if errors.As(err, &rootsErr) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := strings.ToLower(err.Error())
|
||||||
|
return strings.Contains(msg, "tls:") || strings.Contains(msg, "x509:")
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsProxy reports whether err is proxy related.
|
||||||
|
func IsProxy(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if isProxyMessage(strings.ToLower(err.Error())) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
var uerr *url.Error
|
||||||
|
if errors.As(err, &uerr) {
|
||||||
|
if strings.Contains(strings.ToLower(uerr.Op), "proxy") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if uerr.Err != nil && isProxyMessage(strings.ToLower(uerr.Err.Error())) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var opErr *net.OpError
|
||||||
|
if errors.As(err, &opErr) && strings.Contains(strings.ToLower(opErr.Op), "proxy") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func isProxyMessage(msg string) bool {
|
||||||
|
return strings.Contains(msg, "proxyconnect") ||
|
||||||
|
strings.Contains(msg, "proxy error") ||
|
||||||
|
strings.Contains(msg, "proxy authentication required") ||
|
||||||
|
strings.Contains(msg, "proxy: unknown scheme") ||
|
||||||
|
strings.Contains(msg, "socks connect") ||
|
||||||
|
strings.Contains(msg, "socks5")
|
||||||
|
}
|
||||||
@@ -0,0 +1,116 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/url"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type timeoutErr struct{}
|
||||||
|
|
||||||
|
func (timeoutErr) Error() string { return "i/o timeout" }
|
||||||
|
func (timeoutErr) Timeout() bool { return true }
|
||||||
|
func (timeoutErr) Temporary() bool { return true }
|
||||||
|
|
||||||
|
func TestIsTimeout(t *testing.T) {
|
||||||
|
if !IsTimeout(context.DeadlineExceeded) {
|
||||||
|
t.Fatal("context deadline should be timeout")
|
||||||
|
}
|
||||||
|
|
||||||
|
uerr := &url.Error{
|
||||||
|
Op: "Get",
|
||||||
|
URL: "http://example.com",
|
||||||
|
Err: timeoutErr{},
|
||||||
|
}
|
||||||
|
if !IsTimeout(uerr) {
|
||||||
|
t.Fatal("url timeout error should be timeout")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !IsTimeout(fmt.Errorf("wrapped: %w", uerr)) {
|
||||||
|
t.Fatal("wrapped timeout should be timeout")
|
||||||
|
}
|
||||||
|
|
||||||
|
if IsTimeout(errors.New("plain error")) {
|
||||||
|
t.Fatal("plain error must not be timeout")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsDNS(t *testing.T) {
|
||||||
|
dnsErr := &net.DNSError{
|
||||||
|
Err: "no such host",
|
||||||
|
Name: "example.invalid",
|
||||||
|
IsNotFound: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
if !IsDNS(dnsErr) {
|
||||||
|
t.Fatal("dns error should be dns")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !IsDNS(fmt.Errorf("wrapped: %w", dnsErr)) {
|
||||||
|
t.Fatal("wrapped dns error should be dns")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !IsDNS(errors.New("lookup example.invalid: no such host")) {
|
||||||
|
t.Fatal("lookup no such host should be dns")
|
||||||
|
}
|
||||||
|
|
||||||
|
if IsDNS(errors.New("connection reset by peer")) {
|
||||||
|
t.Fatal("non dns error should not be dns")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsTLS(t *testing.T) {
|
||||||
|
tlsErr := tls.RecordHeaderError{Msg: "first record does not look like a TLS handshake"}
|
||||||
|
if !IsTLS(tlsErr) {
|
||||||
|
t.Fatal("tls record header error should be tls")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !IsTLS(fmt.Errorf("wrapped: %w", tlsErr)) {
|
||||||
|
t.Fatal("wrapped tls error should be tls")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !IsTLS(errors.New("x509: certificate signed by unknown authority")) {
|
||||||
|
t.Fatal("x509 error text should be tls")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !IsTLS(ErrNotTLS) {
|
||||||
|
t.Fatal("ErrNotTLS should be tls related")
|
||||||
|
}
|
||||||
|
|
||||||
|
if IsTLS(errors.New("plain error")) {
|
||||||
|
t.Fatal("plain error should not be tls")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsProxy(t *testing.T) {
|
||||||
|
raw := errors.New("proxyconnect tcp: dial tcp 127.0.0.1:8080: connect: connection refused")
|
||||||
|
if !IsProxy(raw) {
|
||||||
|
t.Fatal("proxyconnect error should be proxy")
|
||||||
|
}
|
||||||
|
|
||||||
|
uerr := &url.Error{
|
||||||
|
Op: "Get",
|
||||||
|
URL: "http://example.com",
|
||||||
|
Err: raw,
|
||||||
|
}
|
||||||
|
if !IsProxy(uerr) {
|
||||||
|
t.Fatal("wrapped proxy error should be proxy")
|
||||||
|
}
|
||||||
|
|
||||||
|
opErr := &net.OpError{
|
||||||
|
Op: "proxyconnect",
|
||||||
|
Net: "tcp",
|
||||||
|
Err: errors.New("connect failed"),
|
||||||
|
}
|
||||||
|
if !IsProxy(opErr) {
|
||||||
|
t.Fatal("net.OpError proxyconnect should be proxy")
|
||||||
|
}
|
||||||
|
|
||||||
|
if IsProxy(errors.New("dial tcp 127.0.0.1:8080: connect: connection refused")) {
|
||||||
|
t.Fatal("non proxy dial error should not be proxy")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,49 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestIsCanceled(t *testing.T) {
|
||||||
|
if !IsCanceled(context.Canceled) {
|
||||||
|
t.Fatal("context canceled should be canceled")
|
||||||
|
}
|
||||||
|
if !IsCanceled(fmt.Errorf("wrapped: %w", context.Canceled)) {
|
||||||
|
t.Fatal("wrapped context canceled should be canceled")
|
||||||
|
}
|
||||||
|
if IsCanceled(context.DeadlineExceeded) {
|
||||||
|
t.Fatal("deadline exceeded must not be canceled")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClassifyError(t *testing.T) {
|
||||||
|
dnsErr := &net.DNSError{Err: "no such host", Name: "example.invalid", IsNotFound: true}
|
||||||
|
tlsErr := tls.RecordHeaderError{Msg: "bad tls record"}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
err error
|
||||||
|
want ErrorKind
|
||||||
|
}{
|
||||||
|
{name: "nil", err: nil, want: ErrorKindNone},
|
||||||
|
{name: "canceled", err: context.Canceled, want: ErrorKindCanceled},
|
||||||
|
{name: "proxy", err: errors.New("proxyconnect tcp: dial tcp 127.0.0.1:8080: i/o timeout"), want: ErrorKindProxy},
|
||||||
|
{name: "dns", err: dnsErr, want: ErrorKindDNS},
|
||||||
|
{name: "tls", err: tlsErr, want: ErrorKindTLS},
|
||||||
|
{name: "timeout", err: context.DeadlineExceeded, want: ErrorKindTimeout},
|
||||||
|
{name: "other", err: errors.New("boom"), want: ErrorKindOther},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := ClassifyError(tt.err); got != tt.want {
|
||||||
|
t.Fatalf("ClassifyError()=%s want=%s err=%v", got, tt.want, tt.err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
+256
@@ -0,0 +1,256 @@
|
|||||||
|
package starnet_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"b612.me/starnet"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ExampleGet() {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write([]byte("Hello, World!"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
resp, err := starnet.Get(server.URL)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
body, _ := resp.Body().String()
|
||||||
|
fmt.Println(body)
|
||||||
|
// Output: Hello, World!
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExamplePost() {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write([]byte("Posted"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
resp, err := starnet.Post(server.URL,
|
||||||
|
starnet.WithBodyString("test data"))
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
body, _ := resp.Body().String()
|
||||||
|
fmt.Println(body)
|
||||||
|
// Output: Posted
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExampleNewSimpleRequest() {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write([]byte("OK"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
req := starnet.NewSimpleRequest(server.URL, "GET").
|
||||||
|
SetHeader("X-Custom", "value").
|
||||||
|
AddQuery("name", "test")
|
||||||
|
|
||||||
|
resp, err := req.Do()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
fmt.Println(resp.StatusCode)
|
||||||
|
// Output: 200
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExampleClient_Get() {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write([]byte("Client GET"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
client := starnet.NewClientNoErr(
|
||||||
|
starnet.WithTimeout(10*time.Second),
|
||||||
|
starnet.WithUserAgent("MyApp/1.0"),
|
||||||
|
)
|
||||||
|
|
||||||
|
resp, err := client.Get(server.URL)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
body, _ := resp.Body().String()
|
||||||
|
fmt.Println(body)
|
||||||
|
// Output: Client GET
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExampleRequest_SetJSON() {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.Write([]byte(`{"status":"ok"}`))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
type User struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
}
|
||||||
|
|
||||||
|
user := User{Name: "John", Email: "john@example.com"}
|
||||||
|
|
||||||
|
resp, err := starnet.NewSimpleRequest(server.URL, "POST").
|
||||||
|
SetJSON(user).
|
||||||
|
Do()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
var result map[string]string
|
||||||
|
resp.Body().JSON(&result)
|
||||||
|
fmt.Println(result["status"])
|
||||||
|
// Output: ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExampleRequest_AddFormData() {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
r.ParseForm()
|
||||||
|
fmt.Fprintf(w, "name=%s", r.FormValue("name"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
resp, err := starnet.NewSimpleRequest(server.URL, "POST").
|
||||||
|
AddFormData("name", "John").
|
||||||
|
AddFormData("age", "30").
|
||||||
|
Do()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
body, _ := resp.Body().String()
|
||||||
|
fmt.Println(body)
|
||||||
|
// Output: name=John
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExampleRequest_SetSkipTLSVerify() {
|
||||||
|
// This example shows how to skip TLS verification
|
||||||
|
// Useful for testing with self-signed certificates
|
||||||
|
|
||||||
|
req := starnet.NewSimpleRequest("https://self-signed.example.com", "GET").
|
||||||
|
SetSkipTLSVerify(true)
|
||||||
|
|
||||||
|
// In a real scenario, you would call req.Do()
|
||||||
|
fmt.Println(req.Method())
|
||||||
|
// Output: GET
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExampleRequest_Clone() {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write([]byte("OK"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
baseReq := starnet.NewSimpleRequest(server.URL, "GET").
|
||||||
|
SetHeader("X-API-Key", "secret")
|
||||||
|
|
||||||
|
// Clone and modify
|
||||||
|
req1 := baseReq.Clone().AddQuery("page", "1")
|
||||||
|
req2 := baseReq.Clone().AddQuery("page", "2")
|
||||||
|
|
||||||
|
resp1, _ := req1.Do()
|
||||||
|
resp2, _ := req2.Do()
|
||||||
|
|
||||||
|
defer resp1.Close()
|
||||||
|
defer resp2.Close()
|
||||||
|
|
||||||
|
fmt.Println(resp1.StatusCode, resp2.StatusCode)
|
||||||
|
// Output: 200 200
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExampleClient_SetDefaultSkipTLSVerify() {
|
||||||
|
client := starnet.NewClientNoErr()
|
||||||
|
client.SetDefaultSkipTLSVerify(true)
|
||||||
|
|
||||||
|
// All requests from this client will skip TLS verification
|
||||||
|
// unless overridden at request level
|
||||||
|
|
||||||
|
fmt.Println("Client configured")
|
||||||
|
// Output: Client configured
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExampleWithTimeout() {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
w.Write([]byte("OK"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
resp, err := starnet.Get(server.URL,
|
||||||
|
starnet.WithTimeout(200*time.Millisecond))
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
fmt.Println(resp.StatusCode)
|
||||||
|
// Output: 200
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExampleWithRetry() {
|
||||||
|
var hits int32
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
n := atomic.AddInt32(&hits, 1)
|
||||||
|
if n <= 2 {
|
||||||
|
w.WriteHeader(http.StatusServiceUnavailable)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
resp, err := starnet.Get(server.URL,
|
||||||
|
starnet.WithRetry(2,
|
||||||
|
starnet.WithRetryBackoff(0, 0, 1),
|
||||||
|
starnet.WithRetryJitter(0),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
fmt.Println(resp.StatusCode, atomic.LoadInt32(&hits))
|
||||||
|
// Output: 200 3
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExampleRequest_SetRetry() {
|
||||||
|
var hits int32
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
n := atomic.AddInt32(&hits, 1)
|
||||||
|
if n == 1 {
|
||||||
|
w.WriteHeader(http.StatusServiceUnavailable)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
resp, err := starnet.NewSimpleRequest(server.URL, http.MethodPost).
|
||||||
|
SetBodyString("hello"). // 可重放 body 才能安全重试
|
||||||
|
SetRetry(1).
|
||||||
|
SetRetryIdempotentOnly(false).
|
||||||
|
SetRetryBackoff(0, 0, 1).
|
||||||
|
SetRetryJitter(0).
|
||||||
|
Do()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
fmt.Println(resp.StatusCode, atomic.LoadInt32(&hits))
|
||||||
|
// Output: 200 2
|
||||||
|
}
|
||||||
@@ -0,0 +1,172 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRequestAddFileStream(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
err := r.ParseMultipartForm(10 << 20) // 10 MB
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ParseMultipartForm error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
file, header, err := r.FormFile("file")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("FormFile error: %v", err)
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
|
||||||
|
content, _ := io.ReadAll(file)
|
||||||
|
w.Write([]byte(header.Filename + ":" + string(content)))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
fileContent := "test file content"
|
||||||
|
reader := strings.NewReader(fileContent)
|
||||||
|
|
||||||
|
req := NewSimpleRequest(server.URL, "POST").
|
||||||
|
AddFileStream("file", "test.txt", int64(len(fileContent)), reader)
|
||||||
|
|
||||||
|
resp, err := req.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
body, _ := resp.Body().String()
|
||||||
|
expected := "test.txt:" + fileContent
|
||||||
|
if body != expected {
|
||||||
|
t.Errorf("Body = %v; want %v", body, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestAddFileWithFormData(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
err := r.ParseMultipartForm(10 << 20)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ParseMultipartForm error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check form field
|
||||||
|
name := r.FormValue("name")
|
||||||
|
if name != "John" {
|
||||||
|
t.Errorf("name = %v; want John", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check file
|
||||||
|
file, header, err := r.FormFile("file")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("FormFile error: %v", err)
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
|
||||||
|
w.Write([]byte("OK:" + header.Filename))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
fileContent := "file data"
|
||||||
|
reader := strings.NewReader(fileContent)
|
||||||
|
|
||||||
|
req := NewSimpleRequest(server.URL, "POST").
|
||||||
|
AddFormData("name", "John").
|
||||||
|
AddFileStream("file", "document.txt", int64(len(fileContent)), reader)
|
||||||
|
|
||||||
|
resp, err := req.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
body, _ := resp.Body().String()
|
||||||
|
if !strings.Contains(body, "document.txt") {
|
||||||
|
t.Errorf("Body should contain filename, got: %v", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestUploadProgress(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
r.ParseMultipartForm(10 << 20)
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
progressCalled := false
|
||||||
|
var lastUploaded int64
|
||||||
|
|
||||||
|
fileContent := strings.Repeat("a", 1024*10) // 10KB
|
||||||
|
reader := strings.NewReader(fileContent)
|
||||||
|
|
||||||
|
req := NewSimpleRequest(server.URL, "POST").
|
||||||
|
SetUploadProgress(func(filename string, uploaded, total int64) {
|
||||||
|
progressCalled = true
|
||||||
|
lastUploaded = uploaded
|
||||||
|
if filename != "test.txt" {
|
||||||
|
t.Errorf("filename = %v; want test.txt", filename)
|
||||||
|
}
|
||||||
|
}).
|
||||||
|
AddFileStream("file", "test.txt", int64(len(fileContent)), reader)
|
||||||
|
|
||||||
|
resp, err := req.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
if !progressCalled {
|
||||||
|
t.Error("Progress callback was not called")
|
||||||
|
}
|
||||||
|
|
||||||
|
if lastUploaded != int64(len(fileContent)) {
|
||||||
|
t.Errorf("lastUploaded = %v; want %v", lastUploaded, len(fileContent))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRequestAddFileFromDisk tests uploading a real file from disk
|
||||||
|
func TestRequestAddFileFromDisk(t *testing.T) {
|
||||||
|
// Create a temporary file
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
tmpFile := filepath.Join(tmpDir, "test.txt")
|
||||||
|
fileContent := []byte("test file content from disk")
|
||||||
|
|
||||||
|
err := os.WriteFile(tmpFile, fileContent, 0644)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("WriteFile error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
err := r.ParseMultipartForm(10 << 20)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ParseMultipartForm error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
file, header, err := r.FormFile("file")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("FormFile error: %v", err)
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
|
||||||
|
content, _ := io.ReadAll(file)
|
||||||
|
w.Write([]byte(header.Filename + ":" + string(content)))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
req := NewSimpleRequest(server.URL, "POST").AddFile("file", tmpFile)
|
||||||
|
|
||||||
|
resp, err := req.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
body, _ := resp.Body().String()
|
||||||
|
if !strings.Contains(body, string(fileContent)) {
|
||||||
|
t.Errorf("Body should contain file content, got: %v", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,5 +1,3 @@
|
|||||||
module b612.me/starnet
|
module b612.me/starnet
|
||||||
|
|
||||||
go 1.16
|
go 1.16
|
||||||
|
|
||||||
require b612.me/stario v0.0.8
|
|
||||||
|
|||||||
@@ -1,13 +0,0 @@
|
|||||||
b612.me/stario v0.0.8 h1:kaA4pszAKLZJm2D9JmiuYSpgjTeE3VaO74vm+H0vBGM=
|
|
||||||
b612.me/stario v0.0.8/go.mod h1:or4ssWcxQSjMeu+hRKEgtp0X517b3zdlEOAms8Qscvw=
|
|
||||||
golang.org/x/crypto v0.0.0-20220313003712-b769efc7c000 h1:SL+8VVnkqyshUSz5iNnXtrBQzvFF2SkROm6t5RczFAE=
|
|
||||||
golang.org/x/crypto v0.0.0-20220313003712-b769efc7c000/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
|
||||||
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
|
||||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
|
||||||
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
|
||||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1 h1:SrN+KX8Art/Sf4HNj6Zcz06G7VEz+7w9tdXTPOZ7+l4=
|
|
||||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
|
||||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1 h1:v+OssWQX+hTHEmOBgwxdZxK4zHq3yOs8F9J7mk0PY8E=
|
|
||||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
|
||||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
|
||||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
|
||||||
|
|||||||
+140
@@ -0,0 +1,140 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRequestHeaders(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
headers := make(map[string]string)
|
||||||
|
for k, v := range r.Header {
|
||||||
|
if len(v) > 0 {
|
||||||
|
headers[k] = v[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
json.NewEncoder(w).Encode(headers)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
req := NewSimpleRequest(server.URL, "GET").
|
||||||
|
SetHeader("X-Custom-Header", "value1").
|
||||||
|
AddHeader("X-Multi-Header", "value1").
|
||||||
|
AddHeader("X-Multi-Header", "value2")
|
||||||
|
|
||||||
|
resp, err := req.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
var headers map[string]string
|
||||||
|
resp.Body().JSON(&headers)
|
||||||
|
|
||||||
|
if headers["X-Custom-Header"] != "value1" {
|
||||||
|
t.Errorf("X-Custom-Header = %v; want value1", headers["X-Custom-Header"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestCookies(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
cookies := make(map[string]string)
|
||||||
|
for _, cookie := range r.Cookies() {
|
||||||
|
cookies[cookie.Name] = cookie.Value
|
||||||
|
}
|
||||||
|
json.NewEncoder(w).Encode(cookies)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
req := NewSimpleRequest(server.URL, "GET").
|
||||||
|
AddSimpleCookie("session", "abc123").
|
||||||
|
AddSimpleCookie("user", "john")
|
||||||
|
|
||||||
|
resp, err := req.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
var cookies map[string]string
|
||||||
|
resp.Body().JSON(&cookies)
|
||||||
|
|
||||||
|
if cookies["session"] != "abc123" {
|
||||||
|
t.Errorf("session cookie = %v; want abc123", cookies["session"])
|
||||||
|
}
|
||||||
|
if cookies["user"] != "john" {
|
||||||
|
t.Errorf("user cookie = %v; want john", cookies["user"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestUserAgent(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write([]byte(r.UserAgent()))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
req := NewSimpleRequest(server.URL, "GET").
|
||||||
|
SetUserAgent("CustomAgent/1.0")
|
||||||
|
|
||||||
|
resp, err := req.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
body, _ := resp.Body().String()
|
||||||
|
if body != "CustomAgent/1.0" {
|
||||||
|
t.Errorf("User-Agent = %v; want CustomAgent/1.0", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestBearerToken(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
auth := r.Header.Get("Authorization")
|
||||||
|
w.Write([]byte(auth))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
req := NewSimpleRequest(server.URL, "GET").
|
||||||
|
SetBearerToken("mytoken123")
|
||||||
|
|
||||||
|
resp, err := req.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
body, _ := resp.Body().String()
|
||||||
|
expected := "Bearer mytoken123"
|
||||||
|
if body != expected {
|
||||||
|
t.Errorf("Authorization = %v; want %v", body, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestBasicAuth(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
username, password, ok := r.BasicAuth()
|
||||||
|
if !ok {
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Write([]byte(username + ":" + password))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
req := NewSimpleRequest(server.URL, "GET").
|
||||||
|
SetBasicAuth("user", "pass")
|
||||||
|
|
||||||
|
resp, err := req.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
body, _ := resp.Body().String()
|
||||||
|
if body != "user:pass" {
|
||||||
|
t.Errorf("BasicAuth = %v; want user:pass", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,150 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRequestSetURLDoesNotMutateProvidedTLSConfig(t *testing.T) {
|
||||||
|
cfg := &tls.Config{}
|
||||||
|
|
||||||
|
req := NewSimpleRequest("https://example.com", http.MethodGet).
|
||||||
|
SetTLSConfig(cfg).
|
||||||
|
SetURL("https://other.example")
|
||||||
|
|
||||||
|
if req.Err() != nil {
|
||||||
|
t.Fatalf("unexpected request error: %v", req.Err())
|
||||||
|
}
|
||||||
|
if cfg.ServerName != "" {
|
||||||
|
t.Fatalf("provided tls.Config was mutated, ServerName=%q", cfg.ServerName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestPrepareSetTLSServerNameDoesNotMutateProvidedTLSConfig(t *testing.T) {
|
||||||
|
cfg := &tls.Config{InsecureSkipVerify: true}
|
||||||
|
|
||||||
|
req := NewSimpleRequest("https://example.com", http.MethodGet).
|
||||||
|
SetTLSConfig(cfg).
|
||||||
|
SetTLSServerName("override.example")
|
||||||
|
|
||||||
|
if err := req.prepare(); err != nil {
|
||||||
|
t.Fatalf("prepare error: %v", err)
|
||||||
|
}
|
||||||
|
if cfg.ServerName != "" {
|
||||||
|
t.Fatalf("provided tls.Config was mutated, ServerName=%q", cfg.ServerName)
|
||||||
|
}
|
||||||
|
|
||||||
|
rc := getRequestContext(req.execCtx)
|
||||||
|
if rc.TLSConfig == nil {
|
||||||
|
t.Fatal("expected injected tls config")
|
||||||
|
}
|
||||||
|
if rc.TLSConfig == cfg {
|
||||||
|
t.Fatal("expected injected tls config to be cloned")
|
||||||
|
}
|
||||||
|
if rc.TLSConfig.ServerName != "override.example" {
|
||||||
|
t.Fatalf("injected ServerName=%q", rc.TLSConfig.ServerName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestPrepareWithTLSServerNameWithoutTLSConfig(t *testing.T) {
|
||||||
|
req := NewSimpleRequest("https://example.com", http.MethodGet).
|
||||||
|
SetTLSServerName("override.example")
|
||||||
|
|
||||||
|
if err := req.prepare(); err != nil {
|
||||||
|
t.Fatalf("prepare error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rc := getRequestContext(req.execCtx)
|
||||||
|
if rc.TLSConfig == nil {
|
||||||
|
t.Fatal("expected injected tls config")
|
||||||
|
}
|
||||||
|
if rc.TLSConfig.ServerName != "override.example" {
|
||||||
|
t.Fatalf("injected ServerName=%q", rc.TLSConfig.ServerName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestPrepareDefaultPathSkipsRequestContextInjection(t *testing.T) {
|
||||||
|
req := NewSimpleRequest("https://example.com", http.MethodGet)
|
||||||
|
|
||||||
|
if err := req.prepare(); err != nil {
|
||||||
|
t.Fatalf("prepare error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := req.execCtx.Value(ctxKeyRequestContext); got != nil {
|
||||||
|
t.Fatalf("unexpected request context injection: %#v", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
rc := getRequestContext(req.execCtx)
|
||||||
|
if needsDynamicTransport(rc) {
|
||||||
|
t.Fatalf("default path unexpectedly marked dynamic: %#v", rc)
|
||||||
|
}
|
||||||
|
if rc.TLSServerName != "" {
|
||||||
|
t.Fatalf("default path unexpectedly injected tls server name: %q", rc.TLSServerName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestPrepareDynamicPathInjectsAggregatedRequestContext(t *testing.T) {
|
||||||
|
req := NewSimpleRequest("https://example.com", http.MethodGet).
|
||||||
|
SetCustomIP([]string{"127.0.0.1"}).
|
||||||
|
SetSkipTLSVerify(true)
|
||||||
|
|
||||||
|
if err := req.prepare(); err != nil {
|
||||||
|
t.Fatalf("prepare error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
raw := req.execCtx.Value(ctxKeyRequestContext)
|
||||||
|
rc, ok := raw.(*RequestContext)
|
||||||
|
if !ok || rc == nil {
|
||||||
|
t.Fatalf("expected aggregated request context, got %#v", raw)
|
||||||
|
}
|
||||||
|
if len(rc.CustomIP) != 1 || rc.CustomIP[0] != "127.0.0.1" {
|
||||||
|
t.Fatalf("custom ip=%v", rc.CustomIP)
|
||||||
|
}
|
||||||
|
if rc.TLSConfig == nil || !rc.TLSConfig.InsecureSkipVerify {
|
||||||
|
t.Fatal("expected tls config with skip verify")
|
||||||
|
}
|
||||||
|
if rc.TLSServerName != "example.com" {
|
||||||
|
t.Fatalf("default tls server name=%q", rc.TLSServerName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestSetHostOverridesRequestHost(t *testing.T) {
|
||||||
|
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Host != "override.example" {
|
||||||
|
t.Fatalf("host=%q", r.Host)
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
resp, err := NewSimpleRequest(s.URL, http.MethodGet).
|
||||||
|
SetHost("override.example").
|
||||||
|
Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithHostOverridesRequestHost(t *testing.T) {
|
||||||
|
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Host != "option.example" {
|
||||||
|
t.Fatalf("host=%q", r.Host)
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
resp, err := NewRequest(s.URL, http.MethodGet, WithHost("option.example"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewRequest() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := resp.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer got.Close()
|
||||||
|
}
|
||||||
@@ -0,0 +1,258 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// 这些测试使用 httpbin.org 作为测试服务
|
||||||
|
// 可以通过环境变量 STARNET_INTEGRATION_TEST=1 来启用
|
||||||
|
|
||||||
|
func skipIfNoIntegration(t *testing.T) {
|
||||||
|
if os.Getenv("STARNET_INTEGRATION_TEST") != "1" {
|
||||||
|
t.Skip("Skipping integration test. Set STARNET_INTEGRATION_TEST=1 to run")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIntegrationHTTPBinGet(t *testing.T) {
|
||||||
|
skipIfNoIntegration(t)
|
||||||
|
|
||||||
|
resp, err := Get("https://httpbin.org/get",
|
||||||
|
WithQuery("name", "starnet"),
|
||||||
|
WithQuery("version", "1.0"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Get() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != 200 {
|
||||||
|
t.Errorf("StatusCode = %v; want 200", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result map[string]interface{}
|
||||||
|
err = resp.Body().JSON(&result)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("JSON() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
args, ok := result["args"].(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("args not found in response")
|
||||||
|
}
|
||||||
|
|
||||||
|
if args["name"] != "starnet" {
|
||||||
|
t.Errorf("args[name] = %v; want starnet", args["name"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIntegrationHTTPBinPost(t *testing.T) {
|
||||||
|
skipIfNoIntegration(t)
|
||||||
|
|
||||||
|
type PostData struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
}
|
||||||
|
|
||||||
|
data := PostData{
|
||||||
|
Name: "John Doe",
|
||||||
|
Email: "john@example.com",
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := Post("https://httpbin.org/post", WithJSON(data))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Post() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != 200 {
|
||||||
|
t.Errorf("StatusCode = %v; want 200", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result map[string]interface{}
|
||||||
|
err = resp.Body().JSON(&result)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("JSON() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonData, ok := result["json"].(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("json not found in response")
|
||||||
|
}
|
||||||
|
|
||||||
|
if jsonData["name"] != data.Name {
|
||||||
|
t.Errorf("name = %v; want %v", jsonData["name"], data.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIntegrationHTTPBinHeaders(t *testing.T) {
|
||||||
|
skipIfNoIntegration(t)
|
||||||
|
|
||||||
|
resp, err := Get("https://httpbin.org/headers",
|
||||||
|
WithHeader("X-Custom-Header", "test-value"),
|
||||||
|
WithUserAgent("Starnet-Test/1.0"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Get() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
var result map[string]interface{}
|
||||||
|
err = resp.Body().JSON(&result)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("JSON() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
headers, ok := result["headers"].(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("headers not found in response")
|
||||||
|
}
|
||||||
|
|
||||||
|
if headers["X-Custom-Header"] != "test-value" {
|
||||||
|
t.Errorf("X-Custom-Header = %v; want test-value", headers["X-Custom-Header"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIntegrationHTTPBinBasicAuth(t *testing.T) {
|
||||||
|
skipIfNoIntegration(t)
|
||||||
|
|
||||||
|
resp, err := Get("https://httpbin.org/basic-auth/user/passwd",
|
||||||
|
WithBasicAuth("user", "passwd"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Get() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != 200 {
|
||||||
|
t.Errorf("StatusCode = %v; want 200", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result map[string]interface{}
|
||||||
|
err = resp.Body().JSON(&result)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("JSON() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result["authenticated"] != true {
|
||||||
|
t.Error("authenticated should be true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIntegrationHTTPBinDelay(t *testing.T) {
|
||||||
|
skipIfNoIntegration(t)
|
||||||
|
|
||||||
|
// Test timeout
|
||||||
|
start := time.Now()
|
||||||
|
_, err := Get("https://httpbin.org/delay/3",
|
||||||
|
WithTimeout(1*time.Second))
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected timeout error, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if elapsed > 2*time.Second {
|
||||||
|
t.Errorf("Timeout took too long: %v", elapsed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIntegrationHTTPBinRedirect(t *testing.T) {
|
||||||
|
skipIfNoIntegration(t)
|
||||||
|
|
||||||
|
// Test with redirect enabled
|
||||||
|
client := NewClientNoErr()
|
||||||
|
resp, err := client.Get("https://httpbin.org/redirect/2")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Get() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != 200 {
|
||||||
|
t.Errorf("StatusCode = %v; want 200 (after redirect)", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test with redirect disabled
|
||||||
|
client.DisableRedirect()
|
||||||
|
resp2, err := client.Get("https://httpbin.org/redirect/2")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Get() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp2.Close()
|
||||||
|
|
||||||
|
if resp2.StatusCode != 302 {
|
||||||
|
t.Errorf("StatusCode = %v; want 302 (redirect disabled)", resp2.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIntegrationHTTPBinCookies(t *testing.T) {
|
||||||
|
skipIfNoIntegration(t)
|
||||||
|
|
||||||
|
// 创建一个禁用重定向的 Client
|
||||||
|
client := NewClientNoErr()
|
||||||
|
client.DisableRedirect()
|
||||||
|
|
||||||
|
resp, err := client.Get("https://httpbin.org/cookies/set?name=value")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Get() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
// 现在应该能获取到 Set-Cookie
|
||||||
|
cookies := resp.Cookies()
|
||||||
|
if len(cookies) == 0 {
|
||||||
|
t.Error("Expected cookies in response")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证 cookie
|
||||||
|
found := false
|
||||||
|
for _, cookie := range cookies {
|
||||||
|
if cookie.Name == "name" && cookie.Value == "value" {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
t.Error("Expected cookie 'name=value' not found")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIntegrationHTTPBinUserAgent(t *testing.T) {
|
||||||
|
skipIfNoIntegration(t)
|
||||||
|
|
||||||
|
customUA := "Starnet-Integration-Test/1.0"
|
||||||
|
resp, err := Get("https://httpbin.org/user-agent",
|
||||||
|
WithUserAgent(customUA))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Get() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
var result map[string]interface{}
|
||||||
|
err = resp.Body().JSON(&result)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("JSON() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result["user-agent"] != customUA {
|
||||||
|
t.Errorf("user-agent = %v; want %v", result["user-agent"], customUA)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIntegrationHTTPBinGzip(t *testing.T) {
|
||||||
|
skipIfNoIntegration(t)
|
||||||
|
|
||||||
|
resp, err := Get("https://httpbin.org/gzip")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Get() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
var result map[string]interface{}
|
||||||
|
err = resp.Body().JSON(&result)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("JSON() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result["gzipped"] != true {
|
||||||
|
t.Error("Response should be gzipped")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,230 @@
|
|||||||
|
package pingcore
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const icmpHeaderLen = 8
|
||||||
|
|
||||||
|
type ICMP struct {
|
||||||
|
Type uint8
|
||||||
|
Code uint8
|
||||||
|
CheckSum uint16
|
||||||
|
Identifier uint16
|
||||||
|
SequenceNum uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
type Options struct {
|
||||||
|
Count int
|
||||||
|
Timeout time.Duration
|
||||||
|
Interval time.Duration
|
||||||
|
Deadline time.Time
|
||||||
|
PreferIPv4 bool
|
||||||
|
PreferIPv6 bool
|
||||||
|
SourceIP net.IP
|
||||||
|
PayloadSize int
|
||||||
|
}
|
||||||
|
|
||||||
|
type Result struct {
|
||||||
|
Duration time.Duration
|
||||||
|
RecvCount int
|
||||||
|
RemoteIP string
|
||||||
|
}
|
||||||
|
|
||||||
|
var identifierSeed uint32
|
||||||
|
|
||||||
|
func NextIdentifier() uint16 {
|
||||||
|
pid := uint32(os.Getpid() & 0xffff)
|
||||||
|
n := atomic.AddUint32(&identifierSeed, 1)
|
||||||
|
return uint16((pid + n) & 0xffff)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Payload(size int) []byte {
|
||||||
|
if size <= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
payload := make([]byte, size)
|
||||||
|
for index := 0; index < len(payload); index++ {
|
||||||
|
payload[index] = byte(index)
|
||||||
|
}
|
||||||
|
return payload
|
||||||
|
}
|
||||||
|
|
||||||
|
func BuildICMP(seq, identifier uint16, typ uint8, payload []byte) ICMP {
|
||||||
|
icmp := ICMP{
|
||||||
|
Type: typ,
|
||||||
|
Code: 0,
|
||||||
|
CheckSum: 0,
|
||||||
|
Identifier: identifier,
|
||||||
|
SequenceNum: seq,
|
||||||
|
}
|
||||||
|
buf := MarshalPacket(icmp, payload)
|
||||||
|
icmp.CheckSum = Checksum(buf)
|
||||||
|
return icmp
|
||||||
|
}
|
||||||
|
|
||||||
|
func Checksum(data []byte) uint16 {
|
||||||
|
var (
|
||||||
|
sum uint32
|
||||||
|
length = len(data)
|
||||||
|
index int
|
||||||
|
)
|
||||||
|
for length > 1 {
|
||||||
|
sum += uint32(data[index])<<8 + uint32(data[index+1])
|
||||||
|
index += 2
|
||||||
|
length -= 2
|
||||||
|
}
|
||||||
|
if length > 0 {
|
||||||
|
sum += uint32(data[index]) << 8
|
||||||
|
}
|
||||||
|
for sum>>16 != 0 {
|
||||||
|
sum = (sum & 0xffff) + (sum >> 16)
|
||||||
|
}
|
||||||
|
return uint16(^sum)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Marshal(icmp ICMP) []byte {
|
||||||
|
return MarshalPacket(icmp, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func MarshalPacket(icmp ICMP, payload []byte) []byte {
|
||||||
|
buf := make([]byte, icmpHeaderLen+len(payload))
|
||||||
|
buf[0] = icmp.Type
|
||||||
|
buf[1] = icmp.Code
|
||||||
|
binary.BigEndian.PutUint16(buf[2:], icmp.CheckSum)
|
||||||
|
binary.BigEndian.PutUint16(buf[4:], icmp.Identifier)
|
||||||
|
binary.BigEndian.PutUint16(buf[6:], icmp.SequenceNum)
|
||||||
|
copy(buf[icmpHeaderLen:], payload)
|
||||||
|
return buf
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsExpectedEchoReply(packet []byte, family int, expectedType uint8, identifier, seq uint16) bool {
|
||||||
|
for _, offset := range CandidateICMPOffsets(packet, family) {
|
||||||
|
if offset < 0 || offset+icmpHeaderLen > len(packet) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if packet[offset] != expectedType || packet[offset+1] != 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if binary.BigEndian.Uint16(packet[offset+4:offset+6]) != identifier {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if binary.BigEndian.Uint16(packet[offset+6:offset+8]) != seq {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func CandidateICMPOffsets(packet []byte, family int) []int {
|
||||||
|
offsets := []int{0}
|
||||||
|
if len(packet) == 0 {
|
||||||
|
return offsets
|
||||||
|
}
|
||||||
|
|
||||||
|
version := packet[0] >> 4
|
||||||
|
if version == 4 && len(packet) >= 20 {
|
||||||
|
ihl := int(packet[0]&0x0f) * 4
|
||||||
|
if ihl >= 20 && ihl <= len(packet)-icmpHeaderLen {
|
||||||
|
offsets = append(offsets, ihl)
|
||||||
|
}
|
||||||
|
} else if version == 6 && len(packet) >= 40+icmpHeaderLen {
|
||||||
|
offsets = append(offsets, 40)
|
||||||
|
}
|
||||||
|
|
||||||
|
if family == 4 && len(packet) >= 20+icmpHeaderLen {
|
||||||
|
offsets = append(offsets, 20)
|
||||||
|
}
|
||||||
|
if family == 6 && len(packet) >= 40+icmpHeaderLen {
|
||||||
|
offsets = append(offsets, 40)
|
||||||
|
}
|
||||||
|
|
||||||
|
return DedupOffsets(offsets)
|
||||||
|
}
|
||||||
|
|
||||||
|
func DedupOffsets(offsets []int) []int {
|
||||||
|
if len(offsets) <= 1 {
|
||||||
|
return offsets
|
||||||
|
}
|
||||||
|
seen := make(map[int]struct{}, len(offsets))
|
||||||
|
out := make([]int, 0, len(offsets))
|
||||||
|
for _, offset := range offsets {
|
||||||
|
if _, ok := seen[offset]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[offset] = struct{}{}
|
||||||
|
out = append(out, offset)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func ResolveTargets(host string, preferIPv4, preferIPv6 bool) ([]*net.IPAddr, error) {
|
||||||
|
if parsed := net.ParseIP(host); parsed != nil {
|
||||||
|
return []*net.IPAddr{{IP: parsed}}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var targets []*net.IPAddr
|
||||||
|
var err4 error
|
||||||
|
var err6 error
|
||||||
|
|
||||||
|
if ip4, err := net.ResolveIPAddr("ip4", host); err == nil && ip4 != nil && ip4.IP != nil {
|
||||||
|
targets = append(targets, ip4)
|
||||||
|
} else {
|
||||||
|
err4 = err
|
||||||
|
}
|
||||||
|
|
||||||
|
if ip6, err := net.ResolveIPAddr("ip6", host); err == nil && ip6 != nil && ip6.IP != nil {
|
||||||
|
targets = append(targets, ip6)
|
||||||
|
} else {
|
||||||
|
err6 = err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(targets) > 0 {
|
||||||
|
return OrderTargets(targets, preferIPv4, preferIPv6), nil
|
||||||
|
}
|
||||||
|
if err4 != nil {
|
||||||
|
return nil, err4
|
||||||
|
}
|
||||||
|
if err6 != nil {
|
||||||
|
return nil, err6
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func OrderTargets(targets []*net.IPAddr, preferIPv4, preferIPv6 bool) []*net.IPAddr {
|
||||||
|
if len(targets) <= 1 || preferIPv4 == preferIPv6 {
|
||||||
|
return targets
|
||||||
|
}
|
||||||
|
|
||||||
|
ordered := make([]*net.IPAddr, 0, len(targets))
|
||||||
|
if preferIPv4 {
|
||||||
|
for _, target := range targets {
|
||||||
|
if target != nil && target.IP != nil && target.IP.To4() != nil {
|
||||||
|
ordered = append(ordered, target)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, target := range targets {
|
||||||
|
if target != nil && target.IP != nil && target.IP.To4() == nil {
|
||||||
|
ordered = append(ordered, target)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ordered
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, target := range targets {
|
||||||
|
if target != nil && target.IP != nil && target.IP.To4() == nil {
|
||||||
|
ordered = append(ordered, target)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, target := range targets {
|
||||||
|
if target != nil && target.IP != nil && target.IP.To4() != nil {
|
||||||
|
ordered = append(ordered, target)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ordered
|
||||||
|
}
|
||||||
@@ -0,0 +1,123 @@
|
|||||||
|
package tlssniffercore
|
||||||
|
|
||||||
|
import "crypto/tls"
|
||||||
|
|
||||||
|
func ComposeServerTLSConfig(base, selected *tls.Config) *tls.Config {
|
||||||
|
if base == nil {
|
||||||
|
return selected
|
||||||
|
}
|
||||||
|
if selected == nil {
|
||||||
|
return base
|
||||||
|
}
|
||||||
|
|
||||||
|
out := base.Clone()
|
||||||
|
ApplyServerTLSOverrides(out, selected)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func ApplyServerTLSOverrides(dst, src *tls.Config) {
|
||||||
|
if dst == nil || src == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if src.Rand != nil {
|
||||||
|
dst.Rand = src.Rand
|
||||||
|
}
|
||||||
|
if src.Time != nil {
|
||||||
|
dst.Time = src.Time
|
||||||
|
}
|
||||||
|
if len(src.Certificates) > 0 {
|
||||||
|
dst.Certificates = append([]tls.Certificate(nil), src.Certificates...)
|
||||||
|
}
|
||||||
|
if len(src.NameToCertificate) > 0 {
|
||||||
|
copied := make(map[string]*tls.Certificate, len(src.NameToCertificate))
|
||||||
|
for name, cert := range src.NameToCertificate {
|
||||||
|
copied[name] = cert
|
||||||
|
}
|
||||||
|
dst.NameToCertificate = copied
|
||||||
|
}
|
||||||
|
if src.GetCertificate != nil {
|
||||||
|
dst.GetCertificate = src.GetCertificate
|
||||||
|
}
|
||||||
|
if src.GetClientCertificate != nil {
|
||||||
|
dst.GetClientCertificate = src.GetClientCertificate
|
||||||
|
}
|
||||||
|
if src.GetConfigForClient != nil {
|
||||||
|
dst.GetConfigForClient = src.GetConfigForClient
|
||||||
|
}
|
||||||
|
if src.VerifyPeerCertificate != nil {
|
||||||
|
dst.VerifyPeerCertificate = src.VerifyPeerCertificate
|
||||||
|
}
|
||||||
|
if src.VerifyConnection != nil {
|
||||||
|
dst.VerifyConnection = src.VerifyConnection
|
||||||
|
}
|
||||||
|
if src.RootCAs != nil {
|
||||||
|
dst.RootCAs = src.RootCAs
|
||||||
|
}
|
||||||
|
if len(src.NextProtos) > 0 {
|
||||||
|
dst.NextProtos = append([]string(nil), src.NextProtos...)
|
||||||
|
}
|
||||||
|
if src.ServerName != "" {
|
||||||
|
dst.ServerName = src.ServerName
|
||||||
|
}
|
||||||
|
if src.ClientAuth > dst.ClientAuth {
|
||||||
|
dst.ClientAuth = src.ClientAuth
|
||||||
|
}
|
||||||
|
if src.ClientCAs != nil {
|
||||||
|
dst.ClientCAs = src.ClientCAs
|
||||||
|
}
|
||||||
|
if src.InsecureSkipVerify {
|
||||||
|
dst.InsecureSkipVerify = true
|
||||||
|
}
|
||||||
|
if len(src.CipherSuites) > 0 {
|
||||||
|
dst.CipherSuites = append([]uint16(nil), src.CipherSuites...)
|
||||||
|
}
|
||||||
|
if src.PreferServerCipherSuites {
|
||||||
|
dst.PreferServerCipherSuites = true
|
||||||
|
}
|
||||||
|
if src.SessionTicketsDisabled {
|
||||||
|
dst.SessionTicketsDisabled = true
|
||||||
|
}
|
||||||
|
if src.SessionTicketKey != ([32]byte{}) {
|
||||||
|
dst.SessionTicketKey = src.SessionTicketKey
|
||||||
|
}
|
||||||
|
if src.ClientSessionCache != nil {
|
||||||
|
dst.ClientSessionCache = src.ClientSessionCache
|
||||||
|
}
|
||||||
|
if src.UnwrapSession != nil {
|
||||||
|
dst.UnwrapSession = src.UnwrapSession
|
||||||
|
}
|
||||||
|
if src.WrapSession != nil {
|
||||||
|
dst.WrapSession = src.WrapSession
|
||||||
|
}
|
||||||
|
if src.MinVersion != 0 && (dst.MinVersion == 0 || src.MinVersion > dst.MinVersion) {
|
||||||
|
dst.MinVersion = src.MinVersion
|
||||||
|
}
|
||||||
|
if src.MaxVersion != 0 && (dst.MaxVersion == 0 || src.MaxVersion < dst.MaxVersion) {
|
||||||
|
dst.MaxVersion = src.MaxVersion
|
||||||
|
}
|
||||||
|
if len(src.CurvePreferences) > 0 {
|
||||||
|
dst.CurvePreferences = append([]tls.CurveID(nil), src.CurvePreferences...)
|
||||||
|
}
|
||||||
|
if src.DynamicRecordSizingDisabled {
|
||||||
|
dst.DynamicRecordSizingDisabled = true
|
||||||
|
}
|
||||||
|
if src.Renegotiation != 0 {
|
||||||
|
dst.Renegotiation = src.Renegotiation
|
||||||
|
}
|
||||||
|
if src.KeyLogWriter != nil {
|
||||||
|
dst.KeyLogWriter = src.KeyLogWriter
|
||||||
|
}
|
||||||
|
if len(src.EncryptedClientHelloConfigList) > 0 {
|
||||||
|
dst.EncryptedClientHelloConfigList = append([]byte(nil), src.EncryptedClientHelloConfigList...)
|
||||||
|
}
|
||||||
|
if src.EncryptedClientHelloRejectionVerify != nil {
|
||||||
|
dst.EncryptedClientHelloRejectionVerify = src.EncryptedClientHelloRejectionVerify
|
||||||
|
}
|
||||||
|
if src.GetEncryptedClientHelloKeys != nil {
|
||||||
|
dst.GetEncryptedClientHelloKeys = src.GetEncryptedClientHelloKeys
|
||||||
|
}
|
||||||
|
if len(src.EncryptedClientHelloKeys) > 0 {
|
||||||
|
dst.EncryptedClientHelloKeys = append([]tls.EncryptedClientHelloKey(nil), src.EncryptedClientHelloKeys...)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,237 @@
|
|||||||
|
package tlssniffercore
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ClientHelloMeta struct {
|
||||||
|
ServerName string
|
||||||
|
LocalAddr net.Addr
|
||||||
|
RemoteAddr net.Addr
|
||||||
|
SupportedProtos []string
|
||||||
|
SupportedVersions []uint16
|
||||||
|
CipherSuites []uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
type SniffResult struct {
|
||||||
|
IsTLS bool
|
||||||
|
ClientHello *ClientHelloMeta
|
||||||
|
Buffer *bytes.Buffer
|
||||||
|
}
|
||||||
|
|
||||||
|
type Sniffer struct{}
|
||||||
|
|
||||||
|
func (s Sniffer) Sniff(conn net.Conn, maxBytes int) (SniffResult, error) {
|
||||||
|
if maxBytes <= 0 {
|
||||||
|
maxBytes = 64 * 1024
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
limited := &io.LimitedReader{R: conn, N: int64(maxBytes)}
|
||||||
|
meta, isTLS := sniffClientHello(limited, &buf, conn)
|
||||||
|
|
||||||
|
out := SniffResult{
|
||||||
|
IsTLS: isTLS,
|
||||||
|
Buffer: bytes.NewBuffer(append([]byte(nil), buf.Bytes()...)),
|
||||||
|
}
|
||||||
|
if isTLS {
|
||||||
|
out.ClientHello = meta
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func sniffClientHello(reader io.Reader, buf *bytes.Buffer, conn net.Conn) (*ClientHelloMeta, bool) {
|
||||||
|
meta := &ClientHelloMeta{
|
||||||
|
LocalAddr: conn.LocalAddr(),
|
||||||
|
RemoteAddr: conn.RemoteAddr(),
|
||||||
|
}
|
||||||
|
|
||||||
|
header, complete := readTLSRecordHeader(reader, buf)
|
||||||
|
if len(header) < 3 {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
isTLS := header[0] == 0x16 && header[1] == 0x03
|
||||||
|
if !isTLS {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
if len(header) < 5 || !complete {
|
||||||
|
return meta, true
|
||||||
|
}
|
||||||
|
|
||||||
|
recordLen := int(binary.BigEndian.Uint16(header[3:5]))
|
||||||
|
recordBody, bodyOK := readBufferedBytes(reader, buf, recordLen)
|
||||||
|
if !bodyOK {
|
||||||
|
return meta, true
|
||||||
|
}
|
||||||
|
if len(recordBody) < 4 || recordBody[0] != 0x01 {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
helloLen := int(recordBody[1])<<16 | int(recordBody[2])<<8 | int(recordBody[3])
|
||||||
|
helloBytes := append([]byte(nil), recordBody[4:]...)
|
||||||
|
for len(helloBytes) < helloLen {
|
||||||
|
nextHeader, ok := readTLSRecordHeader(reader, buf)
|
||||||
|
if len(nextHeader) < 5 || !ok {
|
||||||
|
return meta, true
|
||||||
|
}
|
||||||
|
if nextHeader[0] != 0x16 || nextHeader[1] != 0x03 {
|
||||||
|
return meta, true
|
||||||
|
}
|
||||||
|
nextLen := int(binary.BigEndian.Uint16(nextHeader[3:5]))
|
||||||
|
nextBody, bodyOK := readBufferedBytes(reader, buf, nextLen)
|
||||||
|
if !bodyOK {
|
||||||
|
return meta, true
|
||||||
|
}
|
||||||
|
helloBytes = append(helloBytes, nextBody...)
|
||||||
|
}
|
||||||
|
|
||||||
|
parseClientHelloBody(meta, helloBytes[:helloLen])
|
||||||
|
return meta, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func readTLSRecordHeader(reader io.Reader, buf *bytes.Buffer) ([]byte, bool) {
|
||||||
|
return readBufferedBytes(reader, buf, 5)
|
||||||
|
}
|
||||||
|
|
||||||
|
func readBufferedBytes(reader io.Reader, buf *bytes.Buffer, count int) ([]byte, bool) {
|
||||||
|
if count <= 0 {
|
||||||
|
return nil, true
|
||||||
|
}
|
||||||
|
tmp := make([]byte, count)
|
||||||
|
readN, err := io.ReadFull(reader, tmp)
|
||||||
|
if readN > 0 {
|
||||||
|
buf.Write(tmp[:readN])
|
||||||
|
}
|
||||||
|
return append([]byte(nil), tmp[:readN]...), err == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseClientHelloBody(meta *ClientHelloMeta, body []byte) {
|
||||||
|
if meta == nil || len(body) < 34 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
offset := 2 + 32
|
||||||
|
sessionIDLen := int(body[offset])
|
||||||
|
offset++
|
||||||
|
if offset+sessionIDLen > len(body) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
offset += sessionIDLen
|
||||||
|
|
||||||
|
if offset+2 > len(body) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cipherSuitesLen := int(binary.BigEndian.Uint16(body[offset : offset+2]))
|
||||||
|
offset += 2
|
||||||
|
if offset+cipherSuitesLen > len(body) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for index := 0; index+1 < cipherSuitesLen; index += 2 {
|
||||||
|
meta.CipherSuites = append(meta.CipherSuites, binary.BigEndian.Uint16(body[offset+index:offset+index+2]))
|
||||||
|
}
|
||||||
|
offset += cipherSuitesLen
|
||||||
|
|
||||||
|
if offset >= len(body) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
compressionMethodsLen := int(body[offset])
|
||||||
|
offset++
|
||||||
|
if offset+compressionMethodsLen > len(body) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
offset += compressionMethodsLen
|
||||||
|
|
||||||
|
if offset+2 > len(body) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
extensionsLen := int(binary.BigEndian.Uint16(body[offset : offset+2]))
|
||||||
|
offset += 2
|
||||||
|
if offset+extensionsLen > len(body) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
parseClientHelloExtensions(meta, body[offset:offset+extensionsLen])
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseClientHelloExtensions(meta *ClientHelloMeta, exts []byte) {
|
||||||
|
for offset := 0; offset+4 <= len(exts); {
|
||||||
|
extType := binary.BigEndian.Uint16(exts[offset : offset+2])
|
||||||
|
extLen := int(binary.BigEndian.Uint16(exts[offset+2 : offset+4]))
|
||||||
|
offset += 4
|
||||||
|
if offset+extLen > len(exts) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
extData := exts[offset : offset+extLen]
|
||||||
|
offset += extLen
|
||||||
|
|
||||||
|
switch extType {
|
||||||
|
case 0:
|
||||||
|
parseServerNameExtension(meta, extData)
|
||||||
|
case 16:
|
||||||
|
parseALPNExtension(meta, extData)
|
||||||
|
case 43:
|
||||||
|
parseSupportedVersionsExtension(meta, extData)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseServerNameExtension(meta *ClientHelloMeta, data []byte) {
|
||||||
|
if len(data) < 2 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
listLen := int(binary.BigEndian.Uint16(data[:2]))
|
||||||
|
if listLen == 0 || 2+listLen > len(data) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
list := data[2 : 2+listLen]
|
||||||
|
for offset := 0; offset+3 <= len(list); {
|
||||||
|
nameType := list[offset]
|
||||||
|
nameLen := int(binary.BigEndian.Uint16(list[offset+1 : offset+3]))
|
||||||
|
offset += 3
|
||||||
|
if offset+nameLen > len(list) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if nameType == 0 {
|
||||||
|
meta.ServerName = string(list[offset : offset+nameLen])
|
||||||
|
return
|
||||||
|
}
|
||||||
|
offset += nameLen
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseALPNExtension(meta *ClientHelloMeta, data []byte) {
|
||||||
|
if len(data) < 2 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
listLen := int(binary.BigEndian.Uint16(data[:2]))
|
||||||
|
if listLen == 0 || 2+listLen > len(data) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
list := data[2 : 2+listLen]
|
||||||
|
for offset := 0; offset < len(list); {
|
||||||
|
nameLen := int(list[offset])
|
||||||
|
offset++
|
||||||
|
if offset+nameLen > len(list) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
meta.SupportedProtos = append(meta.SupportedProtos, string(list[offset:offset+nameLen]))
|
||||||
|
offset += nameLen
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseSupportedVersionsExtension(meta *ClientHelloMeta, data []byte) {
|
||||||
|
if len(data) < 1 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
listLen := int(data[0])
|
||||||
|
if listLen == 0 || 1+listLen > len(data) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
list := data[1 : 1+listLen]
|
||||||
|
for offset := 0; offset+1 < len(list); offset += 2 {
|
||||||
|
meta.SupportedVersions = append(meta.SupportedVersions, binary.BigEndian.Uint16(list[offset:offset+2]))
|
||||||
|
}
|
||||||
|
}
|
||||||
+112
@@ -0,0 +1,112 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
)
|
||||||
|
|
||||||
|
// WithBody 设置请求体(字节)
|
||||||
|
func WithBody(body []byte) RequestOpt {
|
||||||
|
return func(r *Request) error {
|
||||||
|
setBytesBodyConfig(&r.config.Body, body)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithBodyString 设置请求体(字符串)
|
||||||
|
func WithBodyString(body string) RequestOpt {
|
||||||
|
return func(r *Request) error {
|
||||||
|
setBytesBodyConfig(&r.config.Body, []byte(body))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithBodyReader 设置请求体(Reader)。
|
||||||
|
// 出于避免重复写的保守策略,Reader 形态的 body 在非幂等方法上不会自动参与 retry。
|
||||||
|
func WithBodyReader(reader io.Reader) RequestOpt {
|
||||||
|
return func(r *Request) error {
|
||||||
|
setReaderBodyConfig(&r.config.Body, reader)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithJSON 设置 JSON 请求体
|
||||||
|
func WithJSON(v interface{}) RequestOpt {
|
||||||
|
return func(r *Request) error {
|
||||||
|
data, err := json.Marshal(v)
|
||||||
|
if err != nil {
|
||||||
|
return wrapError(err, "marshal json")
|
||||||
|
}
|
||||||
|
r.config.Headers.Set("Content-Type", ContentTypeJSON)
|
||||||
|
setBytesBodyConfig(&r.config.Body, data)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithFormData 设置表单数据
|
||||||
|
func WithFormData(data map[string][]string) RequestOpt {
|
||||||
|
return func(r *Request) error {
|
||||||
|
setFormBodyConfig(&r.config.Body, data)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithFormDataMap 设置表单数据(简化版)
|
||||||
|
func WithFormDataMap(data map[string]string) RequestOpt {
|
||||||
|
return func(r *Request) error {
|
||||||
|
setFormBodyConfig(&r.config.Body, nil)
|
||||||
|
for key, value := range data {
|
||||||
|
r.config.Body.FormData[key] = []string{value}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithAddFormData 添加表单数据
|
||||||
|
func WithAddFormData(key, value string) RequestOpt {
|
||||||
|
return func(r *Request) error {
|
||||||
|
ensureFormMode(&r.config.Body)
|
||||||
|
r.config.Body.FormData[key] = append(r.config.Body.FormData[key], value)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithFile 添加文件
|
||||||
|
func WithFile(formName, filePath string) RequestOpt {
|
||||||
|
return func(r *Request) error {
|
||||||
|
stat, err := os.Stat(filePath)
|
||||||
|
if err != nil {
|
||||||
|
return wrapError(ErrFileNotFound, "file: %s", filePath)
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureMultipartMode(&r.config.Body)
|
||||||
|
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
|
||||||
|
FormName: formName,
|
||||||
|
FileName: stat.Name(),
|
||||||
|
FilePath: filePath,
|
||||||
|
FileSize: stat.Size(),
|
||||||
|
FileType: ContentTypeOctetStream,
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithFileStream 添加文件流
|
||||||
|
func WithFileStream(formName, fileName string, size int64, reader io.Reader) RequestOpt {
|
||||||
|
return func(r *Request) error {
|
||||||
|
if reader == nil {
|
||||||
|
return ErrNilReader
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureMultipartMode(&r.config.Body)
|
||||||
|
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
|
||||||
|
FormName: formName,
|
||||||
|
FileName: fileName,
|
||||||
|
FileData: reader,
|
||||||
|
FileSize: size,
|
||||||
|
FileType: ContentTypeOctetStream,
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,137 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// WithTimeout 设置请求总超时时间
|
||||||
|
// timeout > 0: 为本次请求注入 context 超时
|
||||||
|
// timeout = 0: 不额外设置请求总超时
|
||||||
|
// timeout < 0: 禁用 starnet 默认总超时
|
||||||
|
func WithTimeout(timeout time.Duration) RequestOpt {
|
||||||
|
return requestOptFromMutation(mutateTimeout(timeout))
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithDialTimeout 设置连接超时时间
|
||||||
|
func WithDialTimeout(timeout time.Duration) RequestOpt {
|
||||||
|
return requestOptFromMutation(mutateDialTimeout(timeout))
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithProxy 设置代理
|
||||||
|
func WithProxy(proxy string) RequestOpt {
|
||||||
|
return requestOptFromMutation(mutateProxy(proxy))
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithDialFunc 设置自定义 Dial 函数
|
||||||
|
func WithDialFunc(fn func(ctx context.Context, network, addr string) (net.Conn, error)) RequestOpt {
|
||||||
|
return requestOptFromMutation(mutateDialFunc(fn))
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithTLSConfig 设置 TLS 配置
|
||||||
|
func WithTLSConfig(tlsConfig *tls.Config) RequestOpt {
|
||||||
|
return requestOptFromMutation(mutateTLSConfig(tlsConfig))
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithTLSServerName 设置显式 TLS ServerName/SNI。
|
||||||
|
func WithTLSServerName(serverName string) RequestOpt {
|
||||||
|
return requestOptFromMutation(mutateTLSServerName(serverName))
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithTraceHooks 设置请求 trace 回调。
|
||||||
|
func WithTraceHooks(hooks *TraceHooks) RequestOpt {
|
||||||
|
return requestOptFromMutation(mutateTraceHooks(hooks))
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithTraceRecorder 设置请求级 trace 摘要记录器。
|
||||||
|
func WithTraceRecorder(recorder *TraceRecorder) RequestOpt {
|
||||||
|
return requestOptFromMutation(mutateTraceRecorder(recorder))
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithSkipTLSVerify 设置是否跳过 TLS 验证
|
||||||
|
func WithSkipTLSVerify(skip bool) RequestOpt {
|
||||||
|
return requestOptFromMutation(mutateSkipTLSVerify(skip))
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithCustomIP 设置自定义 IP
|
||||||
|
func WithCustomIP(ips []string) RequestOpt {
|
||||||
|
return requestOptFromMutation(mutateCustomIP(ips))
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithAddCustomIP 添加自定义 IP
|
||||||
|
func WithAddCustomIP(ip string) RequestOpt {
|
||||||
|
return requestOptFromMutation(mutateAddCustomIP(ip))
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithCustomDNS 设置自定义 DNS 服务器
|
||||||
|
func WithCustomDNS(dnsServers []string) RequestOpt {
|
||||||
|
return requestOptFromMutation(mutateCustomDNS(dnsServers))
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithAddCustomDNS 添加自定义 DNS 服务器
|
||||||
|
func WithAddCustomDNS(dns string) RequestOpt {
|
||||||
|
return requestOptFromMutation(mutateAddCustomDNS(dns))
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithLookupFunc 设置自定义 DNS 解析函数
|
||||||
|
func WithLookupFunc(fn func(ctx context.Context, host string) ([]net.IPAddr, error)) RequestOpt {
|
||||||
|
return requestOptFromMutation(mutateLookupFunc(fn))
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithBasicAuth 设置 Basic 认证
|
||||||
|
func WithBasicAuth(username, password string) RequestOpt {
|
||||||
|
return requestOptFromMutation(mutateBasicAuth(username, password))
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithQuery 添加查询参数
|
||||||
|
func WithQuery(key, value string) RequestOpt {
|
||||||
|
return requestOptFromMutation(mutateAddQuery(key, value))
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithQueries 批量添加查询参数
|
||||||
|
func WithQueries(queries map[string]string) RequestOpt {
|
||||||
|
return requestOptFromMutation(mutateAddQueries(queries))
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithContentLength 设置 Content-Length
|
||||||
|
func WithContentLength(length int64) RequestOpt {
|
||||||
|
return requestOptFromMutation(mutateContentLength(length))
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithAutoCalcContentLength 设置是否自动计算 Content-Length
|
||||||
|
func WithAutoCalcContentLength(auto bool) RequestOpt {
|
||||||
|
return requestOptFromMutation(mutateAutoCalcContentLength(auto))
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithUploadProgress 设置文件上传进度回调
|
||||||
|
func WithUploadProgress(fn UploadProgressFunc) RequestOpt {
|
||||||
|
return requestOptFromMutation(mutateUploadProgress(fn))
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithTransport 设置自定义 Transport
|
||||||
|
func WithTransport(transport *http.Transport) RequestOpt {
|
||||||
|
return requestOptFromMutation(mutateTransport(transport))
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithAutoFetch 设置是否自动获取响应体
|
||||||
|
func WithAutoFetch(auto bool) RequestOpt {
|
||||||
|
return requestOptFromMutation(mutateAutoFetch(auto))
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithMaxRespBodyBytes 设置响应体最大读取字节数(<=0 表示不限制)
|
||||||
|
func WithMaxRespBodyBytes(maxBytes int64) RequestOpt {
|
||||||
|
return requestOptFromMutation(mutateMaxRespBodyBytes(maxBytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithRawRequest 设置原始请求
|
||||||
|
func WithRawRequest(httpReq *http.Request) RequestOpt {
|
||||||
|
return requestOptFromMutation(mutateRawRequest(httpReq))
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithContext 设置 context
|
||||||
|
func WithContext(ctx context.Context) RequestOpt {
|
||||||
|
return requestOptFromMutation(mutateContext(ctx))
|
||||||
|
}
|
||||||
@@ -0,0 +1,99 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import "net/http"
|
||||||
|
|
||||||
|
// WithHeader 设置 Header
|
||||||
|
func WithHeader(key, value string) RequestOpt {
|
||||||
|
return func(r *Request) error {
|
||||||
|
if isHostHeaderKey(key) {
|
||||||
|
setRequestHostConfig(r.config, value)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
r.config.Headers.Set(key, value)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithHost 设置显式 Host 头覆盖。
|
||||||
|
func WithHost(host string) RequestOpt {
|
||||||
|
return func(r *Request) error {
|
||||||
|
setRequestHostConfig(r.config, host)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithHeaders 批量设置 Headers
|
||||||
|
func WithHeaders(headers map[string]string) RequestOpt {
|
||||||
|
return func(r *Request) error {
|
||||||
|
for key, value := range headers {
|
||||||
|
if isHostHeaderKey(key) {
|
||||||
|
setRequestHostConfig(r.config, value)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
r.config.Headers.Set(key, value)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithContentType 设置 Content-Type
|
||||||
|
func WithContentType(contentType string) RequestOpt {
|
||||||
|
return func(r *Request) error {
|
||||||
|
r.config.Headers.Set("Content-Type", contentType)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithUserAgent 设置 User-Agent
|
||||||
|
func WithUserAgent(userAgent string) RequestOpt {
|
||||||
|
return func(r *Request) error {
|
||||||
|
r.config.Headers.Set("User-Agent", userAgent)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithBearerToken 设置 Bearer Token
|
||||||
|
func WithBearerToken(token string) RequestOpt {
|
||||||
|
return func(r *Request) error {
|
||||||
|
r.config.Headers.Set("Authorization", "Bearer "+token)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithCookie 添加 Cookie
|
||||||
|
func WithCookie(name, value, path string) RequestOpt {
|
||||||
|
return func(r *Request) error {
|
||||||
|
r.config.Cookies = append(r.config.Cookies, &http.Cookie{
|
||||||
|
Name: name,
|
||||||
|
Value: value,
|
||||||
|
Path: path,
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithSimpleCookie 添加简单 Cookie(path 为 /)
|
||||||
|
func WithSimpleCookie(name, value string) RequestOpt {
|
||||||
|
return func(r *Request) error {
|
||||||
|
r.config.Cookies = append(r.config.Cookies, &http.Cookie{
|
||||||
|
Name: name,
|
||||||
|
Value: value,
|
||||||
|
Path: "/",
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithCookies 批量添加 Cookies
|
||||||
|
func WithCookies(cookies map[string]string) RequestOpt {
|
||||||
|
return func(r *Request) error {
|
||||||
|
for name, value := range cookies {
|
||||||
|
r.config.Cookies = append(r.config.Cookies, &http.Cookie{
|
||||||
|
Name: name,
|
||||||
|
Value: value,
|
||||||
|
Path: "/",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
+234
@@ -0,0 +1,234 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestWithJSONOpt(t *testing.T) {
|
||||||
|
type payload struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Age int `json:"age"`
|
||||||
|
}
|
||||||
|
|
||||||
|
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if ct := r.Header.Get("Content-Type"); ct != ContentTypeJSON {
|
||||||
|
t.Fatalf("content-type=%s", ct)
|
||||||
|
}
|
||||||
|
var p payload
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&p); err != nil {
|
||||||
|
t.Fatalf("decode err: %v", err)
|
||||||
|
}
|
||||||
|
if p.Name != "alice" || p.Age != 18 {
|
||||||
|
t.Fatalf("payload mismatch: %+v", p)
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
resp, err := Post(s.URL, WithJSON(payload{Name: "alice", Age: 18}))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Post error: %v", err)
|
||||||
|
}
|
||||||
|
resp.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithFileOpt(t *testing.T) {
|
||||||
|
// temp file + cleanup
|
||||||
|
f, err := os.CreateTemp("", "starnet-upload-*.txt")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer os.Remove(f.Name())
|
||||||
|
_, _ = f.WriteString("hello-file")
|
||||||
|
_ = f.Close()
|
||||||
|
|
||||||
|
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if err := r.ParseMultipartForm(10 << 20); err != nil {
|
||||||
|
t.Fatalf("parse form err: %v", err)
|
||||||
|
}
|
||||||
|
file, header, err := r.FormFile("file")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("form file err: %v", err)
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
b, _ := io.ReadAll(file)
|
||||||
|
if header.Filename == "" || string(b) != "hello-file" {
|
||||||
|
t.Fatalf("upload mismatch filename=%q body=%q", header.Filename, string(b))
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
resp, err := Post(s.URL, WithFile("file", f.Name()))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Post error: %v", err)
|
||||||
|
}
|
||||||
|
resp.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithFileStreamOpt(t *testing.T) {
|
||||||
|
content := "stream-content"
|
||||||
|
reader := strings.NewReader(content)
|
||||||
|
|
||||||
|
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if err := r.ParseMultipartForm(10 << 20); err != nil {
|
||||||
|
t.Fatalf("parse form err: %v", err)
|
||||||
|
}
|
||||||
|
file, header, err := r.FormFile("up")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("form file err: %v", err)
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
b, _ := io.ReadAll(file)
|
||||||
|
if header.Filename != "a.txt" || string(b) != content {
|
||||||
|
t.Fatalf("upload mismatch filename=%q body=%q", header.Filename, string(b))
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
resp, err := Post(s.URL, WithFileStream("up", "a.txt", int64(len(content)), reader))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Post error: %v", err)
|
||||||
|
}
|
||||||
|
resp.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithQueryOpt(t *testing.T) {
|
||||||
|
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Query().Get("k") != "v" {
|
||||||
|
t.Fatalf("query mismatch: %v", r.URL.Query())
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
resp, err := Get(s.URL, WithQuery("k", "v"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Get error: %v", err)
|
||||||
|
}
|
||||||
|
resp.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithUploadProgressOpt(t *testing.T) {
|
||||||
|
var called int32
|
||||||
|
var last int64
|
||||||
|
content := strings.Repeat("x", 4096)
|
||||||
|
|
||||||
|
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
_ = r.ParseMultipartForm(10 << 20)
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
resp, err := Post(s.URL,
|
||||||
|
WithUploadProgress(func(filename string, uploaded, total int64) {
|
||||||
|
atomic.StoreInt32(&called, 1)
|
||||||
|
last = uploaded
|
||||||
|
}),
|
||||||
|
WithFileStream("f", "p.txt", int64(len(content)), strings.NewReader(content)),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Post error: %v", err)
|
||||||
|
}
|
||||||
|
resp.Close()
|
||||||
|
|
||||||
|
if atomic.LoadInt32(&called) == 0 {
|
||||||
|
t.Fatal("progress not called")
|
||||||
|
}
|
||||||
|
if last != int64(len(content)) {
|
||||||
|
t.Fatalf("last uploaded=%d want=%d", last, len(content))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithTransportOpt(t *testing.T) {
|
||||||
|
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
resp, err := Get(s.URL, WithTransport(&http.Transport{}))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Get error: %v", err)
|
||||||
|
}
|
||||||
|
resp.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithContextOpt(t *testing.T) {
|
||||||
|
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
_, err := Get(s.URL, WithContext(ctx))
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected context timeout error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithCustomDNSOpt_ConfigApplied(t *testing.T) {
|
||||||
|
req := NewSimpleRequest("http://example.com", "GET", WithCustomDNS([]string{"8.8.8.8", "1.1.1.1"}))
|
||||||
|
if req.Err() != nil {
|
||||||
|
t.Fatalf("unexpected err: %v", req.Err())
|
||||||
|
}
|
||||||
|
if len(req.config.DNS.CustomDNS) != 2 {
|
||||||
|
t.Fatalf("custom dns len=%d", len(req.config.DNS.CustomDNS))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithAddCustomIPOpt(t *testing.T) {
|
||||||
|
req := NewSimpleRequest("http://example.com", "GET", WithAddCustomIP("1.2.3.4"))
|
||||||
|
if req.Err() != nil {
|
||||||
|
t.Fatalf("unexpected err: %v", req.Err())
|
||||||
|
}
|
||||||
|
if len(req.config.DNS.CustomIP) != 1 || req.config.DNS.CustomIP[0] != "1.2.3.4" {
|
||||||
|
t.Fatalf("custom ip mismatch: %v", req.config.DNS.CustomIP)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithCustomIPOpt(t *testing.T) {
|
||||||
|
req := NewSimpleRequest("http://example.com", "GET", WithCustomIP([]string{"1.1.1.1", "8.8.8.8"}))
|
||||||
|
if req.Err() != nil {
|
||||||
|
t.Fatalf("unexpected err: %v", req.Err())
|
||||||
|
}
|
||||||
|
if len(req.config.DNS.CustomIP) != 2 {
|
||||||
|
t.Fatalf("custom ip len=%d", len(req.config.DNS.CustomIP))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithDialFuncOpt(t *testing.T) {
|
||||||
|
called := int32(0)
|
||||||
|
fn := func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
atomic.StoreInt32(&called, 1)
|
||||||
|
return nil, io.EOF
|
||||||
|
}
|
||||||
|
req := NewSimpleRequest("http://example.com", "GET", WithDialFunc(fn))
|
||||||
|
if req.config.Network.DialFunc == nil {
|
||||||
|
t.Fatal("dial func not set")
|
||||||
|
}
|
||||||
|
_, _ = req.config.Network.DialFunc(context.Background(), "tcp", "x:1")
|
||||||
|
if atomic.LoadInt32(&called) == 0 {
|
||||||
|
t.Fatal("dial func not called")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithDialTimeoutOpt(t *testing.T) {
|
||||||
|
req := NewSimpleRequest("http://example.com", "GET", WithDialTimeout(123*time.Millisecond))
|
||||||
|
if req.config.Network.DialTimeout != 123*time.Millisecond {
|
||||||
|
t.Fatalf("dial timeout=%v", req.config.Network.DialTimeout)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,110 +1,375 @@
|
|||||||
package starnet
|
package starnet
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"context"
|
||||||
"encoding/binary"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"b612.me/starnet/internal/pingcore"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ICMP struct {
|
const (
|
||||||
Type uint8
|
icmpTypeEchoReplyV4 = 0
|
||||||
Code uint8
|
icmpTypeEchoRequestV4 = 8
|
||||||
CheckSum uint16
|
icmpTypeEchoRequestV6 = 128
|
||||||
Identifier uint16
|
icmpTypeEchoReplyV6 = 129
|
||||||
SequenceNum uint16
|
|
||||||
|
icmpReadBufSz = 1500
|
||||||
|
|
||||||
|
defaultPingAttemptTimeout = 2 * time.Second
|
||||||
|
defaultPingableCount = 3
|
||||||
|
maxPingPayloadSize = 65499 // 65507 - ICMP header(8)
|
||||||
|
)
|
||||||
|
|
||||||
|
type ICMP = pingcore.ICMP
|
||||||
|
|
||||||
|
type pingSocketSpec struct {
|
||||||
|
network string
|
||||||
|
family int
|
||||||
|
requestType uint8
|
||||||
|
replyType uint8
|
||||||
}
|
}
|
||||||
|
|
||||||
func getICMP(seq uint16) ICMP {
|
// PingOptions controls ping probing behavior.
|
||||||
icmp := ICMP{
|
type PingOptions = pingcore.Options
|
||||||
Type: 8,
|
|
||||||
Code: 0,
|
|
||||||
CheckSum: 0,
|
|
||||||
Identifier: 0,
|
|
||||||
SequenceNum: seq,
|
|
||||||
}
|
|
||||||
var buffer bytes.Buffer
|
|
||||||
binary.Write(&buffer, binary.BigEndian, icmp)
|
|
||||||
icmp.CheckSum = checkSum(buffer.Bytes())
|
|
||||||
buffer.Reset()
|
|
||||||
|
|
||||||
return icmp
|
type PingResult = pingcore.Result
|
||||||
|
|
||||||
|
func nextPingIdentifier() uint16 {
|
||||||
|
return pingcore.NextIdentifier()
|
||||||
}
|
}
|
||||||
|
|
||||||
func sendICMPRequest(icmp ICMP, destAddr *net.IPAddr, timeout time.Duration) (PingResult, error) {
|
func pingPayload(size int) []byte {
|
||||||
|
return pingcore.Payload(size)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getICMP(seq, identifier uint16, typ uint8, payload []byte) ICMP {
|
||||||
|
return pingcore.BuildICMP(seq, identifier, typ, payload)
|
||||||
|
}
|
||||||
|
|
||||||
|
func sendICMPRequest(ctx context.Context, icmp ICMP, payload []byte, destAddr *net.IPAddr, sourceIP net.IP, spec pingSocketSpec, timeout time.Duration) (PingResult, error) {
|
||||||
var res PingResult
|
var res PingResult
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
if err := ctx.Err(); err != nil {
|
||||||
|
return res, wrapError(err, "ping context done")
|
||||||
|
}
|
||||||
|
if destAddr == nil || destAddr.IP == nil {
|
||||||
|
return res, fmt.Errorf("destination ip is nil")
|
||||||
|
}
|
||||||
res.RemoteIP = destAddr.String()
|
res.RemoteIP = destAddr.String()
|
||||||
conn, err := net.DialIP("ip:icmp", nil, destAddr)
|
|
||||||
|
localAddr, err := localIPAddrForFamily(sourceIP, spec.family)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return res, err
|
return res, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
conn, err := net.DialIP(spec.network, localAddr, destAddr)
|
||||||
|
if err != nil {
|
||||||
|
return res, normalizePingDialError(err)
|
||||||
|
}
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
var buffer bytes.Buffer
|
|
||||||
binary.Write(&buffer, binary.BigEndian, icmp)
|
|
||||||
|
|
||||||
if _, err := conn.Write(buffer.Bytes()); err != nil {
|
packet := marshalICMPPacket(icmp, payload)
|
||||||
return res, err
|
if _, err := conn.Write(packet); err != nil {
|
||||||
|
return res, wrapError(err, "ping write request")
|
||||||
}
|
}
|
||||||
|
|
||||||
tStart := time.Now()
|
startedAt := time.Now()
|
||||||
|
deadline := startedAt.Add(timeout)
|
||||||
conn.SetReadDeadline((time.Now().Add(timeout)))
|
if d, ok := ctx.Deadline(); ok && d.Before(deadline) {
|
||||||
|
deadline = d
|
||||||
recv := make([]byte, 1024)
|
}
|
||||||
res.RecvCount, err = conn.Read(recv)
|
if err := conn.SetReadDeadline(deadline); err != nil {
|
||||||
|
return res, wrapError(err, "ping set read deadline")
|
||||||
if err != nil {
|
|
||||||
return res, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
tEnd := time.Now()
|
doneCh := make(chan struct{})
|
||||||
res.Duration = tEnd.Sub(tStart)
|
go func() {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
_ = conn.SetReadDeadline(time.Now())
|
||||||
|
case <-doneCh:
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
defer close(doneCh)
|
||||||
|
|
||||||
return res, err
|
recv := make([]byte, icmpReadBufSz)
|
||||||
|
for {
|
||||||
|
n, err := conn.Read(recv)
|
||||||
|
if err != nil {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return res, wrapError(ctx.Err(), "ping context done")
|
||||||
|
}
|
||||||
|
return res, wrapError(err, "ping read reply")
|
||||||
|
}
|
||||||
|
if isExpectedEchoReply(recv[:n], spec.family, spec.replyType, icmp.Identifier, icmp.SequenceNum) {
|
||||||
|
res.RecvCount = n
|
||||||
|
res.Duration = time.Since(startedAt)
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkSum(data []byte) uint16 {
|
func checkSum(data []byte) uint16 {
|
||||||
var (
|
return pingcore.Checksum(data)
|
||||||
sum uint32
|
|
||||||
length int = len(data)
|
|
||||||
index int
|
|
||||||
)
|
|
||||||
for length > 1 {
|
|
||||||
sum += uint32(data[index])<<8 + uint32(data[index+1])
|
|
||||||
index += 2
|
|
||||||
length -= 2
|
|
||||||
}
|
|
||||||
if length > 0 {
|
|
||||||
sum += uint32(data[index])
|
|
||||||
}
|
|
||||||
sum += (sum >> 16)
|
|
||||||
|
|
||||||
return uint16(^sum)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type PingResult struct {
|
func marshalICMP(icmp ICMP) []byte {
|
||||||
Duration time.Duration
|
return pingcore.Marshal(icmp)
|
||||||
RecvCount int
|
|
||||||
RemoteIP string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func Ping(ip string, seq int, timeout time.Duration) (PingResult, error) {
|
func marshalICMPPacket(icmp ICMP, payload []byte) []byte {
|
||||||
var res PingResult
|
return pingcore.MarshalPacket(icmp, payload)
|
||||||
ipAddr, err := net.ResolveIPAddr("ip", ip)
|
}
|
||||||
|
|
||||||
|
func isExpectedEchoReply(packet []byte, family int, expectedType uint8, identifier, seq uint16) bool {
|
||||||
|
return pingcore.IsExpectedEchoReply(packet, family, expectedType, identifier, seq)
|
||||||
|
}
|
||||||
|
|
||||||
|
func candidateICMPOffsets(packet []byte, family int) []int {
|
||||||
|
return pingcore.CandidateICMPOffsets(packet, family)
|
||||||
|
}
|
||||||
|
|
||||||
|
func dedupOffsets(offsets []int) []int {
|
||||||
|
return pingcore.DedupOffsets(offsets)
|
||||||
|
}
|
||||||
|
|
||||||
|
func socketSpecForIP(ip net.IP) (pingSocketSpec, error) {
|
||||||
|
if ip == nil {
|
||||||
|
return pingSocketSpec{}, wrapError(ErrInvalidIP, "ip is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if ip4 := ip.To4(); ip4 != nil {
|
||||||
|
return pingSocketSpec{
|
||||||
|
network: "ip4:icmp",
|
||||||
|
family: 4,
|
||||||
|
requestType: icmpTypeEchoRequestV4,
|
||||||
|
replyType: icmpTypeEchoReplyV4,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if ip16 := ip.To16(); ip16 != nil {
|
||||||
|
return pingSocketSpec{
|
||||||
|
network: "ip6:ipv6-icmp",
|
||||||
|
family: 6,
|
||||||
|
requestType: icmpTypeEchoRequestV6,
|
||||||
|
replyType: icmpTypeEchoReplyV6,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return pingSocketSpec{}, wrapError(ErrInvalidIP, "invalid ip: %q", ip.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func localIPAddrForFamily(sourceIP net.IP, family int) (*net.IPAddr, error) {
|
||||||
|
if sourceIP == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
if sourceIP.To16() == nil {
|
||||||
|
return nil, wrapError(ErrInvalidIP, "invalid source ip: %q", sourceIP.String())
|
||||||
|
}
|
||||||
|
if family == 4 && sourceIP.To4() == nil {
|
||||||
|
return nil, wrapError(ErrInvalidIP, "source ip family mismatch with IPv4 target")
|
||||||
|
}
|
||||||
|
if family == 6 && sourceIP.To4() != nil {
|
||||||
|
return nil, wrapError(ErrInvalidIP, "source ip family mismatch with IPv6 target")
|
||||||
|
}
|
||||||
|
return &net.IPAddr{IP: sourceIP}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolvePingTargets(host string, preferIPv4, preferIPv6 bool) ([]*net.IPAddr, error) {
|
||||||
|
targets, err := pingcore.ResolveTargets(host, preferIPv4, preferIPv6)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return res, err
|
return nil, err
|
||||||
}
|
}
|
||||||
icmp := getICMP(uint16(seq))
|
if len(targets) == 0 {
|
||||||
return sendICMPRequest(icmp, ipAddr, timeout)
|
return nil, ErrPingNoResolvedTarget
|
||||||
|
}
|
||||||
|
return targets, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func IsIpPingable(ip string, timeout time.Duration, retryLimit int) bool {
|
func orderPingTargets(targets []*net.IPAddr, preferIPv4, preferIPv6 bool) []*net.IPAddr {
|
||||||
for i := 0; i < retryLimit; i++ {
|
return pingcore.OrderTargets(targets, preferIPv4, preferIPv6)
|
||||||
_, err := Ping(ip, 29, timeout)
|
}
|
||||||
|
|
||||||
|
func normalizePingDialError(err error) error {
|
||||||
|
if err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := strings.ToLower(err.Error())
|
||||||
|
if errors.Is(err, os.ErrPermission) ||
|
||||||
|
strings.Contains(msg, "operation not permitted") ||
|
||||||
|
strings.Contains(msg, "permission denied") {
|
||||||
|
return fmt.Errorf("%w: %v", ErrPingPermissionDenied, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.Contains(msg, "unknown network") ||
|
||||||
|
strings.Contains(msg, "protocol not available") ||
|
||||||
|
strings.Contains(msg, "address family not supported by protocol") ||
|
||||||
|
strings.Contains(msg, "socket type not supported") {
|
||||||
|
return fmt.Errorf("%w: %v", ErrPingProtocolUnsupported, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return wrapError(err, "ping dial")
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizePingOptions(opts *PingOptions, defaultCount int, defaultTimeout time.Duration) (PingOptions, error) {
|
||||||
|
out := PingOptions{
|
||||||
|
Count: defaultCount,
|
||||||
|
Timeout: defaultTimeout,
|
||||||
|
Interval: 0,
|
||||||
|
PayloadSize: 0,
|
||||||
|
}
|
||||||
|
if opts != nil {
|
||||||
|
out = *opts
|
||||||
|
if out.Count == 0 {
|
||||||
|
out.Count = defaultCount
|
||||||
|
}
|
||||||
|
if out.Timeout == 0 {
|
||||||
|
out.Timeout = defaultTimeout
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if out.Count < 0 {
|
||||||
|
return out, fmt.Errorf("ping count must be >= 0")
|
||||||
|
}
|
||||||
|
if out.Timeout <= 0 {
|
||||||
|
return out, wrapError(ErrPingInvalidTimeout, "timeout must be > 0")
|
||||||
|
}
|
||||||
|
if out.Interval < 0 {
|
||||||
|
return out, fmt.Errorf("ping interval must be >= 0")
|
||||||
|
}
|
||||||
|
if out.PayloadSize < 0 || out.PayloadSize > maxPingPayloadSize {
|
||||||
|
return out, fmt.Errorf("ping payload size must be in [0,%d]", maxPingPayloadSize)
|
||||||
|
}
|
||||||
|
if out.SourceIP != nil && out.SourceIP.To16() == nil {
|
||||||
|
return out, wrapError(ErrInvalidIP, "invalid source ip: %q", out.SourceIP.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func pingOnceWithOptions(ctx context.Context, host string, seq int, opts PingOptions) (PingResult, error) {
|
||||||
|
var res PingResult
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
if err := ctx.Err(); err != nil {
|
||||||
|
return res, wrapError(err, "ping context done")
|
||||||
|
}
|
||||||
|
|
||||||
|
targets, err := resolvePingTargets(host, opts.PreferIPv4, opts.PreferIPv6)
|
||||||
|
if err != nil {
|
||||||
|
return res, wrapError(err, "resolve ping target")
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := pingPayload(opts.PayloadSize)
|
||||||
|
var lastErr error
|
||||||
|
for _, target := range targets {
|
||||||
|
spec, err := socketSpecForIP(target.IP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
lastErr = err
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
return true
|
|
||||||
|
icmp := getICMP(uint16(seq), nextPingIdentifier(), spec.requestType, payload)
|
||||||
|
resp, err := sendICMPRequest(ctx, icmp, payload, target, opts.SourceIP, spec, opts.Timeout)
|
||||||
|
if err == nil {
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if errors.Is(err, ErrPingPermissionDenied) {
|
||||||
|
return res, err
|
||||||
|
}
|
||||||
|
lastErr = err
|
||||||
}
|
}
|
||||||
return false
|
|
||||||
|
if lastErr != nil {
|
||||||
|
return res, wrapError(lastErr, "ping all resolved targets failed")
|
||||||
|
}
|
||||||
|
return res, ErrPingNoResolvedTarget
|
||||||
|
}
|
||||||
|
|
||||||
|
// PingWithContext sends one ICMP echo request with context cancel support.
|
||||||
|
func PingWithContext(ctx context.Context, host string, seq int, timeout time.Duration) (PingResult, error) {
|
||||||
|
opts, err := normalizePingOptions(&PingOptions{
|
||||||
|
Count: 1,
|
||||||
|
Timeout: timeout,
|
||||||
|
}, 1, timeout)
|
||||||
|
if err != nil {
|
||||||
|
return PingResult{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !opts.Deadline.IsZero() {
|
||||||
|
var cancel context.CancelFunc
|
||||||
|
ctx, cancel = context.WithDeadline(ctx, opts.Deadline)
|
||||||
|
defer cancel()
|
||||||
|
}
|
||||||
|
return pingOnceWithOptions(ctx, host, seq, opts)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ping sends one ICMP echo request.
|
||||||
|
func Ping(ip string, seq int, timeout time.Duration) (PingResult, error) {
|
||||||
|
return PingWithContext(context.Background(), ip, seq, timeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pingable checks host reachability with retry options.
|
||||||
|
func Pingable(host string, opts *PingOptions) (bool, error) {
|
||||||
|
cfg, err := normalizePingOptions(opts, defaultPingableCount, defaultPingAttemptTimeout)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
if !cfg.Deadline.IsZero() {
|
||||||
|
var cancel context.CancelFunc
|
||||||
|
ctx, cancel = context.WithDeadline(ctx, cfg.Deadline)
|
||||||
|
defer cancel()
|
||||||
|
}
|
||||||
|
|
||||||
|
var lastErr error
|
||||||
|
for index := 0; index < cfg.Count; index++ {
|
||||||
|
_, err := pingOnceWithOptions(ctx, host, 29+index, cfg)
|
||||||
|
if err == nil {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
lastErr = err
|
||||||
|
|
||||||
|
if errors.Is(err, ErrPingPermissionDenied) || errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if index < cfg.Count-1 && cfg.Interval > 0 {
|
||||||
|
timer := time.NewTimer(cfg.Interval)
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
timer.Stop()
|
||||||
|
return false, wrapError(ctx.Err(), "pingable context done")
|
||||||
|
case <-timer.C:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if lastErr == nil {
|
||||||
|
lastErr = ErrPingNoResolvedTarget
|
||||||
|
}
|
||||||
|
return false, lastErr
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsIpPingable keeps backward-compatible bool-only behavior.
|
||||||
|
func IsIpPingable(ip string, timeout time.Duration, retryLimit int) bool {
|
||||||
|
if retryLimit <= 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
ok, _ := Pingable(ip, &PingOptions{
|
||||||
|
Count: retryLimit,
|
||||||
|
Timeout: timeout,
|
||||||
|
})
|
||||||
|
return ok
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,214 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func buildICMPPacket(typ uint8, identifier, seq uint16) []byte {
|
||||||
|
icmp := ICMP{
|
||||||
|
Type: typ,
|
||||||
|
Code: 0,
|
||||||
|
CheckSum: 0,
|
||||||
|
Identifier: identifier,
|
||||||
|
SequenceNum: seq,
|
||||||
|
}
|
||||||
|
buf := marshalICMPPacket(icmp, nil)
|
||||||
|
cs := checkSum(buf)
|
||||||
|
buf[2] = byte(cs >> 8)
|
||||||
|
buf[3] = byte(cs)
|
||||||
|
return buf
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNextPingIdentifierChanges(t *testing.T) {
|
||||||
|
id1 := nextPingIdentifier()
|
||||||
|
id2 := nextPingIdentifier()
|
||||||
|
if id1 == id2 {
|
||||||
|
t.Fatalf("identifier should change between calls: %d == %d", id1, id2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsExpectedEchoReplyIPv4(t *testing.T) {
|
||||||
|
identifier := uint16(0x1234)
|
||||||
|
seq := uint16(0x0102)
|
||||||
|
reply := buildICMPPacket(icmpTypeEchoReplyV4, identifier, seq)
|
||||||
|
|
||||||
|
if !isExpectedEchoReply(reply, 4, icmpTypeEchoReplyV4, identifier, seq) {
|
||||||
|
t.Fatal("expected IPv4 echo reply to match")
|
||||||
|
}
|
||||||
|
if isExpectedEchoReply(reply, 4, icmpTypeEchoReplyV4, identifier, seq+1) {
|
||||||
|
t.Fatal("mismatched sequence should not match")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsExpectedEchoReplyIPv4WithIPHeader(t *testing.T) {
|
||||||
|
identifier := uint16(0x1111)
|
||||||
|
seq := uint16(0x2222)
|
||||||
|
|
||||||
|
ipHeader := make([]byte, 20)
|
||||||
|
ipHeader[0] = 0x45 // version=4, ihl=5
|
||||||
|
|
||||||
|
reply := buildICMPPacket(icmpTypeEchoReplyV4, identifier, seq)
|
||||||
|
packet := append(ipHeader, reply...)
|
||||||
|
|
||||||
|
if !isExpectedEchoReply(packet, 4, icmpTypeEchoReplyV4, identifier, seq) {
|
||||||
|
t.Fatal("expected IPv4 packet with header to match")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsExpectedEchoReplyIPv6WithHeader(t *testing.T) {
|
||||||
|
identifier := uint16(0xabcd)
|
||||||
|
seq := uint16(0x00ff)
|
||||||
|
|
||||||
|
ipv6Header := make([]byte, 40)
|
||||||
|
ipv6Header[0] = 0x60 // version=6
|
||||||
|
|
||||||
|
reply := buildICMPPacket(icmpTypeEchoReplyV6, identifier, seq)
|
||||||
|
packet := append(ipv6Header, reply...)
|
||||||
|
|
||||||
|
if !isExpectedEchoReply(packet, 6, icmpTypeEchoReplyV6, identifier, seq) {
|
||||||
|
t.Fatal("expected IPv6 packet with header to match")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPingInvalidTimeout(t *testing.T) {
|
||||||
|
_, err := Ping("127.0.0.1", 1, 0)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for non-positive timeout")
|
||||||
|
}
|
||||||
|
if !errors.Is(err, ErrPingInvalidTimeout) {
|
||||||
|
t.Fatalf("expected ErrPingInvalidTimeout, got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsIPPingableInvalidRetry(t *testing.T) {
|
||||||
|
if IsIpPingable("127.0.0.1", time.Millisecond, 0) {
|
||||||
|
t.Fatal("retryLimit=0 should return false")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSocketSpecForIP(t *testing.T) {
|
||||||
|
v4, err := socketSpecForIP(net.ParseIP("127.0.0.1"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected v4 error: %v", err)
|
||||||
|
}
|
||||||
|
if v4.network != "ip4:icmp" || v4.family != 4 || v4.requestType != icmpTypeEchoRequestV4 || v4.replyType != icmpTypeEchoReplyV4 {
|
||||||
|
t.Fatalf("unexpected v4 spec: %+v", v4)
|
||||||
|
}
|
||||||
|
|
||||||
|
v6, err := socketSpecForIP(net.ParseIP("::1"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected v6 error: %v", err)
|
||||||
|
}
|
||||||
|
if v6.network != "ip6:ipv6-icmp" || v6.family != 6 || v6.requestType != icmpTypeEchoRequestV6 || v6.replyType != icmpTypeEchoReplyV6 {
|
||||||
|
t.Fatalf("unexpected v6 spec: %+v", v6)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = socketSpecForIP(nil)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for nil ip")
|
||||||
|
}
|
||||||
|
if !errors.Is(err, ErrInvalidIP) {
|
||||||
|
t.Fatalf("expected ErrInvalidIP, got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolvePingTargetsLiteral(t *testing.T) {
|
||||||
|
v4, err := resolvePingTargets("127.0.0.1", false, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected v4 resolve error: %v", err)
|
||||||
|
}
|
||||||
|
if len(v4) != 1 || v4[0] == nil || v4[0].IP == nil || v4[0].IP.To4() == nil {
|
||||||
|
t.Fatalf("unexpected v4 targets: %+v", v4)
|
||||||
|
}
|
||||||
|
|
||||||
|
v6, err := resolvePingTargets("::1", false, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected v6 resolve error: %v", err)
|
||||||
|
}
|
||||||
|
if len(v6) != 1 || v6[0] == nil || v6[0].IP == nil || v6[0].IP.To16() == nil || v6[0].IP.To4() != nil {
|
||||||
|
t.Fatalf("unexpected v6 targets: %+v", v6)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizePingDialError(t *testing.T) {
|
||||||
|
perr := normalizePingDialError(errors.New("socket: operation not permitted"))
|
||||||
|
if !errors.Is(perr, ErrPingPermissionDenied) {
|
||||||
|
t.Fatalf("expected ErrPingPermissionDenied, got: %v", perr)
|
||||||
|
}
|
||||||
|
|
||||||
|
uerr := normalizePingDialError(errors.New("unknown network ip6:ipv6-icmp"))
|
||||||
|
if !errors.Is(uerr, ErrPingProtocolUnsupported) {
|
||||||
|
t.Fatalf("expected ErrPingProtocolUnsupported, got: %v", uerr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOrderPingTargets(t *testing.T) {
|
||||||
|
targets := []*net.IPAddr{
|
||||||
|
{IP: net.ParseIP("::1")},
|
||||||
|
{IP: net.ParseIP("127.0.0.1")},
|
||||||
|
}
|
||||||
|
|
||||||
|
v4First := orderPingTargets(targets, true, false)
|
||||||
|
if v4First[0].IP.To4() == nil {
|
||||||
|
t.Fatalf("expected IPv4 first, got: %v", v4First[0].IP)
|
||||||
|
}
|
||||||
|
|
||||||
|
v6First := orderPingTargets(targets, false, true)
|
||||||
|
if v6First[0].IP.To4() != nil {
|
||||||
|
t.Fatalf("expected IPv6 first, got: %v", v6First[0].IP)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizePingOptions(t *testing.T) {
|
||||||
|
opts, err := normalizePingOptions(nil, 3, 2*time.Second)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if opts.Count != 3 || opts.Timeout != 2*time.Second {
|
||||||
|
t.Fatalf("unexpected defaults: %+v", opts)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = normalizePingOptions(&PingOptions{Count: -1}, 3, 2*time.Second)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for negative count")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = normalizePingOptions(&PingOptions{Timeout: -1}, 3, 2*time.Second)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for negative timeout")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = normalizePingOptions(&PingOptions{PayloadSize: maxPingPayloadSize + 1}, 3, 2*time.Second)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for too large payload")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPingWithContextCanceled(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
_, err := PingWithContext(ctx, "127.0.0.1", 1, time.Second)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected canceled error")
|
||||||
|
}
|
||||||
|
if !errors.Is(err, context.Canceled) {
|
||||||
|
t.Fatalf("expected context.Canceled, got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPingableInvalidOptions(t *testing.T) {
|
||||||
|
_, err := Pingable("127.0.0.1", &PingOptions{Count: -1})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected invalid count error")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = Pingable("127.0.0.1", &PingOptions{Interval: -1})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected invalid interval error")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,110 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRequestProxyWithCustomIPTargetsOriginWithoutRewritingProxyDial(t *testing.T) {
|
||||||
|
tlsReqInfo := make(chan struct {
|
||||||
|
host string
|
||||||
|
sni string
|
||||||
|
}, 1)
|
||||||
|
tlsServer := newIPv4TLSServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
tlsReqInfo <- struct {
|
||||||
|
host string
|
||||||
|
sni string
|
||||||
|
}{
|
||||||
|
host: r.Host,
|
||||||
|
sni: r.TLS.ServerName,
|
||||||
|
}
|
||||||
|
_, _ = w.Write([]byte("ok"))
|
||||||
|
}))
|
||||||
|
defer tlsServer.Close()
|
||||||
|
|
||||||
|
_, port, err := net.SplitHostPort(tlsServer.Listener.Addr().String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("split tls server addr: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
proxyServer := newIPv4ConnectProxyServer(t, nil)
|
||||||
|
defer proxyServer.Close()
|
||||||
|
|
||||||
|
targetHost := "proxy-custom-ip.test"
|
||||||
|
reqURL := fmt.Sprintf("https://%s:%s", targetHost, port)
|
||||||
|
req := NewSimpleRequest(reqURL, http.MethodGet).
|
||||||
|
SetProxy(proxyServer.URL).
|
||||||
|
SetCustomIP([]string{"127.0.0.1"}).
|
||||||
|
SetSkipTLSVerify(true)
|
||||||
|
|
||||||
|
resp, err := req.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
targets := proxyServer.Targets()
|
||||||
|
if len(targets) != 1 {
|
||||||
|
t.Fatalf("connect targets=%v; want 1 target", targets)
|
||||||
|
}
|
||||||
|
gotConnectTarget := targets[0]
|
||||||
|
wantConnectTarget := net.JoinHostPort("127.0.0.1", port)
|
||||||
|
if gotConnectTarget != wantConnectTarget {
|
||||||
|
t.Fatalf("CONNECT target = %q; want %q", gotConnectTarget, wantConnectTarget)
|
||||||
|
}
|
||||||
|
|
||||||
|
gotTLS := <-tlsReqInfo
|
||||||
|
wantHost := net.JoinHostPort(targetHost, port)
|
||||||
|
if gotTLS.host != wantHost {
|
||||||
|
t.Fatalf("request host = %q; want %q", gotTLS.host, wantHost)
|
||||||
|
}
|
||||||
|
if gotTLS.sni != targetHost {
|
||||||
|
t.Fatalf("tls sni = %q; want %q", gotTLS.sni, targetHost)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestCustomIPPreservesOriginalHostAndSNI(t *testing.T) {
|
||||||
|
tlsReqInfo := make(chan struct {
|
||||||
|
host string
|
||||||
|
sni string
|
||||||
|
}, 1)
|
||||||
|
tlsServer := newIPv4TLSServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
tlsReqInfo <- struct {
|
||||||
|
host string
|
||||||
|
sni string
|
||||||
|
}{
|
||||||
|
host: r.Host,
|
||||||
|
sni: r.TLS.ServerName,
|
||||||
|
}
|
||||||
|
_, _ = w.Write([]byte("ok"))
|
||||||
|
}))
|
||||||
|
defer tlsServer.Close()
|
||||||
|
|
||||||
|
_, port, err := net.SplitHostPort(tlsServer.Listener.Addr().String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("split tls server addr: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
targetHost := "custom-ip-direct.test"
|
||||||
|
reqURL := fmt.Sprintf("https://%s:%s", targetHost, port)
|
||||||
|
req := NewSimpleRequest(reqURL, http.MethodGet).
|
||||||
|
SetCustomIP([]string{"127.0.0.1"}).
|
||||||
|
SetSkipTLSVerify(true)
|
||||||
|
|
||||||
|
resp, err := req.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
gotTLS := <-tlsReqInfo
|
||||||
|
wantHost := net.JoinHostPort(targetHost, port)
|
||||||
|
if gotTLS.host != wantHost {
|
||||||
|
t.Fatalf("request host = %q; want %q", gotTLS.host, wantHost)
|
||||||
|
}
|
||||||
|
if gotTLS.sni != targetHost {
|
||||||
|
t.Fatalf("tls sni = %q; want %q", gotTLS.sni, targetHost)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,331 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type connectProxyServer struct {
|
||||||
|
*httptest.Server
|
||||||
|
mu sync.Mutex
|
||||||
|
targets []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newIPv4Server(t testing.TB, handler http.Handler) *httptest.Server {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
listener, err := net.Listen("tcp4", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("listen tcp4: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
server := httptest.NewUnstartedServer(handler)
|
||||||
|
server.Listener = listener
|
||||||
|
server.Start()
|
||||||
|
return server
|
||||||
|
}
|
||||||
|
|
||||||
|
func newIPv4TLSServer(t testing.TB, handler http.Handler) *httptest.Server {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
listener, err := net.Listen("tcp4", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("listen tcp4: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
server := httptest.NewUnstartedServer(handler)
|
||||||
|
server.Listener = listener
|
||||||
|
server.StartTLS()
|
||||||
|
return server
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTrustedIPv4TLSServer(t testing.TB, dnsName string, handler http.Handler) (*httptest.Server, *x509.CertPool) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
testT, ok := t.(*testing.T)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("newTrustedIPv4TLSServer requires *testing.T")
|
||||||
|
}
|
||||||
|
|
||||||
|
certPEM, keyPEM := genSelfSignedCertPEM(testT, dnsName)
|
||||||
|
cert, err := tls.X509KeyPair(certPEM, keyPEM)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("X509KeyPair: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
pool := x509.NewCertPool()
|
||||||
|
if !pool.AppendCertsFromPEM(certPEM) {
|
||||||
|
t.Fatal("AppendCertsFromPEM returned false")
|
||||||
|
}
|
||||||
|
|
||||||
|
server := httptest.NewUnstartedServer(handler)
|
||||||
|
listener, err := net.Listen("tcp4", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("listen tcp4: %v", err)
|
||||||
|
}
|
||||||
|
server.Listener = listener
|
||||||
|
server.TLS = &tls.Config{
|
||||||
|
Certificates: []tls.Certificate{cert},
|
||||||
|
}
|
||||||
|
server.StartTLS()
|
||||||
|
return server, pool
|
||||||
|
}
|
||||||
|
|
||||||
|
func httpsURLForHost(t testing.TB, server *httptest.Server, host string) string {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
_, port, err := net.SplitHostPort(server.Listener.Addr().String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("split host port: %v", err)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("https://%s:%s", host, port)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newIPv4ConnectProxyServer(t testing.TB, dialTarget func(target string) (net.Conn, error)) *connectProxyServer {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
proxy := &connectProxyServer{}
|
||||||
|
if dialTarget == nil {
|
||||||
|
dialTarget = func(target string) (net.Conn, error) {
|
||||||
|
return net.Dial("tcp", target)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
proxy.Server = newIPv4Server(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodConnect {
|
||||||
|
http.Error(w, "connect required", http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
proxy.mu.Lock()
|
||||||
|
proxy.targets = append(proxy.targets, r.Host)
|
||||||
|
proxy.mu.Unlock()
|
||||||
|
|
||||||
|
targetConn, err := dialTarget(r.Host)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusBadGateway)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
hijacker, ok := w.(http.Hijacker)
|
||||||
|
if !ok {
|
||||||
|
targetConn.Close()
|
||||||
|
t.Fatal("proxy response writer is not a hijacker")
|
||||||
|
}
|
||||||
|
|
||||||
|
clientConn, rw, err := hijacker.Hijack()
|
||||||
|
if err != nil {
|
||||||
|
targetConn.Close()
|
||||||
|
t.Fatalf("hijack proxy conn: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := rw.WriteString("HTTP/1.1 200 Connection Established\r\n\r\n"); err != nil {
|
||||||
|
clientConn.Close()
|
||||||
|
targetConn.Close()
|
||||||
|
t.Fatalf("write connect response: %v", err)
|
||||||
|
}
|
||||||
|
if err := rw.Flush(); err != nil {
|
||||||
|
clientConn.Close()
|
||||||
|
targetConn.Close()
|
||||||
|
t.Fatalf("flush connect response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
relayProxyConns(clientConn, targetConn)
|
||||||
|
}))
|
||||||
|
return proxy
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *connectProxyServer) Targets() []string {
|
||||||
|
p.mu.Lock()
|
||||||
|
defer p.mu.Unlock()
|
||||||
|
return append([]string(nil), p.targets...)
|
||||||
|
}
|
||||||
|
|
||||||
|
type socks5ProxyServer struct {
|
||||||
|
ln net.Listener
|
||||||
|
addr string
|
||||||
|
dial func(target string) (net.Conn, error)
|
||||||
|
stopCh chan struct{}
|
||||||
|
wg sync.WaitGroup
|
||||||
|
mu sync.Mutex
|
||||||
|
targets []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSOCKS5ProxyServer(t testing.TB, dialTarget func(target string) (net.Conn, error)) *socks5ProxyServer {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
if dialTarget == nil {
|
||||||
|
dialTarget = func(target string) (net.Conn, error) {
|
||||||
|
return net.Dial("tcp", target)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ln, err := net.Listen("tcp4", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("listen tcp4 socks5: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
proxy := &socks5ProxyServer{
|
||||||
|
ln: ln,
|
||||||
|
addr: ln.Addr().String(),
|
||||||
|
dial: dialTarget,
|
||||||
|
stopCh: make(chan struct{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
proxy.wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer proxy.wg.Done()
|
||||||
|
for {
|
||||||
|
conn, err := ln.Accept()
|
||||||
|
if err != nil {
|
||||||
|
select {
|
||||||
|
case <-proxy.stopCh:
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
proxy.wg.Add(1)
|
||||||
|
go func(c net.Conn) {
|
||||||
|
defer proxy.wg.Done()
|
||||||
|
proxy.handleConn(t, c)
|
||||||
|
}(conn)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return proxy
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *socks5ProxyServer) URL() string {
|
||||||
|
return "socks5://" + p.addr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *socks5ProxyServer) Targets() []string {
|
||||||
|
p.mu.Lock()
|
||||||
|
defer p.mu.Unlock()
|
||||||
|
return append([]string(nil), p.targets...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *socks5ProxyServer) Close() {
|
||||||
|
close(p.stopCh)
|
||||||
|
_ = p.ln.Close()
|
||||||
|
p.wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *socks5ProxyServer) handleConn(t testing.TB, conn net.Conn) {
|
||||||
|
t.Helper()
|
||||||
|
closeConn := true
|
||||||
|
defer func() {
|
||||||
|
if closeConn {
|
||||||
|
_ = conn.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
header := make([]byte, 2)
|
||||||
|
if _, err := io.ReadFull(conn, header); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if header[0] != 0x05 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
methods := make([]byte, int(header[1]))
|
||||||
|
if _, err := io.ReadFull(conn, methods); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, err := conn.Write([]byte{0x05, 0x00}); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
reqHeader := make([]byte, 4)
|
||||||
|
if _, err := io.ReadFull(conn, reqHeader); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if reqHeader[0] != 0x05 || reqHeader[1] != 0x01 {
|
||||||
|
_, _ = conn.Write([]byte{0x05, 0x07, 0x00, 0x01, 0, 0, 0, 0, 0, 0})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
host, err := readSOCKS5Addr(conn, reqHeader[3])
|
||||||
|
if err != nil {
|
||||||
|
_, _ = conn.Write([]byte{0x05, 0x08, 0x00, 0x01, 0, 0, 0, 0, 0, 0})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
portBytes := make([]byte, 2)
|
||||||
|
if _, err := io.ReadFull(conn, portBytes); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
target := net.JoinHostPort(host, fmt.Sprintf("%d", binary.BigEndian.Uint16(portBytes)))
|
||||||
|
|
||||||
|
p.mu.Lock()
|
||||||
|
p.targets = append(p.targets, target)
|
||||||
|
p.mu.Unlock()
|
||||||
|
|
||||||
|
targetConn, err := p.dial(target)
|
||||||
|
if err != nil {
|
||||||
|
_, _ = conn.Write([]byte{0x05, 0x05, 0x00, 0x01, 0, 0, 0, 0, 0, 0})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := conn.Write([]byte{0x05, 0x00, 0x00, 0x01, 0, 0, 0, 0, 0, 0}); err != nil {
|
||||||
|
targetConn.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
closeConn = false
|
||||||
|
relayProxyConns(conn, targetConn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func readSOCKS5Addr(r io.Reader, atyp byte) (string, error) {
|
||||||
|
switch atyp {
|
||||||
|
case 0x01:
|
||||||
|
buf := make([]byte, 4)
|
||||||
|
if _, err := io.ReadFull(r, buf); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return net.IP(buf).String(), nil
|
||||||
|
case 0x03:
|
||||||
|
var size [1]byte
|
||||||
|
if _, err := io.ReadFull(r, size[:]); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
buf := make([]byte, int(size[0]))
|
||||||
|
if _, err := io.ReadFull(r, buf); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return string(buf), nil
|
||||||
|
case 0x04:
|
||||||
|
buf := make([]byte, 16)
|
||||||
|
if _, err := io.ReadFull(r, buf); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return net.IP(buf).String(), nil
|
||||||
|
default:
|
||||||
|
return "", fmt.Errorf("unsupported atyp: %d", atyp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func relayProxyConns(left, right net.Conn) {
|
||||||
|
var once sync.Once
|
||||||
|
closeBoth := func() {
|
||||||
|
_ = left.Close()
|
||||||
|
_ = right.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
_, _ = io.Copy(left, right)
|
||||||
|
once.Do(closeBoth)
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
_, _ = io.Copy(right, left)
|
||||||
|
once.Do(closeBoth)
|
||||||
|
}()
|
||||||
|
}
|
||||||
@@ -0,0 +1,50 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRequestProxy(t *testing.T) {
|
||||||
|
// Create a proxy server
|
||||||
|
proxyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Proxy received the request
|
||||||
|
w.Header().Set("X-Proxied", "true")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte("proxied"))
|
||||||
|
}))
|
||||||
|
defer proxyServer.Close()
|
||||||
|
|
||||||
|
// Note: This is a simplified test. Real proxy testing requires more setup
|
||||||
|
req := NewSimpleRequest("http://example.com", "GET").
|
||||||
|
SetProxy(proxyServer.URL)
|
||||||
|
|
||||||
|
// Just verify the proxy is set in config
|
||||||
|
if req.config.Network.Proxy != proxyServer.URL {
|
||||||
|
t.Errorf("Proxy = %v; want %v", req.config.Network.Proxy, proxyServer.URL)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientLevelProxy(t *testing.T) {
|
||||||
|
proxyURL := "http://proxy.example.com:8080"
|
||||||
|
client := NewClientNoErr(WithProxy(proxyURL))
|
||||||
|
|
||||||
|
req, _ := client.NewRequest("http://example.com", "GET")
|
||||||
|
if req.config.Network.Proxy != proxyURL {
|
||||||
|
t.Errorf("Proxy = %v; want %v", req.config.Network.Proxy, proxyURL)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestLevelProxyOverride(t *testing.T) {
|
||||||
|
clientProxy := "http://client-proxy.com:8080"
|
||||||
|
requestProxy := "http://request-proxy.com:8080"
|
||||||
|
|
||||||
|
client := NewClientNoErr(WithProxy(clientProxy))
|
||||||
|
req, _ := client.NewRequest("http://example.com", "GET", WithProxy(requestProxy))
|
||||||
|
|
||||||
|
// Request level should override client level
|
||||||
|
if req.config.Network.Proxy != requestProxy {
|
||||||
|
t.Errorf("Proxy = %v; want %v", req.config.Network.Proxy, requestProxy)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,317 +0,0 @@
|
|||||||
package starnet
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"encoding/binary"
|
|
||||||
"errors"
|
|
||||||
"os"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// 识别头
|
|
||||||
var header = []byte{11, 27, 19, 96, 12, 25, 02, 20}
|
|
||||||
|
|
||||||
// MsgQueue 为基本的信息单位
|
|
||||||
type MsgQueue struct {
|
|
||||||
ID uint16
|
|
||||||
Msg []byte
|
|
||||||
Conn interface{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// StarQueue 为流数据中的消息队列分发
|
|
||||||
type StarQueue struct {
|
|
||||||
count int64
|
|
||||||
Encode bool
|
|
||||||
Reserve uint16
|
|
||||||
Msgid uint16
|
|
||||||
MsgPool chan MsgQueue
|
|
||||||
UnFinMsg sync.Map
|
|
||||||
LastID int //= -1
|
|
||||||
ctx context.Context
|
|
||||||
cancel context.CancelFunc
|
|
||||||
duration time.Duration
|
|
||||||
EncodeFunc func([]byte) []byte
|
|
||||||
DecodeFunc func([]byte) []byte
|
|
||||||
//restoreMu sync.Mutex
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewQueueCtx(ctx context.Context, count int64) *StarQueue {
|
|
||||||
var que StarQueue
|
|
||||||
que.Encode = false
|
|
||||||
que.count = count
|
|
||||||
que.MsgPool = make(chan MsgQueue, count)
|
|
||||||
if ctx == nil {
|
|
||||||
que.ctx, que.cancel = context.WithCancel(context.Background())
|
|
||||||
} else {
|
|
||||||
que.ctx, que.cancel = context.WithCancel(ctx)
|
|
||||||
}
|
|
||||||
que.duration = 0
|
|
||||||
return &que
|
|
||||||
}
|
|
||||||
func NewQueueWithCount(count int64) *StarQueue {
|
|
||||||
return NewQueueCtx(nil, count)
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewQueue 建立一个新消息队列
|
|
||||||
func NewQueue() *StarQueue {
|
|
||||||
return NewQueueWithCount(32)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Uint32ToByte 4位uint32转byte
|
|
||||||
func Uint32ToByte(src uint32) []byte {
|
|
||||||
res := make([]byte, 4)
|
|
||||||
res[3] = uint8(src)
|
|
||||||
res[2] = uint8(src >> 8)
|
|
||||||
res[1] = uint8(src >> 16)
|
|
||||||
res[0] = uint8(src >> 24)
|
|
||||||
return res
|
|
||||||
}
|
|
||||||
|
|
||||||
// ByteToUint32 byte转4位uint32
|
|
||||||
func ByteToUint32(src []byte) uint32 {
|
|
||||||
var res uint32
|
|
||||||
buffer := bytes.NewBuffer(src)
|
|
||||||
binary.Read(buffer, binary.BigEndian, &res)
|
|
||||||
return res
|
|
||||||
}
|
|
||||||
|
|
||||||
// Uint16ToByte 2位uint16转byte
|
|
||||||
func Uint16ToByte(src uint16) []byte {
|
|
||||||
res := make([]byte, 2)
|
|
||||||
res[1] = uint8(src)
|
|
||||||
res[0] = uint8(src >> 8)
|
|
||||||
return res
|
|
||||||
}
|
|
||||||
|
|
||||||
// ByteToUint16 用于byte转uint16
|
|
||||||
func ByteToUint16(src []byte) uint16 {
|
|
||||||
var res uint16
|
|
||||||
buffer := bytes.NewBuffer(src)
|
|
||||||
binary.Read(buffer, binary.BigEndian, &res)
|
|
||||||
return res
|
|
||||||
}
|
|
||||||
|
|
||||||
// BuildMessage 生成编码后的信息用于发送
|
|
||||||
func (que *StarQueue) BuildMessage(src []byte) []byte {
|
|
||||||
var buff bytes.Buffer
|
|
||||||
que.Msgid++
|
|
||||||
if que.Encode {
|
|
||||||
src = que.EncodeFunc(src)
|
|
||||||
}
|
|
||||||
length := uint32(len(src))
|
|
||||||
buff.Write(header)
|
|
||||||
buff.Write(Uint32ToByte(length))
|
|
||||||
buff.Write(Uint16ToByte(que.Msgid))
|
|
||||||
buff.Write(src)
|
|
||||||
return buff.Bytes()
|
|
||||||
}
|
|
||||||
|
|
||||||
// BuildHeader 生成编码后的Header用于发送
|
|
||||||
func (que *StarQueue) BuildHeader(length uint32) []byte {
|
|
||||||
var buff bytes.Buffer
|
|
||||||
que.Msgid++
|
|
||||||
buff.Write(header)
|
|
||||||
buff.Write(Uint32ToByte(length))
|
|
||||||
buff.Write(Uint16ToByte(que.Msgid))
|
|
||||||
return buff.Bytes()
|
|
||||||
}
|
|
||||||
|
|
||||||
type unFinMsg struct {
|
|
||||||
ID uint16
|
|
||||||
LengthRecv uint32
|
|
||||||
// HeaderMsg 信息头,应当为14位:8位识别码+4位长度码+2位id
|
|
||||||
HeaderMsg []byte
|
|
||||||
RecvMsg []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
func (que *StarQueue) push2list(msg MsgQueue) {
|
|
||||||
que.MsgPool <- msg
|
|
||||||
}
|
|
||||||
|
|
||||||
// ParseMessage 用于解析收到的msg信息
|
|
||||||
func (que *StarQueue) ParseMessage(msg []byte, conn interface{}) error {
|
|
||||||
return que.parseMessage(msg, conn)
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseMessage 用于解析收到的msg信息
|
|
||||||
func (que *StarQueue) parseMessage(msg []byte, conn interface{}) error {
|
|
||||||
tmp, ok := que.UnFinMsg.Load(conn)
|
|
||||||
if ok { //存在未完成的信息
|
|
||||||
lastMsg := tmp.(*unFinMsg)
|
|
||||||
headerLen := len(lastMsg.HeaderMsg)
|
|
||||||
if headerLen < 14 { //未完成头标题
|
|
||||||
//传输的数据不能填充header头
|
|
||||||
if len(msg) < 14-headerLen {
|
|
||||||
//加入header头并退出
|
|
||||||
lastMsg.HeaderMsg = bytesMerge(lastMsg.HeaderMsg, msg)
|
|
||||||
que.UnFinMsg.Store(conn, lastMsg)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
//获取14字节完整的header
|
|
||||||
header := msg[0 : 14-headerLen]
|
|
||||||
lastMsg.HeaderMsg = bytesMerge(lastMsg.HeaderMsg, header)
|
|
||||||
//检查收到的header是否为认证header
|
|
||||||
//若不是,丢弃并重新来过
|
|
||||||
if !checkHeader(lastMsg.HeaderMsg[0:8]) {
|
|
||||||
que.UnFinMsg.Delete(conn)
|
|
||||||
if len(msg) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return que.parseMessage(msg, conn)
|
|
||||||
}
|
|
||||||
//获得本数据包长度
|
|
||||||
lastMsg.LengthRecv = ByteToUint32(lastMsg.HeaderMsg[8:12])
|
|
||||||
//获得本数据包ID
|
|
||||||
lastMsg.ID = ByteToUint16(lastMsg.HeaderMsg[12:14])
|
|
||||||
//存入列表
|
|
||||||
que.UnFinMsg.Store(conn, lastMsg)
|
|
||||||
msg = msg[14-headerLen:]
|
|
||||||
if uint32(len(msg)) < lastMsg.LengthRecv {
|
|
||||||
lastMsg.RecvMsg = msg
|
|
||||||
que.UnFinMsg.Store(conn, lastMsg)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if uint32(len(msg)) >= lastMsg.LengthRecv {
|
|
||||||
lastMsg.RecvMsg = msg[0:lastMsg.LengthRecv]
|
|
||||||
if que.Encode {
|
|
||||||
lastMsg.RecvMsg = que.DecodeFunc(lastMsg.RecvMsg)
|
|
||||||
}
|
|
||||||
msg = msg[lastMsg.LengthRecv:]
|
|
||||||
storeMsg := MsgQueue{
|
|
||||||
ID: lastMsg.ID,
|
|
||||||
Msg: lastMsg.RecvMsg,
|
|
||||||
Conn: conn,
|
|
||||||
}
|
|
||||||
//que.restoreMu.Lock()
|
|
||||||
que.push2list(storeMsg)
|
|
||||||
//que.restoreMu.Unlock()
|
|
||||||
que.UnFinMsg.Delete(conn)
|
|
||||||
return que.parseMessage(msg, conn)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
lastID := int(lastMsg.LengthRecv) - len(lastMsg.RecvMsg)
|
|
||||||
if lastID < 0 {
|
|
||||||
que.UnFinMsg.Delete(conn)
|
|
||||||
return que.parseMessage(msg, conn)
|
|
||||||
}
|
|
||||||
if len(msg) >= lastID {
|
|
||||||
lastMsg.RecvMsg = bytesMerge(lastMsg.RecvMsg, msg[0:lastID])
|
|
||||||
if que.Encode {
|
|
||||||
lastMsg.RecvMsg = que.DecodeFunc(lastMsg.RecvMsg)
|
|
||||||
}
|
|
||||||
storeMsg := MsgQueue{
|
|
||||||
ID: lastMsg.ID,
|
|
||||||
Msg: lastMsg.RecvMsg,
|
|
||||||
Conn: conn,
|
|
||||||
}
|
|
||||||
que.push2list(storeMsg)
|
|
||||||
que.UnFinMsg.Delete(conn)
|
|
||||||
if len(msg) == lastID {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
msg = msg[lastID:]
|
|
||||||
return que.parseMessage(msg, conn)
|
|
||||||
}
|
|
||||||
lastMsg.RecvMsg = bytesMerge(lastMsg.RecvMsg, msg)
|
|
||||||
que.UnFinMsg.Store(conn, lastMsg)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(msg) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
var start int
|
|
||||||
if start = searchHeader(msg); start == -1 {
|
|
||||||
return errors.New("data format error")
|
|
||||||
}
|
|
||||||
msg = msg[start:]
|
|
||||||
lastMsg := unFinMsg{}
|
|
||||||
que.UnFinMsg.Store(conn, &lastMsg)
|
|
||||||
return que.parseMessage(msg, conn)
|
|
||||||
}
|
|
||||||
|
|
||||||
func checkHeader(msg []byte) bool {
|
|
||||||
if len(msg) != 8 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
for k, v := range msg {
|
|
||||||
if v != header[k] {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func searchHeader(msg []byte) int {
|
|
||||||
if len(msg) < 8 {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
for k, v := range msg {
|
|
||||||
find := 0
|
|
||||||
if v == header[0] {
|
|
||||||
for k2, v2 := range header {
|
|
||||||
if msg[k+k2] == v2 {
|
|
||||||
find++
|
|
||||||
} else {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if find == 8 {
|
|
||||||
return k
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return -1
|
|
||||||
}
|
|
||||||
|
|
||||||
func bytesMerge(src ...[]byte) []byte {
|
|
||||||
var buff bytes.Buffer
|
|
||||||
for _, v := range src {
|
|
||||||
buff.Write(v)
|
|
||||||
}
|
|
||||||
return buff.Bytes()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Restore 获取收到的信息
|
|
||||||
func (que *StarQueue) Restore() (MsgQueue, error) {
|
|
||||||
if que.duration.Seconds() == 0 {
|
|
||||||
que.duration = 86400 * time.Second
|
|
||||||
}
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-que.ctx.Done():
|
|
||||||
return MsgQueue{}, errors.New("Stoped By External Function Call")
|
|
||||||
case <-time.After(que.duration):
|
|
||||||
if que.duration != 0 {
|
|
||||||
return MsgQueue{}, os.ErrDeadlineExceeded
|
|
||||||
}
|
|
||||||
case data, ok := <-que.MsgPool:
|
|
||||||
if !ok {
|
|
||||||
return MsgQueue{}, os.ErrClosed
|
|
||||||
}
|
|
||||||
return data, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// RestoreOne 获取收到的一个信息
|
|
||||||
//兼容性修改
|
|
||||||
func (que *StarQueue) RestoreOne() (MsgQueue, error) {
|
|
||||||
return que.Restore()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stop 立即停止Restore
|
|
||||||
func (que *StarQueue) Stop() {
|
|
||||||
que.cancel()
|
|
||||||
}
|
|
||||||
|
|
||||||
// RestoreDuration Restore最大超时时间
|
|
||||||
func (que *StarQueue) RestoreDuration(tm time.Duration) {
|
|
||||||
que.duration = tm
|
|
||||||
}
|
|
||||||
|
|
||||||
func (que *StarQueue) RestoreChan() <-chan MsgQueue {
|
|
||||||
return que.MsgPool
|
|
||||||
}
|
|
||||||
-42
@@ -1,42 +0,0 @@
|
|||||||
package starnet
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
func Test_QueSpeed(t *testing.T) {
|
|
||||||
que := NewQueueWithCount(0)
|
|
||||||
stop := make(chan struct{}, 1)
|
|
||||||
que.RestoreDuration(time.Second * 10)
|
|
||||||
var count int64
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-stop:
|
|
||||||
//fmt.Println(count)
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
_, err := que.RestoreOne()
|
|
||||||
if err == nil {
|
|
||||||
count++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
cp := 0
|
|
||||||
stoped := time.After(time.Second * 10)
|
|
||||||
data := que.BuildMessage([]byte("hello"))
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-stoped:
|
|
||||||
fmt.Println(count, cp)
|
|
||||||
stop <- struct{}{}
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
que.ParseMessage(data, "lala")
|
|
||||||
cp++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,98 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRequestQuery(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
query := r.URL.Query()
|
||||||
|
result := make(map[string][]string)
|
||||||
|
for k, v := range query {
|
||||||
|
result[k] = v
|
||||||
|
}
|
||||||
|
json.NewEncoder(w).Encode(result)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
req := NewSimpleRequest(server.URL, "GET").
|
||||||
|
AddQuery("name", "John").
|
||||||
|
AddQuery("age", "30").
|
||||||
|
AddQuery("tags", "go").
|
||||||
|
AddQuery("tags", "http")
|
||||||
|
|
||||||
|
resp, err := req.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
var result map[string][]string
|
||||||
|
resp.Body().JSON(&result)
|
||||||
|
|
||||||
|
if len(result["name"]) != 1 || result["name"][0] != "John" {
|
||||||
|
t.Errorf("name = %v; want [John]", result["name"])
|
||||||
|
}
|
||||||
|
if len(result["tags"]) != 2 {
|
||||||
|
t.Errorf("tags length = %v; want 2", len(result["tags"]))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestSetQuery(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
query := r.URL.Query()
|
||||||
|
w.Write([]byte(query.Get("key")))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
req := NewSimpleRequest(server.URL, "GET").
|
||||||
|
SetQuery("key", "value1").
|
||||||
|
SetQuery("key", "value2") // Should overwrite
|
||||||
|
|
||||||
|
resp, err := req.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
body, _ := resp.Body().String()
|
||||||
|
if body != "value2" {
|
||||||
|
t.Errorf("query value = %v; want value2", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestDeleteQuery(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
query := r.URL.Query()
|
||||||
|
result := make(map[string]string)
|
||||||
|
for k := range query {
|
||||||
|
result[k] = query.Get(k)
|
||||||
|
}
|
||||||
|
json.NewEncoder(w).Encode(result)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
req := NewSimpleRequest(server.URL, "GET").
|
||||||
|
AddQuery("keep", "yes").
|
||||||
|
AddQuery("delete", "no").
|
||||||
|
DeleteQuery("delete")
|
||||||
|
|
||||||
|
resp, err := req.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
var result map[string]string
|
||||||
|
resp.Body().JSON(&result)
|
||||||
|
|
||||||
|
if _, exists := result["delete"]; exists {
|
||||||
|
t.Error("delete query should not exist")
|
||||||
|
}
|
||||||
|
if result["keep"] != "yes" {
|
||||||
|
t.Errorf("keep = %v; want yes", result["keep"])
|
||||||
|
}
|
||||||
|
}
|
||||||
+607
@@ -0,0 +1,607 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Request HTTP 请求
|
||||||
|
type Request struct {
|
||||||
|
ctx context.Context
|
||||||
|
execCtx context.Context // 执行时的 context(注入了配置)
|
||||||
|
cancel context.CancelFunc
|
||||||
|
url string
|
||||||
|
method string
|
||||||
|
err error // 累积的错误
|
||||||
|
|
||||||
|
config *RequestConfig
|
||||||
|
client *Client
|
||||||
|
httpClient *http.Client
|
||||||
|
httpReq *http.Request
|
||||||
|
retry *retryPolicy
|
||||||
|
traceHooks *TraceHooks
|
||||||
|
traceRecorder *TraceRecorder
|
||||||
|
traceRun *TraceRecorder
|
||||||
|
lastTraceSummary *TraceSummary
|
||||||
|
traceState *traceState
|
||||||
|
|
||||||
|
applied bool // 是否已应用配置
|
||||||
|
doRaw bool // 是否使用原始请求(不修改)
|
||||||
|
autoFetch bool // 是否自动获取响应体
|
||||||
|
|
||||||
|
rawSourceExternal bool // 是否由 SetRawRequest/WithRawRequest 注入外部 raw request
|
||||||
|
rawTemplate *http.Request
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeContext(ctx context.Context) context.Context {
|
||||||
|
if ctx != nil {
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
return context.Background()
|
||||||
|
}
|
||||||
|
|
||||||
|
func cloneRawHTTPRequest(httpReq *http.Request, ctx context.Context) (*http.Request, error) {
|
||||||
|
if httpReq == nil {
|
||||||
|
return nil, fmt.Errorf("http request is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
cloned := httpReq.Clone(normalizeContext(ctx))
|
||||||
|
switch {
|
||||||
|
case httpReq.Body == nil || httpReq.Body == http.NoBody:
|
||||||
|
cloned.Body = httpReq.Body
|
||||||
|
case httpReq.GetBody != nil:
|
||||||
|
body, err := httpReq.GetBody()
|
||||||
|
if err != nil {
|
||||||
|
return cloned, wrapError(err, "clone raw request body")
|
||||||
|
}
|
||||||
|
cloned.Body = body
|
||||||
|
default:
|
||||||
|
return cloned, fmt.Errorf("cannot clone raw request with non-replayable body")
|
||||||
|
}
|
||||||
|
|
||||||
|
return cloned, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Request) rawBaseRequest() *http.Request {
|
||||||
|
if r == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if r.rawTemplate != nil {
|
||||||
|
return r.rawTemplate
|
||||||
|
}
|
||||||
|
return r.httpReq
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Request) invalidatePreparedState() {
|
||||||
|
if r == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if r.cancel != nil {
|
||||||
|
r.cancel()
|
||||||
|
r.cancel = nil
|
||||||
|
}
|
||||||
|
r.execCtx = nil
|
||||||
|
r.traceRun = nil
|
||||||
|
r.traceState = nil
|
||||||
|
r.httpClient = nil
|
||||||
|
|
||||||
|
wasApplied := r.applied
|
||||||
|
r.applied = false
|
||||||
|
if !wasApplied || r.doRaw {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := r.rebuildPreparedRequestBase(); err != nil && r.err == nil {
|
||||||
|
r.err = err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Request) rebuildPreparedRequestBase() error {
|
||||||
|
if r == nil || r.doRaw {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
ctx := r.ctx
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
httpReq, err := http.NewRequestWithContext(ctx, r.method, r.url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return wrapError(err, "rebuild http request")
|
||||||
|
}
|
||||||
|
r.httpReq = httpReq
|
||||||
|
r.syncRequestHost()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Request) rebuildRawRequestBase() error {
|
||||||
|
if r == nil || !r.doRaw {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
baseReq := r.rawBaseRequest()
|
||||||
|
rawReq, err := cloneRawHTTPRequest(baseReq, normalizeContext(r.ctx))
|
||||||
|
if err != nil && baseReq != nil && baseReq == r.httpReq {
|
||||||
|
r.httpReq = baseReq.WithContext(normalizeContext(r.ctx))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if rawReq != nil {
|
||||||
|
r.httpReq = rawReq
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Request) rebuildExecutionRequestBase() error {
|
||||||
|
if r == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if r.cancel != nil {
|
||||||
|
r.cancel()
|
||||||
|
r.cancel = nil
|
||||||
|
}
|
||||||
|
r.execCtx = nil
|
||||||
|
r.traceState = nil
|
||||||
|
r.applied = false
|
||||||
|
if r.doRaw {
|
||||||
|
return r.rebuildRawRequestBase()
|
||||||
|
}
|
||||||
|
return r.rebuildPreparedRequestBase()
|
||||||
|
}
|
||||||
|
|
||||||
|
// newRequest 创建新请求(内部使用)
|
||||||
|
func newRequest(ctx context.Context, urlStr string, method string, opts ...RequestOpt) (*Request, error) {
|
||||||
|
ctx = normalizeContext(ctx)
|
||||||
|
if method == "" {
|
||||||
|
method = http.MethodGet
|
||||||
|
}
|
||||||
|
method = strings.ToUpper(method)
|
||||||
|
|
||||||
|
// 创建 http.Request
|
||||||
|
httpReq, err := http.NewRequestWithContext(ctx, method, urlStr, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, wrapError(err, "create http request")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 初始化配置
|
||||||
|
config := &RequestConfig{
|
||||||
|
Network: NetworkConfig{
|
||||||
|
DialTimeout: DefaultDialTimeout,
|
||||||
|
Timeout: DefaultTimeout,
|
||||||
|
},
|
||||||
|
Headers: make(http.Header),
|
||||||
|
Queries: make(map[string][]string),
|
||||||
|
Body: BodyConfig{
|
||||||
|
FormData: make(map[string][]string),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// 设置默认 User-Agent
|
||||||
|
config.Headers.Set("User-Agent", DefaultUserAgent)
|
||||||
|
|
||||||
|
// POST 请求默认 Content-Type
|
||||||
|
if method == http.MethodPost {
|
||||||
|
config.Headers.Set("Content-Type", ContentTypeFormURLEncoded)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := &Request{
|
||||||
|
ctx: ctx,
|
||||||
|
url: urlStr,
|
||||||
|
method: method,
|
||||||
|
config: config,
|
||||||
|
httpReq: httpReq,
|
||||||
|
autoFetch: DefaultFetchRespBody,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 应用选项
|
||||||
|
for _, opt := range opts {
|
||||||
|
if opt != nil {
|
||||||
|
if err := opt(req); err != nil {
|
||||||
|
req.err = err
|
||||||
|
return req, nil // 不返回错误,累积到 req.err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return req, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRequest 创建新请求
|
||||||
|
func NewRequest(url, method string, opts ...RequestOpt) (*Request, error) {
|
||||||
|
req, err := newRequest(context.Background(), url, method, opts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if req.err != nil {
|
||||||
|
return nil, req.err
|
||||||
|
}
|
||||||
|
return req, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRequestWithContext 创建新请求(带 context)
|
||||||
|
func NewRequestWithContext(ctx context.Context, url, method string, opts ...RequestOpt) (*Request, error) {
|
||||||
|
req, err := newRequest(ctx, url, method, opts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// 新增
|
||||||
|
if req.err != nil {
|
||||||
|
return nil, req.err
|
||||||
|
}
|
||||||
|
return req, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSimpleRequest 创建新请求(忽略错误,支持链式调用)
|
||||||
|
func NewSimpleRequest(url, method string, opts ...RequestOpt) *Request {
|
||||||
|
req, err := newRequest(context.Background(), url, method, opts...)
|
||||||
|
if err != nil {
|
||||||
|
// 返回一个带错误的请求
|
||||||
|
return &Request{
|
||||||
|
ctx: context.Background(),
|
||||||
|
url: url,
|
||||||
|
method: method,
|
||||||
|
err: err,
|
||||||
|
config: &RequestConfig{
|
||||||
|
Headers: make(http.Header),
|
||||||
|
Queries: make(map[string][]string),
|
||||||
|
Body: BodyConfig{
|
||||||
|
FormData: make(map[string][]string),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return req
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSimpleRequestWithContext 创建新请求(带 context,忽略错误)
|
||||||
|
func NewSimpleRequestWithContext(ctx context.Context, url, method string, opts ...RequestOpt) *Request {
|
||||||
|
ctx = normalizeContext(ctx)
|
||||||
|
req, err := newRequest(ctx, url, method, opts...)
|
||||||
|
if err != nil {
|
||||||
|
return &Request{
|
||||||
|
ctx: ctx,
|
||||||
|
url: url,
|
||||||
|
method: method,
|
||||||
|
err: err,
|
||||||
|
config: &RequestConfig{
|
||||||
|
Headers: make(http.Header),
|
||||||
|
Queries: make(map[string][]string),
|
||||||
|
Body: BodyConfig{
|
||||||
|
FormData: make(map[string][]string),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return req
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clone 克隆请求
|
||||||
|
func (r *Request) Clone() *Request {
|
||||||
|
cloned := &Request{
|
||||||
|
ctx: r.ctx,
|
||||||
|
url: r.url,
|
||||||
|
method: r.method,
|
||||||
|
err: r.err,
|
||||||
|
config: r.config.Clone(),
|
||||||
|
client: r.client,
|
||||||
|
httpClient: r.httpClient,
|
||||||
|
retry: cloneRetryPolicy(r.retry),
|
||||||
|
traceHooks: r.traceHooks,
|
||||||
|
traceRecorder: r.traceRecorder,
|
||||||
|
applied: false, // 重置应用状态
|
||||||
|
doRaw: r.doRaw,
|
||||||
|
autoFetch: r.autoFetch,
|
||||||
|
|
||||||
|
rawSourceExternal: r.rawSourceExternal,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 重新创建 http.Request
|
||||||
|
if !r.doRaw {
|
||||||
|
cloned.httpReq, _ = http.NewRequestWithContext(cloned.ctx, cloned.method, cloned.url, nil)
|
||||||
|
} else {
|
||||||
|
rawTemplate, err := cloneRawHTTPRequest(r.rawBaseRequest(), cloned.ctx)
|
||||||
|
cloned.rawTemplate = rawTemplate
|
||||||
|
cloned.httpReq = rawTemplate
|
||||||
|
if err != nil && cloned.err == nil {
|
||||||
|
cloned.err = err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return cloned
|
||||||
|
}
|
||||||
|
|
||||||
|
// Err 获取累积的错误
|
||||||
|
func (r *Request) Err() error {
|
||||||
|
return r.err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Context 获取 context
|
||||||
|
func (r *Request) Context() context.Context {
|
||||||
|
return r.ctx
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetContext 设置 context
|
||||||
|
func (r *Request) SetContext(ctx context.Context) *Request {
|
||||||
|
return r.applyMutation(mutateContext(ctx))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Method 获取 HTTP 方法
|
||||||
|
func (r *Request) Method() string {
|
||||||
|
return r.method
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetMethod 设置 HTTP 方法
|
||||||
|
func (r *Request) SetMethod(method string) *Request {
|
||||||
|
if r.err != nil {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
method = strings.ToUpper(method)
|
||||||
|
if !validMethod(method) {
|
||||||
|
r.err = wrapError(ErrInvalidMethod, "method: %s", method)
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
r.method = method
|
||||||
|
if r.httpReq != nil {
|
||||||
|
r.httpReq.Method = method
|
||||||
|
}
|
||||||
|
if r.doRaw && r.rawTemplate != nil {
|
||||||
|
r.rawTemplate.Method = method
|
||||||
|
}
|
||||||
|
r.invalidatePreparedState()
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// URL 获取 URL
|
||||||
|
func (r *Request) URL() string {
|
||||||
|
return r.url
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetURL 设置 URL
|
||||||
|
func (r *Request) SetURL(urlStr string) *Request {
|
||||||
|
if r.err != nil {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.doRaw {
|
||||||
|
r.err = fmt.Errorf("cannot set URL when using raw request")
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
u, err := url.Parse(urlStr)
|
||||||
|
if err != nil {
|
||||||
|
r.err = wrapError(ErrInvalidURL, "url: %s", urlStr)
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
r.url = urlStr
|
||||||
|
u.Host = removeEmptyPort(u.Host)
|
||||||
|
r.httpReq.URL = u
|
||||||
|
r.syncRequestHost()
|
||||||
|
r.invalidatePreparedState()
|
||||||
|
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Request) effectiveRequestHost() string {
|
||||||
|
if r == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if r.config != nil && r.config.Host != "" {
|
||||||
|
return r.config.Host
|
||||||
|
}
|
||||||
|
if r.httpReq != nil && r.httpReq.URL != nil {
|
||||||
|
return removeEmptyPort(r.httpReq.URL.Host)
|
||||||
|
}
|
||||||
|
if r.url == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
u, err := url.Parse(r.url)
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return removeEmptyPort(u.Host)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Request) syncRequestHost() {
|
||||||
|
if r == nil || r.httpReq == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r.httpReq.Host = r.effectiveRequestHost()
|
||||||
|
}
|
||||||
|
|
||||||
|
// RawRequest 获取底层 http.Request
|
||||||
|
func (r *Request) RawRequest() *http.Request {
|
||||||
|
if r != nil && r.doRaw && r.rawTemplate != nil && !r.applied {
|
||||||
|
return r.rawTemplate
|
||||||
|
}
|
||||||
|
return r.httpReq
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetRawRequest 设置底层 http.Request(启用原始模式)
|
||||||
|
func (r *Request) SetRawRequest(httpReq *http.Request) *Request {
|
||||||
|
return r.applyMutation(mutateRawRequest(httpReq))
|
||||||
|
}
|
||||||
|
|
||||||
|
// EnableRawMode 启用原始模式(不修改请求)
|
||||||
|
func (r *Request) EnableRawMode() *Request {
|
||||||
|
if r.doRaw {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
r.doRaw = true
|
||||||
|
r.invalidatePreparedState()
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// DisableRawMode 禁用原始模式
|
||||||
|
func (r *Request) DisableRawMode() *Request {
|
||||||
|
if !r.doRaw {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
if r.rawSourceExternal {
|
||||||
|
r.err = fmt.Errorf("cannot disable raw mode after SetRawRequest")
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
r.doRaw = false
|
||||||
|
r.invalidatePreparedState()
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetAutoFetch 设置是否自动获取响应体
|
||||||
|
func (r *Request) SetAutoFetch(auto bool) *Request {
|
||||||
|
r.autoFetch = auto
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// HTTPClient 获取底层 http.Client(只读)
|
||||||
|
func (r *Request) HTTPClient() (*http.Client, error) {
|
||||||
|
if r.err != nil {
|
||||||
|
return nil, r.err
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.httpClient != nil {
|
||||||
|
return r.httpClient, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果还没构建,先准备
|
||||||
|
if err := r.prepare(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return r.httpClient, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Client 获取关联的 Client(只读)
|
||||||
|
func (r *Request) Client() *Client {
|
||||||
|
return r.client
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do 执行请求
|
||||||
|
func (r *Request) Do() (*Response, error) {
|
||||||
|
// 检查累积的错误
|
||||||
|
if r.err != nil {
|
||||||
|
return nil, r.err
|
||||||
|
}
|
||||||
|
|
||||||
|
r.startTraceExecution()
|
||||||
|
|
||||||
|
var (
|
||||||
|
resp *Response
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
if r.hasRetryPolicy() {
|
||||||
|
resp, err = r.doWithRetry()
|
||||||
|
} else {
|
||||||
|
resp, err = r.doOnce()
|
||||||
|
}
|
||||||
|
|
||||||
|
r.finishTraceExecution(resp)
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Request) doOnce() (*Response, error) {
|
||||||
|
if err := r.rebuildExecutionRequestBase(); err != nil {
|
||||||
|
return nil, wrapError(err, "rebuild execution request")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 准备请求
|
||||||
|
if err := r.prepare(); err != nil {
|
||||||
|
return nil, wrapError(err, "prepare request")
|
||||||
|
}
|
||||||
|
if r.traceRun != nil {
|
||||||
|
r.traceRun.observePreparedRequest(r.httpReq)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 执行请求
|
||||||
|
httpResp, err := r.httpClient.Do(r.httpReq)
|
||||||
|
if err != nil {
|
||||||
|
if r.cancel != nil {
|
||||||
|
r.cancel()
|
||||||
|
r.cancel = nil
|
||||||
|
}
|
||||||
|
return &Response{
|
||||||
|
Response: &http.Response{},
|
||||||
|
request: r,
|
||||||
|
httpClient: r.httpClient,
|
||||||
|
body: &Body{},
|
||||||
|
}, wrapError(err, "do request")
|
||||||
|
}
|
||||||
|
if r.traceRun != nil {
|
||||||
|
r.traceRun.observeResponse(httpResp)
|
||||||
|
}
|
||||||
|
|
||||||
|
rawBody := httpResp.Body
|
||||||
|
if r.cancel != nil {
|
||||||
|
rawBody = &cancelReadCloser{
|
||||||
|
ReadCloser: httpResp.Body,
|
||||||
|
cancel: r.cancel,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建响应
|
||||||
|
resp := &Response{
|
||||||
|
Response: httpResp,
|
||||||
|
request: r,
|
||||||
|
httpClient: r.httpClient,
|
||||||
|
cancel: r.cancel,
|
||||||
|
body: &Body{
|
||||||
|
raw: rawBody,
|
||||||
|
maxBytes: r.config.MaxRespBodyBytes,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
r.cancel = nil
|
||||||
|
|
||||||
|
// 自动获取响应体
|
||||||
|
if r.autoFetch {
|
||||||
|
if err := resp.body.readAll(); err != nil {
|
||||||
|
_ = resp.Close()
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get 发送 GET 请求
|
||||||
|
func (r *Request) Get() (*Response, error) {
|
||||||
|
return r.SetMethod(http.MethodGet).Do()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Post 发送 POST 请求
|
||||||
|
func (r *Request) Post() (*Response, error) {
|
||||||
|
return r.SetMethod(http.MethodPost).Do()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Put 发送 PUT 请求
|
||||||
|
func (r *Request) Put() (*Response, error) {
|
||||||
|
return r.SetMethod(http.MethodPut).Do()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete 发送 DELETE 请求
|
||||||
|
func (r *Request) Delete() (*Response, error) {
|
||||||
|
return r.SetMethod(http.MethodDelete).Do()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Head 发送 HEAD 请求
|
||||||
|
func (r *Request) Head() (*Response, error) {
|
||||||
|
return r.SetMethod(http.MethodHead).Do()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Patch 发送 PATCH 请求
|
||||||
|
func (r *Request) Patch() (*Response, error) {
|
||||||
|
return r.SetMethod(http.MethodPatch).Do()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Options 发送 OPTIONS 请求
|
||||||
|
func (r *Request) Options() (*Response, error) {
|
||||||
|
return r.SetMethod(http.MethodOptions).Do()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Trace 发送 TRACE 请求
|
||||||
|
func (r *Request) Trace() (*Response, error) {
|
||||||
|
return r.SetMethod(http.MethodTrace).Do()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect 发送 CONNECT 请求
|
||||||
|
func (r *Request) Connect() (*Response, error) {
|
||||||
|
return r.SetMethod(http.MethodConnect).Do()
|
||||||
|
}
|
||||||
+220
@@ -0,0 +1,220 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SetBody 设置请求体(字节)
|
||||||
|
func (r *Request) SetBody(body []byte) *Request {
|
||||||
|
if r.err != nil {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
if r.doRaw {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
setBytesBodyConfig(&r.config.Body, body)
|
||||||
|
r.invalidatePreparedState()
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetBodyReader 设置请求体(Reader)。
|
||||||
|
// 出于避免重复写的保守策略,Reader 形态的 body 在非幂等方法上不会自动参与 retry。
|
||||||
|
func (r *Request) SetBodyReader(reader io.Reader) *Request {
|
||||||
|
if r.err != nil {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
if r.doRaw {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
setReaderBodyConfig(&r.config.Body, reader)
|
||||||
|
r.invalidatePreparedState()
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetBodyString 设置请求体(字符串)
|
||||||
|
func (r *Request) SetBodyString(body string) *Request {
|
||||||
|
return r.SetBody([]byte(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetJSON 设置 JSON 请求体
|
||||||
|
func (r *Request) SetJSON(v interface{}) *Request {
|
||||||
|
if r.err != nil {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.Marshal(v)
|
||||||
|
if err != nil {
|
||||||
|
r.err = wrapError(err, "marshal json")
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
return r.SetContentType(ContentTypeJSON).SetBody(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetFormData 设置表单数据(覆盖)
|
||||||
|
func (r *Request) SetFormData(data map[string][]string) *Request {
|
||||||
|
if r.err != nil {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
if r.doRaw {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
setFormBodyConfig(&r.config.Body, data)
|
||||||
|
r.invalidatePreparedState()
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddFormData 添加表单数据
|
||||||
|
func (r *Request) AddFormData(key, value string) *Request {
|
||||||
|
if r.err != nil {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
if r.doRaw {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
ensureFormMode(&r.config.Body)
|
||||||
|
r.config.Body.FormData[key] = append(r.config.Body.FormData[key], value)
|
||||||
|
r.invalidatePreparedState()
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddFormDataMap 批量添加表单数据
|
||||||
|
func (r *Request) AddFormDataMap(data map[string]string) *Request {
|
||||||
|
if r.err != nil {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
if r.doRaw {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
ensureFormMode(&r.config.Body)
|
||||||
|
for key, value := range data {
|
||||||
|
r.config.Body.FormData[key] = append(r.config.Body.FormData[key], value)
|
||||||
|
}
|
||||||
|
r.invalidatePreparedState()
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddFile 添加文件(从路径)
|
||||||
|
func (r *Request) AddFile(formName, filePath string) *Request {
|
||||||
|
if r.err != nil {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
stat, err := os.Stat(filePath)
|
||||||
|
if err != nil {
|
||||||
|
r.err = wrapError(ErrFileNotFound, "file: %s", filePath)
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureMultipartMode(&r.config.Body)
|
||||||
|
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
|
||||||
|
FormName: formName,
|
||||||
|
FileName: stat.Name(),
|
||||||
|
FilePath: filePath,
|
||||||
|
FileSize: stat.Size(),
|
||||||
|
FileType: ContentTypeOctetStream,
|
||||||
|
})
|
||||||
|
r.invalidatePreparedState()
|
||||||
|
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddFileWithName 添加文件(指定文件名)
|
||||||
|
func (r *Request) AddFileWithName(formName, filePath, fileName string) *Request {
|
||||||
|
if r.err != nil {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
stat, err := os.Stat(filePath)
|
||||||
|
if err != nil {
|
||||||
|
r.err = wrapError(ErrFileNotFound, "file: %s", filePath)
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureMultipartMode(&r.config.Body)
|
||||||
|
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
|
||||||
|
FormName: formName,
|
||||||
|
FileName: fileName,
|
||||||
|
FilePath: filePath,
|
||||||
|
FileSize: stat.Size(),
|
||||||
|
FileType: ContentTypeOctetStream,
|
||||||
|
})
|
||||||
|
r.invalidatePreparedState()
|
||||||
|
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddFileWithType 添加文件(指定 MIME 类型)
|
||||||
|
func (r *Request) AddFileWithType(formName, filePath, fileType string) *Request {
|
||||||
|
if r.err != nil {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
stat, err := os.Stat(filePath)
|
||||||
|
if err != nil {
|
||||||
|
r.err = wrapError(ErrFileNotFound, "file: %s", filePath)
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureMultipartMode(&r.config.Body)
|
||||||
|
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
|
||||||
|
FormName: formName,
|
||||||
|
FileName: stat.Name(),
|
||||||
|
FilePath: filePath,
|
||||||
|
FileSize: stat.Size(),
|
||||||
|
FileType: fileType,
|
||||||
|
})
|
||||||
|
r.invalidatePreparedState()
|
||||||
|
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddFileStream 添加文件流
|
||||||
|
func (r *Request) AddFileStream(formName, fileName string, size int64, reader io.Reader) *Request {
|
||||||
|
if r.err != nil {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
if reader == nil {
|
||||||
|
r.err = ErrNilReader
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureMultipartMode(&r.config.Body)
|
||||||
|
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
|
||||||
|
FormName: formName,
|
||||||
|
FileName: fileName,
|
||||||
|
FileData: reader,
|
||||||
|
FileSize: size,
|
||||||
|
FileType: ContentTypeOctetStream,
|
||||||
|
})
|
||||||
|
r.invalidatePreparedState()
|
||||||
|
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddFileStreamWithType 添加文件流(指定 MIME 类型)
|
||||||
|
func (r *Request) AddFileStreamWithType(formName, fileName, fileType string, size int64, reader io.Reader) *Request {
|
||||||
|
if r.err != nil {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
if reader == nil {
|
||||||
|
r.err = ErrNilReader
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureMultipartMode(&r.config.Body)
|
||||||
|
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
|
||||||
|
FormName: formName,
|
||||||
|
FileName: fileName,
|
||||||
|
FileData: reader,
|
||||||
|
FileSize: size,
|
||||||
|
FileType: fileType,
|
||||||
|
})
|
||||||
|
r.invalidatePreparedState()
|
||||||
|
|
||||||
|
return r
|
||||||
|
}
|
||||||
@@ -0,0 +1,34 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import "net/http"
|
||||||
|
|
||||||
|
// SetBasicAuth 设置 Basic 认证
|
||||||
|
func (r *Request) SetBasicAuth(username, password string) *Request {
|
||||||
|
return r.applyMutation(mutateBasicAuth(username, password))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetContentLength 设置 Content-Length
|
||||||
|
func (r *Request) SetContentLength(length int64) *Request {
|
||||||
|
return r.applyMutation(mutateContentLength(length))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetAutoCalcContentLength 设置是否自动计算 Content-Length
|
||||||
|
// 警告:启用后会将整个 body 读入内存
|
||||||
|
func (r *Request) SetAutoCalcContentLength(auto bool) *Request {
|
||||||
|
return r.applyMutation(mutateAutoCalcContentLength(auto))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTransport 设置自定义 Transport
|
||||||
|
func (r *Request) SetTransport(transport *http.Transport) *Request {
|
||||||
|
return r.applyMutation(mutateTransport(transport))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetUploadProgress 设置文件上传进度回调
|
||||||
|
func (r *Request) SetUploadProgress(fn UploadProgressFunc) *Request {
|
||||||
|
return r.applyMutation(mutateUploadProgress(fn))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetMaxRespBodyBytes 设置响应体最大读取字节数(<=0 表示不限制)
|
||||||
|
func (r *Request) SetMaxRespBodyBytes(maxBytes int64) *Request {
|
||||||
|
return r.applyMutation(mutateMaxRespBodyBytes(maxBytes))
|
||||||
|
}
|
||||||
@@ -0,0 +1,172 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRequestDoTwiceRebuildsExecutionState(t *testing.T) {
|
||||||
|
var attempts int32
|
||||||
|
|
||||||
|
req := NewSimpleRequest("http://example.com/path", http.MethodPost).
|
||||||
|
SetHeader("X-Test", "one").
|
||||||
|
AddQuery("q", "v").
|
||||||
|
SetBodyReader(strings.NewReader("payload"))
|
||||||
|
req.client = &Client{client: &http.Client{
|
||||||
|
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||||
|
if err := r.Context().Err(); err != nil {
|
||||||
|
t.Fatalf("request context already done: %v", err)
|
||||||
|
}
|
||||||
|
if values := r.Header.Values("X-Test"); len(values) != 1 || values[0] != "one" {
|
||||||
|
t.Fatalf("header values=%v", values)
|
||||||
|
}
|
||||||
|
if values := r.URL.Query()["q"]; len(values) != 1 || values[0] != "v" {
|
||||||
|
t.Fatalf("query values=%v", values)
|
||||||
|
}
|
||||||
|
body, err := io.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
_ = r.Body.Close()
|
||||||
|
if string(body) != "payload" {
|
||||||
|
t.Fatalf("body=%q", string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
n := atomic.AddInt32(&attempts, 1)
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: make(http.Header),
|
||||||
|
Body: io.NopCloser(strings.NewReader(fmt.Sprintf("ok-%d", n))),
|
||||||
|
Request: r,
|
||||||
|
}, nil
|
||||||
|
}),
|
||||||
|
}}
|
||||||
|
|
||||||
|
resp1, err := req.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("first Do() error: %v", err)
|
||||||
|
}
|
||||||
|
if err := resp1.Close(); err != nil {
|
||||||
|
t.Fatalf("first Close() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp2, err := req.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("second Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp2.Close()
|
||||||
|
|
||||||
|
if got := atomic.LoadInt32(&attempts); got != 2 {
|
||||||
|
t.Fatalf("attempts=%d; want 2", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestPrepareRawDynamicPathInjectsAggregatedRequestContext(t *testing.T) {
|
||||||
|
rawReq, err := http.NewRequest(http.MethodGet, "https://example.com/resource", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewRequest() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := NewSimpleRequest("", http.MethodGet).
|
||||||
|
SetRawRequest(rawReq).
|
||||||
|
SetProxy("http://proxy.example:8080").
|
||||||
|
SetCustomIP([]string{"127.0.0.1"}).
|
||||||
|
SetSkipTLSVerify(true).
|
||||||
|
SetTLSServerName("override.example")
|
||||||
|
|
||||||
|
if err := req.prepare(); err != nil {
|
||||||
|
t.Fatalf("prepare() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
raw := req.execCtx.Value(ctxKeyRequestContext)
|
||||||
|
rc, ok := raw.(*RequestContext)
|
||||||
|
if !ok || rc == nil {
|
||||||
|
t.Fatalf("expected request context, got %#v", raw)
|
||||||
|
}
|
||||||
|
if rc.Proxy != "http://proxy.example:8080" {
|
||||||
|
t.Fatalf("proxy=%q", rc.Proxy)
|
||||||
|
}
|
||||||
|
if len(rc.CustomIP) != 1 || rc.CustomIP[0] != "127.0.0.1" {
|
||||||
|
t.Fatalf("custom ip=%v", rc.CustomIP)
|
||||||
|
}
|
||||||
|
if rc.TLSConfig == nil || !rc.TLSConfig.InsecureSkipVerify {
|
||||||
|
t.Fatalf("tls config=%#v", rc.TLSConfig)
|
||||||
|
}
|
||||||
|
if rc.TLSServerName != "override.example" {
|
||||||
|
t.Fatalf("tls server name=%q", rc.TLSServerName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestSetFormDataOverridesBytesBody(t *testing.T) {
|
||||||
|
req := NewSimpleRequest("http://example.com", http.MethodPost).
|
||||||
|
SetBodyString("stale").
|
||||||
|
SetFormData(map[string][]string{"k": []string{"v"}})
|
||||||
|
|
||||||
|
if req.config.Body.Mode != bodyModeForm {
|
||||||
|
t.Fatalf("body mode=%v", req.config.Body.Mode)
|
||||||
|
}
|
||||||
|
if req.config.Body.Reader != nil || req.config.Body.Bytes != nil || len(req.config.Body.Files) != 0 {
|
||||||
|
t.Fatalf("unexpected stale body state: %#v", req.config.Body)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := req.prepare(); err != nil {
|
||||||
|
t.Fatalf("prepare() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := req.httpReq.GetBody()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetBody() error: %v", err)
|
||||||
|
}
|
||||||
|
defer body.Close()
|
||||||
|
|
||||||
|
data, err := io.ReadAll(body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ReadAll() error: %v", err)
|
||||||
|
}
|
||||||
|
if string(data) != "k=v" {
|
||||||
|
t.Fatalf("body=%q; want k=v", string(data))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestAddFileClearsPreviousBytesBody(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
filePath := filepath.Join(tmpDir, "payload.txt")
|
||||||
|
if err := os.WriteFile(filePath, []byte("file-body"), 0644); err != nil {
|
||||||
|
t.Fatalf("WriteFile() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := NewSimpleRequest("http://example.com", http.MethodPost).
|
||||||
|
SetJSON(map[string]string{"old": "json-only"}).
|
||||||
|
AddFile("file", filePath)
|
||||||
|
|
||||||
|
if req.config.Body.Mode != bodyModeMultipart {
|
||||||
|
t.Fatalf("body mode=%v", req.config.Body.Mode)
|
||||||
|
}
|
||||||
|
if req.config.Body.Reader != nil || req.config.Body.Bytes != nil {
|
||||||
|
t.Fatalf("unexpected stale simple body state: %#v", req.config.Body)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := req.prepare(); err != nil {
|
||||||
|
t.Fatalf("prepare() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := io.ReadAll(req.httpReq.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ReadAll() error: %v", err)
|
||||||
|
}
|
||||||
|
if !strings.Contains(req.httpReq.Header.Get("Content-Type"), "multipart/form-data") {
|
||||||
|
t.Fatalf("content-type=%q", req.httpReq.Header.Get("Content-Type"))
|
||||||
|
}
|
||||||
|
if !strings.Contains(string(data), "file-body") {
|
||||||
|
t.Fatalf("multipart body missing file content: %q", string(data))
|
||||||
|
}
|
||||||
|
if strings.Contains(string(data), "json-only") {
|
||||||
|
t.Fatalf("multipart body still contains stale json: %q", string(data))
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,264 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
func isHostHeaderKey(key string) bool {
|
||||||
|
return http.CanonicalHeaderKey(key) == "Host"
|
||||||
|
}
|
||||||
|
|
||||||
|
func setRequestHostConfig(config *RequestConfig, host string) {
|
||||||
|
if config == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if config.Headers == nil {
|
||||||
|
config.Headers = make(http.Header)
|
||||||
|
}
|
||||||
|
config.Host = host
|
||||||
|
if host == "" {
|
||||||
|
config.Headers.Del("Host")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
config.Headers.Set("Host", host)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetHeader 设置 Header(覆盖)
|
||||||
|
func (r *Request) SetHeader(key, value string) *Request {
|
||||||
|
if r.err != nil {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
if r.doRaw {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
if isHostHeaderKey(key) {
|
||||||
|
return r.SetHost(value)
|
||||||
|
}
|
||||||
|
r.config.Headers.Set(key, value)
|
||||||
|
r.invalidatePreparedState()
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddHeader 添加 Header
|
||||||
|
func (r *Request) AddHeader(key, value string) *Request {
|
||||||
|
if r.err != nil {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
if r.doRaw {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
if isHostHeaderKey(key) {
|
||||||
|
return r.SetHost(value)
|
||||||
|
}
|
||||||
|
r.config.Headers.Add(key, value)
|
||||||
|
r.invalidatePreparedState()
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetHeaders 设置所有 Headers(覆盖)
|
||||||
|
func (r *Request) SetHeaders(headers http.Header) *Request {
|
||||||
|
if r.err != nil {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
if r.doRaw {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
r.config.Headers = cloneHeader(headers)
|
||||||
|
r.config.Host = r.config.Headers.Get("Host")
|
||||||
|
r.syncRequestHost()
|
||||||
|
r.invalidatePreparedState()
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddHeaders 批量添加 Headers
|
||||||
|
func (r *Request) AddHeaders(headers map[string]string) *Request {
|
||||||
|
if r.err != nil {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
if r.doRaw {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
for k, v := range headers {
|
||||||
|
if isHostHeaderKey(k) {
|
||||||
|
setRequestHostConfig(r.config, v)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
r.config.Headers.Add(k, v)
|
||||||
|
}
|
||||||
|
r.syncRequestHost()
|
||||||
|
r.invalidatePreparedState()
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteHeader 删除 Header
|
||||||
|
func (r *Request) DeleteHeader(key string) *Request {
|
||||||
|
if r.err != nil {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
if r.doRaw {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
if isHostHeaderKey(key) {
|
||||||
|
setRequestHostConfig(r.config, "")
|
||||||
|
r.syncRequestHost()
|
||||||
|
r.invalidatePreparedState()
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
r.config.Headers.Del(key)
|
||||||
|
r.invalidatePreparedState()
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetHeader 获取 Header
|
||||||
|
func (r *Request) GetHeader(key string) string {
|
||||||
|
if isHostHeaderKey(key) {
|
||||||
|
return r.config.Host
|
||||||
|
}
|
||||||
|
return r.config.Headers.Get(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Headers 获取所有 Headers
|
||||||
|
func (r *Request) Headers() http.Header {
|
||||||
|
if r == nil || r.config == nil {
|
||||||
|
return make(http.Header)
|
||||||
|
}
|
||||||
|
return cloneHeader(r.config.Headers)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetHost 设置请求 Host 头覆盖。
|
||||||
|
func (r *Request) SetHost(host string) *Request {
|
||||||
|
if r.err != nil {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
if r.doRaw {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
setRequestHostConfig(r.config, host)
|
||||||
|
r.syncRequestHost()
|
||||||
|
r.invalidatePreparedState()
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// Host 获取显式 Host 覆盖。
|
||||||
|
func (r *Request) Host() string {
|
||||||
|
if r.config != nil && r.config.Host != "" {
|
||||||
|
return r.config.Host
|
||||||
|
}
|
||||||
|
if r.httpReq != nil {
|
||||||
|
return r.httpReq.Host
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetContentType 设置 Content-Type
|
||||||
|
func (r *Request) SetContentType(contentType string) *Request {
|
||||||
|
return r.SetHeader("Content-Type", contentType)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetUserAgent 设置 User-Agent
|
||||||
|
func (r *Request) SetUserAgent(userAgent string) *Request {
|
||||||
|
return r.SetHeader("User-Agent", userAgent)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetReferer 设置 Referer
|
||||||
|
func (r *Request) SetReferer(referer string) *Request {
|
||||||
|
return r.SetHeader("Referer", referer)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetBearerToken 设置 Bearer Token
|
||||||
|
func (r *Request) SetBearerToken(token string) *Request {
|
||||||
|
return r.SetHeader("Authorization", "Bearer "+token)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddCookie 添加 Cookie
|
||||||
|
func (r *Request) AddCookie(cookie *http.Cookie) *Request {
|
||||||
|
if r.err != nil {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
if r.doRaw {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
r.config.Cookies = append(r.config.Cookies, cloneCookie(cookie))
|
||||||
|
r.invalidatePreparedState()
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddSimpleCookie 添加简单 Cookie(path 为 /)
|
||||||
|
func (r *Request) AddSimpleCookie(name, value string) *Request {
|
||||||
|
return r.AddCookie(&http.Cookie{
|
||||||
|
Name: name,
|
||||||
|
Value: value,
|
||||||
|
Path: "/",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddCookieKV 添加 Cookie(指定 path)
|
||||||
|
func (r *Request) AddCookieKV(name, value, path string) *Request {
|
||||||
|
return r.AddCookie(&http.Cookie{
|
||||||
|
Name: name,
|
||||||
|
Value: value,
|
||||||
|
Path: path,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetCookies 设置所有 Cookies(覆盖)
|
||||||
|
func (r *Request) SetCookies(cookies []*http.Cookie) *Request {
|
||||||
|
if r.err != nil {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
if r.doRaw {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
r.config.Cookies = cloneCookies(cookies)
|
||||||
|
r.invalidatePreparedState()
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddCookies 批量添加 Cookies
|
||||||
|
func (r *Request) AddCookies(cookies map[string]string) *Request {
|
||||||
|
if r.err != nil {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
if r.doRaw {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
for name, value := range cookies {
|
||||||
|
r.config.Cookies = append(r.config.Cookies, &http.Cookie{
|
||||||
|
Name: name,
|
||||||
|
Value: value,
|
||||||
|
Path: "/",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
r.invalidatePreparedState()
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cookies 获取所有 Cookies
|
||||||
|
func (r *Request) Cookies() []*http.Cookie {
|
||||||
|
if r == nil || r.config == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return cloneCookies(r.config.Cookies)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetHeaders 重置所有 Headers
|
||||||
|
func (r *Request) ResetHeaders() *Request {
|
||||||
|
if r.err != nil {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
r.config.Headers = make(http.Header)
|
||||||
|
r.config.Host = ""
|
||||||
|
r.syncRequestHost()
|
||||||
|
r.invalidatePreparedState()
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetCookies 重置所有 Cookies
|
||||||
|
func (r *Request) ResetCookies() *Request {
|
||||||
|
if r.err != nil {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
r.config.Cookies = []*http.Cookie{}
|
||||||
|
r.invalidatePreparedState()
|
||||||
|
return r
|
||||||
|
}
|
||||||
@@ -0,0 +1,43 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRequestConvenienceMethods(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("X-Method", r.Method)
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("ok"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
method string
|
||||||
|
do func(r *Request) (*Response, error)
|
||||||
|
}{
|
||||||
|
{name: "Head", method: http.MethodHead, do: (*Request).Head},
|
||||||
|
{name: "Patch", method: http.MethodPatch, do: (*Request).Patch},
|
||||||
|
{name: "Options", method: http.MethodOptions, do: (*Request).Options},
|
||||||
|
{name: "Trace", method: http.MethodTrace, do: (*Request).Trace},
|
||||||
|
{name: "Connect", method: http.MethodConnect, do: (*Request).Connect},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
req := NewSimpleRequest(server.URL, http.MethodGet)
|
||||||
|
resp, err := tt.do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
if got := resp.Header.Get("X-Method"); got != tt.method {
|
||||||
|
t.Fatalf("method=%s want=%s", got, tt.method)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,69 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"mime/multipart"
|
||||||
|
"os"
|
||||||
|
)
|
||||||
|
|
||||||
|
// applyMultipartBody 应用 multipart 请求体
|
||||||
|
func (r *Request) applyMultipartBody(execCtx context.Context) error {
|
||||||
|
pr, pw := io.Pipe()
|
||||||
|
writer := multipart.NewWriter(pw)
|
||||||
|
|
||||||
|
r.httpReq.Header.Set("Content-Type", writer.FormDataContentType())
|
||||||
|
r.httpReq.Body = pr
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer pw.Close()
|
||||||
|
defer writer.Close()
|
||||||
|
|
||||||
|
for key, values := range r.config.Body.FormData {
|
||||||
|
for _, value := range values {
|
||||||
|
if err := writer.WriteField(key, value); err != nil {
|
||||||
|
pw.CloseWithError(wrapError(err, "write form field"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, file := range r.config.Body.Files {
|
||||||
|
if err := r.writeFile(execCtx, writer, file); err != nil {
|
||||||
|
pw.CloseWithError(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeFile 写入文件到 multipart writer
|
||||||
|
func (r *Request) writeFile(execCtx context.Context, writer *multipart.Writer, file RequestFile) error {
|
||||||
|
part, err := writer.CreateFormFile(file.FormName, file.FileName)
|
||||||
|
if err != nil {
|
||||||
|
return wrapError(err, "create form file")
|
||||||
|
}
|
||||||
|
|
||||||
|
var reader io.Reader
|
||||||
|
if file.FileData != nil {
|
||||||
|
reader = file.FileData
|
||||||
|
} else if file.FilePath != "" {
|
||||||
|
f, err := os.Open(file.FilePath)
|
||||||
|
if err != nil {
|
||||||
|
return wrapError(err, "open file")
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
reader = f
|
||||||
|
} else {
|
||||||
|
return ErrNilReader
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = copyWithProgress(execCtx, part, reader, file.FileName, file.FileSize, r.config.UploadProgress)
|
||||||
|
if err != nil {
|
||||||
|
return wrapError(err, "copy file data")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,333 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type requestMutation func(*Request) error
|
||||||
|
|
||||||
|
func (r *Request) applyMutation(mutation requestMutation) *Request {
|
||||||
|
if r == nil || r.err != nil {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
if err := mutation(r); err != nil {
|
||||||
|
r.err = err
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
r.invalidatePreparedState()
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func requestOptFromMutation(mutation requestMutation) RequestOpt {
|
||||||
|
return func(r *Request) error {
|
||||||
|
if r == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return mutation(r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateCustomIPs(ips []string) error {
|
||||||
|
for _, ip := range ips {
|
||||||
|
if net.ParseIP(ip) == nil {
|
||||||
|
return wrapError(ErrInvalidIP, "ip: %s", ip)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateCustomDNS(dnsServers []string) error {
|
||||||
|
for _, dns := range dnsServers {
|
||||||
|
if net.ParseIP(dns) == nil {
|
||||||
|
return wrapError(ErrInvalidDNS, "dns: %s", dns)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseProxyURL(proxy string) (*url.URL, error) {
|
||||||
|
if proxy == "" {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
proxyURL, err := url.Parse(proxy)
|
||||||
|
if err != nil {
|
||||||
|
return nil, wrapError(err, "parse proxy url")
|
||||||
|
}
|
||||||
|
if proxyURL.Scheme == "" {
|
||||||
|
return nil, fmt.Errorf("proxy scheme is required: %s", proxy)
|
||||||
|
}
|
||||||
|
if proxyURL.Host == "" {
|
||||||
|
return nil, fmt.Errorf("proxy host is required: %s", proxy)
|
||||||
|
}
|
||||||
|
|
||||||
|
return proxyURL, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func mutateTimeout(timeout time.Duration) requestMutation {
|
||||||
|
return func(r *Request) error {
|
||||||
|
r.config.Network.Timeout = timeout
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mutateDialTimeout(timeout time.Duration) requestMutation {
|
||||||
|
return func(r *Request) error {
|
||||||
|
r.config.Network.DialTimeout = timeout
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mutateProxy(proxy string) requestMutation {
|
||||||
|
return func(r *Request) error {
|
||||||
|
if _, err := parseProxyURL(proxy); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
r.config.Network.Proxy = proxy
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mutateDialFunc(fn func(ctx context.Context, network, addr string) (net.Conn, error)) requestMutation {
|
||||||
|
return func(r *Request) error {
|
||||||
|
r.config.Network.DialFunc = fn
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mutateTLSConfig(tlsConfig *tls.Config) requestMutation {
|
||||||
|
return func(r *Request) error {
|
||||||
|
r.config.TLS.Config = tlsConfig
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mutateTLSServerName(serverName string) requestMutation {
|
||||||
|
return func(r *Request) error {
|
||||||
|
r.config.TLS.ServerName = serverName
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mutateTraceHooks(hooks *TraceHooks) requestMutation {
|
||||||
|
return func(r *Request) error {
|
||||||
|
r.traceHooks = hooks
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mutateTraceRecorder(recorder *TraceRecorder) requestMutation {
|
||||||
|
return func(r *Request) error {
|
||||||
|
r.traceRecorder = recorder
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mutateSkipTLSVerify(skip bool) requestMutation {
|
||||||
|
return func(r *Request) error {
|
||||||
|
r.config.TLS.SkipVerify = skip
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mutateCustomIP(ips []string) requestMutation {
|
||||||
|
return func(r *Request) error {
|
||||||
|
if err := validateCustomIPs(ips); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
r.config.DNS.CustomIP = cloneStringSlice(ips)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mutateAddCustomIP(ip string) requestMutation {
|
||||||
|
return func(r *Request) error {
|
||||||
|
if err := validateCustomIPs([]string{ip}); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
r.config.DNS.CustomIP = append(r.config.DNS.CustomIP, ip)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mutateCustomDNS(dnsServers []string) requestMutation {
|
||||||
|
return func(r *Request) error {
|
||||||
|
if err := validateCustomDNS(dnsServers); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
r.config.DNS.CustomDNS = cloneStringSlice(dnsServers)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mutateAddCustomDNS(dns string) requestMutation {
|
||||||
|
return func(r *Request) error {
|
||||||
|
if err := validateCustomDNS([]string{dns}); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
r.config.DNS.CustomDNS = append(r.config.DNS.CustomDNS, dns)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mutateLookupFunc(fn func(ctx context.Context, host string) ([]net.IPAddr, error)) requestMutation {
|
||||||
|
return func(r *Request) error {
|
||||||
|
r.config.DNS.LookupFunc = fn
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mutateBasicAuth(username, password string) requestMutation {
|
||||||
|
return func(r *Request) error {
|
||||||
|
r.config.BasicAuth = [2]string{username, password}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mutateContentLength(length int64) requestMutation {
|
||||||
|
return func(r *Request) error {
|
||||||
|
r.config.ContentLength = length
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mutateAutoCalcContentLength(auto bool) requestMutation {
|
||||||
|
return func(r *Request) error {
|
||||||
|
if r.doRaw {
|
||||||
|
return fmt.Errorf("cannot set auto calc content length in raw mode")
|
||||||
|
}
|
||||||
|
r.config.AutoCalcContentLength = auto
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mutateTransport(transport *http.Transport) requestMutation {
|
||||||
|
return func(r *Request) error {
|
||||||
|
r.config.Transport = transport
|
||||||
|
r.config.CustomTransport = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mutateUploadProgress(fn UploadProgressFunc) requestMutation {
|
||||||
|
return func(r *Request) error {
|
||||||
|
r.config.UploadProgress = fn
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mutateAutoFetch(auto bool) requestMutation {
|
||||||
|
return func(r *Request) error {
|
||||||
|
r.autoFetch = auto
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mutateMaxRespBodyBytes(maxBytes int64) requestMutation {
|
||||||
|
return func(r *Request) error {
|
||||||
|
if maxBytes < 0 {
|
||||||
|
return fmt.Errorf("max response body bytes must be >= 0")
|
||||||
|
}
|
||||||
|
r.config.MaxRespBodyBytes = maxBytes
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mutateContext(ctx context.Context) requestMutation {
|
||||||
|
return func(r *Request) error {
|
||||||
|
ctx = normalizeContext(ctx)
|
||||||
|
r.ctx = ctx
|
||||||
|
if r.doRaw && r.rawTemplate != nil {
|
||||||
|
r.rawTemplate = r.rawTemplate.WithContext(ctx)
|
||||||
|
}
|
||||||
|
if r.httpReq != nil {
|
||||||
|
r.httpReq = r.httpReq.WithContext(ctx)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mutateRawRequest(httpReq *http.Request) requestMutation {
|
||||||
|
return func(r *Request) error {
|
||||||
|
if httpReq == nil {
|
||||||
|
return fmt.Errorf("httpReq cannot be nil")
|
||||||
|
}
|
||||||
|
r.httpReq = httpReq
|
||||||
|
r.rawTemplate = httpReq
|
||||||
|
r.ctx = normalizeContext(httpReq.Context())
|
||||||
|
r.method = httpReq.Method
|
||||||
|
if httpReq.URL != nil {
|
||||||
|
r.url = httpReq.URL.String()
|
||||||
|
}
|
||||||
|
r.doRaw = true
|
||||||
|
r.rawSourceExternal = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mutateAddQuery(key, value string) requestMutation {
|
||||||
|
return func(r *Request) error {
|
||||||
|
r.config.Queries[key] = append(r.config.Queries[key], value)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mutateSetQuery(key, value string) requestMutation {
|
||||||
|
return func(r *Request) error {
|
||||||
|
r.config.Queries[key] = []string{value}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mutateSetQueries(queries map[string][]string) requestMutation {
|
||||||
|
return func(r *Request) error {
|
||||||
|
r.config.Queries = cloneStringMapSlice(queries)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mutateAddQueries(queries map[string]string) requestMutation {
|
||||||
|
return func(r *Request) error {
|
||||||
|
for key, value := range queries {
|
||||||
|
r.config.Queries[key] = append(r.config.Queries[key], value)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mutateDeleteQuery(key string) requestMutation {
|
||||||
|
return func(r *Request) error {
|
||||||
|
delete(r.config.Queries, key)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mutateDeleteQueryValue(key, value string) requestMutation {
|
||||||
|
return func(r *Request) error {
|
||||||
|
values, ok := r.config.Queries[key]
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
newValues := make([]string, 0, len(values))
|
||||||
|
for _, item := range values {
|
||||||
|
if item != value {
|
||||||
|
newValues = append(newValues, item)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(newValues) == 0 {
|
||||||
|
delete(r.config.Queries, key)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
r.config.Queries[key] = newValues
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,71 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SetTimeout 设置请求总超时时间
|
||||||
|
// timeout > 0: 为本次请求注入 context 超时
|
||||||
|
// timeout = 0: 不额外设置请求总超时
|
||||||
|
// timeout < 0: 禁用 starnet 默认总超时
|
||||||
|
func (r *Request) SetTimeout(timeout time.Duration) *Request {
|
||||||
|
return r.applyMutation(mutateTimeout(timeout))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDialTimeout 设置连接超时时间
|
||||||
|
func (r *Request) SetDialTimeout(timeout time.Duration) *Request {
|
||||||
|
return r.applyMutation(mutateDialTimeout(timeout))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetProxy 设置代理
|
||||||
|
func (r *Request) SetProxy(proxy string) *Request {
|
||||||
|
return r.applyMutation(mutateProxy(proxy))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDialFunc 设置自定义 Dial 函数
|
||||||
|
func (r *Request) SetDialFunc(fn func(ctx context.Context, network, addr string) (net.Conn, error)) *Request {
|
||||||
|
return r.applyMutation(mutateDialFunc(fn))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTLSConfig 设置 TLS 配置
|
||||||
|
func (r *Request) SetTLSConfig(tlsConfig *tls.Config) *Request {
|
||||||
|
return r.applyMutation(mutateTLSConfig(tlsConfig))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTLSServerName 设置显式 TLS ServerName/SNI。
|
||||||
|
func (r *Request) SetTLSServerName(serverName string) *Request {
|
||||||
|
return r.applyMutation(mutateTLSServerName(serverName))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetSkipTLSVerify 设置是否跳过 TLS 验证
|
||||||
|
func (r *Request) SetSkipTLSVerify(skip bool) *Request {
|
||||||
|
return r.applyMutation(mutateSkipTLSVerify(skip))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetCustomIP 设置自定义 IP(直接指定 IP,跳过 DNS)
|
||||||
|
func (r *Request) SetCustomIP(ips []string) *Request {
|
||||||
|
return r.applyMutation(mutateCustomIP(ips))
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddCustomIP 添加自定义 IP
|
||||||
|
func (r *Request) AddCustomIP(ip string) *Request {
|
||||||
|
return r.applyMutation(mutateAddCustomIP(ip))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetCustomDNS 设置自定义 DNS 服务器
|
||||||
|
func (r *Request) SetCustomDNS(dnsServers []string) *Request {
|
||||||
|
return r.applyMutation(mutateCustomDNS(dnsServers))
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddCustomDNS 添加自定义 DNS 服务器
|
||||||
|
func (r *Request) AddCustomDNS(dns string) *Request {
|
||||||
|
return r.applyMutation(mutateAddCustomDNS(dns))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetLookupFunc 设置自定义 DNS 解析函数
|
||||||
|
func (r *Request) SetLookupFunc(fn func(ctx context.Context, host string) ([]net.IPAddr, error)) *Request {
|
||||||
|
return r.applyMutation(mutateLookupFunc(fn))
|
||||||
|
}
|
||||||
@@ -0,0 +1,346 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptrace"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setReplayableRequestBodyBytes(httpReq *http.Request, data []byte) {
|
||||||
|
if httpReq == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
httpReq.Body = io.NopCloser(bytes.NewReader(data))
|
||||||
|
httpReq.ContentLength = int64(len(data))
|
||||||
|
httpReq.GetBody = func() (io.ReadCloser, error) {
|
||||||
|
return io.NopCloser(bytes.NewReader(data)), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func clearSimpleBodyState(body *BodyConfig) {
|
||||||
|
if body == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
body.Bytes = nil
|
||||||
|
body.Reader = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func resetFormBodyState(body *BodyConfig) {
|
||||||
|
if body == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
body.FormData = make(map[string][]string)
|
||||||
|
}
|
||||||
|
|
||||||
|
func resetMultipartBodyState(body *BodyConfig) {
|
||||||
|
if body == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
body.Files = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func setBytesBodyConfig(body *BodyConfig, data []byte) {
|
||||||
|
if body == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
body.Mode = bodyModeBytes
|
||||||
|
body.Bytes = cloneBytes(data)
|
||||||
|
body.Reader = nil
|
||||||
|
resetFormBodyState(body)
|
||||||
|
resetMultipartBodyState(body)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setReaderBodyConfig(body *BodyConfig, reader io.Reader) {
|
||||||
|
if body == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
body.Mode = bodyModeReader
|
||||||
|
body.Reader = reader
|
||||||
|
body.Bytes = nil
|
||||||
|
resetFormBodyState(body)
|
||||||
|
resetMultipartBodyState(body)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setFormBodyConfig(body *BodyConfig, data map[string][]string) {
|
||||||
|
if body == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
body.Mode = bodyModeForm
|
||||||
|
clearSimpleBodyState(body)
|
||||||
|
resetMultipartBodyState(body)
|
||||||
|
body.FormData = cloneStringMapSlice(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ensureFormMode(body *BodyConfig) {
|
||||||
|
if body == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if body.Mode == bodyModeForm || body.Mode == bodyModeMultipart {
|
||||||
|
if body.FormData == nil {
|
||||||
|
body.FormData = make(map[string][]string)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
clearSimpleBodyState(body)
|
||||||
|
resetMultipartBodyState(body)
|
||||||
|
body.FormData = make(map[string][]string)
|
||||||
|
body.Mode = bodyModeForm
|
||||||
|
}
|
||||||
|
|
||||||
|
func ensureMultipartMode(body *BodyConfig) {
|
||||||
|
if body == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if body.Mode == bodyModeMultipart {
|
||||||
|
if body.FormData == nil {
|
||||||
|
body.FormData = make(map[string][]string)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if body.Mode != bodyModeForm {
|
||||||
|
clearSimpleBodyState(body)
|
||||||
|
body.FormData = make(map[string][]string)
|
||||||
|
}
|
||||||
|
body.Mode = bodyModeMultipart
|
||||||
|
if body.FormData == nil {
|
||||||
|
body.FormData = make(map[string][]string)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func snapshotBytesReader(reader *bytes.Reader) ([]byte, error) {
|
||||||
|
if reader == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
data := make([]byte, reader.Len())
|
||||||
|
_, err := reader.ReadAt(data, reader.Size()-int64(reader.Len()))
|
||||||
|
if err != nil && err != io.EOF {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func snapshotStringReader(reader *strings.Reader) ([]byte, error) {
|
||||||
|
if reader == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
data := make([]byte, reader.Len())
|
||||||
|
_, err := reader.ReadAt(data, reader.Size()-int64(reader.Len()))
|
||||||
|
if err != nil && err != io.EOF {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyBody 应用请求体
|
||||||
|
func (r *Request) applyBody(execCtx context.Context) error {
|
||||||
|
r.httpReq.Body = nil
|
||||||
|
r.httpReq.GetBody = nil
|
||||||
|
r.httpReq.ContentLength = 0
|
||||||
|
|
||||||
|
switch r.config.Body.Mode {
|
||||||
|
case bodyModeReader:
|
||||||
|
if r.config.Body.Reader == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
switch reader := r.config.Body.Reader.(type) {
|
||||||
|
case *bytes.Buffer:
|
||||||
|
setReplayableRequestBodyBytes(r.httpReq, append([]byte(nil), reader.Bytes()...))
|
||||||
|
case *bytes.Reader:
|
||||||
|
data, err := snapshotBytesReader(reader)
|
||||||
|
if err != nil {
|
||||||
|
return wrapError(err, "snapshot bytes reader")
|
||||||
|
}
|
||||||
|
setReplayableRequestBodyBytes(r.httpReq, data)
|
||||||
|
case *strings.Reader:
|
||||||
|
data, err := snapshotStringReader(reader)
|
||||||
|
if err != nil {
|
||||||
|
return wrapError(err, "snapshot strings reader")
|
||||||
|
}
|
||||||
|
setReplayableRequestBodyBytes(r.httpReq, data)
|
||||||
|
default:
|
||||||
|
r.httpReq.Body = io.NopCloser(r.config.Body.Reader)
|
||||||
|
}
|
||||||
|
switch reader := r.config.Body.Reader.(type) {
|
||||||
|
case *bytes.Buffer:
|
||||||
|
r.httpReq.ContentLength = int64(reader.Len())
|
||||||
|
case *bytes.Reader:
|
||||||
|
r.httpReq.ContentLength = int64(reader.Len())
|
||||||
|
case *strings.Reader:
|
||||||
|
r.httpReq.ContentLength = int64(reader.Len())
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
case bodyModeBytes:
|
||||||
|
setReplayableRequestBodyBytes(r.httpReq, r.config.Body.Bytes)
|
||||||
|
return nil
|
||||||
|
case bodyModeMultipart:
|
||||||
|
return r.applyMultipartBody(execCtx)
|
||||||
|
case bodyModeForm:
|
||||||
|
values := url.Values{}
|
||||||
|
for key, items := range r.config.Body.FormData {
|
||||||
|
for _, value := range items {
|
||||||
|
values.Add(key, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
encoded := values.Encode()
|
||||||
|
setReplayableRequestBodyBytes(r.httpReq, []byte(encoded))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildTraceTLSHandshakeInfo(req *http.Request, execCtx context.Context, defaultServerName string) TraceTLSHandshakeStartInfo {
|
||||||
|
if req == nil || req.URL == nil || req.URL.Scheme != "https" {
|
||||||
|
return TraceTLSHandshakeStartInfo{}
|
||||||
|
}
|
||||||
|
|
||||||
|
reqCtx := getRequestContext(execCtx)
|
||||||
|
info := TraceTLSHandshakeStartInfo{}
|
||||||
|
// 自定义 DialFunc 的真实落点由调用方决定,这里只在默认拨号路径下预填地址,避免 trace 元信息误导。
|
||||||
|
if reqCtx == nil || reqCtx.DialFn == nil {
|
||||||
|
info.Network = "tcp"
|
||||||
|
info.Addr = req.URL.Host
|
||||||
|
}
|
||||||
|
if reqCtx != nil {
|
||||||
|
if reqCtx.TLSConfig != nil && reqCtx.TLSConfig.ServerName != "" {
|
||||||
|
info.ServerName = reqCtx.TLSConfig.ServerName
|
||||||
|
} else if reqCtx.TLSServerName != "" {
|
||||||
|
info.ServerName = reqCtx.TLSServerName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if info.ServerName == "" {
|
||||||
|
if defaultServerName != "" {
|
||||||
|
info.ServerName = defaultServerName
|
||||||
|
} else {
|
||||||
|
info.ServerName = req.URL.Hostname()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return info
|
||||||
|
}
|
||||||
|
|
||||||
|
// prepare 准备请求(应用配置)
|
||||||
|
func (r *Request) prepare() (err error) {
|
||||||
|
if r.applied {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.httpReq == nil {
|
||||||
|
return fmt.Errorf("http request is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
execCtx := r.ctx
|
||||||
|
if execCtx == nil {
|
||||||
|
execCtx = context.Background()
|
||||||
|
}
|
||||||
|
defaultTLSServerName := ""
|
||||||
|
if r.httpReq.URL != nil && r.httpReq.URL.Scheme == "https" {
|
||||||
|
defaultTLSServerName = r.httpReq.URL.Hostname()
|
||||||
|
}
|
||||||
|
execCtx = injectRequestConfig(execCtx, r.config, defaultTLSServerName)
|
||||||
|
|
||||||
|
var traceState *traceState
|
||||||
|
traceHooks := composeTraceHooks(r.traceHooks, traceRecorderHooks(r.traceRun))
|
||||||
|
if traceHooks != nil {
|
||||||
|
traceState = newTraceState(traceHooks)
|
||||||
|
traceState.setDefaultTLSHandshakeInfo(buildTraceTLSHandshakeInfo(r.httpReq, execCtx, defaultTLSServerName))
|
||||||
|
execCtx = withTraceState(execCtx, traceState)
|
||||||
|
if clientTrace := traceState.clientTrace(); clientTrace != nil {
|
||||||
|
execCtx = httptrace.WithClientTrace(execCtx, clientTrace)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var cancel context.CancelFunc
|
||||||
|
if r.config.Network.Timeout > 0 {
|
||||||
|
execCtx, cancel = context.WithTimeout(execCtx, r.config.Network.Timeout)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err != nil && cancel != nil {
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if r.httpClient == nil {
|
||||||
|
r.httpClient, err = r.buildHTTPClient()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !r.doRaw {
|
||||||
|
if len(r.config.Queries) > 0 {
|
||||||
|
query := r.httpReq.URL.Query()
|
||||||
|
for key, values := range r.config.Queries {
|
||||||
|
for _, value := range values {
|
||||||
|
query.Add(key, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
r.httpReq.URL.RawQuery = query.Encode()
|
||||||
|
}
|
||||||
|
|
||||||
|
for key, values := range r.config.Headers {
|
||||||
|
if isHostHeaderKey(key) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, value := range values {
|
||||||
|
r.httpReq.Header.Add(key, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, cookie := range r.config.Cookies {
|
||||||
|
r.httpReq.AddCookie(cookie)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.config.BasicAuth[0] != "" || r.config.BasicAuth[1] != "" {
|
||||||
|
r.httpReq.SetBasicAuth(r.config.BasicAuth[0], r.config.BasicAuth[1])
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.applyBody(execCtx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.config.ContentLength > 0 {
|
||||||
|
r.httpReq.ContentLength = r.config.ContentLength
|
||||||
|
} else if r.config.ContentLength < 0 {
|
||||||
|
r.httpReq.ContentLength = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.config.AutoCalcContentLength && r.httpReq.Body != nil {
|
||||||
|
data, err := io.ReadAll(r.httpReq.Body)
|
||||||
|
if err != nil {
|
||||||
|
return wrapError(err, "read body for content length")
|
||||||
|
}
|
||||||
|
setReplayableRequestBodyBytes(r.httpReq, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.syncRequestHost()
|
||||||
|
}
|
||||||
|
|
||||||
|
r.execCtx = execCtx
|
||||||
|
r.traceState = traceState
|
||||||
|
r.cancel = cancel
|
||||||
|
r.httpReq = r.httpReq.WithContext(r.execCtx)
|
||||||
|
r.applied = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildHTTPClient 构建 HTTP Client
|
||||||
|
func (r *Request) buildHTTPClient() (*http.Client, error) {
|
||||||
|
if r.client != nil {
|
||||||
|
return r.client.HTTPClient(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.config.CustomTransport && r.config.Transport != nil {
|
||||||
|
return &http.Client{
|
||||||
|
Transport: &Transport{base: r.config.Transport},
|
||||||
|
Timeout: 0,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return DefaultHTTPClient(), nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,335 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"mime/multipart"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type roundTripFunc func(*http.Request) (*http.Response, error)
|
||||||
|
|
||||||
|
func (fn roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
return fn(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestPreparedMutationReappliesHeadersAndBody(t *testing.T) {
|
||||||
|
req := NewSimpleRequest("http://example.com", http.MethodPost).
|
||||||
|
SetHeader("X-Test", "one").
|
||||||
|
SetBodyString("first")
|
||||||
|
req.client = &Client{client: &http.Client{
|
||||||
|
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||||
|
body, err := io.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
_ = r.Body.Close()
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: make(http.Header),
|
||||||
|
Body: io.NopCloser(strings.NewReader(r.Header.Get("X-Test") + ":" + string(body))),
|
||||||
|
Request: r,
|
||||||
|
}, nil
|
||||||
|
}),
|
||||||
|
}}
|
||||||
|
|
||||||
|
if _, err := req.HTTPClient(); err != nil {
|
||||||
|
t.Fatalf("HTTPClient() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req.SetHeader("X-Test", "two").SetBodyString("second")
|
||||||
|
|
||||||
|
resp, err := req.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
body, err := resp.Body().String()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Body().String() error: %v", err)
|
||||||
|
}
|
||||||
|
if body != "two:second" {
|
||||||
|
t.Fatalf("body=%q; want %q", body, "two:second")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestPreparedMutationReappliesTimeout(t *testing.T) {
|
||||||
|
var attempts int32
|
||||||
|
req := NewSimpleRequest("http://example.com", http.MethodGet)
|
||||||
|
req.client = &Client{client: &http.Client{
|
||||||
|
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||||
|
if atomic.AddInt32(&attempts, 1) == 1 {
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusNoContent,
|
||||||
|
Header: make(http.Header),
|
||||||
|
Body: io.NopCloser(strings.NewReader("")),
|
||||||
|
Request: r,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-time.After(50 * time.Millisecond):
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusNoContent,
|
||||||
|
Header: make(http.Header),
|
||||||
|
Body: io.NopCloser(strings.NewReader("")),
|
||||||
|
Request: r,
|
||||||
|
}, nil
|
||||||
|
case <-r.Context().Done():
|
||||||
|
return nil, r.Context().Err()
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
}}
|
||||||
|
|
||||||
|
resp, err := req.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("first Do() error: %v", err)
|
||||||
|
}
|
||||||
|
_ = resp.Close()
|
||||||
|
|
||||||
|
_, err = req.SetTimeout(10 * time.Millisecond).Do()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("second Do() succeeded; want timeout error")
|
||||||
|
}
|
||||||
|
if !IsTimeout(err) && !errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
t.Fatalf("second Do() error=%v; want timeout", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteFileUsesExecContextWithoutProgressHook(t *testing.T) {
|
||||||
|
req := NewSimpleRequest("http://example.com", http.MethodPost)
|
||||||
|
|
||||||
|
pr, pw := io.Pipe()
|
||||||
|
writer := multipart.NewWriter(pw)
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
_, _ = io.Copy(io.Discard, pr)
|
||||||
|
_ = pr.Close()
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
err := req.writeFile(ctx, writer, RequestFile{
|
||||||
|
FormName: "file",
|
||||||
|
FileName: "payload.txt",
|
||||||
|
FileData: strings.NewReader("payload"),
|
||||||
|
FileSize: int64(len("payload")),
|
||||||
|
})
|
||||||
|
_ = writer.Close()
|
||||||
|
_ = pw.Close()
|
||||||
|
<-done
|
||||||
|
|
||||||
|
if !errors.Is(err, context.Canceled) {
|
||||||
|
t.Fatalf("writeFile() error=%v; want context.Canceled", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCopyWithProgressHonorsCanceledContextWithoutHook(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
_, err := copyWithProgress(ctx, io.Discard, strings.NewReader("payload"), "payload.txt", int64(len("payload")), nil)
|
||||||
|
if !errors.Is(err, context.Canceled) {
|
||||||
|
t.Fatalf("copyWithProgress() error=%v; want context.Canceled", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPrepareSetsGetBodyForReplayableBodies(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
req *Request
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "bytes",
|
||||||
|
req: NewSimpleRequest("http://example.com", http.MethodPost).SetBody([]byte("payload")),
|
||||||
|
want: "payload",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "bytes-reader",
|
||||||
|
req: NewSimpleRequest("http://example.com", http.MethodPost).SetBodyReader(bytes.NewReader([]byte("payload"))),
|
||||||
|
want: "payload",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "strings-reader",
|
||||||
|
req: NewSimpleRequest("http://example.com", http.MethodPost).SetBodyReader(strings.NewReader("payload")),
|
||||||
|
want: "payload",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "form-data",
|
||||||
|
req: NewSimpleRequest("http://example.com", http.MethodPost).AddFormData("k", "v"),
|
||||||
|
want: "k=v",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if err := tt.req.prepare(); err != nil {
|
||||||
|
t.Fatalf("prepare() error: %v", err)
|
||||||
|
}
|
||||||
|
if tt.req.httpReq.GetBody == nil {
|
||||||
|
t.Fatal("GetBody is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := tt.req.httpReq.GetBody()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetBody() error: %v", err)
|
||||||
|
}
|
||||||
|
defer body.Close()
|
||||||
|
|
||||||
|
data, err := io.ReadAll(body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ReadAll() error: %v", err)
|
||||||
|
}
|
||||||
|
if string(data) != tt.want {
|
||||||
|
t.Fatalf("body=%q; want %q", string(data), tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type replayRoundTripper struct {
|
||||||
|
attempts int
|
||||||
|
bodies []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rt *replayRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
body, err := io.ReadAll(req.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
_ = req.Body.Close()
|
||||||
|
|
||||||
|
rt.attempts++
|
||||||
|
rt.bodies = append(rt.bodies, string(body))
|
||||||
|
if rt.attempts == 1 {
|
||||||
|
return nil, errors.New("first target failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: make(http.Header),
|
||||||
|
Body: io.NopCloser(strings.NewReader("ok")),
|
||||||
|
Request: req,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoundTripResolvedTargetsReplaysPreparedBody(t *testing.T) {
|
||||||
|
req := NewSimpleRequest("http://example.com/upload", http.MethodPut).
|
||||||
|
SetBodyReader(strings.NewReader("payload"))
|
||||||
|
|
||||||
|
if err := req.prepare(); err != nil {
|
||||||
|
t.Fatalf("prepare() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rt := &replayRoundTripper{}
|
||||||
|
resp, err := roundTripResolvedTargets(rt, req.httpReq, []string{"127.0.0.2:80", "127.0.0.1:80"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("roundTripResolvedTargets() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if len(rt.bodies) != 2 {
|
||||||
|
t.Fatalf("attempt bodies=%v; want 2 attempts", rt.bodies)
|
||||||
|
}
|
||||||
|
if rt.bodies[0] != "payload" || rt.bodies[1] != "payload" {
|
||||||
|
t.Fatalf("attempt bodies=%v; want both payload", rt.bodies)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoundTripResolvedTargetsDoesNotFallbackNonIdempotentRequest(t *testing.T) {
|
||||||
|
req := NewSimpleRequest("http://example.com/upload", http.MethodPost).
|
||||||
|
SetBodyReader(strings.NewReader("payload"))
|
||||||
|
|
||||||
|
if err := req.prepare(); err != nil {
|
||||||
|
t.Fatalf("prepare() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rt := &replayRoundTripper{}
|
||||||
|
_, err := roundTripResolvedTargets(rt, req.httpReq, []string{"127.0.0.2:80", "127.0.0.1:80"})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("roundTripResolvedTargets() succeeded; want first target error")
|
||||||
|
}
|
||||||
|
if len(rt.bodies) != 1 {
|
||||||
|
t.Fatalf("attempt bodies=%v; want only first target attempt", rt.bodies)
|
||||||
|
}
|
||||||
|
if rt.bodies[0] != "payload" {
|
||||||
|
t.Fatalf("attempt body=%q; want payload", rt.bodies[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRetryReplayableReaderBody(t *testing.T) {
|
||||||
|
var attempts int32
|
||||||
|
req := NewSimpleRequest("http://example.com/upload", http.MethodPut).
|
||||||
|
SetBodyReader(strings.NewReader("payload")).
|
||||||
|
SetRetry(1, WithRetryBackoff(0, 0, 1), WithRetryJitter(0))
|
||||||
|
req.client = &Client{client: &http.Client{
|
||||||
|
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||||
|
body, err := io.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
_ = r.Body.Close()
|
||||||
|
if string(body) != "payload" {
|
||||||
|
t.Fatalf("body=%q; want payload", string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
if atomic.AddInt32(&attempts, 1) == 1 {
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusServiceUnavailable,
|
||||||
|
Header: make(http.Header),
|
||||||
|
Body: io.NopCloser(strings.NewReader("retry")),
|
||||||
|
Request: r,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: make(http.Header),
|
||||||
|
Body: io.NopCloser(strings.NewReader("ok")),
|
||||||
|
Request: r,
|
||||||
|
}, nil
|
||||||
|
}),
|
||||||
|
}}
|
||||||
|
|
||||||
|
resp, err := req.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
if got := atomic.LoadInt32(&attempts); got != 2 {
|
||||||
|
t.Fatalf("attempts=%d; want 2", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithProxyInvalidReturnsError(t *testing.T) {
|
||||||
|
_, err := NewRequest("http://example.com", http.MethodGet, WithProxy("://bad-proxy"))
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("NewRequest() succeeded; want invalid proxy error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientNewRequestWithInvalidProxyReturnsError(t *testing.T) {
|
||||||
|
client := NewClientNoErr()
|
||||||
|
|
||||||
|
_, err := client.NewRequest("http://example.com", http.MethodGet, WithProxy("://bad-proxy"))
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("Client.NewRequest() succeeded; want invalid proxy error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewClientWithInvalidProxyReturnsError(t *testing.T) {
|
||||||
|
_, err := NewClient(WithProxy("://bad-proxy"))
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("NewClient() succeeded; want invalid proxy error")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,31 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
// AddQuery 添加查询参数
|
||||||
|
func (r *Request) AddQuery(key, value string) *Request {
|
||||||
|
return r.applyMutation(mutateAddQuery(key, value))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetQuery 设置查询参数(覆盖)
|
||||||
|
func (r *Request) SetQuery(key, value string) *Request {
|
||||||
|
return r.applyMutation(mutateSetQuery(key, value))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetQueries 设置所有查询参数(覆盖)
|
||||||
|
func (r *Request) SetQueries(queries map[string][]string) *Request {
|
||||||
|
return r.applyMutation(mutateSetQueries(queries))
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddQueries 批量添加查询参数
|
||||||
|
func (r *Request) AddQueries(queries map[string]string) *Request {
|
||||||
|
return r.applyMutation(mutateAddQueries(queries))
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteQuery 删除查询参数
|
||||||
|
func (r *Request) DeleteQuery(key string) *Request {
|
||||||
|
return r.applyMutation(mutateDeleteQuery(key))
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteQueryValue 删除查询参数的特定值
|
||||||
|
func (r *Request) DeleteQueryValue(key, value string) *Request {
|
||||||
|
return r.applyMutation(mutateDeleteQueryValue(key, value))
|
||||||
|
}
|
||||||
@@ -0,0 +1,168 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type stateRoundTripperFunc func(*http.Request) (*http.Response, error)
|
||||||
|
|
||||||
|
func (fn stateRoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
return fn(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetContextNilUsesBackground(t *testing.T) {
|
||||||
|
req := NewSimpleRequest("http://example.com", http.MethodGet)
|
||||||
|
req.client = &Client{client: &http.Client{
|
||||||
|
Transport: stateRoundTripperFunc(func(r *http.Request) (*http.Response, error) {
|
||||||
|
if r.Context() == nil {
|
||||||
|
t.Fatal("request context is nil")
|
||||||
|
}
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: make(http.Header),
|
||||||
|
Body: io.NopCloser(strings.NewReader("ok")),
|
||||||
|
Request: r,
|
||||||
|
}, nil
|
||||||
|
}),
|
||||||
|
}}
|
||||||
|
|
||||||
|
resp, err := req.SetContext(nil).Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
if req.Context() == nil {
|
||||||
|
t.Fatal("request Context() is nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithContextNilRetryPathDoesNotPanic(t *testing.T) {
|
||||||
|
var hits int32
|
||||||
|
req, err := NewRequest("http://example.com", http.MethodGet, WithContext(nil))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewRequest() error: %v", err)
|
||||||
|
}
|
||||||
|
req.client = &Client{client: &http.Client{
|
||||||
|
Transport: stateRoundTripperFunc(func(r *http.Request) (*http.Response, error) {
|
||||||
|
if r.Context() == nil {
|
||||||
|
t.Fatal("retry request context is nil")
|
||||||
|
}
|
||||||
|
if atomic.AddInt32(&hits, 1) == 1 {
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusServiceUnavailable,
|
||||||
|
Header: make(http.Header),
|
||||||
|
Body: io.NopCloser(strings.NewReader("retry")),
|
||||||
|
Request: r,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: make(http.Header),
|
||||||
|
Body: io.NopCloser(strings.NewReader("ok")),
|
||||||
|
Request: r,
|
||||||
|
}, nil
|
||||||
|
}),
|
||||||
|
}}
|
||||||
|
|
||||||
|
resp, err := req.
|
||||||
|
SetTimeout(DefaultTimeout).
|
||||||
|
SetRetry(1, WithRetryBackoff(0, 0, 1), WithRetryJitter(0)).
|
||||||
|
Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
if got := atomic.LoadInt32(&hits); got != 2 {
|
||||||
|
t.Fatalf("hits=%d; want 2", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCloneRawRequestCreatesIndependentCopy(t *testing.T) {
|
||||||
|
rawReq, err := http.NewRequest(http.MethodPost, "http://example.com/upload", strings.NewReader("payload"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewRequest() error: %v", err)
|
||||||
|
}
|
||||||
|
rawReq.Header.Set("X-Test", "one")
|
||||||
|
|
||||||
|
req := NewSimpleRequest("", http.MethodPost).SetRawRequest(rawReq)
|
||||||
|
cloned := req.Clone()
|
||||||
|
|
||||||
|
if cloned.Err() != nil {
|
||||||
|
t.Fatalf("Clone() err = %v", cloned.Err())
|
||||||
|
}
|
||||||
|
if cloned.RawRequest() == rawReq {
|
||||||
|
t.Fatal("raw request pointer reused")
|
||||||
|
}
|
||||||
|
|
||||||
|
cloned.RawRequest().Header.Set("X-Test", "two")
|
||||||
|
if rawReq.Header.Get("X-Test") != "one" {
|
||||||
|
t.Fatalf("original header mutated: %q", rawReq.Header.Get("X-Test"))
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := cloned.RawRequest().GetBody()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetBody() error: %v", err)
|
||||||
|
}
|
||||||
|
defer body.Close()
|
||||||
|
|
||||||
|
data, err := io.ReadAll(body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ReadAll() error: %v", err)
|
||||||
|
}
|
||||||
|
if string(data) != "payload" {
|
||||||
|
t.Fatalf("body=%q; want payload", string(data))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCloneRawRequestWithNonReplayableBodyFailsExplicitly(t *testing.T) {
|
||||||
|
rawReq := &http.Request{
|
||||||
|
Method: http.MethodPost,
|
||||||
|
URL: mustParseURL(t, "http://example.com/upload"),
|
||||||
|
Header: make(http.Header),
|
||||||
|
Body: io.NopCloser(io.MultiReader(strings.NewReader("payload"))),
|
||||||
|
}
|
||||||
|
|
||||||
|
req := NewSimpleRequest("", http.MethodPost).SetRawRequest(rawReq)
|
||||||
|
cloned := req.Clone()
|
||||||
|
|
||||||
|
if cloned.Err() == nil {
|
||||||
|
t.Fatal("Clone() should fail for non-replayable raw body")
|
||||||
|
}
|
||||||
|
if !strings.Contains(cloned.Err().Error(), "non-replayable") {
|
||||||
|
t.Fatalf("Clone() err=%v; want non-replayable body error", cloned.Err())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDisableRawModeAfterSetRawRequestReturnsError(t *testing.T) {
|
||||||
|
rawReq, err := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewRequest() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := NewSimpleRequest("", http.MethodGet).SetRawRequest(rawReq).DisableRawMode()
|
||||||
|
if req.Err() == nil {
|
||||||
|
t.Fatal("DisableRawMode() should set error")
|
||||||
|
}
|
||||||
|
if !strings.Contains(req.Err().Error(), "cannot disable raw mode") {
|
||||||
|
t.Fatalf("DisableRawMode() err=%v", req.Err())
|
||||||
|
}
|
||||||
|
if !req.doRaw {
|
||||||
|
t.Fatal("request should remain in raw mode")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustParseURL(t *testing.T, raw string) *url.URL {
|
||||||
|
t.Helper()
|
||||||
|
parsed, err := url.Parse(raw)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("url.Parse() error: %v", err)
|
||||||
|
}
|
||||||
|
return parsed
|
||||||
|
}
|
||||||
+172
@@ -0,0 +1,172 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewSimpleRequest(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
url string
|
||||||
|
method string
|
||||||
|
expectErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid GET request",
|
||||||
|
url: "https://example.com",
|
||||||
|
method: "GET",
|
||||||
|
expectErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid POST request",
|
||||||
|
url: "https://example.com",
|
||||||
|
method: "POST",
|
||||||
|
expectErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid URL",
|
||||||
|
url: "://invalid",
|
||||||
|
method: "GET",
|
||||||
|
expectErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
req, err := NewRequest(tt.url, tt.method)
|
||||||
|
if tt.expectErr {
|
||||||
|
if err == nil && req.Err() == nil {
|
||||||
|
t.Errorf("NewRequest() expected error, got nil")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("NewRequest() unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if req.Method() != strings.ToUpper(tt.method) {
|
||||||
|
t.Errorf("Method = %v; want %v", req.Method(), tt.method)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestMethods(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("X-Method", r.Method)
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte(r.Method))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
methods := []string{"GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"}
|
||||||
|
|
||||||
|
for _, method := range methods {
|
||||||
|
t.Run(method, func(t *testing.T) {
|
||||||
|
req := NewSimpleRequest(server.URL, method)
|
||||||
|
resp, err := req.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("StatusCode = %v; want %v", resp.StatusCode, http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
if method != "HEAD" {
|
||||||
|
body, _ := resp.Body().String()
|
||||||
|
if body != method {
|
||||||
|
t.Errorf("Body = %v; want %v", body, method)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestSetMethod(t *testing.T) {
|
||||||
|
req := NewSimpleRequest("https://example.com", "GET")
|
||||||
|
|
||||||
|
req.SetMethod("POST")
|
||||||
|
if req.Method() != "POST" {
|
||||||
|
t.Errorf("Method = %v; want POST", req.Method())
|
||||||
|
}
|
||||||
|
|
||||||
|
req.SetMethod("invalid method!")
|
||||||
|
if req.Err() == nil {
|
||||||
|
t.Error("SetMethod with invalid method should set error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestSetURL(t *testing.T) {
|
||||||
|
req := NewSimpleRequest("https://example.com", "GET")
|
||||||
|
|
||||||
|
req.SetURL("https://newexample.com")
|
||||||
|
if req.URL() != "https://newexample.com" {
|
||||||
|
t.Errorf("URL = %v; want https://newexample.com", req.URL())
|
||||||
|
}
|
||||||
|
|
||||||
|
req2 := NewSimpleRequest("https://example.com", "GET")
|
||||||
|
req2.SetURL("://invalid")
|
||||||
|
if req2.Err() == nil {
|
||||||
|
t.Error("SetURL with invalid URL should set error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestClone(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Header.Get("X-Test") != "value" {
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte("OK"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
req := NewSimpleRequest(server.URL, "GET").
|
||||||
|
SetHeader("X-Test", "value")
|
||||||
|
|
||||||
|
// 第一次请求
|
||||||
|
resp, err := req.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do() error: %v", err)
|
||||||
|
}
|
||||||
|
resp.Close()
|
||||||
|
|
||||||
|
// 克隆请求
|
||||||
|
cloned := req.Clone()
|
||||||
|
cloned.SetHeader("X-Extra", "extra")
|
||||||
|
|
||||||
|
// 克隆的请求应该也能成功
|
||||||
|
resp2, err := cloned.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Cloned Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp2.Close()
|
||||||
|
|
||||||
|
if resp2.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("Cloned request StatusCode = %v; want %v", resp2.StatusCode, http.StatusOK)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestContext(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
req := NewSimpleRequest(server.URL, "GET").SetContext(ctx)
|
||||||
|
_, err := req.Do()
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected timeout error, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,56 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
// SetTraceHooks 设置请求 trace 回调。
|
||||||
|
func (r *Request) SetTraceHooks(hooks *TraceHooks) *Request {
|
||||||
|
return r.applyMutation(mutateTraceHooks(hooks))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTraceRecorder 设置请求级 trace 摘要记录器。
|
||||||
|
// 记录器会保存最近一次已完成请求的摘要;若多个请求共享同一个记录器,则以最后一次完成的请求为准。
|
||||||
|
func (r *Request) SetTraceRecorder(recorder *TraceRecorder) *Request {
|
||||||
|
return r.applyMutation(mutateTraceRecorder(recorder))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TraceSummary 返回当前请求最近一次执行的 trace 摘要快照。
|
||||||
|
func (r *Request) TraceSummary() *TraceSummary {
|
||||||
|
if r == nil || r.lastTraceSummary == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
summary := cloneTraceSummary(*r.lastTraceSummary)
|
||||||
|
return &summary
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Request) startTraceExecution() {
|
||||||
|
if r == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r.traceRun = nil
|
||||||
|
if r.traceRecorder == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r.traceRun = r.traceRecorder.forkExecution()
|
||||||
|
if r.traceRun != nil {
|
||||||
|
r.traceRun.startRequest()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Request) finishTraceExecution(resp *Response) {
|
||||||
|
if r == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if r.traceRun == nil {
|
||||||
|
r.lastTraceSummary = nil
|
||||||
|
if resp != nil {
|
||||||
|
resp.traceSummary = nil
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
summary := r.traceRun.Summary()
|
||||||
|
r.lastTraceSummary = cloneTraceSummaryPtr(summary)
|
||||||
|
r.traceRecorder.publishSummary(summary)
|
||||||
|
if resp != nil {
|
||||||
|
resp.traceSummary = cloneTraceSummaryPtr(summary)
|
||||||
|
}
|
||||||
|
r.traceRun = nil
|
||||||
|
}
|
||||||
+209
@@ -0,0 +1,209 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Response HTTP 响应
|
||||||
|
type Response struct {
|
||||||
|
*http.Response
|
||||||
|
request *Request
|
||||||
|
httpClient *http.Client
|
||||||
|
cancel func()
|
||||||
|
body *Body
|
||||||
|
traceSummary *TraceSummary
|
||||||
|
}
|
||||||
|
|
||||||
|
// Body 响应体
|
||||||
|
type Body struct {
|
||||||
|
raw io.ReadCloser
|
||||||
|
data []byte
|
||||||
|
consumed bool
|
||||||
|
maxBytes int64
|
||||||
|
mu sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
type cancelReadCloser struct {
|
||||||
|
io.ReadCloser
|
||||||
|
cancel func()
|
||||||
|
once sync.Once
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *cancelReadCloser) Close() error {
|
||||||
|
err := c.ReadCloser.Close()
|
||||||
|
c.once.Do(func() {
|
||||||
|
if c.cancel != nil {
|
||||||
|
c.cancel()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Request 获取原始请求
|
||||||
|
func (r *Response) Request() *Request {
|
||||||
|
return r.request
|
||||||
|
}
|
||||||
|
|
||||||
|
// TraceSummary 获取当前响应对应的 trace 摘要快照。
|
||||||
|
func (r *Response) TraceSummary() *TraceSummary {
|
||||||
|
if r == nil || r.traceSummary == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
summary := cloneTraceSummary(*r.traceSummary)
|
||||||
|
return &summary
|
||||||
|
}
|
||||||
|
|
||||||
|
// Body 获取响应体
|
||||||
|
func (r *Response) Body() *Body {
|
||||||
|
return r.body
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close 关闭响应体
|
||||||
|
func (r *Response) Close() error {
|
||||||
|
if r == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if r.body != nil && r.body.raw != nil {
|
||||||
|
return r.body.raw.Close()
|
||||||
|
}
|
||||||
|
if r.cancel != nil {
|
||||||
|
r.cancel()
|
||||||
|
r.cancel = nil
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloseWithClient 关闭响应体并关闭空闲连接
|
||||||
|
func (r *Response) CloseWithClient() error {
|
||||||
|
if r == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if r.httpClient != nil {
|
||||||
|
r.httpClient.CloseIdleConnections()
|
||||||
|
}
|
||||||
|
return r.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// readAll 读取所有数据
|
||||||
|
func (b *Body) readAll() error {
|
||||||
|
b.mu.Lock()
|
||||||
|
defer b.mu.Unlock()
|
||||||
|
|
||||||
|
if b.consumed {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if b.raw == nil {
|
||||||
|
b.consumed = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
reader := io.Reader(b.raw)
|
||||||
|
if b.maxBytes > 0 {
|
||||||
|
reader = io.LimitReader(b.raw, b.maxBytes+1)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := io.ReadAll(reader)
|
||||||
|
if err != nil {
|
||||||
|
return wrapError(err, "read response body")
|
||||||
|
}
|
||||||
|
if b.maxBytes > 0 && int64(len(data)) > b.maxBytes {
|
||||||
|
b.consumed = true
|
||||||
|
_ = b.raw.Close()
|
||||||
|
return wrapError(ErrRespBodyTooLarge, "response body exceeds max bytes: %d > %d", len(data), b.maxBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.data = data
|
||||||
|
b.consumed = true
|
||||||
|
_ = b.raw.Close()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bytes 获取响应体字节
|
||||||
|
func (b *Body) Bytes() ([]byte, error) {
|
||||||
|
if err := b.readAll(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return b.data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// String 获取响应体字符串
|
||||||
|
func (b *Body) String() (string, error) {
|
||||||
|
data, err := b.Bytes()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return string(data), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// JSON 解析 JSON 响应
|
||||||
|
func (b *Body) JSON(v interface{}) error {
|
||||||
|
data, err := b.Bytes()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return json.Unmarshal(data, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reader 获取 Reader(只能调用一次)
|
||||||
|
func (b *Body) Reader() (io.ReadCloser, error) {
|
||||||
|
b.mu.Lock()
|
||||||
|
defer b.mu.Unlock()
|
||||||
|
|
||||||
|
if b.consumed {
|
||||||
|
if b.data != nil {
|
||||||
|
// 已读取,返回缓存数据的 Reader
|
||||||
|
return io.NopCloser(bytes.NewReader(b.data)), nil
|
||||||
|
}
|
||||||
|
return nil, ErrBodyAlreadyConsumed
|
||||||
|
}
|
||||||
|
|
||||||
|
b.consumed = true
|
||||||
|
return b.raw, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsConsumed 检查是否已消费
|
||||||
|
func (b *Body) IsConsumed() bool {
|
||||||
|
b.mu.Lock()
|
||||||
|
defer b.mu.Unlock()
|
||||||
|
return b.consumed
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close 关闭 Body
|
||||||
|
func (b *Body) Close() error {
|
||||||
|
b.mu.Lock()
|
||||||
|
defer b.mu.Unlock()
|
||||||
|
|
||||||
|
if b.raw != nil {
|
||||||
|
return b.raw.Close()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustBytes 获取响应体字节(忽略错误,失败返回 nil)
|
||||||
|
func (b *Body) MustBytes() []byte {
|
||||||
|
data, err := b.Bytes()
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustString 获取响应体字符串(忽略错误,失败返回空串)
|
||||||
|
func (b *Body) MustString() string {
|
||||||
|
s, err := b.String()
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unmarshal 解析 JSON 响应(兼容旧 API)
|
||||||
|
func (b *Body) Unmarshal(v interface{}) error {
|
||||||
|
return b.JSON(v)
|
||||||
|
}
|
||||||
@@ -0,0 +1,84 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestWithMaxRespBodyBytes(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
_, _ = w.Write([]byte("123456"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
resp, err := Get(server.URL, WithMaxRespBodyBytes(4))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected request error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
_, err = resp.Body().Bytes()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected body too large error")
|
||||||
|
}
|
||||||
|
if !errors.Is(err, ErrRespBodyTooLarge) {
|
||||||
|
t.Fatalf("expected ErrRespBodyTooLarge, got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetMaxRespBodyBytes(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
_, _ = w.Write([]byte("1234"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
resp, err := NewSimpleRequest(server.URL, http.MethodGet).
|
||||||
|
SetMaxRespBodyBytes(4).
|
||||||
|
Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
body, err := resp.Body().String()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected read error: %v", err)
|
||||||
|
}
|
||||||
|
if body != "1234" {
|
||||||
|
t.Fatalf("body=%q want=1234", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetMaxRespBodyBytesWithAutoFetch(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
_, _ = w.Write([]byte("123456"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
_, err := NewSimpleRequest(server.URL, http.MethodGet).
|
||||||
|
SetAutoFetch(true).
|
||||||
|
SetMaxRespBodyBytes(4).
|
||||||
|
Do()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected body too large error with auto fetch")
|
||||||
|
}
|
||||||
|
if !errors.Is(err, ErrRespBodyTooLarge) {
|
||||||
|
t.Fatalf("expected ErrRespBodyTooLarge, got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetMaxRespBodyBytesInvalid(t *testing.T) {
|
||||||
|
req := NewSimpleRequest("http://example.com", http.MethodGet).SetMaxRespBodyBytes(-1)
|
||||||
|
if req.Err() == nil {
|
||||||
|
t.Fatal("expected error for negative max bytes")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithMaxRespBodyBytesInvalid(t *testing.T) {
|
||||||
|
_, err := NewRequest("http://example.com", http.MethodGet, WithMaxRespBodyBytes(-1))
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for negative max bytes")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,179 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestResponseBody(t *testing.T) {
|
||||||
|
testData := "test response data"
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write([]byte(testData))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
resp, err := Get(server.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Get() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
// Test String()
|
||||||
|
body, err := resp.Body().String()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Body().String() error: %v", err)
|
||||||
|
}
|
||||||
|
if body != testData {
|
||||||
|
t.Errorf("Body = %v; want %v", body, testData)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test multiple reads (should work because body is cached)
|
||||||
|
body2, err := resp.Body().String()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Second Body().String() error: %v", err)
|
||||||
|
}
|
||||||
|
if body2 != testData {
|
||||||
|
t.Errorf("Second Body = %v; want %v", body2, testData)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResponseJSON(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.Write([]byte(`{"name":"John","age":30}`))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
resp, err := Get(server.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Get() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
var result struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Age int `json:"age"`
|
||||||
|
}
|
||||||
|
|
||||||
|
err = resp.Body().JSON(&result)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Body().JSON() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Name != "John" {
|
||||||
|
t.Errorf("Name = %v; want John", result.Name)
|
||||||
|
}
|
||||||
|
if result.Age != 30 {
|
||||||
|
t.Errorf("Age = %v; want 30", result.Age)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResponseBytes(t *testing.T) {
|
||||||
|
testData := []byte("binary data")
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write(testData)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
resp, err := Get(server.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Get() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
body, err := resp.Body().Bytes()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Body().Bytes() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(body) != string(testData) {
|
||||||
|
t.Errorf("Body = %v; want %v", body, testData)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResponseReader(t *testing.T) {
|
||||||
|
testData := "stream data"
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write([]byte(testData))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
resp, err := Get(server.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Get() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
reader, err := resp.Body().Reader()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Body().Reader() error: %v", err)
|
||||||
|
}
|
||||||
|
defer reader.Close()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(reader)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ReadAll() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(body) != testData {
|
||||||
|
t.Errorf("Body = %v; want %v", string(body), testData)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResponseAutoFetch(t *testing.T) {
|
||||||
|
testData := "auto fetch data"
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write([]byte(testData))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
// With auto fetch
|
||||||
|
resp, err := Get(server.URL, WithAutoFetch(true))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Get() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
if !resp.Body().IsConsumed() {
|
||||||
|
t.Error("Body should be consumed with auto fetch")
|
||||||
|
}
|
||||||
|
|
||||||
|
body, _ := resp.Body().String()
|
||||||
|
if body != testData {
|
||||||
|
t.Errorf("Body = %v; want %v", body, testData)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResponseStatusCode(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
statusCode int
|
||||||
|
}{
|
||||||
|
{"OK", http.StatusOK},
|
||||||
|
{"Created", http.StatusCreated},
|
||||||
|
{"BadRequest", http.StatusBadRequest},
|
||||||
|
{"NotFound", http.StatusNotFound},
|
||||||
|
{"InternalServerError", http.StatusInternalServerError},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(tt.statusCode)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
resp, err := Get(server.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Get() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != tt.statusCode {
|
||||||
|
t.Errorf("StatusCode = %v; want %v", resp.StatusCode, tt.statusCode)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,468 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"math"
|
||||||
|
"math/rand"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RetryOpt func(*retryPolicy) error
|
||||||
|
|
||||||
|
type retryPolicy struct {
|
||||||
|
maxRetries int
|
||||||
|
baseDelay time.Duration
|
||||||
|
maxDelay time.Duration
|
||||||
|
factor float64
|
||||||
|
jitter float64
|
||||||
|
idempotentOnly bool
|
||||||
|
statuses map[int]struct{}
|
||||||
|
onError func(error) bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func cloneRetryPolicy(p *retryPolicy) *retryPolicy {
|
||||||
|
if p == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
cloned := &retryPolicy{
|
||||||
|
maxRetries: p.maxRetries,
|
||||||
|
baseDelay: p.baseDelay,
|
||||||
|
maxDelay: p.maxDelay,
|
||||||
|
factor: p.factor,
|
||||||
|
jitter: p.jitter,
|
||||||
|
idempotentOnly: p.idempotentOnly,
|
||||||
|
onError: p.onError,
|
||||||
|
}
|
||||||
|
if p.statuses != nil {
|
||||||
|
cloned.statuses = make(map[int]struct{}, len(p.statuses))
|
||||||
|
for code := range p.statuses {
|
||||||
|
cloned.statuses[code] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return cloned
|
||||||
|
}
|
||||||
|
|
||||||
|
func defaultRetryPolicy(max int) *retryPolicy {
|
||||||
|
return &retryPolicy{
|
||||||
|
maxRetries: max,
|
||||||
|
baseDelay: 100 * time.Millisecond,
|
||||||
|
maxDelay: 2 * time.Second,
|
||||||
|
factor: 2.0,
|
||||||
|
jitter: 0.1,
|
||||||
|
idempotentOnly: true,
|
||||||
|
statuses: map[int]struct{}{
|
||||||
|
http.StatusRequestTimeout: {},
|
||||||
|
http.StatusTooEarly: {},
|
||||||
|
http.StatusTooManyRequests: {},
|
||||||
|
http.StatusInternalServerError: {},
|
||||||
|
http.StatusBadGateway: {},
|
||||||
|
http.StatusServiceUnavailable: {},
|
||||||
|
http.StatusGatewayTimeout: {},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildRetryPolicy(max int, opts ...RetryOpt) (*retryPolicy, error) {
|
||||||
|
if max < 0 {
|
||||||
|
return nil, fmt.Errorf("max retry must be >= 0")
|
||||||
|
}
|
||||||
|
if max == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
policy := defaultRetryPolicy(max)
|
||||||
|
for _, opt := range opts {
|
||||||
|
if opt == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := opt(policy); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return policy, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithRetry 为请求启用自动重试。
|
||||||
|
// 默认只重试幂等方法;即使显式关闭幂等限制,Reader 形态的 body 仍会对非幂等方法保持保守禁用,
|
||||||
|
// 以避免请求体已落地后再次发送。
|
||||||
|
func WithRetry(max int, opts ...RetryOpt) RequestOpt {
|
||||||
|
return func(r *Request) error {
|
||||||
|
policy, err := buildRetryPolicy(max, opts...)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
r.retry = policy
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetRetry 为请求启用自动重试。
|
||||||
|
// 默认只重试幂等方法;即使显式关闭幂等限制,Reader 形态的 body 仍会对非幂等方法保持保守禁用,
|
||||||
|
// 以避免请求体已落地后再次发送。
|
||||||
|
func (r *Request) SetRetry(max int, opts ...RetryOpt) *Request {
|
||||||
|
if r.err != nil {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
policy, err := buildRetryPolicy(max, opts...)
|
||||||
|
if err != nil {
|
||||||
|
r.err = err
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
r.retry = policy
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Request) DisableRetry() *Request {
|
||||||
|
if r.err != nil {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
r.retry = nil
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Request) applyRetryOpt(opt RetryOpt) *Request {
|
||||||
|
if r.err != nil {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
if opt == nil {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
if r.retry == nil {
|
||||||
|
r.err = fmt.Errorf("retry policy is not enabled, call SetRetry first")
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
if err := opt(r.retry); err != nil {
|
||||||
|
r.err = err
|
||||||
|
}
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Request) SetRetryBackoff(base, max time.Duration, factor float64) *Request {
|
||||||
|
return r.applyRetryOpt(WithRetryBackoff(base, max, factor))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Request) SetRetryJitter(ratio float64) *Request {
|
||||||
|
return r.applyRetryOpt(WithRetryJitter(ratio))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Request) SetRetryStatuses(codes ...int) *Request {
|
||||||
|
return r.applyRetryOpt(WithRetryStatuses(codes...))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Request) SetRetryIdempotentOnly(enabled bool) *Request {
|
||||||
|
return r.applyRetryOpt(WithRetryIdempotentOnly(enabled))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Request) SetRetryOnError(fn func(error) bool) *Request {
|
||||||
|
return r.applyRetryOpt(WithRetryOnError(fn))
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithRetryBackoff(base, max time.Duration, factor float64) RetryOpt {
|
||||||
|
return func(p *retryPolicy) error {
|
||||||
|
if base < 0 {
|
||||||
|
return fmt.Errorf("retry base delay must be >= 0")
|
||||||
|
}
|
||||||
|
if max < 0 {
|
||||||
|
return fmt.Errorf("retry max delay must be >= 0")
|
||||||
|
}
|
||||||
|
if factor <= 0 {
|
||||||
|
return fmt.Errorf("retry factor must be > 0")
|
||||||
|
}
|
||||||
|
p.baseDelay = base
|
||||||
|
p.maxDelay = max
|
||||||
|
p.factor = factor
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithRetryJitter(ratio float64) RetryOpt {
|
||||||
|
return func(p *retryPolicy) error {
|
||||||
|
if ratio < 0 || ratio > 1 {
|
||||||
|
return fmt.Errorf("retry jitter ratio must be in [0,1]")
|
||||||
|
}
|
||||||
|
p.jitter = ratio
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithRetryStatuses(codes ...int) RetryOpt {
|
||||||
|
return func(p *retryPolicy) error {
|
||||||
|
statuses := make(map[int]struct{}, len(codes))
|
||||||
|
for _, code := range codes {
|
||||||
|
if code < 100 || code > 999 {
|
||||||
|
return fmt.Errorf("invalid retry status code: %d", code)
|
||||||
|
}
|
||||||
|
statuses[code] = struct{}{}
|
||||||
|
}
|
||||||
|
p.statuses = statuses
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithRetryIdempotentOnly(enabled bool) RetryOpt {
|
||||||
|
return func(p *retryPolicy) error {
|
||||||
|
p.idempotentOnly = enabled
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithRetryOnError(fn func(error) bool) RetryOpt {
|
||||||
|
return func(p *retryPolicy) error {
|
||||||
|
p.onError = fn
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Request) hasRetryPolicy() bool {
|
||||||
|
return r.retry != nil && r.retry.maxRetries > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Request) doWithRetry() (*Response, error) {
|
||||||
|
policy := cloneRetryPolicy(r.retry)
|
||||||
|
if policy == nil || policy.maxRetries <= 0 {
|
||||||
|
return r.doOnce()
|
||||||
|
}
|
||||||
|
|
||||||
|
if !policy.canRetryRequest(r) {
|
||||||
|
return r.doOnce()
|
||||||
|
}
|
||||||
|
|
||||||
|
retryCtx := normalizeContext(r.ctx)
|
||||||
|
retryCancel := func() {}
|
||||||
|
if r.config.Network.Timeout > 0 {
|
||||||
|
retryCtx, retryCancel = context.WithTimeout(retryCtx, r.config.Network.Timeout)
|
||||||
|
}
|
||||||
|
defer retryCancel()
|
||||||
|
|
||||||
|
maxAttempts := policy.maxRetries + 1
|
||||||
|
var lastResp *Response
|
||||||
|
var lastErr error
|
||||||
|
|
||||||
|
for attempt := 0; attempt < maxAttempts; attempt++ {
|
||||||
|
attemptNo := attempt + 1
|
||||||
|
emitRetryAttemptStart(r.traceHooks, TraceRetryAttemptStartInfo{
|
||||||
|
Attempt: attemptNo,
|
||||||
|
MaxAttempts: maxAttempts,
|
||||||
|
})
|
||||||
|
|
||||||
|
attemptReq, err := r.newRetryAttempt(retryCtx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, wrapError(err, "build retry attempt")
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := attemptReq.doOnce()
|
||||||
|
if resp != nil {
|
||||||
|
resp.request = r
|
||||||
|
}
|
||||||
|
|
||||||
|
willRetry := policy.shouldRetry(resp, err, attempt, maxAttempts, retryCtx)
|
||||||
|
statusCode := 0
|
||||||
|
if resp != nil {
|
||||||
|
statusCode = resp.StatusCode
|
||||||
|
}
|
||||||
|
emitRetryAttemptDone(r.traceHooks, TraceRetryAttemptDoneInfo{
|
||||||
|
Attempt: attemptNo,
|
||||||
|
MaxAttempts: maxAttempts,
|
||||||
|
StatusCode: statusCode,
|
||||||
|
Err: err,
|
||||||
|
WillRetry: willRetry,
|
||||||
|
})
|
||||||
|
if !willRetry {
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
|
||||||
|
lastResp = resp
|
||||||
|
lastErr = err
|
||||||
|
if lastResp != nil {
|
||||||
|
_ = lastResp.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
delay := policy.nextDelay(attempt)
|
||||||
|
if delay <= 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
emitRetryBackoff(r.traceHooks, TraceRetryBackoffInfo{
|
||||||
|
Attempt: attemptNo,
|
||||||
|
Delay: delay,
|
||||||
|
})
|
||||||
|
|
||||||
|
timer := time.NewTimer(delay)
|
||||||
|
select {
|
||||||
|
case <-retryCtx.Done():
|
||||||
|
timer.Stop()
|
||||||
|
return lastResp, wrapError(retryCtx.Err(), "retry context done")
|
||||||
|
case <-timer.C:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return lastResp, lastErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Request) newRetryAttempt(ctx context.Context) (*Request, error) {
|
||||||
|
attempt := r.Clone()
|
||||||
|
attempt.retry = nil
|
||||||
|
attempt.cancel = nil
|
||||||
|
attempt.applied = false
|
||||||
|
attempt.execCtx = nil
|
||||||
|
attempt.ctx = ctx
|
||||||
|
attempt.traceRun = r.traceRun
|
||||||
|
attempt.lastTraceSummary = nil
|
||||||
|
|
||||||
|
// 共享总超时上下文后,避免每次 attempt 再创建一次 timeout context。
|
||||||
|
if attempt.config != nil && attempt.config.Network.Timeout > 0 {
|
||||||
|
attempt.config.Network.Timeout = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
if !attempt.doRaw {
|
||||||
|
attempt.httpReq = attempt.httpReq.WithContext(ctx)
|
||||||
|
return attempt, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
raw, err := cloneRawHTTPRequest(r.httpReq, ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
attempt.httpReq = raw
|
||||||
|
return attempt, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *retryPolicy) canRetryRequest(r *Request) bool {
|
||||||
|
if p.idempotentOnly && !isIdempotentMethod(r.method) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if hasReaderRequestBody(r) && !isIdempotentMethod(r.method) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return isReplayableRequest(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func isIdempotentMethod(method string) bool {
|
||||||
|
switch method {
|
||||||
|
case http.MethodGet, http.MethodHead, http.MethodPut, http.MethodDelete, http.MethodOptions, http.MethodTrace:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func isReplayableRequest(r *Request) bool {
|
||||||
|
if r == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.doRaw {
|
||||||
|
if r.httpReq == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if r.httpReq.Body == nil || r.httpReq.Body == http.NoBody {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return r.httpReq.GetBody != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.config == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return isReplayableConfiguredBody(r.config.Body)
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasReaderRequestBody(r *Request) bool {
|
||||||
|
if r == nil || r.config == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return r.config.Body.Mode == bodyModeReader && r.config.Body.Reader != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isReplayableConfiguredBody(body BodyConfig) bool {
|
||||||
|
switch body.Mode {
|
||||||
|
case bodyModeReader:
|
||||||
|
return isReplayableBodyReader(body.Reader)
|
||||||
|
case bodyModeMultipart:
|
||||||
|
for _, file := range body.Files {
|
||||||
|
if file.FileData != nil || file.FilePath == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func isReplayableBodyReader(reader io.Reader) bool {
|
||||||
|
switch reader.(type) {
|
||||||
|
case *bytes.Buffer, *bytes.Reader, *strings.Reader:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *retryPolicy) shouldRetry(resp *Response, err error, attempt, maxAttempts int, ctx context.Context) bool {
|
||||||
|
if attempt >= maxAttempts-1 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if ctx != nil && ctx.Err() != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if p.onError != nil {
|
||||||
|
return p.onError(err)
|
||||||
|
}
|
||||||
|
return isRetryableError(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp == nil || resp.Response == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
_, ok := p.statuses[resp.StatusCode]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func isRetryableError(err error) bool {
|
||||||
|
var netErr net.Error
|
||||||
|
if errors.As(err, &netErr) {
|
||||||
|
if netErr.Timeout() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if netErr.Temporary() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *retryPolicy) nextDelay(attempt int) time.Duration {
|
||||||
|
if p.baseDelay <= 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
delay := time.Duration(float64(p.baseDelay) * math.Pow(p.factor, float64(attempt)))
|
||||||
|
if p.maxDelay > 0 && delay > p.maxDelay {
|
||||||
|
delay = p.maxDelay
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.jitter <= 0 {
|
||||||
|
return delay
|
||||||
|
}
|
||||||
|
|
||||||
|
low := 1 - p.jitter
|
||||||
|
if low < 0 {
|
||||||
|
low = 0
|
||||||
|
}
|
||||||
|
high := 1 + p.jitter
|
||||||
|
scale := low + rand.Float64()*(high-low)
|
||||||
|
return time.Duration(float64(delay) * scale)
|
||||||
|
}
|
||||||
+298
@@ -0,0 +1,298 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestWithRetrySmokeGet(t *testing.T) {
|
||||||
|
var hits int32
|
||||||
|
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
n := atomic.AddInt32(&hits, 1)
|
||||||
|
if n <= 2 {
|
||||||
|
w.WriteHeader(http.StatusServiceUnavailable)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("ok"))
|
||||||
|
}))
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
resp, err := Get(s.URL,
|
||||||
|
WithRetry(2,
|
||||||
|
WithRetryBackoff(0, 0, 1),
|
||||||
|
WithRetryJitter(0),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("status=%d want=%d", resp.StatusCode, http.StatusOK)
|
||||||
|
}
|
||||||
|
if atomic.LoadInt32(&hits) != 3 {
|
||||||
|
t.Fatalf("hits=%d want=3", hits)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithRetryResponseRequestPointerStable(t *testing.T) {
|
||||||
|
var hits int32
|
||||||
|
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
n := atomic.AddInt32(&hits, 1)
|
||||||
|
if n == 1 {
|
||||||
|
w.WriteHeader(http.StatusServiceUnavailable)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
req := NewSimpleRequest(s.URL, http.MethodGet).
|
||||||
|
SetRetry(1, WithRetryBackoff(0, 0, 1), WithRetryJitter(0))
|
||||||
|
resp, err := req.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
if resp.Request() != req {
|
||||||
|
t.Fatal("response request pointer should point to original request")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithRetryNoRetryForNonReplayableBodyReader(t *testing.T) {
|
||||||
|
var hits int32
|
||||||
|
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
atomic.AddInt32(&hits, 1)
|
||||||
|
w.WriteHeader(http.StatusServiceUnavailable)
|
||||||
|
}))
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
req := NewSimpleRequest(s.URL, http.MethodPost).
|
||||||
|
SetBodyReader(strings.NewReader("payload")).
|
||||||
|
SetRetry(3,
|
||||||
|
WithRetryIdempotentOnly(false),
|
||||||
|
WithRetryBackoff(0, 0, 1),
|
||||||
|
WithRetryJitter(0),
|
||||||
|
)
|
||||||
|
resp, err := req.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusServiceUnavailable {
|
||||||
|
t.Fatalf("status=%d want=%d", resp.StatusCode, http.StatusServiceUnavailable)
|
||||||
|
}
|
||||||
|
if atomic.LoadInt32(&hits) != 1 {
|
||||||
|
t.Fatalf("hits=%d want=1", hits)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithRetryPostWhenIdempotentDisabled(t *testing.T) {
|
||||||
|
var hits int32
|
||||||
|
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
n := atomic.AddInt32(&hits, 1)
|
||||||
|
_, _ = io.Copy(io.Discard, r.Body)
|
||||||
|
if n == 1 {
|
||||||
|
w.WriteHeader(http.StatusServiceUnavailable)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
req := NewSimpleRequest(s.URL, http.MethodPost).
|
||||||
|
SetBodyString("hello").
|
||||||
|
SetRetry(1,
|
||||||
|
WithRetryIdempotentOnly(false),
|
||||||
|
WithRetryBackoff(0, 0, 1),
|
||||||
|
WithRetryJitter(0),
|
||||||
|
)
|
||||||
|
resp, err := req.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("status=%d want=%d", resp.StatusCode, http.StatusOK)
|
||||||
|
}
|
||||||
|
if atomic.LoadInt32(&hits) != 2 {
|
||||||
|
t.Fatalf("hits=%d want=2", hits)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithRetryRawWithoutGetBodyNoRetry(t *testing.T) {
|
||||||
|
var hits int32
|
||||||
|
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
atomic.AddInt32(&hits, 1)
|
||||||
|
w.WriteHeader(http.StatusServiceUnavailable)
|
||||||
|
}))
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
rawReq, _ := http.NewRequest(http.MethodPost, s.URL, io.MultiReader(strings.NewReader("raw")))
|
||||||
|
if rawReq.GetBody != nil {
|
||||||
|
t.Fatal("raw request GetBody should be nil in this test")
|
||||||
|
}
|
||||||
|
|
||||||
|
req := NewSimpleRequest("", http.MethodPost, WithRawRequest(rawReq)).
|
||||||
|
SetRetry(2,
|
||||||
|
WithRetryIdempotentOnly(false),
|
||||||
|
WithRetryBackoff(0, 0, 1),
|
||||||
|
WithRetryJitter(0),
|
||||||
|
)
|
||||||
|
resp, err := req.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
if atomic.LoadInt32(&hits) != 1 {
|
||||||
|
t.Fatalf("hits=%d want=1", hits)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithRetryRespectsTotalTimeoutBudget(t *testing.T) {
|
||||||
|
var hits int32
|
||||||
|
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
atomic.AddInt32(&hits, 1)
|
||||||
|
time.Sleep(80 * time.Millisecond)
|
||||||
|
w.WriteHeader(http.StatusServiceUnavailable)
|
||||||
|
}))
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
req := NewSimpleRequest(s.URL, http.MethodGet).
|
||||||
|
SetTimeout(120*time.Millisecond).
|
||||||
|
SetRetry(3,
|
||||||
|
WithRetryBackoff(0, 0, 1),
|
||||||
|
WithRetryJitter(0),
|
||||||
|
)
|
||||||
|
|
||||||
|
_, err := req.Do()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected timeout error")
|
||||||
|
}
|
||||||
|
if !errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
t.Fatalf("expected context deadline exceeded, got: %v", err)
|
||||||
|
}
|
||||||
|
if h := atomic.LoadInt32(&hits); h > 2 {
|
||||||
|
t.Fatalf("hits=%d want<=2 under tight timeout budget", h)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetRetryInvalidMax(t *testing.T) {
|
||||||
|
req := NewSimpleRequest("http://example.com", http.MethodGet).SetRetry(-1)
|
||||||
|
if req.Err() == nil {
|
||||||
|
t.Fatal("expected error for negative retry max")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetRetrySeriesSmokeGet(t *testing.T) {
|
||||||
|
var hits int32
|
||||||
|
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
n := atomic.AddInt32(&hits, 1)
|
||||||
|
if n <= 2 {
|
||||||
|
w.WriteHeader(http.StatusTooManyRequests)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
req := NewSimpleRequest(s.URL, http.MethodGet).
|
||||||
|
SetRetry(2).
|
||||||
|
SetRetryBackoff(0, 0, 1).
|
||||||
|
SetRetryJitter(0).
|
||||||
|
SetRetryStatuses(http.StatusTooManyRequests)
|
||||||
|
resp, err := req.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("status=%d want=%d", resp.StatusCode, http.StatusOK)
|
||||||
|
}
|
||||||
|
if h := atomic.LoadInt32(&hits); h != 3 {
|
||||||
|
t.Fatalf("hits=%d want=3", h)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetRetryIdempotentOnlyWithPost(t *testing.T) {
|
||||||
|
var hits int32
|
||||||
|
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
n := atomic.AddInt32(&hits, 1)
|
||||||
|
_, _ = io.Copy(io.Discard, r.Body)
|
||||||
|
if n == 1 {
|
||||||
|
w.WriteHeader(http.StatusServiceUnavailable)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
req := NewSimpleRequest(s.URL, http.MethodPost).
|
||||||
|
SetBodyString("hello").
|
||||||
|
SetRetry(1).
|
||||||
|
SetRetryIdempotentOnly(false).
|
||||||
|
SetRetryBackoff(0, 0, 1).
|
||||||
|
SetRetryJitter(0)
|
||||||
|
resp, err := req.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("status=%d want=%d", resp.StatusCode, http.StatusOK)
|
||||||
|
}
|
||||||
|
if h := atomic.LoadInt32(&hits); h != 2 {
|
||||||
|
t.Fatalf("hits=%d want=2", h)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetRetryOnErrorOverridesDefault(t *testing.T) {
|
||||||
|
var dials int32
|
||||||
|
dialErr := errors.New("dial failed")
|
||||||
|
|
||||||
|
req := NewSimpleRequest("http://example.com", http.MethodGet).
|
||||||
|
SetDialFunc(func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
atomic.AddInt32(&dials, 1)
|
||||||
|
return nil, dialErr
|
||||||
|
}).
|
||||||
|
SetRetry(1).
|
||||||
|
SetRetryBackoff(0, 0, 1).
|
||||||
|
SetRetryJitter(0).
|
||||||
|
SetRetryOnError(func(err error) bool {
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
_, err := req.Do()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error")
|
||||||
|
}
|
||||||
|
if h := atomic.LoadInt32(&dials); h != 2 {
|
||||||
|
t.Fatalf("dial attempts=%d want=2", h)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetRetryOptionRequireEnableRetry(t *testing.T) {
|
||||||
|
req := NewSimpleRequest("http://example.com", http.MethodGet).SetRetryBackoff(10*time.Millisecond, 100*time.Millisecond, 2)
|
||||||
|
if req.Err() == nil {
|
||||||
|
t.Fatal("expected error when setting retry options before SetRetry")
|
||||||
|
}
|
||||||
|
if !strings.Contains(req.Err().Error(), "call SetRetry first") {
|
||||||
|
t.Fatalf("unexpected error: %v", req.Err())
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,244 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRequestProxyWithCustomIPFallbackTriesNextResolvedTarget(t *testing.T) {
|
||||||
|
tlsServer := newIPv4TLSServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer tlsServer.Close()
|
||||||
|
|
||||||
|
_, port, err := net.SplitHostPort(tlsServer.Listener.Addr().String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("split tls server addr: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
firstTarget := net.JoinHostPort("127.0.0.2", port)
|
||||||
|
secondTarget := net.JoinHostPort("127.0.0.1", port)
|
||||||
|
|
||||||
|
var (
|
||||||
|
mu sync.Mutex
|
||||||
|
connectTargets []string
|
||||||
|
)
|
||||||
|
proxyServer := newIPv4Server(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodConnect {
|
||||||
|
http.Error(w, "connect required", http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
mu.Lock()
|
||||||
|
connectTargets = append(connectTargets, r.Host)
|
||||||
|
mu.Unlock()
|
||||||
|
|
||||||
|
if r.Host == firstTarget {
|
||||||
|
http.Error(w, "first target failed", http.StatusBadGateway)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
targetConn, err := net.Dial("tcp", r.Host)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusBadGateway)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
hijacker, ok := w.(http.Hijacker)
|
||||||
|
if !ok {
|
||||||
|
targetConn.Close()
|
||||||
|
t.Fatal("proxy response writer is not a hijacker")
|
||||||
|
}
|
||||||
|
|
||||||
|
clientConn, rw, err := hijacker.Hijack()
|
||||||
|
if err != nil {
|
||||||
|
targetConn.Close()
|
||||||
|
t.Fatalf("hijack proxy conn: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := rw.WriteString("HTTP/1.1 200 Connection Established\r\n\r\n"); err != nil {
|
||||||
|
clientConn.Close()
|
||||||
|
targetConn.Close()
|
||||||
|
t.Fatalf("write connect response: %v", err)
|
||||||
|
}
|
||||||
|
if err := rw.Flush(); err != nil {
|
||||||
|
clientConn.Close()
|
||||||
|
targetConn.Close()
|
||||||
|
t.Fatalf("flush connect response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
relayProxyConns(clientConn, targetConn)
|
||||||
|
}))
|
||||||
|
defer proxyServer.Close()
|
||||||
|
|
||||||
|
reqURL := fmt.Sprintf("https://proxy-fallback.test:%s", port)
|
||||||
|
resp, err := NewSimpleRequest(reqURL, http.MethodGet).
|
||||||
|
SetProxy(proxyServer.URL).
|
||||||
|
SetCustomIP([]string{"127.0.0.2", "127.0.0.1"}).
|
||||||
|
SetSkipTLSVerify(true).
|
||||||
|
Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
if len(connectTargets) != 2 {
|
||||||
|
t.Fatalf("connect target attempts=%d; want 2 (%v)", len(connectTargets), connectTargets)
|
||||||
|
}
|
||||||
|
if connectTargets[0] != firstTarget {
|
||||||
|
t.Fatalf("first connect target=%q; want %q", connectTargets[0], firstTarget)
|
||||||
|
}
|
||||||
|
if connectTargets[1] != secondTarget {
|
||||||
|
t.Fatalf("second connect target=%q; want %q", connectTargets[1], secondTarget)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTraceHooksDefaultResolverEmitsDNSEvents(t *testing.T) {
|
||||||
|
server := newIPv4Server(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
addr, err := net.ResolveTCPAddr("tcp", server.Listener.Addr().String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ResolveTCPAddr() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
mu sync.Mutex
|
||||||
|
dnsStartCount int
|
||||||
|
dnsDoneCount int
|
||||||
|
lastHost string
|
||||||
|
)
|
||||||
|
hooks := &TraceHooks{
|
||||||
|
DNSStart: func(info TraceDNSStartInfo) {
|
||||||
|
mu.Lock()
|
||||||
|
dnsStartCount++
|
||||||
|
lastHost = info.Host
|
||||||
|
mu.Unlock()
|
||||||
|
},
|
||||||
|
DNSDone: func(info TraceDNSDoneInfo) {
|
||||||
|
mu.Lock()
|
||||||
|
dnsDoneCount++
|
||||||
|
mu.Unlock()
|
||||||
|
if info.Err != nil {
|
||||||
|
t.Errorf("unexpected dns error: %v", info.Err)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
reqURL := "http://localhost:" + strconv.Itoa(addr.Port)
|
||||||
|
resp, err := NewSimpleRequest(reqURL, http.MethodGet).
|
||||||
|
SetDialTimeout(DefaultDialTimeout + 200*time.Millisecond).
|
||||||
|
SetTraceHooks(hooks).
|
||||||
|
Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
if dnsStartCount != 1 {
|
||||||
|
t.Fatalf("dnsStartCount=%d", dnsStartCount)
|
||||||
|
}
|
||||||
|
if dnsDoneCount != 1 {
|
||||||
|
t.Fatalf("dnsDoneCount=%d", dnsDoneCount)
|
||||||
|
}
|
||||||
|
if lastHost != "localhost" {
|
||||||
|
t.Fatalf("lastHost=%q; want localhost", lastHost)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestHeadersReturnsCopy(t *testing.T) {
|
||||||
|
req := NewSimpleRequest("http://example.com", http.MethodGet).
|
||||||
|
SetHeader("X-Test", "one").
|
||||||
|
SetHost("origin.example")
|
||||||
|
|
||||||
|
headers := req.Headers()
|
||||||
|
headers.Set("X-Test", "two")
|
||||||
|
headers.Set("Host", "mutated.example")
|
||||||
|
|
||||||
|
if got := req.GetHeader("X-Test"); got != "one" {
|
||||||
|
t.Fatalf("request header=%q; want one", got)
|
||||||
|
}
|
||||||
|
if got := req.Host(); got != "origin.example" {
|
||||||
|
t.Fatalf("request host=%q; want origin.example", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestCookiesIsolation(t *testing.T) {
|
||||||
|
req := NewSimpleRequest("http://example.com", http.MethodGet)
|
||||||
|
source := []*http.Cookie{{
|
||||||
|
Name: "session",
|
||||||
|
Value: "one",
|
||||||
|
Path: "/",
|
||||||
|
}}
|
||||||
|
|
||||||
|
req.SetCookies(source)
|
||||||
|
source[0].Value = "mutated-outside"
|
||||||
|
|
||||||
|
got := req.Cookies()
|
||||||
|
if len(got) != 1 || got[0].Value != "one" {
|
||||||
|
t.Fatalf("cookies after SetCookies=%v", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
got[0].Value = "mutated-copy"
|
||||||
|
if latest := req.Cookies()[0].Value; latest != "one" {
|
||||||
|
t.Fatalf("internal cookie mutated via getter, got %q", latest)
|
||||||
|
}
|
||||||
|
|
||||||
|
cookie := &http.Cookie{Name: "auth", Value: "token"}
|
||||||
|
req.ResetCookies().AddCookie(cookie)
|
||||||
|
cookie.Value = "changed"
|
||||||
|
|
||||||
|
latest := req.Cookies()
|
||||||
|
if len(latest) != 1 || latest[0].Value != "token" {
|
||||||
|
t.Fatalf("cookies after AddCookie=%v", latest)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTraceHooksLookupFuncStillEmitsDNSEvents(t *testing.T) {
|
||||||
|
server := newIPv4Server(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
addr, err := net.ResolveTCPAddr("tcp", server.Listener.Addr().String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ResolveTCPAddr() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var dnsStartCount int
|
||||||
|
var dnsDoneCount int
|
||||||
|
hooks := &TraceHooks{
|
||||||
|
DNSStart: func(info TraceDNSStartInfo) {
|
||||||
|
dnsStartCount++
|
||||||
|
},
|
||||||
|
DNSDone: func(info TraceDNSDoneInfo) {
|
||||||
|
dnsDoneCount++
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := NewSimpleRequest("http://lookup-copy.test:"+strconv.Itoa(addr.Port), http.MethodGet).
|
||||||
|
SetLookupFunc(func(ctx context.Context, host string) ([]net.IPAddr, error) {
|
||||||
|
return []net.IPAddr{{IP: addr.IP}}, nil
|
||||||
|
}).
|
||||||
|
SetTraceHooks(hooks).
|
||||||
|
Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
if dnsStartCount != 1 || dnsDoneCount != 1 {
|
||||||
|
t.Fatalf("dns trace counts start=%d done=%d", dnsStartCount, dnsDoneCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,115 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRequestTimeoutDoesNotMutateClientTimeout(t *testing.T) {
|
||||||
|
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
client := NewClientNoErr()
|
||||||
|
baseTimeout := client.HTTPClient().Timeout
|
||||||
|
|
||||||
|
_, err := client.Get(s.URL, WithTimeout(80*time.Millisecond))
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected request timeout error")
|
||||||
|
}
|
||||||
|
|
||||||
|
if client.HTTPClient().Timeout != baseTimeout {
|
||||||
|
t.Fatalf("client timeout mutated: got=%v want=%v", client.HTTPClient().Timeout, baseTimeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := client.Get(s.URL, WithTimeout(400*time.Millisecond))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("second request should succeed with larger timeout: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestTimeoutNoLingeringConnDeadline(t *testing.T) {
|
||||||
|
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
_, _ = io.Copy(io.Discard, r.Body)
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("ok"))
|
||||||
|
}))
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
client := NewClientNoErr()
|
||||||
|
|
||||||
|
resp1, err := client.Post(s.URL, WithBodyString("first"), WithTimeout(120*time.Millisecond))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("first request should succeed: %v", err)
|
||||||
|
}
|
||||||
|
_ = resp1.Close()
|
||||||
|
|
||||||
|
// 如果请求超时依赖连接级绝对 deadline,经过该等待后复用连接会出现误超时。
|
||||||
|
time.Sleep(220 * time.Millisecond)
|
||||||
|
|
||||||
|
resp2, err := client.Post(s.URL, WithBodyString("second"), WithTimeout(1*time.Second))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("second request should not be affected by previous timeout window: %v", err)
|
||||||
|
}
|
||||||
|
_ = resp2.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnReadDeadlineTimeoutAndRecover(t *testing.T) {
|
||||||
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("listen error: %v", err)
|
||||||
|
}
|
||||||
|
defer ln.Close()
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer close(done)
|
||||||
|
conn, err := ln.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
time.Sleep(180 * time.Millisecond)
|
||||||
|
_, _ = conn.Write([]byte("x"))
|
||||||
|
}()
|
||||||
|
|
||||||
|
c, err := Dial("tcp", ln.Addr().String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("dial error: %v", err)
|
||||||
|
}
|
||||||
|
defer c.Close()
|
||||||
|
|
||||||
|
if err := c.SetReadDeadline(time.Now().Add(60 * time.Millisecond)); err != nil {
|
||||||
|
t.Fatalf("set read deadline error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := make([]byte, 1)
|
||||||
|
_, err = c.Read(buf)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected read timeout error")
|
||||||
|
}
|
||||||
|
if ne, ok := err.(net.Error); !ok || !ne.Timeout() {
|
||||||
|
t.Fatalf("expected net timeout error, got: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.SetReadDeadline(time.Time{}); err != nil {
|
||||||
|
t.Fatalf("clear read deadline error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := io.ReadFull(c, buf); err != nil {
|
||||||
|
t.Fatalf("read after clearing deadline should succeed: %v", err)
|
||||||
|
}
|
||||||
|
if string(buf) != "x" {
|
||||||
|
t.Fatalf("unexpected payload: %q", string(buf))
|
||||||
|
}
|
||||||
|
|
||||||
|
<-done
|
||||||
|
}
|
||||||
@@ -0,0 +1,66 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRequestTimeout(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
// Should timeout
|
||||||
|
req := NewSimpleRequest(server.URL, "GET").SetTimeout(100 * time.Millisecond)
|
||||||
|
_, err := req.Do()
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected timeout error, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should succeed
|
||||||
|
req2 := NewSimpleRequest(server.URL, "GET").SetTimeout(300 * time.Millisecond)
|
||||||
|
resp, err := req2.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if resp != nil {
|
||||||
|
resp.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestDialTimeout(t *testing.T) {
|
||||||
|
// Use a non-routable IP to test dial timeout
|
||||||
|
req := NewSimpleRequest("http://192.0.2.1:80", "GET").
|
||||||
|
SetDialTimeout(100 * time.Millisecond)
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
_, err := req.Do()
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected dial timeout error, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should timeout within reasonable time (not wait forever)
|
||||||
|
if elapsed > 2*time.Second {
|
||||||
|
t.Errorf("Dial timeout took too long: %v", elapsed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientTimeout(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
client := NewClientNoErr(WithTimeout(100 * time.Millisecond))
|
||||||
|
_, err := client.Get(server.URL)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected timeout error, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
+336
@@ -0,0 +1,336 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRequestSkipTLSVerify(t *testing.T) {
|
||||||
|
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte("OK"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
// Without skip verify (should fail)
|
||||||
|
req := NewSimpleRequest(server.URL, "GET")
|
||||||
|
_, err := req.Do()
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected TLS error without skip verify, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// With skip verify (should succeed)
|
||||||
|
req2 := NewSimpleRequest(server.URL, "GET").SetSkipTLSVerify(true)
|
||||||
|
resp, err := req2.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do() with skip verify error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
body, _ := resp.Body().String()
|
||||||
|
if body != "OK" {
|
||||||
|
t.Errorf("Body = %v; want OK", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestCustomTLSConfig(t *testing.T) {
|
||||||
|
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte("OK"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
tlsConfig := &tls.Config{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
MinVersion: tls.VersionTLS12,
|
||||||
|
}
|
||||||
|
|
||||||
|
req := NewSimpleRequest(server.URL, "GET").SetTLSConfig(tlsConfig)
|
||||||
|
resp, err := req.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("StatusCode = %v; want %v", resp.StatusCode, http.StatusOK)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientDefaultTLSConfig(t *testing.T) {
|
||||||
|
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
client := NewClientNoErr()
|
||||||
|
client.SetDefaultSkipTLSVerify(true)
|
||||||
|
|
||||||
|
resp, err := client.Get(server.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Get() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("StatusCode = %v; want %v", resp.StatusCode, http.StatusOK)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestLevelTLSOverride(t *testing.T) {
|
||||||
|
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
// Client level: skip verify = false
|
||||||
|
client := NewClientNoErr()
|
||||||
|
client.SetDefaultSkipTLSVerify(false)
|
||||||
|
|
||||||
|
// Request level: skip verify = true (should override)
|
||||||
|
resp, err := client.Get(server.URL, WithSkipTLSVerify(true))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Get() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("StatusCode = %v; want %v", resp.StatusCode, http.StatusOK)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestTls(t *testing.T) {
|
||||||
|
var requestCount int
|
||||||
|
server, pool := newTrustedIPv4TLSServer(t, "localhost", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
requestCount++
|
||||||
|
switch requestCount {
|
||||||
|
case 1:
|
||||||
|
if r.Header.Get("Hello") != "" {
|
||||||
|
t.Fatalf("unexpected hello header on first request: %q", r.Header.Get("Hello"))
|
||||||
|
}
|
||||||
|
if auth := r.Header.Get("Authorization"); auth != "" {
|
||||||
|
t.Fatalf("unexpected authorization on first request: %q", auth)
|
||||||
|
}
|
||||||
|
case 2:
|
||||||
|
if got := r.Header.Get("Hello"); got != "world" {
|
||||||
|
t.Fatalf("hello header=%q; want world", got)
|
||||||
|
}
|
||||||
|
if got := r.Header.Get("Authorization"); got != "Bearer ddddddd" {
|
||||||
|
t.Fatalf("authorization=%q; want bearer token", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("OK"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
localURL := httpsURLForHost(t, server, "localhost")
|
||||||
|
resp, err := NewSimpleRequest(localURL, "GET").
|
||||||
|
SetTLSConfig(&tls.Config{RootCAs: pool}).
|
||||||
|
Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("StatusCode = %v; want %v", resp.StatusCode, http.StatusOK)
|
||||||
|
}
|
||||||
|
t.Logf("Response: %v", resp.Body().MustString())
|
||||||
|
|
||||||
|
client, err := NewClient()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewClient() error: %v", err)
|
||||||
|
}
|
||||||
|
resp, err = client.NewSimpleRequest(localURL, "GET",
|
||||||
|
WithTLSConfig(&tls.Config{RootCAs: pool}),
|
||||||
|
WithHeader("hello", "world"),
|
||||||
|
WithContext(context.Background()),
|
||||||
|
WithBearerToken("ddddddd")).Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("StatusCode = %v; want %v", resp.StatusCode, http.StatusOK)
|
||||||
|
}
|
||||||
|
t.Logf("Response: %v", resp.Body().MustString())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTLSWithProxyPath(t *testing.T) {
|
||||||
|
server, pool := newTrustedIPv4TLSServer(t, "localhost", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("proxied"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
proxy := newIPv4ConnectProxyServer(t, nil)
|
||||||
|
defer proxy.Close()
|
||||||
|
|
||||||
|
client, err := NewClient()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := client.NewRequest(httpsURLForHost(t, server, "localhost"), "GET",
|
||||||
|
WithTimeout(10*time.Second),
|
||||||
|
WithProxy(proxy.URL),
|
||||||
|
WithTLSConfig(&tls.Config{RootCAs: pool}),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := req.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
if targets := proxy.Targets(); len(targets) != 1 {
|
||||||
|
t.Fatalf("proxy targets=%v; want 1 target", targets)
|
||||||
|
}
|
||||||
|
t.Log(resp.Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTLSWithProxyBug(t *testing.T) {
|
||||||
|
server, pool := newTrustedIPv4TLSServer(t, "proxy-bug.test", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("ok"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
proxy := newIPv4ConnectProxyServer(t, nil)
|
||||||
|
defer proxy.Close()
|
||||||
|
|
||||||
|
client, err := NewClient()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 关键:使用 WithProxy 触发 needsDynamicTransport
|
||||||
|
// 即使 proxy 是空串或无效地址,只要设置了就会走 buildDynamicTransport 分支
|
||||||
|
req, err := client.NewRequest(httpsURLForHost(t, server, "proxy-bug.test"), "GET",
|
||||||
|
WithTimeout(10*time.Second),
|
||||||
|
WithProxy(proxy.URL),
|
||||||
|
WithCustomIP([]string{"127.0.0.1"}),
|
||||||
|
WithTLSConfig(&tls.Config{RootCAs: pool}),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := req.Do()
|
||||||
|
if err != nil {
|
||||||
|
// 修复前会报:tls: either ServerName or InsecureSkipVerify must be specified
|
||||||
|
t.Fatalf("Do error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
if targets := proxy.Targets(); len(targets) != 1 || targets[0] == "" {
|
||||||
|
t.Fatalf("proxy targets=%v", targets)
|
||||||
|
}
|
||||||
|
t.Logf("Status: %s", resp.Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更精准的复现:直接测试有问题的分支
|
||||||
|
func TestTLSDialWithoutServerName(t *testing.T) {
|
||||||
|
server, pool := newTrustedIPv4TLSServer(t, "custom-ip.test", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("ok"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
client, err := NewClient()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 使用 WithCustomIP 也能触发 defaultDialTLSFunc
|
||||||
|
req, err := client.NewRequest(httpsURLForHost(t, server, "custom-ip.test"), "GET",
|
||||||
|
WithTimeout(10*time.Second),
|
||||||
|
WithCustomIP([]string{"127.0.0.1"}),
|
||||||
|
WithTLSConfig(&tls.Config{RootCAs: pool}),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := req.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
t.Logf("Status: %s", resp.Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 最小复现:只要触发 needsDynamicTransport 即可
|
||||||
|
func TestMinimalTLSBug(t *testing.T) {
|
||||||
|
server, pool := newTrustedIPv4TLSServer(t, "localhost", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("ok"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
client, err := NewClient()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithDialTimeout 也会触发动态 transport
|
||||||
|
req, err := client.NewRequest(httpsURLForHost(t, server, "localhost"), "GET",
|
||||||
|
WithDialTimeout(5*time.Second),
|
||||||
|
WithTLSConfig(&tls.Config{RootCAs: pool}),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := req.Do()
|
||||||
|
if err != nil {
|
||||||
|
// 修复前必现:tls handshake: tls: either ServerName or InsecureSkipVerify must be specified
|
||||||
|
t.Fatalf("Do error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
t.Logf("Status: %s", resp.Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTLSWithSOCKS5ProxyPath(t *testing.T) {
|
||||||
|
server, pool := newTrustedIPv4TLSServer(t, "socks5-proxy.test", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("ok"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
proxy := newSOCKS5ProxyServer(t, nil)
|
||||||
|
defer proxy.Close()
|
||||||
|
|
||||||
|
client, err := NewClient()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := client.NewRequest(httpsURLForHost(t, server, "socks5-proxy.test"), "GET",
|
||||||
|
WithTimeout(10*time.Second),
|
||||||
|
WithProxy(proxy.URL()),
|
||||||
|
WithCustomIP([]string{"127.0.0.1"}),
|
||||||
|
WithTLSConfig(&tls.Config{RootCAs: pool}),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := req.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
if targets := proxy.Targets(); len(targets) != 1 || targets[0] == "" {
|
||||||
|
t.Fatalf("socks5 targets=%v", targets)
|
||||||
|
}
|
||||||
|
t.Logf("Status: %s", resp.Status)
|
||||||
|
}
|
||||||
@@ -0,0 +1,90 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetConfigForClientFunc selects TLS config by hostname/SNI.
|
||||||
|
type GetConfigForClientFunc func(hostname string) (*tls.Config, error)
|
||||||
|
|
||||||
|
// ClientHelloMeta carries sniffed TLS metadata and connection context.
|
||||||
|
type ClientHelloMeta struct {
|
||||||
|
ServerName string
|
||||||
|
LocalAddr net.Addr
|
||||||
|
RemoteAddr net.Addr
|
||||||
|
SupportedProtos []string
|
||||||
|
SupportedVersions []uint16
|
||||||
|
CipherSuites []uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clone returns a detached copy safe for callers to mutate.
|
||||||
|
func (m *ClientHelloMeta) Clone() *ClientHelloMeta {
|
||||||
|
if m == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := *m
|
||||||
|
if m.SupportedProtos != nil {
|
||||||
|
out.SupportedProtos = append([]string(nil), m.SupportedProtos...)
|
||||||
|
}
|
||||||
|
if m.SupportedVersions != nil {
|
||||||
|
out.SupportedVersions = append([]uint16(nil), m.SupportedVersions...)
|
||||||
|
}
|
||||||
|
if m.CipherSuites != nil {
|
||||||
|
out.CipherSuites = append([]uint16(nil), m.CipherSuites...)
|
||||||
|
}
|
||||||
|
return &out
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetConfigForClientHelloFunc selects TLS config by sniffed TLS metadata.
|
||||||
|
type GetConfigForClientHelloFunc func(hello *ClientHelloMeta) (*tls.Config, error)
|
||||||
|
|
||||||
|
// ListenerConfig controls listener behavior.
|
||||||
|
type ListenerConfig struct {
|
||||||
|
// BaseTLSConfig is used for TLS when dynamic selection returns nil.
|
||||||
|
BaseTLSConfig *tls.Config
|
||||||
|
|
||||||
|
// GetConfigForClient selects TLS config for a hostname/SNI.
|
||||||
|
// Deprecated: prefer GetConfigForClientHello for richer context.
|
||||||
|
GetConfigForClient GetConfigForClientFunc
|
||||||
|
|
||||||
|
// GetConfigForClientHello selects TLS config for sniffed TLS metadata.
|
||||||
|
GetConfigForClientHello GetConfigForClientHelloFunc
|
||||||
|
|
||||||
|
// AllowNonTLS allows plain TCP fallback.
|
||||||
|
AllowNonTLS bool
|
||||||
|
|
||||||
|
// SniffTimeout bounds protocol sniffing time. 0 means no timeout.
|
||||||
|
SniffTimeout time.Duration
|
||||||
|
|
||||||
|
// MaxClientHelloBytes limits buffered sniff data.
|
||||||
|
// If <= 0, default 64KiB.
|
||||||
|
MaxClientHelloBytes int
|
||||||
|
|
||||||
|
// Logger is optional.
|
||||||
|
Logger Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultListenerConfig returns a conservative default config.
|
||||||
|
func DefaultListenerConfig() ListenerConfig {
|
||||||
|
return ListenerConfig{
|
||||||
|
AllowNonTLS: false,
|
||||||
|
SniffTimeout: 5 * time.Second,
|
||||||
|
MaxClientHelloBytes: 64 * 1024,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TLSDefaults returns a TLS config baseline.
|
||||||
|
// Caller should set Certificates / GetCertificate as needed.
|
||||||
|
func TLSDefaults() *tls.Config {
|
||||||
|
return &tls.Config{
|
||||||
|
MinVersion: tls.VersionTLS12,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DialConfig controls dialing behavior.
|
||||||
|
type DialConfig struct {
|
||||||
|
Timeout time.Duration
|
||||||
|
LocalAddr net.Addr
|
||||||
|
}
|
||||||
+543
@@ -0,0 +1,543 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"b612.me/starnet/internal/tlssniffercore"
|
||||||
|
)
|
||||||
|
|
||||||
|
// replayConn replays buffered bytes first, then reads from live conn.
|
||||||
|
type replayConn struct {
|
||||||
|
reader io.Reader
|
||||||
|
conn net.Conn
|
||||||
|
}
|
||||||
|
|
||||||
|
func newReplayConn(buffered io.Reader, conn net.Conn) *replayConn {
|
||||||
|
return &replayConn{
|
||||||
|
reader: io.MultiReader(buffered, conn),
|
||||||
|
conn: conn,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *replayConn) Read(p []byte) (int, error) { return c.reader.Read(p) }
|
||||||
|
func (c *replayConn) Write(p []byte) (int, error) { return c.conn.Write(p) }
|
||||||
|
func (c *replayConn) Close() error { return c.conn.Close() }
|
||||||
|
func (c *replayConn) LocalAddr() net.Addr { return c.conn.LocalAddr() }
|
||||||
|
func (c *replayConn) RemoteAddr() net.Addr { return c.conn.RemoteAddr() }
|
||||||
|
func (c *replayConn) SetDeadline(t time.Time) error { return c.conn.SetDeadline(t) }
|
||||||
|
func (c *replayConn) SetReadDeadline(t time.Time) error { return c.conn.SetReadDeadline(t) }
|
||||||
|
func (c *replayConn) SetWriteDeadline(t time.Time) error { return c.conn.SetWriteDeadline(t) }
|
||||||
|
|
||||||
|
// SniffResult describes protocol sniffing result.
|
||||||
|
type SniffResult struct {
|
||||||
|
IsTLS bool
|
||||||
|
ClientHello *ClientHelloMeta
|
||||||
|
Buffer *bytes.Buffer
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sniffer detects protocol and metadata from initial bytes.
|
||||||
|
type Sniffer interface {
|
||||||
|
Sniff(conn net.Conn, maxBytes int) (SniffResult, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TLSSniffer is the default sniffer implementation.
|
||||||
|
type TLSSniffer struct{}
|
||||||
|
|
||||||
|
// Sniff detects TLS and extracts SNI when possible.
|
||||||
|
func (s TLSSniffer) Sniff(conn net.Conn, maxBytes int) (SniffResult, error) {
|
||||||
|
res, err := (tlssniffercore.Sniffer{}).Sniff(conn, maxBytes)
|
||||||
|
if err != nil {
|
||||||
|
return SniffResult{}, err
|
||||||
|
}
|
||||||
|
return convertCoreSniffResult(res), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertCoreSniffResult(res tlssniffercore.SniffResult) SniffResult {
|
||||||
|
out := SniffResult{
|
||||||
|
IsTLS: res.IsTLS,
|
||||||
|
Buffer: res.Buffer,
|
||||||
|
}
|
||||||
|
if res.ClientHello != nil {
|
||||||
|
out.ClientHello = convertCoreClientHelloMeta(res.ClientHello)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertCoreClientHelloMeta(meta *tlssniffercore.ClientHelloMeta) *ClientHelloMeta {
|
||||||
|
if meta == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &ClientHelloMeta{
|
||||||
|
ServerName: meta.ServerName,
|
||||||
|
LocalAddr: meta.LocalAddr,
|
||||||
|
RemoteAddr: meta.RemoteAddr,
|
||||||
|
SupportedProtos: append([]string(nil), meta.SupportedProtos...),
|
||||||
|
SupportedVersions: append([]uint16(nil), meta.SupportedVersions...),
|
||||||
|
CipherSuites: append([]uint16(nil), meta.CipherSuites...),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Conn wraps net.Conn with lazy protocol initialization.
|
||||||
|
type Conn struct {
|
||||||
|
net.Conn
|
||||||
|
|
||||||
|
once sync.Once
|
||||||
|
initErr error
|
||||||
|
closeOnce sync.Once
|
||||||
|
|
||||||
|
isTLS bool
|
||||||
|
tlsConn *tls.Conn
|
||||||
|
plainConn net.Conn
|
||||||
|
|
||||||
|
clientHello *ClientHelloMeta
|
||||||
|
|
||||||
|
baseTLSConfig *tls.Config
|
||||||
|
getConfigForClient GetConfigForClientFunc
|
||||||
|
getConfigForClientHello GetConfigForClientHelloFunc
|
||||||
|
allowNonTLS bool
|
||||||
|
sniffer Sniffer
|
||||||
|
sniffTimeout time.Duration
|
||||||
|
maxClientHello int
|
||||||
|
logger Logger
|
||||||
|
stats *Stats
|
||||||
|
skipSniff bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func newConn(raw net.Conn, cfg ListenerConfig, stats *Stats) *Conn {
|
||||||
|
return &Conn{
|
||||||
|
Conn: raw,
|
||||||
|
plainConn: raw,
|
||||||
|
baseTLSConfig: cfg.BaseTLSConfig,
|
||||||
|
getConfigForClient: cfg.GetConfigForClient,
|
||||||
|
getConfigForClientHello: cfg.GetConfigForClientHello,
|
||||||
|
allowNonTLS: cfg.AllowNonTLS,
|
||||||
|
sniffer: TLSSniffer{},
|
||||||
|
sniffTimeout: cfg.SniffTimeout,
|
||||||
|
maxClientHello: cfg.MaxClientHelloBytes,
|
||||||
|
logger: cfg.Logger,
|
||||||
|
stats: stats,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) init() {
|
||||||
|
c.once.Do(func() {
|
||||||
|
if c.skipSniff {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if c.baseTLSConfig == nil && c.getConfigForClient == nil && c.getConfigForClientHello == nil {
|
||||||
|
c.isTLS = false
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.sniffTimeout > 0 {
|
||||||
|
_ = c.Conn.SetReadDeadline(time.Now().Add(c.sniffTimeout))
|
||||||
|
}
|
||||||
|
res, err := c.sniffer.Sniff(c.Conn, c.maxClientHello)
|
||||||
|
if c.sniffTimeout > 0 {
|
||||||
|
_ = c.Conn.SetReadDeadline(time.Time{})
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
c.initErr = errors.Join(ErrTLSSniffFailed, err)
|
||||||
|
c.failSniff(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.isTLS = res.IsTLS
|
||||||
|
c.clientHello = res.ClientHello
|
||||||
|
|
||||||
|
if c.isTLS {
|
||||||
|
if c.stats != nil {
|
||||||
|
c.stats.incTLSDetected()
|
||||||
|
}
|
||||||
|
tlsCfg, errCfg := c.selectTLSConfig()
|
||||||
|
if errCfg != nil {
|
||||||
|
c.initErr = errors.Join(ErrTLSConfigSelectionFailed, errCfg)
|
||||||
|
c.failTLSConfigSelection(errCfg)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
rc := newReplayConn(bytes.NewBuffer(res.Buffer.Bytes()), c.Conn)
|
||||||
|
c.tlsConn = tls.Server(rc, tlsCfg)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.stats != nil {
|
||||||
|
c.stats.incPlainDetected()
|
||||||
|
}
|
||||||
|
if !c.allowNonTLS {
|
||||||
|
c.initErr = ErrNonTLSNotAllowed
|
||||||
|
c.failPlainRejected()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.plainConn = newReplayConn(bytes.NewBuffer(res.Buffer.Bytes()), c.Conn)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) failAndClose(format string, v ...interface{}) {
|
||||||
|
if c.logger != nil {
|
||||||
|
c.logger.Printf("starnet: "+format, v...)
|
||||||
|
}
|
||||||
|
_ = c.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) failSniff(err error) {
|
||||||
|
if c.stats != nil {
|
||||||
|
c.stats.incSniffFailures()
|
||||||
|
}
|
||||||
|
c.failAndClose("tls sniff failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) failTLSConfigSelection(err error) {
|
||||||
|
if c.stats != nil {
|
||||||
|
c.stats.incTLSConfigFailures()
|
||||||
|
}
|
||||||
|
c.failAndClose("tls config selection failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) failPlainRejected() {
|
||||||
|
if c.stats != nil {
|
||||||
|
c.stats.incPlainRejected()
|
||||||
|
}
|
||||||
|
c.failAndClose("plain tcp rejected")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) selectTLSConfig() (*tls.Config, error) {
|
||||||
|
var selected *tls.Config
|
||||||
|
if c.getConfigForClientHello != nil {
|
||||||
|
cfg, err := c.getConfigForClientHello(c.clientHello.Clone())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if cfg != nil {
|
||||||
|
selected = cfg
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if selected == nil && c.getConfigForClient != nil {
|
||||||
|
cfg, err := c.getConfigForClient(c.serverName())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if cfg != nil {
|
||||||
|
selected = cfg
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
composed := composeServerTLSConfig(c.baseTLSConfig, selected)
|
||||||
|
if composed != nil {
|
||||||
|
return composed, nil
|
||||||
|
}
|
||||||
|
return nil, ErrNoTLSConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hostname returns sniffed SNI hostname (if any).
|
||||||
|
func (c *Conn) Hostname() string {
|
||||||
|
c.init()
|
||||||
|
return c.serverName()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClientHello returns sniffed TLS metadata (if any).
|
||||||
|
func (c *Conn) ClientHello() *ClientHelloMeta {
|
||||||
|
c.init()
|
||||||
|
return c.clientHello.Clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) serverName() string {
|
||||||
|
if c.clientHello == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return c.clientHello.ServerName
|
||||||
|
}
|
||||||
|
|
||||||
|
func composeServerTLSConfig(base, selected *tls.Config) *tls.Config {
|
||||||
|
return tlssniffercore.ComposeServerTLSConfig(base, selected)
|
||||||
|
}
|
||||||
|
|
||||||
|
func applyServerTLSOverrides(dst, src *tls.Config) {
|
||||||
|
tlssniffercore.ApplyServerTLSOverrides(dst, src)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) IsTLS() bool {
|
||||||
|
c.init()
|
||||||
|
return c.initErr == nil && c.isTLS
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) TLSConn() (*tls.Conn, error) {
|
||||||
|
c.init()
|
||||||
|
if c.initErr != nil {
|
||||||
|
return nil, c.initErr
|
||||||
|
}
|
||||||
|
if !c.isTLS || c.tlsConn == nil {
|
||||||
|
return nil, ErrNotTLS
|
||||||
|
}
|
||||||
|
return c.tlsConn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) Read(b []byte) (int, error) {
|
||||||
|
c.init()
|
||||||
|
if c.initErr != nil {
|
||||||
|
return 0, c.initErr
|
||||||
|
}
|
||||||
|
if c.isTLS {
|
||||||
|
return c.tlsConn.Read(b)
|
||||||
|
}
|
||||||
|
return c.plainConn.Read(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) Write(b []byte) (int, error) {
|
||||||
|
c.init()
|
||||||
|
if c.initErr != nil {
|
||||||
|
return 0, c.initErr
|
||||||
|
}
|
||||||
|
if c.isTLS {
|
||||||
|
return c.tlsConn.Write(b)
|
||||||
|
}
|
||||||
|
return c.plainConn.Write(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) Close() error {
|
||||||
|
var err error
|
||||||
|
c.closeOnce.Do(func() {
|
||||||
|
if c.tlsConn != nil {
|
||||||
|
err = c.tlsConn.Close()
|
||||||
|
} else {
|
||||||
|
err = c.Conn.Close()
|
||||||
|
}
|
||||||
|
if c.stats != nil {
|
||||||
|
c.stats.incClosed()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) SetDeadline(t time.Time) error {
|
||||||
|
c.init()
|
||||||
|
if c.initErr != nil {
|
||||||
|
return c.initErr
|
||||||
|
}
|
||||||
|
if c.isTLS && c.tlsConn != nil {
|
||||||
|
return c.tlsConn.SetDeadline(t)
|
||||||
|
}
|
||||||
|
return c.plainConn.SetDeadline(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) SetReadDeadline(t time.Time) error {
|
||||||
|
c.init()
|
||||||
|
if c.initErr != nil {
|
||||||
|
return c.initErr
|
||||||
|
}
|
||||||
|
if c.isTLS && c.tlsConn != nil {
|
||||||
|
return c.tlsConn.SetReadDeadline(t)
|
||||||
|
}
|
||||||
|
return c.plainConn.SetReadDeadline(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) SetWriteDeadline(t time.Time) error {
|
||||||
|
c.init()
|
||||||
|
if c.initErr != nil {
|
||||||
|
return c.initErr
|
||||||
|
}
|
||||||
|
if c.isTLS && c.tlsConn != nil {
|
||||||
|
return c.tlsConn.SetWriteDeadline(t)
|
||||||
|
}
|
||||||
|
return c.plainConn.SetWriteDeadline(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Listener wraps net.Listener and returns starnet.Conn from Accept.
|
||||||
|
type Listener struct {
|
||||||
|
net.Listener
|
||||||
|
|
||||||
|
mu sync.RWMutex
|
||||||
|
cfg ListenerConfig
|
||||||
|
stats Stats
|
||||||
|
}
|
||||||
|
|
||||||
|
// Listen creates a plain listener config (no TLS detection).
|
||||||
|
func Listen(network, address string) (*Listener, error) {
|
||||||
|
ln, err := net.Listen(network, address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
cfg := DefaultListenerConfig()
|
||||||
|
cfg.AllowNonTLS = true
|
||||||
|
cfg.BaseTLSConfig = nil
|
||||||
|
cfg.GetConfigForClient = nil
|
||||||
|
return &Listener{Listener: ln, cfg: cfg}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListenWithConfig creates a listener with full config.
|
||||||
|
func ListenWithConfig(network, address string, cfg ListenerConfig) (*Listener, error) {
|
||||||
|
ln, err := net.Listen(network, address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &Listener{Listener: ln, cfg: normalizeConfig(cfg)}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListenWithListenConfig creates listener using net.ListenConfig.
|
||||||
|
func ListenWithListenConfig(lc net.ListenConfig, network, address string, cfg ListenerConfig) (*Listener, error) {
|
||||||
|
ln, err := lc.Listen(context.Background(), network, address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &Listener{Listener: ln, cfg: normalizeConfig(cfg)}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListenTLS creates TLS listener from cert/key paths.
|
||||||
|
func ListenTLS(network, address, certFile, keyFile string, allowNonTLS bool) (*Listener, error) {
|
||||||
|
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
cfg := DefaultListenerConfig()
|
||||||
|
cfg.AllowNonTLS = allowNonTLS
|
||||||
|
cfg.BaseTLSConfig = TLSDefaults()
|
||||||
|
cfg.BaseTLSConfig.Certificates = []tls.Certificate{cert}
|
||||||
|
return ListenWithConfig(network, address, cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeConfig(cfg ListenerConfig) ListenerConfig {
|
||||||
|
out := DefaultListenerConfig()
|
||||||
|
out.AllowNonTLS = cfg.AllowNonTLS
|
||||||
|
out.SniffTimeout = cfg.SniffTimeout
|
||||||
|
out.MaxClientHelloBytes = cfg.MaxClientHelloBytes
|
||||||
|
out.BaseTLSConfig = cfg.BaseTLSConfig
|
||||||
|
out.GetConfigForClient = cfg.GetConfigForClient
|
||||||
|
out.GetConfigForClientHello = cfg.GetConfigForClientHello
|
||||||
|
out.Logger = cfg.Logger
|
||||||
|
if out.MaxClientHelloBytes <= 0 {
|
||||||
|
out.MaxClientHelloBytes = 64 * 1024
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetConfig atomically replaces listener config for new accepted connections.
|
||||||
|
func (l *Listener) SetConfig(cfg ListenerConfig) {
|
||||||
|
l.mu.Lock()
|
||||||
|
l.cfg = normalizeConfig(cfg)
|
||||||
|
l.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Config returns a copy of current config.
|
||||||
|
func (l *Listener) Config() ListenerConfig {
|
||||||
|
l.mu.RLock()
|
||||||
|
cfg := l.cfg
|
||||||
|
l.mu.RUnlock()
|
||||||
|
return cfg
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stats returns current counters snapshot.
|
||||||
|
func (l *Listener) Stats() StatsSnapshot {
|
||||||
|
return l.stats.Snapshot()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Listener) Accept() (net.Conn, error) {
|
||||||
|
raw, err := l.Listener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
l.stats.incAccepted()
|
||||||
|
|
||||||
|
l.mu.RLock()
|
||||||
|
cfg := l.cfg
|
||||||
|
l.mu.RUnlock()
|
||||||
|
|
||||||
|
return newConn(raw, cfg, &l.stats), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AcceptContext supports cancellation by closing accepted conn when ctx is done early.
|
||||||
|
func (l *Listener) AcceptContext(ctx context.Context) (net.Conn, error) {
|
||||||
|
type result struct {
|
||||||
|
c net.Conn
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
ch := make(chan result, 1)
|
||||||
|
go func() {
|
||||||
|
c, err := l.Accept()
|
||||||
|
ch <- result{c: c, err: err}
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
case r := <-ch:
|
||||||
|
return r.c, r.err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dial creates a plain TCP starnet.Conn.
|
||||||
|
func Dial(network, address string) (*Conn, error) {
|
||||||
|
raw, err := net.Dial(network, address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
cfg := DefaultListenerConfig()
|
||||||
|
cfg.AllowNonTLS = true
|
||||||
|
cfg.BaseTLSConfig = nil
|
||||||
|
cfg.GetConfigForClient = nil
|
||||||
|
c := newConn(raw, cfg, nil)
|
||||||
|
c.isTLS = false
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DialWithConfig dials with net.Dialer options.
|
||||||
|
func DialWithConfig(network, address string, dc DialConfig) (*Conn, error) {
|
||||||
|
d := net.Dialer{
|
||||||
|
Timeout: dc.Timeout,
|
||||||
|
LocalAddr: dc.LocalAddr,
|
||||||
|
}
|
||||||
|
raw, err := d.Dial(network, address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
cfg := DefaultListenerConfig()
|
||||||
|
cfg.AllowNonTLS = true
|
||||||
|
c := newConn(raw, cfg, nil)
|
||||||
|
c.isTLS = false
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DialTLSWithConfig creates a TLS client connection wrapper.
|
||||||
|
func DialTLSWithConfig(network, address string, tlsCfg *tls.Config, timeout time.Duration) (*Conn, error) {
|
||||||
|
d := net.Dialer{Timeout: timeout}
|
||||||
|
raw, err := d.Dial(network, address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
tc := tls.Client(raw, tlsCfg)
|
||||||
|
return &Conn{
|
||||||
|
Conn: raw,
|
||||||
|
plainConn: raw,
|
||||||
|
isTLS: true,
|
||||||
|
tlsConn: tc,
|
||||||
|
initErr: nil,
|
||||||
|
allowNonTLS: false,
|
||||||
|
skipSniff: true,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DialTLS creates TLS client conn from cert/key paths.
|
||||||
|
func DialTLS(network, address, certFile, keyFile string) (*Conn, error) {
|
||||||
|
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
cfg := TLSDefaults()
|
||||||
|
cfg.Certificates = []tls.Certificate{cert}
|
||||||
|
return DialTLSWithConfig(network, address, cfg, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func WrapListener(listener net.Listener, cfg ListenerConfig) (*Listener, error) {
|
||||||
|
if listener == nil {
|
||||||
|
return nil, ErrNilConn
|
||||||
|
}
|
||||||
|
return &Listener{
|
||||||
|
Listener: listener,
|
||||||
|
cfg: normalizeConfig(cfg),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
+1210
File diff suppressed because it is too large
Load Diff
+55
@@ -0,0 +1,55 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import "sync/atomic"
|
||||||
|
|
||||||
|
// StatsSnapshot is a read-only copy of runtime counters.
|
||||||
|
type StatsSnapshot struct {
|
||||||
|
Accepted uint64
|
||||||
|
TLSDetected uint64
|
||||||
|
PlainDetected uint64
|
||||||
|
InitFailures uint64
|
||||||
|
SniffFailures uint64
|
||||||
|
TLSConfigFailures uint64
|
||||||
|
PlainRejected uint64
|
||||||
|
Closed uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stats provides lock-free counters.
|
||||||
|
type Stats struct {
|
||||||
|
accepted uint64
|
||||||
|
tlsDetected uint64
|
||||||
|
plainDetected uint64
|
||||||
|
initFailures uint64
|
||||||
|
sniffFailures uint64
|
||||||
|
tlsConfigFailures uint64
|
||||||
|
plainRejected uint64
|
||||||
|
closed uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Stats) incAccepted() { atomic.AddUint64(&s.accepted, 1) }
|
||||||
|
func (s *Stats) incTLSDetected() { atomic.AddUint64(&s.tlsDetected, 1) }
|
||||||
|
func (s *Stats) incPlainDetected() { atomic.AddUint64(&s.plainDetected, 1) }
|
||||||
|
func (s *Stats) incInitFailures() { atomic.AddUint64(&s.initFailures, 1) }
|
||||||
|
func (s *Stats) incClosed() { atomic.AddUint64(&s.closed, 1) }
|
||||||
|
func (s *Stats) incSniffFailures() { atomic.AddUint64(&s.sniffFailures, 1); s.incInitFailures() }
|
||||||
|
func (s *Stats) incTLSConfigFailures() { atomic.AddUint64(&s.tlsConfigFailures, 1); s.incInitFailures() }
|
||||||
|
func (s *Stats) incPlainRejected() { atomic.AddUint64(&s.plainRejected, 1); s.incInitFailures() }
|
||||||
|
|
||||||
|
// Snapshot returns a stable view of counters.
|
||||||
|
func (s *Stats) Snapshot() StatsSnapshot {
|
||||||
|
return StatsSnapshot{
|
||||||
|
Accepted: atomic.LoadUint64(&s.accepted),
|
||||||
|
TLSDetected: atomic.LoadUint64(&s.tlsDetected),
|
||||||
|
PlainDetected: atomic.LoadUint64(&s.plainDetected),
|
||||||
|
InitFailures: atomic.LoadUint64(&s.initFailures),
|
||||||
|
SniffFailures: atomic.LoadUint64(&s.sniffFailures),
|
||||||
|
TLSConfigFailures: atomic.LoadUint64(&s.tlsConfigFailures),
|
||||||
|
PlainRejected: atomic.LoadUint64(&s.plainRejected),
|
||||||
|
Closed: atomic.LoadUint64(&s.closed),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Logger is a minimal logging abstraction.
|
||||||
|
type Logger interface {
|
||||||
|
Printf(format string, v ...interface{})
|
||||||
|
}
|
||||||
@@ -0,0 +1,604 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"net"
|
||||||
|
"net/http/httptrace"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type traceContextKey struct{}
|
||||||
|
|
||||||
|
// TraceHooks defines optional callbacks for network lifecycle events.
|
||||||
|
// Hooks may be called concurrently.
|
||||||
|
type TraceHooks struct {
|
||||||
|
GetConn func(TraceGetConnInfo)
|
||||||
|
GotConn func(TraceGotConnInfo)
|
||||||
|
PutIdleConn func(TracePutIdleConnInfo)
|
||||||
|
DNSStart func(TraceDNSStartInfo)
|
||||||
|
DNSDone func(TraceDNSDoneInfo)
|
||||||
|
ConnectStart func(TraceConnectStartInfo)
|
||||||
|
ConnectDone func(TraceConnectDoneInfo)
|
||||||
|
TLSHandshakeStart func(TraceTLSHandshakeStartInfo)
|
||||||
|
TLSHandshakeDone func(TraceTLSHandshakeDoneInfo)
|
||||||
|
WroteHeaderField func(TraceWroteHeaderFieldInfo)
|
||||||
|
WroteHeaders func()
|
||||||
|
WroteRequest func(TraceWroteRequestInfo)
|
||||||
|
GotFirstResponseByte func()
|
||||||
|
RetryAttemptStart func(TraceRetryAttemptStartInfo)
|
||||||
|
RetryAttemptDone func(TraceRetryAttemptDoneInfo)
|
||||||
|
RetryBackoff func(TraceRetryBackoffInfo)
|
||||||
|
}
|
||||||
|
|
||||||
|
type TraceGetConnInfo struct {
|
||||||
|
Addr string
|
||||||
|
}
|
||||||
|
|
||||||
|
type TraceGotConnInfo struct {
|
||||||
|
Conn net.Conn
|
||||||
|
Reused bool
|
||||||
|
WasIdle bool
|
||||||
|
IdleTime time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
type TracePutIdleConnInfo struct {
|
||||||
|
Err error
|
||||||
|
}
|
||||||
|
|
||||||
|
type TraceDNSStartInfo struct {
|
||||||
|
Host string
|
||||||
|
}
|
||||||
|
|
||||||
|
type TraceDNSDoneInfo struct {
|
||||||
|
Addrs []net.IPAddr
|
||||||
|
Coalesced bool
|
||||||
|
Err error
|
||||||
|
}
|
||||||
|
|
||||||
|
type TraceConnectStartInfo struct {
|
||||||
|
Network string
|
||||||
|
Addr string
|
||||||
|
}
|
||||||
|
|
||||||
|
type TraceConnectDoneInfo struct {
|
||||||
|
Network string
|
||||||
|
Addr string
|
||||||
|
Err error
|
||||||
|
}
|
||||||
|
|
||||||
|
type TraceTLSHandshakeStartInfo struct {
|
||||||
|
Network string
|
||||||
|
Addr string
|
||||||
|
ServerName string
|
||||||
|
}
|
||||||
|
|
||||||
|
type TraceTLSHandshakeDoneInfo struct {
|
||||||
|
Network string
|
||||||
|
Addr string
|
||||||
|
ServerName string
|
||||||
|
ConnectionState tls.ConnectionState
|
||||||
|
Err error
|
||||||
|
}
|
||||||
|
|
||||||
|
type TraceWroteHeaderFieldInfo struct {
|
||||||
|
Key string
|
||||||
|
Values []string
|
||||||
|
}
|
||||||
|
|
||||||
|
type TraceWroteRequestInfo struct {
|
||||||
|
Err error
|
||||||
|
}
|
||||||
|
|
||||||
|
type TraceRetryAttemptStartInfo struct {
|
||||||
|
Attempt int
|
||||||
|
MaxAttempts int
|
||||||
|
}
|
||||||
|
|
||||||
|
type TraceRetryAttemptDoneInfo struct {
|
||||||
|
Attempt int
|
||||||
|
MaxAttempts int
|
||||||
|
StatusCode int
|
||||||
|
Err error
|
||||||
|
WillRetry bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type TraceRetryBackoffInfo struct {
|
||||||
|
Attempt int
|
||||||
|
Delay time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
type traceState struct {
|
||||||
|
hooks *TraceHooks
|
||||||
|
customTLS atomic.Uint32
|
||||||
|
manualDNSRefs atomic.Int32
|
||||||
|
defaultTLSHandshakeInfo TraceTLSHandshakeStartInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTraceState(hooks *TraceHooks) *traceState {
|
||||||
|
if hooks == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &traceState{hooks: hooks}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *traceState) setDefaultTLSHandshakeInfo(info TraceTLSHandshakeStartInfo) {
|
||||||
|
if t == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.defaultTLSHandshakeInfo = info
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *traceState) getDefaultTLSHandshakeInfo() TraceTLSHandshakeStartInfo {
|
||||||
|
if t == nil {
|
||||||
|
return TraceTLSHandshakeStartInfo{}
|
||||||
|
}
|
||||||
|
return t.defaultTLSHandshakeInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
func withTraceState(ctx context.Context, state *traceState) context.Context {
|
||||||
|
if state == nil {
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
return context.WithValue(ctx, traceContextKey{}, state)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getTraceState(ctx context.Context) *traceState {
|
||||||
|
if ctx == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
state, _ := ctx.Value(traceContextKey{}).(*traceState)
|
||||||
|
return state
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *traceState) needsHTTPTrace() bool {
|
||||||
|
if t == nil || t.hooks == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
h := t.hooks
|
||||||
|
return h.GetConn != nil ||
|
||||||
|
h.GotConn != nil ||
|
||||||
|
h.PutIdleConn != nil ||
|
||||||
|
h.DNSStart != nil ||
|
||||||
|
h.DNSDone != nil ||
|
||||||
|
h.ConnectStart != nil ||
|
||||||
|
h.ConnectDone != nil ||
|
||||||
|
h.TLSHandshakeStart != nil ||
|
||||||
|
h.TLSHandshakeDone != nil ||
|
||||||
|
h.WroteHeaderField != nil ||
|
||||||
|
h.WroteHeaders != nil ||
|
||||||
|
h.WroteRequest != nil ||
|
||||||
|
h.GotFirstResponseByte != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *traceState) clientTrace() *httptrace.ClientTrace {
|
||||||
|
if !t.needsHTTPTrace() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
h := t.hooks
|
||||||
|
trace := &httptrace.ClientTrace{}
|
||||||
|
if h.GetConn != nil {
|
||||||
|
trace.GetConn = func(hostPort string) {
|
||||||
|
h.GetConn(TraceGetConnInfo{Addr: hostPort})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if h.GotConn != nil {
|
||||||
|
trace.GotConn = func(info httptrace.GotConnInfo) {
|
||||||
|
h.GotConn(TraceGotConnInfo{
|
||||||
|
Conn: info.Conn,
|
||||||
|
Reused: info.Reused,
|
||||||
|
WasIdle: info.WasIdle,
|
||||||
|
IdleTime: info.IdleTime,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if h.PutIdleConn != nil {
|
||||||
|
trace.PutIdleConn = func(err error) {
|
||||||
|
h.PutIdleConn(TracePutIdleConnInfo{Err: err})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if h.DNSStart != nil {
|
||||||
|
trace.DNSStart = func(info httptrace.DNSStartInfo) {
|
||||||
|
if t.usesManualDNS() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.DNSStart(TraceDNSStartInfo{Host: info.Host})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if h.DNSDone != nil {
|
||||||
|
trace.DNSDone = func(info httptrace.DNSDoneInfo) {
|
||||||
|
if t.usesManualDNS() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.DNSDone(TraceDNSDoneInfo{
|
||||||
|
Addrs: append([]net.IPAddr(nil), info.Addrs...),
|
||||||
|
Coalesced: info.Coalesced,
|
||||||
|
Err: info.Err,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if h.ConnectStart != nil {
|
||||||
|
trace.ConnectStart = func(network, addr string) {
|
||||||
|
h.ConnectStart(TraceConnectStartInfo{Network: network, Addr: addr})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if h.ConnectDone != nil {
|
||||||
|
trace.ConnectDone = func(network, addr string, err error) {
|
||||||
|
h.ConnectDone(TraceConnectDoneInfo{Network: network, Addr: addr, Err: err})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if h.TLSHandshakeStart != nil {
|
||||||
|
trace.TLSHandshakeStart = func() {
|
||||||
|
if t.usesCustomTLS() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.TLSHandshakeStart(t.getDefaultTLSHandshakeInfo())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if h.TLSHandshakeDone != nil {
|
||||||
|
trace.TLSHandshakeDone = func(state tls.ConnectionState, err error) {
|
||||||
|
if t.usesCustomTLS() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
info := t.getDefaultTLSHandshakeInfo()
|
||||||
|
h.TLSHandshakeDone(TraceTLSHandshakeDoneInfo{
|
||||||
|
Network: info.Network,
|
||||||
|
Addr: info.Addr,
|
||||||
|
ServerName: info.ServerName,
|
||||||
|
ConnectionState: state,
|
||||||
|
Err: err,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if h.WroteHeaderField != nil {
|
||||||
|
trace.WroteHeaderField = func(key string, value []string) {
|
||||||
|
h.WroteHeaderField(TraceWroteHeaderFieldInfo{
|
||||||
|
Key: key,
|
||||||
|
Values: append([]string(nil), value...),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if h.WroteHeaders != nil {
|
||||||
|
trace.WroteHeaders = h.WroteHeaders
|
||||||
|
}
|
||||||
|
if h.WroteRequest != nil {
|
||||||
|
trace.WroteRequest = func(info httptrace.WroteRequestInfo) {
|
||||||
|
h.WroteRequest(TraceWroteRequestInfo{Err: info.Err})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if h.GotFirstResponseByte != nil {
|
||||||
|
trace.GotFirstResponseByte = h.GotFirstResponseByte
|
||||||
|
}
|
||||||
|
return trace
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *traceState) markCustomTLS() {
|
||||||
|
if t == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.customTLS.Store(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *traceState) usesCustomTLS() bool {
|
||||||
|
if t == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return t.customTLS.Load() != 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *traceState) beginManualDNS() {
|
||||||
|
if t == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.manualDNSRefs.Add(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *traceState) endManualDNS() {
|
||||||
|
if t == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.manualDNSRefs.Add(-1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *traceState) usesManualDNS() bool {
|
||||||
|
if t == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return t.manualDNSRefs.Load() > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *traceState) tlsHandshakeStart(info TraceTLSHandshakeStartInfo) {
|
||||||
|
if t == nil || t.hooks == nil || t.hooks.TLSHandshakeStart == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.hooks.TLSHandshakeStart(info)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *traceState) tlsHandshakeDone(info TraceTLSHandshakeDoneInfo) {
|
||||||
|
if t == nil || t.hooks == nil || t.hooks.TLSHandshakeDone == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.hooks.TLSHandshakeDone(info)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *traceState) dnsStart(info TraceDNSStartInfo) {
|
||||||
|
if t == nil || t.hooks == nil || t.hooks.DNSStart == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.hooks.DNSStart(info)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *traceState) dnsDone(info TraceDNSDoneInfo) {
|
||||||
|
if t == nil || t.hooks == nil || t.hooks.DNSDone == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.hooks.DNSDone(info)
|
||||||
|
}
|
||||||
|
|
||||||
|
func emitRetryAttemptStart(hooks *TraceHooks, info TraceRetryAttemptStartInfo) {
|
||||||
|
if hooks == nil || hooks.RetryAttemptStart == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
hooks.RetryAttemptStart(info)
|
||||||
|
}
|
||||||
|
|
||||||
|
func emitRetryAttemptDone(hooks *TraceHooks, info TraceRetryAttemptDoneInfo) {
|
||||||
|
if hooks == nil || hooks.RetryAttemptDone == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
hooks.RetryAttemptDone(info)
|
||||||
|
}
|
||||||
|
|
||||||
|
func emitRetryBackoff(hooks *TraceHooks, info TraceRetryBackoffInfo) {
|
||||||
|
if hooks == nil || hooks.RetryBackoff == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
hooks.RetryBackoff(info)
|
||||||
|
}
|
||||||
|
|
||||||
|
func traceRecorderHooks(recorder *TraceRecorder) *TraceHooks {
|
||||||
|
if recorder == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return recorder.Hooks()
|
||||||
|
}
|
||||||
|
|
||||||
|
func composeTraceHooks(first, second *TraceHooks) *TraceHooks {
|
||||||
|
switch {
|
||||||
|
case first == nil:
|
||||||
|
return second
|
||||||
|
case second == nil:
|
||||||
|
return first
|
||||||
|
}
|
||||||
|
|
||||||
|
return &TraceHooks{
|
||||||
|
GetConn: composeTraceGetConnHook(first.GetConn, second.GetConn),
|
||||||
|
GotConn: composeTraceGotConnHook(first.GotConn, second.GotConn),
|
||||||
|
PutIdleConn: composeTracePutIdleConnHook(first.PutIdleConn, second.PutIdleConn),
|
||||||
|
DNSStart: composeTraceDNSStartHook(first.DNSStart, second.DNSStart),
|
||||||
|
DNSDone: composeTraceDNSDoneHook(first.DNSDone, second.DNSDone),
|
||||||
|
ConnectStart: composeTraceConnectStartHook(first.ConnectStart, second.ConnectStart),
|
||||||
|
ConnectDone: composeTraceConnectDoneHook(first.ConnectDone, second.ConnectDone),
|
||||||
|
TLSHandshakeStart: composeTraceTLSHandshakeStartHook(first.TLSHandshakeStart, second.TLSHandshakeStart),
|
||||||
|
TLSHandshakeDone: composeTraceTLSHandshakeDoneHook(first.TLSHandshakeDone, second.TLSHandshakeDone),
|
||||||
|
WroteHeaderField: composeTraceWroteHeaderFieldHook(first.WroteHeaderField, second.WroteHeaderField),
|
||||||
|
WroteHeaders: composeTraceSimpleHook(first.WroteHeaders, second.WroteHeaders),
|
||||||
|
WroteRequest: composeTraceWroteRequestHook(first.WroteRequest, second.WroteRequest),
|
||||||
|
GotFirstResponseByte: composeTraceSimpleHook(first.GotFirstResponseByte, second.GotFirstResponseByte),
|
||||||
|
RetryAttemptStart: composeTraceRetryAttemptStartHook(first.RetryAttemptStart, second.RetryAttemptStart),
|
||||||
|
RetryAttemptDone: composeTraceRetryAttemptDoneHook(first.RetryAttemptDone, second.RetryAttemptDone),
|
||||||
|
RetryBackoff: composeTraceRetryBackoffHook(first.RetryBackoff, second.RetryBackoff),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func composeTraceGetConnHook(first, second func(TraceGetConnInfo)) func(TraceGetConnInfo) {
|
||||||
|
switch {
|
||||||
|
case first == nil:
|
||||||
|
return second
|
||||||
|
case second == nil:
|
||||||
|
return first
|
||||||
|
default:
|
||||||
|
return func(info TraceGetConnInfo) {
|
||||||
|
first(info)
|
||||||
|
second(info)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func composeTraceGotConnHook(first, second func(TraceGotConnInfo)) func(TraceGotConnInfo) {
|
||||||
|
switch {
|
||||||
|
case first == nil:
|
||||||
|
return second
|
||||||
|
case second == nil:
|
||||||
|
return first
|
||||||
|
default:
|
||||||
|
return func(info TraceGotConnInfo) {
|
||||||
|
first(info)
|
||||||
|
second(info)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func composeTracePutIdleConnHook(first, second func(TracePutIdleConnInfo)) func(TracePutIdleConnInfo) {
|
||||||
|
switch {
|
||||||
|
case first == nil:
|
||||||
|
return second
|
||||||
|
case second == nil:
|
||||||
|
return first
|
||||||
|
default:
|
||||||
|
return func(info TracePutIdleConnInfo) {
|
||||||
|
first(info)
|
||||||
|
second(info)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func composeTraceDNSStartHook(first, second func(TraceDNSStartInfo)) func(TraceDNSStartInfo) {
|
||||||
|
switch {
|
||||||
|
case first == nil:
|
||||||
|
return second
|
||||||
|
case second == nil:
|
||||||
|
return first
|
||||||
|
default:
|
||||||
|
return func(info TraceDNSStartInfo) {
|
||||||
|
first(info)
|
||||||
|
second(info)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func composeTraceDNSDoneHook(first, second func(TraceDNSDoneInfo)) func(TraceDNSDoneInfo) {
|
||||||
|
switch {
|
||||||
|
case first == nil:
|
||||||
|
return second
|
||||||
|
case second == nil:
|
||||||
|
return first
|
||||||
|
default:
|
||||||
|
return func(info TraceDNSDoneInfo) {
|
||||||
|
first(info)
|
||||||
|
second(info)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func composeTraceConnectStartHook(first, second func(TraceConnectStartInfo)) func(TraceConnectStartInfo) {
|
||||||
|
switch {
|
||||||
|
case first == nil:
|
||||||
|
return second
|
||||||
|
case second == nil:
|
||||||
|
return first
|
||||||
|
default:
|
||||||
|
return func(info TraceConnectStartInfo) {
|
||||||
|
first(info)
|
||||||
|
second(info)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func composeTraceConnectDoneHook(first, second func(TraceConnectDoneInfo)) func(TraceConnectDoneInfo) {
|
||||||
|
switch {
|
||||||
|
case first == nil:
|
||||||
|
return second
|
||||||
|
case second == nil:
|
||||||
|
return first
|
||||||
|
default:
|
||||||
|
return func(info TraceConnectDoneInfo) {
|
||||||
|
first(info)
|
||||||
|
second(info)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func composeTraceTLSHandshakeStartHook(first, second func(TraceTLSHandshakeStartInfo)) func(TraceTLSHandshakeStartInfo) {
|
||||||
|
switch {
|
||||||
|
case first == nil:
|
||||||
|
return second
|
||||||
|
case second == nil:
|
||||||
|
return first
|
||||||
|
default:
|
||||||
|
return func(info TraceTLSHandshakeStartInfo) {
|
||||||
|
first(info)
|
||||||
|
second(info)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func composeTraceTLSHandshakeDoneHook(first, second func(TraceTLSHandshakeDoneInfo)) func(TraceTLSHandshakeDoneInfo) {
|
||||||
|
switch {
|
||||||
|
case first == nil:
|
||||||
|
return second
|
||||||
|
case second == nil:
|
||||||
|
return first
|
||||||
|
default:
|
||||||
|
return func(info TraceTLSHandshakeDoneInfo) {
|
||||||
|
first(info)
|
||||||
|
second(info)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func composeTraceWroteHeaderFieldHook(first, second func(TraceWroteHeaderFieldInfo)) func(TraceWroteHeaderFieldInfo) {
|
||||||
|
switch {
|
||||||
|
case first == nil:
|
||||||
|
return second
|
||||||
|
case second == nil:
|
||||||
|
return first
|
||||||
|
default:
|
||||||
|
return func(info TraceWroteHeaderFieldInfo) {
|
||||||
|
first(info)
|
||||||
|
second(info)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func composeTraceWroteRequestHook(first, second func(TraceWroteRequestInfo)) func(TraceWroteRequestInfo) {
|
||||||
|
switch {
|
||||||
|
case first == nil:
|
||||||
|
return second
|
||||||
|
case second == nil:
|
||||||
|
return first
|
||||||
|
default:
|
||||||
|
return func(info TraceWroteRequestInfo) {
|
||||||
|
first(info)
|
||||||
|
second(info)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func composeTraceRetryAttemptStartHook(first, second func(TraceRetryAttemptStartInfo)) func(TraceRetryAttemptStartInfo) {
|
||||||
|
switch {
|
||||||
|
case first == nil:
|
||||||
|
return second
|
||||||
|
case second == nil:
|
||||||
|
return first
|
||||||
|
default:
|
||||||
|
return func(info TraceRetryAttemptStartInfo) {
|
||||||
|
first(info)
|
||||||
|
second(info)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func composeTraceRetryAttemptDoneHook(first, second func(TraceRetryAttemptDoneInfo)) func(TraceRetryAttemptDoneInfo) {
|
||||||
|
switch {
|
||||||
|
case first == nil:
|
||||||
|
return second
|
||||||
|
case second == nil:
|
||||||
|
return first
|
||||||
|
default:
|
||||||
|
return func(info TraceRetryAttemptDoneInfo) {
|
||||||
|
first(info)
|
||||||
|
second(info)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func composeTraceRetryBackoffHook(first, second func(TraceRetryBackoffInfo)) func(TraceRetryBackoffInfo) {
|
||||||
|
switch {
|
||||||
|
case first == nil:
|
||||||
|
return second
|
||||||
|
case second == nil:
|
||||||
|
return first
|
||||||
|
default:
|
||||||
|
return func(info TraceRetryBackoffInfo) {
|
||||||
|
first(info)
|
||||||
|
second(info)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func composeTraceSimpleHook(first, second func()) func() {
|
||||||
|
switch {
|
||||||
|
case first == nil:
|
||||||
|
return second
|
||||||
|
case second == nil:
|
||||||
|
return first
|
||||||
|
default:
|
||||||
|
return func() {
|
||||||
|
first()
|
||||||
|
second()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,550 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TraceSummary 是一次请求执行的 trace 摘要。
|
||||||
|
type TraceSummary struct {
|
||||||
|
Method string
|
||||||
|
URL string
|
||||||
|
StartedAt time.Time
|
||||||
|
ResponseAt time.Time
|
||||||
|
StatusCode int
|
||||||
|
ResponseProto string
|
||||||
|
Conn TraceConnSummary
|
||||||
|
DNS *TraceDNSSummary
|
||||||
|
DNSEvents []TraceDNSSummary
|
||||||
|
Connect []TraceConnectSummary
|
||||||
|
TLS *TraceTLSSummary
|
||||||
|
RequestWrittenAt time.Time
|
||||||
|
RequestWriteErr error
|
||||||
|
FirstResponseByteAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// TraceConnSummary 是连接复用与套接字信息摘要。
|
||||||
|
type TraceConnSummary struct {
|
||||||
|
Addr string
|
||||||
|
LocalAddr string
|
||||||
|
RemoteAddr string
|
||||||
|
Reused bool
|
||||||
|
WasIdle bool
|
||||||
|
IdleTime time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// TraceDNSSummary 是 DNS 解析摘要。
|
||||||
|
type TraceDNSSummary struct {
|
||||||
|
Host string
|
||||||
|
Addrs []string
|
||||||
|
Coalesced bool
|
||||||
|
StartedAt time.Time
|
||||||
|
CompletedAt time.Time
|
||||||
|
Duration time.Duration
|
||||||
|
Err error
|
||||||
|
}
|
||||||
|
|
||||||
|
// TraceConnectSummary 是单次建连尝试摘要。
|
||||||
|
type TraceConnectSummary struct {
|
||||||
|
Network string
|
||||||
|
Addr string
|
||||||
|
StartedAt time.Time
|
||||||
|
CompletedAt time.Time
|
||||||
|
Duration time.Duration
|
||||||
|
Err error
|
||||||
|
}
|
||||||
|
|
||||||
|
// TraceTLSSummary 是 TLS 握手与连接状态摘要。
|
||||||
|
type TraceTLSSummary struct {
|
||||||
|
Network string
|
||||||
|
Addr string
|
||||||
|
ServerName string
|
||||||
|
Version uint16
|
||||||
|
VersionName string
|
||||||
|
CipherSuite uint16
|
||||||
|
CipherSuiteName string
|
||||||
|
CurveID tls.CurveID
|
||||||
|
CurveName string
|
||||||
|
NegotiatedProtocol string
|
||||||
|
DidResume bool
|
||||||
|
ECHAccepted bool
|
||||||
|
VerifiedChains int
|
||||||
|
StartedAt time.Time
|
||||||
|
CompletedAt time.Time
|
||||||
|
Duration time.Duration
|
||||||
|
Err error
|
||||||
|
PeerCertificates []TraceCertificateSummary
|
||||||
|
}
|
||||||
|
|
||||||
|
// TraceCertificateSummary 是单张证书的关键信息摘要。
|
||||||
|
type TraceCertificateSummary struct {
|
||||||
|
Subject string
|
||||||
|
Issuer string
|
||||||
|
DNSNames []string
|
||||||
|
IPAddresses []string
|
||||||
|
}
|
||||||
|
|
||||||
|
// TraceRecorder 聚合最近一次发布的 trace 摘要。
|
||||||
|
// 通过 Request/Client 绑定时,starnet 会为每次执行创建私有运行态并在完成后发布摘要;
|
||||||
|
// 直接使用 Hooks() 时,调用方仍需自行管理 Reset 与生命周期。
|
||||||
|
type TraceRecorder struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
|
||||||
|
summary TraceSummary
|
||||||
|
pendingDNS []TraceDNSSummary
|
||||||
|
pendingConnectStarts map[string][]time.Time
|
||||||
|
pendingTLSStart time.Time
|
||||||
|
hooks *TraceHooks
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTraceRecorder 创建请求级 trace 记录器。
|
||||||
|
func NewTraceRecorder() *TraceRecorder {
|
||||||
|
recorder := &TraceRecorder{}
|
||||||
|
recorder.hooks = &TraceHooks{
|
||||||
|
GetConn: recorder.onGetConn,
|
||||||
|
GotConn: recorder.onGotConn,
|
||||||
|
DNSStart: recorder.onDNSStart,
|
||||||
|
DNSDone: recorder.onDNSDone,
|
||||||
|
ConnectStart: recorder.onConnectStart,
|
||||||
|
ConnectDone: recorder.onConnectDone,
|
||||||
|
TLSHandshakeStart: recorder.onTLSHandshakeStart,
|
||||||
|
TLSHandshakeDone: recorder.onTLSHandshakeDone,
|
||||||
|
WroteRequest: recorder.onWroteRequest,
|
||||||
|
GotFirstResponseByte: recorder.onGotFirstResponseByte,
|
||||||
|
}
|
||||||
|
return recorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hooks 返回可挂到请求上的底层 trace hooks。
|
||||||
|
func (r *TraceRecorder) Hooks() *TraceHooks {
|
||||||
|
if r == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if r.hooks == nil {
|
||||||
|
r.hooks = &TraceHooks{
|
||||||
|
GetConn: r.onGetConn,
|
||||||
|
GotConn: r.onGotConn,
|
||||||
|
DNSStart: r.onDNSStart,
|
||||||
|
DNSDone: r.onDNSDone,
|
||||||
|
ConnectStart: r.onConnectStart,
|
||||||
|
ConnectDone: r.onConnectDone,
|
||||||
|
TLSHandshakeStart: r.onTLSHandshakeStart,
|
||||||
|
TLSHandshakeDone: r.onTLSHandshakeDone,
|
||||||
|
WroteRequest: r.onWroteRequest,
|
||||||
|
GotFirstResponseByte: r.onGotFirstResponseByte,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return r.hooks
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset 清空当前摘要和内部状态。
|
||||||
|
func (r *TraceRecorder) Reset() {
|
||||||
|
if r == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
r.resetLocked()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Summary 返回当前 trace 摘要的快照。
|
||||||
|
func (r *TraceRecorder) Summary() TraceSummary {
|
||||||
|
if r == nil {
|
||||||
|
return TraceSummary{}
|
||||||
|
}
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
return cloneTraceSummary(r.summary)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *TraceRecorder) forkExecution() *TraceRecorder {
|
||||||
|
if r == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return NewTraceRecorder()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *TraceRecorder) publishSummary(summary TraceSummary) {
|
||||||
|
if r == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
r.summary = cloneTraceSummary(summary)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *TraceRecorder) startRequest() {
|
||||||
|
if r == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
r.resetLocked()
|
||||||
|
r.summary.StartedAt = time.Now()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *TraceRecorder) observePreparedRequest(req *http.Request) {
|
||||||
|
if r == nil || req == nil || req.URL == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
r.ensureStartedLocked(time.Now())
|
||||||
|
r.summary.Method = req.Method
|
||||||
|
r.summary.URL = req.URL.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *TraceRecorder) observeResponse(resp *http.Response) {
|
||||||
|
if r == nil || resp == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
r.ensureStartedLocked(now)
|
||||||
|
r.summary.ResponseAt = now
|
||||||
|
r.summary.StatusCode = resp.StatusCode
|
||||||
|
r.summary.ResponseProto = resp.Proto
|
||||||
|
if resp.TLS != nil {
|
||||||
|
r.summary.TLS = mergeTraceTLSSummary(r.summary.TLS, *resp.TLS)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *TraceRecorder) ensureStartedLocked(now time.Time) {
|
||||||
|
if r.summary.StartedAt.IsZero() {
|
||||||
|
r.summary.StartedAt = now
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *TraceRecorder) resetLocked() {
|
||||||
|
r.summary = TraceSummary{}
|
||||||
|
r.pendingDNS = nil
|
||||||
|
r.pendingTLSStart = time.Time{}
|
||||||
|
if len(r.pendingConnectStarts) == 0 {
|
||||||
|
r.pendingConnectStarts = nil
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for key := range r.pendingConnectStarts {
|
||||||
|
delete(r.pendingConnectStarts, key)
|
||||||
|
}
|
||||||
|
r.pendingConnectStarts = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *TraceRecorder) onGetConn(info TraceGetConnInfo) {
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
r.ensureStartedLocked(time.Now())
|
||||||
|
r.summary.Conn.Addr = info.Addr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *TraceRecorder) onGotConn(info TraceGotConnInfo) {
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
now := time.Now()
|
||||||
|
r.ensureStartedLocked(now)
|
||||||
|
r.summary.Conn.Reused = info.Reused
|
||||||
|
r.summary.Conn.WasIdle = info.WasIdle
|
||||||
|
r.summary.Conn.IdleTime = info.IdleTime
|
||||||
|
if info.Conn != nil {
|
||||||
|
r.summary.Conn.LocalAddr = traceAddrString(info.Conn.LocalAddr())
|
||||||
|
r.summary.Conn.RemoteAddr = traceAddrString(info.Conn.RemoteAddr())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *TraceRecorder) onDNSStart(info TraceDNSStartInfo) {
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
now := time.Now()
|
||||||
|
r.ensureStartedLocked(now)
|
||||||
|
dns := TraceDNSSummary{
|
||||||
|
Host: info.Host,
|
||||||
|
StartedAt: now,
|
||||||
|
}
|
||||||
|
r.pendingDNS = append(r.pendingDNS, dns)
|
||||||
|
copyDNS := dns
|
||||||
|
r.summary.DNS = ©DNS
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *TraceRecorder) onDNSDone(info TraceDNSDoneInfo) {
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
now := time.Now()
|
||||||
|
r.ensureStartedLocked(now)
|
||||||
|
|
||||||
|
dns := TraceDNSSummary{
|
||||||
|
Host: "",
|
||||||
|
Addrs: traceIPAddrsToStrings(info.Addrs),
|
||||||
|
Coalesced: info.Coalesced,
|
||||||
|
CompletedAt: now,
|
||||||
|
Err: info.Err,
|
||||||
|
}
|
||||||
|
if len(r.pendingDNS) > 0 {
|
||||||
|
dns.Host = r.pendingDNS[0].Host
|
||||||
|
dns.StartedAt = r.pendingDNS[0].StartedAt
|
||||||
|
if len(r.pendingDNS) == 1 {
|
||||||
|
r.pendingDNS = nil
|
||||||
|
} else {
|
||||||
|
r.pendingDNS = append([]TraceDNSSummary(nil), r.pendingDNS[1:]...)
|
||||||
|
}
|
||||||
|
} else if r.summary.DNS != nil {
|
||||||
|
dns.Host = r.summary.DNS.Host
|
||||||
|
dns.StartedAt = r.summary.DNS.StartedAt
|
||||||
|
}
|
||||||
|
if !dns.StartedAt.IsZero() {
|
||||||
|
dns.Duration = now.Sub(dns.StartedAt)
|
||||||
|
}
|
||||||
|
r.summary.DNSEvents = append(r.summary.DNSEvents, dns)
|
||||||
|
copyDNS := dns
|
||||||
|
r.summary.DNS = ©DNS
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *TraceRecorder) onConnectStart(info TraceConnectStartInfo) {
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
now := time.Now()
|
||||||
|
r.ensureStartedLocked(now)
|
||||||
|
if r.pendingConnectStarts == nil {
|
||||||
|
r.pendingConnectStarts = make(map[string][]time.Time)
|
||||||
|
}
|
||||||
|
key := traceConnectKey(info.Network, info.Addr)
|
||||||
|
r.pendingConnectStarts[key] = append(r.pendingConnectStarts[key], now)
|
||||||
|
r.summary.Connect = append(r.summary.Connect, TraceConnectSummary{
|
||||||
|
Network: info.Network,
|
||||||
|
Addr: info.Addr,
|
||||||
|
StartedAt: now,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *TraceRecorder) onConnectDone(info TraceConnectDoneInfo) {
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
now := time.Now()
|
||||||
|
r.ensureStartedLocked(now)
|
||||||
|
|
||||||
|
start := time.Time{}
|
||||||
|
key := traceConnectKey(info.Network, info.Addr)
|
||||||
|
if starts := r.pendingConnectStarts[key]; len(starts) > 0 {
|
||||||
|
start = starts[0]
|
||||||
|
if len(starts) == 1 {
|
||||||
|
delete(r.pendingConnectStarts, key)
|
||||||
|
} else {
|
||||||
|
r.pendingConnectStarts[key] = starts[1:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
connect := TraceConnectSummary{
|
||||||
|
Network: info.Network,
|
||||||
|
Addr: info.Addr,
|
||||||
|
StartedAt: start,
|
||||||
|
CompletedAt: now,
|
||||||
|
Err: info.Err,
|
||||||
|
}
|
||||||
|
if !start.IsZero() {
|
||||||
|
connect.Duration = now.Sub(start)
|
||||||
|
}
|
||||||
|
|
||||||
|
for index := len(r.summary.Connect) - 1; index >= 0; index-- {
|
||||||
|
item := &r.summary.Connect[index]
|
||||||
|
if item.Network != info.Network || item.Addr != info.Addr || !item.CompletedAt.IsZero() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
item.CompletedAt = now
|
||||||
|
item.Duration = connect.Duration
|
||||||
|
item.Err = info.Err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r.summary.Connect = append(r.summary.Connect, connect)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *TraceRecorder) onTLSHandshakeStart(info TraceTLSHandshakeStartInfo) {
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
now := time.Now()
|
||||||
|
r.ensureStartedLocked(now)
|
||||||
|
r.pendingTLSStart = now
|
||||||
|
r.summary.TLS = &TraceTLSSummary{
|
||||||
|
Network: info.Network,
|
||||||
|
Addr: info.Addr,
|
||||||
|
ServerName: info.ServerName,
|
||||||
|
StartedAt: now,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *TraceRecorder) onTLSHandshakeDone(info TraceTLSHandshakeDoneInfo) {
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
now := time.Now()
|
||||||
|
r.ensureStartedLocked(now)
|
||||||
|
|
||||||
|
var tlsSummary *TraceTLSSummary
|
||||||
|
if r.summary.TLS != nil {
|
||||||
|
copied := *r.summary.TLS
|
||||||
|
tlsSummary = &copied
|
||||||
|
} else {
|
||||||
|
tlsSummary = &TraceTLSSummary{}
|
||||||
|
}
|
||||||
|
if tlsSummary.Network == "" {
|
||||||
|
tlsSummary.Network = info.Network
|
||||||
|
}
|
||||||
|
if tlsSummary.Addr == "" {
|
||||||
|
tlsSummary.Addr = info.Addr
|
||||||
|
}
|
||||||
|
if tlsSummary.ServerName == "" {
|
||||||
|
tlsSummary.ServerName = info.ServerName
|
||||||
|
}
|
||||||
|
if tlsSummary.StartedAt.IsZero() {
|
||||||
|
tlsSummary.StartedAt = r.pendingTLSStart
|
||||||
|
}
|
||||||
|
tlsSummary.CompletedAt = now
|
||||||
|
if !tlsSummary.StartedAt.IsZero() {
|
||||||
|
tlsSummary.Duration = now.Sub(tlsSummary.StartedAt)
|
||||||
|
}
|
||||||
|
tlsSummary.Err = info.Err
|
||||||
|
tlsSummary = mergeTraceTLSSummary(tlsSummary, info.ConnectionState)
|
||||||
|
if tlsSummary.ServerName == "" {
|
||||||
|
tlsSummary.ServerName = info.ServerName
|
||||||
|
}
|
||||||
|
if tlsSummary.Addr == "" {
|
||||||
|
tlsSummary.Addr = info.Addr
|
||||||
|
}
|
||||||
|
if tlsSummary.Network == "" {
|
||||||
|
tlsSummary.Network = info.Network
|
||||||
|
}
|
||||||
|
r.pendingTLSStart = time.Time{}
|
||||||
|
r.summary.TLS = tlsSummary
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *TraceRecorder) onWroteRequest(info TraceWroteRequestInfo) {
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
now := time.Now()
|
||||||
|
r.ensureStartedLocked(now)
|
||||||
|
r.summary.RequestWrittenAt = now
|
||||||
|
r.summary.RequestWriteErr = info.Err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *TraceRecorder) onGotFirstResponseByte() {
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
now := time.Now()
|
||||||
|
r.ensureStartedLocked(now)
|
||||||
|
r.summary.FirstResponseByteAt = now
|
||||||
|
}
|
||||||
|
|
||||||
|
func traceAddrString(addr net.Addr) string {
|
||||||
|
if addr == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return addr.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func traceConnectKey(network, addr string) string {
|
||||||
|
return network + "\x00" + addr
|
||||||
|
}
|
||||||
|
|
||||||
|
func traceIPAddrsToStrings(addrs []net.IPAddr) []string {
|
||||||
|
if len(addrs) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]string, 0, len(addrs))
|
||||||
|
for _, addr := range addrs {
|
||||||
|
out = append(out, addr.String())
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func mergeTraceTLSSummary(summary *TraceTLSSummary, state tls.ConnectionState) *TraceTLSSummary {
|
||||||
|
if summary == nil {
|
||||||
|
summary = &TraceTLSSummary{}
|
||||||
|
}
|
||||||
|
if state.Version != 0 {
|
||||||
|
summary.Version = state.Version
|
||||||
|
summary.VersionName = tls.VersionName(state.Version)
|
||||||
|
}
|
||||||
|
if state.CipherSuite != 0 {
|
||||||
|
summary.CipherSuite = state.CipherSuite
|
||||||
|
summary.CipherSuiteName = tls.CipherSuiteName(state.CipherSuite)
|
||||||
|
}
|
||||||
|
if state.CurveID != 0 {
|
||||||
|
summary.CurveID = state.CurveID
|
||||||
|
summary.CurveName = state.CurveID.String()
|
||||||
|
}
|
||||||
|
if state.ServerName != "" {
|
||||||
|
summary.ServerName = state.ServerName
|
||||||
|
}
|
||||||
|
if state.NegotiatedProtocol != "" {
|
||||||
|
summary.NegotiatedProtocol = state.NegotiatedProtocol
|
||||||
|
}
|
||||||
|
summary.DidResume = state.DidResume
|
||||||
|
summary.ECHAccepted = state.ECHAccepted
|
||||||
|
summary.VerifiedChains = len(state.VerifiedChains)
|
||||||
|
if len(state.PeerCertificates) > 0 {
|
||||||
|
summary.PeerCertificates = make([]TraceCertificateSummary, 0, len(state.PeerCertificates))
|
||||||
|
for _, cert := range state.PeerCertificates {
|
||||||
|
certSummary := TraceCertificateSummary{
|
||||||
|
Subject: cert.Subject.String(),
|
||||||
|
Issuer: cert.Issuer.String(),
|
||||||
|
DNSNames: append([]string(nil), cert.DNSNames...),
|
||||||
|
}
|
||||||
|
if len(cert.IPAddresses) > 0 {
|
||||||
|
certSummary.IPAddresses = make([]string, 0, len(cert.IPAddresses))
|
||||||
|
for _, ip := range cert.IPAddresses {
|
||||||
|
certSummary.IPAddresses = append(certSummary.IPAddresses, ip.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
summary.PeerCertificates = append(summary.PeerCertificates, certSummary)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return summary
|
||||||
|
}
|
||||||
|
|
||||||
|
func cloneTraceSummary(summary TraceSummary) TraceSummary {
|
||||||
|
cloned := summary
|
||||||
|
if summary.DNS != nil {
|
||||||
|
dns := *summary.DNS
|
||||||
|
dns.Addrs = append([]string(nil), summary.DNS.Addrs...)
|
||||||
|
cloned.DNS = &dns
|
||||||
|
}
|
||||||
|
if len(summary.DNSEvents) > 0 {
|
||||||
|
cloned.DNSEvents = make([]TraceDNSSummary, 0, len(summary.DNSEvents))
|
||||||
|
for _, dns := range summary.DNSEvents {
|
||||||
|
cloned.DNSEvents = append(cloned.DNSEvents, TraceDNSSummary{
|
||||||
|
Host: dns.Host,
|
||||||
|
Addrs: append([]string(nil), dns.Addrs...),
|
||||||
|
Coalesced: dns.Coalesced,
|
||||||
|
StartedAt: dns.StartedAt,
|
||||||
|
CompletedAt: dns.CompletedAt,
|
||||||
|
Duration: dns.Duration,
|
||||||
|
Err: dns.Err,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(summary.Connect) > 0 {
|
||||||
|
cloned.Connect = append([]TraceConnectSummary(nil), summary.Connect...)
|
||||||
|
}
|
||||||
|
if summary.TLS != nil {
|
||||||
|
tlsSummary := *summary.TLS
|
||||||
|
if len(summary.TLS.PeerCertificates) > 0 {
|
||||||
|
tlsSummary.PeerCertificates = make([]TraceCertificateSummary, 0, len(summary.TLS.PeerCertificates))
|
||||||
|
for _, cert := range summary.TLS.PeerCertificates {
|
||||||
|
tlsSummary.PeerCertificates = append(tlsSummary.PeerCertificates, TraceCertificateSummary{
|
||||||
|
Subject: cert.Subject,
|
||||||
|
Issuer: cert.Issuer,
|
||||||
|
DNSNames: append([]string(nil), cert.DNSNames...),
|
||||||
|
IPAddresses: append([]string(nil), cert.IPAddresses...),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cloned.TLS = &tlsSummary
|
||||||
|
}
|
||||||
|
return cloned
|
||||||
|
}
|
||||||
|
|
||||||
|
func cloneTraceSummaryPtr(summary TraceSummary) *TraceSummary {
|
||||||
|
cloned := cloneTraceSummary(summary)
|
||||||
|
return &cloned
|
||||||
|
}
|
||||||
@@ -0,0 +1,488 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTraceRecorderCapturesHTTPSummary(t *testing.T) {
|
||||||
|
server := newIPv4TLSServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("ok"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
recorder := NewTraceRecorder()
|
||||||
|
req := NewSimpleRequest(server.URL, http.MethodGet).
|
||||||
|
SetSkipTLSVerify(true).
|
||||||
|
SetTraceRecorder(recorder)
|
||||||
|
|
||||||
|
resp, err := req.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
if _, err := resp.Body().Bytes(); err != nil {
|
||||||
|
t.Fatalf("Body().Bytes() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
summary := recorder.Summary()
|
||||||
|
if summary.Method != http.MethodGet {
|
||||||
|
t.Fatalf("method=%q", summary.Method)
|
||||||
|
}
|
||||||
|
if summary.URL != server.URL {
|
||||||
|
t.Fatalf("url=%q", summary.URL)
|
||||||
|
}
|
||||||
|
if summary.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("status=%d", summary.StatusCode)
|
||||||
|
}
|
||||||
|
if summary.ResponseProto == "" {
|
||||||
|
t.Fatal("expected response proto")
|
||||||
|
}
|
||||||
|
if summary.RequestWrittenAt.IsZero() {
|
||||||
|
t.Fatal("expected request write timestamp")
|
||||||
|
}
|
||||||
|
if summary.FirstResponseByteAt.IsZero() {
|
||||||
|
t.Fatal("expected first response byte timestamp")
|
||||||
|
}
|
||||||
|
if summary.Conn.Addr == "" {
|
||||||
|
t.Fatal("expected get-conn target address")
|
||||||
|
}
|
||||||
|
if summary.TLS == nil {
|
||||||
|
t.Fatal("expected tls summary")
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsSummary := summary.TLS
|
||||||
|
if tlsSummary.Version == 0 || tlsSummary.VersionName == "" {
|
||||||
|
t.Fatalf("unexpected tls version summary: %+v", tlsSummary)
|
||||||
|
}
|
||||||
|
if tlsSummary.CipherSuite == 0 || tlsSummary.CipherSuiteName == "" {
|
||||||
|
t.Fatalf("unexpected cipher suite summary: %+v", tlsSummary)
|
||||||
|
}
|
||||||
|
if tlsSummary.ServerName == "" {
|
||||||
|
t.Fatal("expected tls server name")
|
||||||
|
}
|
||||||
|
if resp.TLS == nil {
|
||||||
|
t.Fatal("expected response TLS state")
|
||||||
|
}
|
||||||
|
if tlsSummary.NegotiatedProtocol != resp.TLS.NegotiatedProtocol {
|
||||||
|
t.Fatalf("alpn=%q resp=%q", tlsSummary.NegotiatedProtocol, resp.TLS.NegotiatedProtocol)
|
||||||
|
}
|
||||||
|
if len(tlsSummary.PeerCertificates) == 0 {
|
||||||
|
t.Fatal("expected certificate summaries")
|
||||||
|
}
|
||||||
|
|
||||||
|
leaf := tlsSummary.PeerCertificates[0]
|
||||||
|
if leaf.Subject == "" || leaf.Issuer == "" {
|
||||||
|
t.Fatalf("unexpected leaf certificate summary: %+v", leaf)
|
||||||
|
}
|
||||||
|
if len(leaf.DNSNames) == 0 && len(leaf.IPAddresses) == 0 {
|
||||||
|
t.Fatalf("expected DNS or IP SANs in leaf certificate: %+v", leaf)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := req.TraceSummary(); got == nil || got.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("request trace summary=%+v", got)
|
||||||
|
}
|
||||||
|
if got := resp.TraceSummary(); got == nil || got.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("response trace summary=%+v", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTraceRecorderCapturesDNSAndConnectSummary(t *testing.T) {
|
||||||
|
server := newIPv4Server(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("ok"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
addr, err := net.ResolveTCPAddr("tcp", server.Listener.Addr().String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ResolveTCPAddr() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
recorder := NewTraceRecorder()
|
||||||
|
targetURL := "http://trace-summary.example.test:" + strconv.Itoa(addr.Port)
|
||||||
|
|
||||||
|
resp, err := NewSimpleRequest(targetURL, http.MethodGet).
|
||||||
|
SetLookupFunc(func(ctx context.Context, host string) ([]net.IPAddr, error) {
|
||||||
|
return []net.IPAddr{{IP: addr.IP}}, nil
|
||||||
|
}).
|
||||||
|
SetTraceRecorder(recorder).
|
||||||
|
Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
if _, err := resp.Body().Bytes(); err != nil {
|
||||||
|
t.Fatalf("Body().Bytes() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
summary := recorder.Summary()
|
||||||
|
if summary.DNS == nil {
|
||||||
|
t.Fatal("expected dns summary")
|
||||||
|
}
|
||||||
|
if summary.DNS.Host != "trace-summary.example.test" {
|
||||||
|
t.Fatalf("dns host=%q", summary.DNS.Host)
|
||||||
|
}
|
||||||
|
if len(summary.DNS.Addrs) == 0 {
|
||||||
|
t.Fatal("expected resolved addresses")
|
||||||
|
}
|
||||||
|
if !strings.Contains(summary.DNS.Addrs[0], addr.IP.String()) {
|
||||||
|
t.Fatalf("dns addrs=%v", summary.DNS.Addrs)
|
||||||
|
}
|
||||||
|
if summary.DNS.CompletedAt.IsZero() {
|
||||||
|
t.Fatal("expected dns completion timestamp")
|
||||||
|
}
|
||||||
|
if len(summary.Connect) == 0 {
|
||||||
|
t.Fatal("expected connect attempts")
|
||||||
|
}
|
||||||
|
|
||||||
|
connect := summary.Connect[0]
|
||||||
|
if connect.Network == "" || connect.Addr == "" {
|
||||||
|
t.Fatalf("unexpected connect summary: %+v", connect)
|
||||||
|
}
|
||||||
|
if connect.CompletedAt.IsZero() {
|
||||||
|
t.Fatalf("expected connect completion timestamp: %+v", connect)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTraceRecorderUsesResponseTLSForReusedConnection(t *testing.T) {
|
||||||
|
server := newIPv4TLSServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("ok"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
client := NewClientNoErr()
|
||||||
|
|
||||||
|
firstResp, err := client.NewSimpleRequest(server.URL, http.MethodGet).
|
||||||
|
SetSkipTLSVerify(true).
|
||||||
|
Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("first Do() error: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := firstResp.Body().Bytes(); err != nil {
|
||||||
|
t.Fatalf("first Body().Bytes() error: %v", err)
|
||||||
|
}
|
||||||
|
if err := firstResp.Close(); err != nil {
|
||||||
|
t.Fatalf("first Close() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
recorder := NewTraceRecorder()
|
||||||
|
secondResp, err := client.NewSimpleRequest(server.URL, http.MethodGet).
|
||||||
|
SetSkipTLSVerify(true).
|
||||||
|
SetTraceRecorder(recorder).
|
||||||
|
Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("second Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer secondResp.Close()
|
||||||
|
|
||||||
|
if _, err := secondResp.Body().Bytes(); err != nil {
|
||||||
|
t.Fatalf("second Body().Bytes() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
summary := recorder.Summary()
|
||||||
|
if !summary.Conn.Reused {
|
||||||
|
t.Fatalf("expected reused connection summary, got %+v", summary.Conn)
|
||||||
|
}
|
||||||
|
if summary.TLS == nil || summary.TLS.Version == 0 {
|
||||||
|
t.Fatalf("expected tls summary from response fallback, got %+v", summary.TLS)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTraceRecorderCoexistsWithTraceHooks(t *testing.T) {
|
||||||
|
server := newIPv4Server(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("ok"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
recorder := NewTraceRecorder()
|
||||||
|
wroteRequest := 0
|
||||||
|
|
||||||
|
resp, err := NewSimpleRequest(server.URL, http.MethodGet).
|
||||||
|
SetTraceHooks(&TraceHooks{
|
||||||
|
WroteRequest: func(info TraceWroteRequestInfo) {
|
||||||
|
wroteRequest++
|
||||||
|
},
|
||||||
|
}).
|
||||||
|
SetTraceRecorder(recorder).
|
||||||
|
Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
if _, err := resp.Body().Bytes(); err != nil {
|
||||||
|
t.Fatalf("Body().Bytes() error: %v", err)
|
||||||
|
}
|
||||||
|
if wroteRequest == 0 {
|
||||||
|
t.Fatal("expected custom trace hook to run")
|
||||||
|
}
|
||||||
|
|
||||||
|
summary := recorder.Summary()
|
||||||
|
if summary.RequestWrittenAt.IsZero() {
|
||||||
|
t.Fatal("expected recorder to capture wrote-request event")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTraceRecorderPreservesMultipleDNSEvents(t *testing.T) {
|
||||||
|
recorder := NewTraceRecorder()
|
||||||
|
hooks := recorder.Hooks()
|
||||||
|
|
||||||
|
hooks.DNSStart(TraceDNSStartInfo{Host: "target.example.test"})
|
||||||
|
hooks.DNSDone(TraceDNSDoneInfo{
|
||||||
|
Addrs: []net.IPAddr{{IP: net.ParseIP("127.0.0.1")}},
|
||||||
|
})
|
||||||
|
|
||||||
|
hooks.DNSStart(TraceDNSStartInfo{Host: "proxy.example.test"})
|
||||||
|
hooks.DNSDone(TraceDNSDoneInfo{
|
||||||
|
Addrs: []net.IPAddr{{IP: net.ParseIP("127.0.0.2")}},
|
||||||
|
})
|
||||||
|
|
||||||
|
summary := recorder.Summary()
|
||||||
|
if len(summary.DNSEvents) != 2 {
|
||||||
|
t.Fatalf("dns events=%d", len(summary.DNSEvents))
|
||||||
|
}
|
||||||
|
if summary.DNSEvents[0].Host != "target.example.test" {
|
||||||
|
t.Fatalf("first dns host=%q", summary.DNSEvents[0].Host)
|
||||||
|
}
|
||||||
|
if summary.DNSEvents[1].Host != "proxy.example.test" {
|
||||||
|
t.Fatalf("second dns host=%q", summary.DNSEvents[1].Host)
|
||||||
|
}
|
||||||
|
if summary.DNS == nil || summary.DNS.Host != "proxy.example.test" {
|
||||||
|
t.Fatalf("last dns summary=%+v", summary.DNS)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTraceHooksStandardTLSPathIncludesMetadata(t *testing.T) {
|
||||||
|
server, pool := newTrustedIPv4TLSServer(t, "localhost", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("ok"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
client := NewClientNoErr()
|
||||||
|
transport, ok := client.HTTPClient().Transport.(*Transport)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("transport type=%T", client.HTTPClient().Transport)
|
||||||
|
}
|
||||||
|
base := newBaseHTTPTransport()
|
||||||
|
base.TLSClientConfig = &tls.Config{RootCAs: pool}
|
||||||
|
transport.SetBase(base)
|
||||||
|
|
||||||
|
targetURL := httpsURLForHost(t, server, "localhost")
|
||||||
|
var startInfo TraceTLSHandshakeStartInfo
|
||||||
|
var doneInfo TraceTLSHandshakeDoneInfo
|
||||||
|
|
||||||
|
resp, err := client.NewSimpleRequest(targetURL, http.MethodGet).
|
||||||
|
SetTraceHooks(&TraceHooks{
|
||||||
|
TLSHandshakeStart: func(info TraceTLSHandshakeStartInfo) {
|
||||||
|
startInfo = info
|
||||||
|
},
|
||||||
|
TLSHandshakeDone: func(info TraceTLSHandshakeDoneInfo) {
|
||||||
|
doneInfo = info
|
||||||
|
},
|
||||||
|
}).
|
||||||
|
Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
if _, err := resp.Body().Bytes(); err != nil {
|
||||||
|
t.Fatalf("Body().Bytes() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
wantAddr := strings.TrimPrefix(targetURL, "https://")
|
||||||
|
if startInfo.Network != "tcp" {
|
||||||
|
t.Fatalf("start network=%q", startInfo.Network)
|
||||||
|
}
|
||||||
|
if startInfo.Addr != wantAddr {
|
||||||
|
t.Fatalf("start addr=%q want=%q", startInfo.Addr, wantAddr)
|
||||||
|
}
|
||||||
|
if startInfo.ServerName != "localhost" {
|
||||||
|
t.Fatalf("start server name=%q", startInfo.ServerName)
|
||||||
|
}
|
||||||
|
if doneInfo.Network != "tcp" || doneInfo.Addr != wantAddr || doneInfo.ServerName != "localhost" {
|
||||||
|
t.Fatalf("done info=%+v", doneInfo)
|
||||||
|
}
|
||||||
|
if doneInfo.ConnectionState.Version == 0 {
|
||||||
|
t.Fatalf("done state=%+v", doneInfo.ConnectionState)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTraceHooksWroteHeaderFieldCopiesValues(t *testing.T) {
|
||||||
|
var captured []string
|
||||||
|
traceState := newTraceState(&TraceHooks{
|
||||||
|
WroteHeaderField: func(info TraceWroteHeaderFieldInfo) {
|
||||||
|
captured = info.Values
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
trace := traceState.clientTrace()
|
||||||
|
values := []string{"a", "b"}
|
||||||
|
trace.WroteHeaderField("X-Test", values)
|
||||||
|
values[0] = "mutated"
|
||||||
|
|
||||||
|
if len(captured) != 2 {
|
||||||
|
t.Fatalf("captured=%v", captured)
|
||||||
|
}
|
||||||
|
if captured[0] != "a" {
|
||||||
|
t.Fatalf("captured=%v", captured)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTraceRecorderSharedAcrossCloneKeepsPerRequestSummaries(t *testing.T) {
|
||||||
|
server := newIPv4Server(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte(r.URL.Path))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
recorder := NewTraceRecorder()
|
||||||
|
client := NewClientNoErr(WithTraceRecorder(recorder))
|
||||||
|
|
||||||
|
req1 := client.NewSimpleRequest(server.URL+"/one", http.MethodGet)
|
||||||
|
resp1, err := req1.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("first Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp1.Close()
|
||||||
|
if _, err := resp1.Body().Bytes(); err != nil {
|
||||||
|
t.Fatalf("first Body().Bytes() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req2 := req1.Clone().SetURL(server.URL + "/two")
|
||||||
|
resp2, err := req2.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("second Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp2.Close()
|
||||||
|
if _, err := resp2.Body().Bytes(); err != nil {
|
||||||
|
t.Fatalf("second Body().Bytes() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := req1.TraceSummary(); got == nil || got.URL != server.URL+"/one" {
|
||||||
|
t.Fatalf("req1 trace summary=%+v", got)
|
||||||
|
}
|
||||||
|
if got := resp1.TraceSummary(); got == nil || got.URL != server.URL+"/one" {
|
||||||
|
t.Fatalf("resp1 trace summary=%+v", got)
|
||||||
|
}
|
||||||
|
if got := req2.TraceSummary(); got == nil || got.URL != server.URL+"/two" {
|
||||||
|
t.Fatalf("req2 trace summary=%+v", got)
|
||||||
|
}
|
||||||
|
if got := resp2.TraceSummary(); got == nil || got.URL != server.URL+"/two" {
|
||||||
|
t.Fatalf("resp2 trace summary=%+v", got)
|
||||||
|
}
|
||||||
|
if got := recorder.Summary(); got.URL != server.URL+"/two" {
|
||||||
|
t.Fatalf("shared recorder summary=%+v", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResponseTraceSummaryIsStableAcrossRequestReuse(t *testing.T) {
|
||||||
|
server := newIPv4Server(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte(r.URL.Path))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
req := NewSimpleRequest(server.URL+"/first", http.MethodGet).
|
||||||
|
SetTraceRecorder(NewTraceRecorder())
|
||||||
|
|
||||||
|
resp1, err := req.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("first Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp1.Close()
|
||||||
|
if _, err := resp1.Body().Bytes(); err != nil {
|
||||||
|
t.Fatalf("first Body().Bytes() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req.SetURL(server.URL + "/second")
|
||||||
|
resp2, err := req.Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("second Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp2.Close()
|
||||||
|
if _, err := resp2.Body().Bytes(); err != nil {
|
||||||
|
t.Fatalf("second Body().Bytes() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := resp1.TraceSummary(); got == nil || got.URL != server.URL+"/first" {
|
||||||
|
t.Fatalf("resp1 trace summary=%+v", got)
|
||||||
|
}
|
||||||
|
if got := req.TraceSummary(); got == nil || got.URL != server.URL+"/second" {
|
||||||
|
t.Fatalf("request trace summary=%+v", got)
|
||||||
|
}
|
||||||
|
if got := resp2.TraceSummary(); got == nil || got.URL != server.URL+"/second" {
|
||||||
|
t.Fatalf("resp2 trace summary=%+v", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTraceHooksCustomDialDoesNotInventTLSAddr(t *testing.T) {
|
||||||
|
server, pool := newTrustedIPv4TLSServer(t, "trace-custom.example.test", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("ok"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
client := NewClientNoErr()
|
||||||
|
transport, ok := client.HTTPClient().Transport.(*Transport)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("transport type=%T", client.HTTPClient().Transport)
|
||||||
|
}
|
||||||
|
base := newBaseHTTPTransport()
|
||||||
|
base.TLSClientConfig = &tls.Config{RootCAs: pool}
|
||||||
|
transport.SetBase(base)
|
||||||
|
|
||||||
|
targetURL := httpsURLForHost(t, server, "trace-custom.example.test")
|
||||||
|
serverAddr := server.Listener.Addr().String()
|
||||||
|
|
||||||
|
var startInfo TraceTLSHandshakeStartInfo
|
||||||
|
var doneInfo TraceTLSHandshakeDoneInfo
|
||||||
|
resp, err := client.NewSimpleRequest(targetURL, http.MethodGet).
|
||||||
|
SetDialFunc(func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
return (&net.Dialer{}).DialContext(ctx, "tcp", serverAddr)
|
||||||
|
}).
|
||||||
|
SetTraceHooks(&TraceHooks{
|
||||||
|
TLSHandshakeStart: func(info TraceTLSHandshakeStartInfo) {
|
||||||
|
startInfo = info
|
||||||
|
},
|
||||||
|
TLSHandshakeDone: func(info TraceTLSHandshakeDoneInfo) {
|
||||||
|
doneInfo = info
|
||||||
|
},
|
||||||
|
}).
|
||||||
|
Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
if _, err := resp.Body().Bytes(); err != nil {
|
||||||
|
t.Fatalf("Body().Bytes() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if startInfo.Network != "" || startInfo.Addr != "" {
|
||||||
|
t.Fatalf("start info=%+v", startInfo)
|
||||||
|
}
|
||||||
|
if doneInfo.Network != "" || doneInfo.Addr != "" {
|
||||||
|
t.Fatalf("done info=%+v", doneInfo)
|
||||||
|
}
|
||||||
|
if startInfo.ServerName != "trace-custom.example.test" {
|
||||||
|
t.Fatalf("start server name=%q", startInfo.ServerName)
|
||||||
|
}
|
||||||
|
if doneInfo.ServerName != "trace-custom.example.test" {
|
||||||
|
t.Fatalf("done server name=%q", doneInfo.ServerName)
|
||||||
|
}
|
||||||
|
if doneInfo.ConnectionState.Version == 0 {
|
||||||
|
t.Fatalf("done state=%+v", doneInfo.ConnectionState)
|
||||||
|
}
|
||||||
|
}
|
||||||
+324
@@ -0,0 +1,324 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTraceHooksStandardHTTPSPath(t *testing.T) {
|
||||||
|
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("ok"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
var mu sync.Mutex
|
||||||
|
events := map[string]int{}
|
||||||
|
hooks := &TraceHooks{
|
||||||
|
GetConn: func(info TraceGetConnInfo) {
|
||||||
|
mu.Lock()
|
||||||
|
events["get_conn"]++
|
||||||
|
mu.Unlock()
|
||||||
|
},
|
||||||
|
GotConn: func(info TraceGotConnInfo) {
|
||||||
|
mu.Lock()
|
||||||
|
events["got_conn"]++
|
||||||
|
mu.Unlock()
|
||||||
|
},
|
||||||
|
TLSHandshakeStart: func(info TraceTLSHandshakeStartInfo) {
|
||||||
|
mu.Lock()
|
||||||
|
events["tls_start"]++
|
||||||
|
mu.Unlock()
|
||||||
|
},
|
||||||
|
TLSHandshakeDone: func(info TraceTLSHandshakeDoneInfo) {
|
||||||
|
mu.Lock()
|
||||||
|
events["tls_done"]++
|
||||||
|
mu.Unlock()
|
||||||
|
if info.Err != nil {
|
||||||
|
t.Errorf("unexpected tls handshake error: %v", info.Err)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
WroteHeaders: func() {
|
||||||
|
mu.Lock()
|
||||||
|
events["wrote_headers"]++
|
||||||
|
mu.Unlock()
|
||||||
|
},
|
||||||
|
WroteRequest: func(info TraceWroteRequestInfo) {
|
||||||
|
mu.Lock()
|
||||||
|
events["wrote_request"]++
|
||||||
|
mu.Unlock()
|
||||||
|
if info.Err != nil {
|
||||||
|
t.Errorf("unexpected write error: %v", info.Err)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
GotFirstResponseByte: func() {
|
||||||
|
mu.Lock()
|
||||||
|
events["first_byte"]++
|
||||||
|
mu.Unlock()
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := NewSimpleRequest(server.URL, http.MethodGet).
|
||||||
|
SetSkipTLSVerify(true).
|
||||||
|
SetTraceHooks(hooks).
|
||||||
|
Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
for _, key := range []string{"get_conn", "got_conn", "tls_start", "tls_done", "wrote_headers", "wrote_request", "first_byte"} {
|
||||||
|
if events[key] == 0 {
|
||||||
|
t.Fatalf("expected trace event %q", key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTraceHooksDynamicHTTPSPathDoesNotDuplicateTLSHandshake(t *testing.T) {
|
||||||
|
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
var mu sync.Mutex
|
||||||
|
tlsStartCount := 0
|
||||||
|
tlsDoneCount := 0
|
||||||
|
var lastInfo TraceTLSHandshakeDoneInfo
|
||||||
|
hooks := &TraceHooks{
|
||||||
|
TLSHandshakeStart: func(info TraceTLSHandshakeStartInfo) {
|
||||||
|
mu.Lock()
|
||||||
|
tlsStartCount++
|
||||||
|
mu.Unlock()
|
||||||
|
},
|
||||||
|
TLSHandshakeDone: func(info TraceTLSHandshakeDoneInfo) {
|
||||||
|
mu.Lock()
|
||||||
|
tlsDoneCount++
|
||||||
|
lastInfo = info
|
||||||
|
mu.Unlock()
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := NewSimpleRequest(server.URL, http.MethodGet).
|
||||||
|
SetSkipTLSVerify(true).
|
||||||
|
SetDialTimeout(1500 * time.Millisecond).
|
||||||
|
SetTraceHooks(hooks).
|
||||||
|
Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
if tlsStartCount != 1 {
|
||||||
|
t.Fatalf("tlsStartCount=%d", tlsStartCount)
|
||||||
|
}
|
||||||
|
if tlsDoneCount != 1 {
|
||||||
|
t.Fatalf("tlsDoneCount=%d", tlsDoneCount)
|
||||||
|
}
|
||||||
|
if lastInfo.Err != nil {
|
||||||
|
t.Fatalf("unexpected tls handshake error: %v", lastInfo.Err)
|
||||||
|
}
|
||||||
|
if lastInfo.ConnectionState.Version == 0 {
|
||||||
|
t.Fatal("expected tls connection state")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTraceHooksCustomLookupFuncEmitsDNSEvents(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
addr, err := net.ResolveTCPAddr("tcp", server.Listener.Addr().String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ResolveTCPAddr() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var mu sync.Mutex
|
||||||
|
dnsStartCount := 0
|
||||||
|
dnsDoneCount := 0
|
||||||
|
var dnsStartHost string
|
||||||
|
hooks := &TraceHooks{
|
||||||
|
DNSStart: func(info TraceDNSStartInfo) {
|
||||||
|
mu.Lock()
|
||||||
|
dnsStartCount++
|
||||||
|
dnsStartHost = info.Host
|
||||||
|
mu.Unlock()
|
||||||
|
},
|
||||||
|
DNSDone: func(info TraceDNSDoneInfo) {
|
||||||
|
mu.Lock()
|
||||||
|
dnsDoneCount++
|
||||||
|
mu.Unlock()
|
||||||
|
if info.Err != nil {
|
||||||
|
t.Errorf("unexpected dns error: %v", info.Err)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
url := "http://trace.example.test:" + strconv.Itoa(addr.Port)
|
||||||
|
resp, err := NewSimpleRequest(url, http.MethodGet).
|
||||||
|
SetLookupFunc(func(ctx context.Context, host string) ([]net.IPAddr, error) {
|
||||||
|
return []net.IPAddr{{IP: addr.IP}}, nil
|
||||||
|
}).
|
||||||
|
SetTraceHooks(hooks).
|
||||||
|
Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
if dnsStartCount != 1 {
|
||||||
|
t.Fatalf("dnsStartCount=%d", dnsStartCount)
|
||||||
|
}
|
||||||
|
if dnsDoneCount != 1 {
|
||||||
|
t.Fatalf("dnsDoneCount=%d", dnsDoneCount)
|
||||||
|
}
|
||||||
|
if dnsStartHost != "trace.example.test" {
|
||||||
|
t.Fatalf("dnsStartHost=%q", dnsStartHost)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTraceHooksCustomDialFuncEmitsConnectEvents(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
var mu sync.Mutex
|
||||||
|
connectStartCount := 0
|
||||||
|
connectDoneCount := 0
|
||||||
|
hooks := &TraceHooks{
|
||||||
|
ConnectStart: func(info TraceConnectStartInfo) {
|
||||||
|
mu.Lock()
|
||||||
|
connectStartCount++
|
||||||
|
mu.Unlock()
|
||||||
|
},
|
||||||
|
ConnectDone: func(info TraceConnectDoneInfo) {
|
||||||
|
mu.Lock()
|
||||||
|
connectDoneCount++
|
||||||
|
mu.Unlock()
|
||||||
|
if info.Err != nil {
|
||||||
|
t.Errorf("unexpected connect error: %v", info.Err)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := NewSimpleRequest(server.URL, http.MethodGet).
|
||||||
|
SetDialFunc(func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
var dialer net.Dialer
|
||||||
|
return dialer.DialContext(context.Background(), network, addr)
|
||||||
|
}).
|
||||||
|
SetTraceHooks(hooks).
|
||||||
|
Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
if connectStartCount != 1 {
|
||||||
|
t.Fatalf("connectStartCount=%d", connectStartCount)
|
||||||
|
}
|
||||||
|
if connectDoneCount != 1 {
|
||||||
|
t.Fatalf("connectDoneCount=%d", connectDoneCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTraceHooksRetryEvents(t *testing.T) {
|
||||||
|
var hits int
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
hits++
|
||||||
|
if hits == 1 {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
var mu sync.Mutex
|
||||||
|
starts := 0
|
||||||
|
dones := 0
|
||||||
|
backoffs := 0
|
||||||
|
var finalDone TraceRetryAttemptDoneInfo
|
||||||
|
hooks := &TraceHooks{
|
||||||
|
RetryAttemptStart: func(info TraceRetryAttemptStartInfo) {
|
||||||
|
mu.Lock()
|
||||||
|
starts++
|
||||||
|
mu.Unlock()
|
||||||
|
},
|
||||||
|
RetryAttemptDone: func(info TraceRetryAttemptDoneInfo) {
|
||||||
|
mu.Lock()
|
||||||
|
dones++
|
||||||
|
finalDone = info
|
||||||
|
mu.Unlock()
|
||||||
|
},
|
||||||
|
RetryBackoff: func(info TraceRetryBackoffInfo) {
|
||||||
|
mu.Lock()
|
||||||
|
backoffs++
|
||||||
|
mu.Unlock()
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := NewSimpleRequest(server.URL, http.MethodGet).
|
||||||
|
SetRetry(1, WithRetryBackoff(time.Millisecond, time.Millisecond, 1), WithRetryJitter(0)).
|
||||||
|
SetTraceHooks(hooks).
|
||||||
|
Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
if starts != 2 {
|
||||||
|
t.Fatalf("starts=%d", starts)
|
||||||
|
}
|
||||||
|
if dones != 2 {
|
||||||
|
t.Fatalf("dones=%d", dones)
|
||||||
|
}
|
||||||
|
if backoffs != 1 {
|
||||||
|
t.Fatalf("backoffs=%d", backoffs)
|
||||||
|
}
|
||||||
|
if finalDone.WillRetry {
|
||||||
|
t.Fatal("expected final attempt not to retry")
|
||||||
|
}
|
||||||
|
if finalDone.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("final status=%d", finalDone.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTraceHooksCustomLookupFuncPropagatesDNSError(t *testing.T) {
|
||||||
|
var gotErr error
|
||||||
|
hooks := &TraceHooks{
|
||||||
|
DNSDone: func(info TraceDNSDoneInfo) {
|
||||||
|
gotErr = info.Err
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := NewSimpleRequest("http://trace.example.test:80", http.MethodGet).
|
||||||
|
SetLookupFunc(func(ctx context.Context, host string) ([]net.IPAddr, error) {
|
||||||
|
return nil, errors.New("lookup failed")
|
||||||
|
}).
|
||||||
|
SetTraceHooks(hooks).
|
||||||
|
Do()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected request error")
|
||||||
|
}
|
||||||
|
if gotErr == nil || gotErr.Error() != "lookup failed" {
|
||||||
|
t.Fatalf("gotErr=%v", gotErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
+416
@@ -0,0 +1,416 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const dynamicTransportCacheMaxEntries = 64
|
||||||
|
|
||||||
|
type dynamicTransportCacheKey struct {
|
||||||
|
proxyKey string
|
||||||
|
dialTimeout time.Duration
|
||||||
|
customIPs string
|
||||||
|
customDNS string
|
||||||
|
tlsServerName string
|
||||||
|
skipVerify bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Transport 自定义 Transport(支持请求级配置)
|
||||||
|
type Transport struct {
|
||||||
|
base *http.Transport
|
||||||
|
dynamicCache map[dynamicTransportCacheKey]*http.Transport
|
||||||
|
dynamicCacheOrder []dynamicTransportCacheKey
|
||||||
|
mu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// RoundTrip 实现 http.RoundTripper 接口
|
||||||
|
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
t.ensureBase()
|
||||||
|
|
||||||
|
// 提取请求级别的配置
|
||||||
|
reqCtx := getRequestContext(req.Context())
|
||||||
|
traceState := getTraceState(req.Context())
|
||||||
|
execReq := req
|
||||||
|
execReqCtx := reqCtx
|
||||||
|
var targetAddrs []string
|
||||||
|
|
||||||
|
// 优先级1:完全自定义的 transport
|
||||||
|
if execReqCtx.Transport != nil {
|
||||||
|
return execReqCtx.Transport.RoundTrip(execReq)
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
execReq, execReqCtx, targetAddrs, err = prepareProxyTargetRequest(execReq, execReqCtx, traceState)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 优先级2:需要动态配置
|
||||||
|
if needsDynamicTransport(execReqCtx) {
|
||||||
|
dynamicTransport := t.getDynamicTransport(execReqCtx, traceState)
|
||||||
|
if len(targetAddrs) > 0 {
|
||||||
|
return roundTripResolvedTargets(dynamicTransport, execReq, targetAddrs)
|
||||||
|
}
|
||||||
|
return dynamicTransport.RoundTrip(execReq)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 优先级3:使用基础 transport
|
||||||
|
t.mu.RLock()
|
||||||
|
baseTransport := t.base
|
||||||
|
t.mu.RUnlock()
|
||||||
|
if len(targetAddrs) > 0 {
|
||||||
|
return roundTripResolvedTargets(baseTransport, execReq, targetAddrs)
|
||||||
|
}
|
||||||
|
return baseTransport.RoundTrip(execReq)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newBaseHTTPTransport() *http.Transport {
|
||||||
|
return &http.Transport{
|
||||||
|
ForceAttemptHTTP2: true,
|
||||||
|
MaxIdleConns: 100,
|
||||||
|
MaxIdleConnsPerHost: 10,
|
||||||
|
IdleConnTimeout: 90 * time.Second,
|
||||||
|
TLSHandshakeTimeout: 10 * time.Second,
|
||||||
|
ExpectContinueTimeout: 1 * time.Second,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Transport) ensureBase() {
|
||||||
|
if t.base != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.mu.Lock()
|
||||||
|
defer t.mu.Unlock()
|
||||||
|
t.ensureBaseLocked()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Transport) ensureBaseLocked() {
|
||||||
|
if t.base == nil {
|
||||||
|
t.base = newBaseHTTPTransport()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Transport) getDynamicTransport(rc *RequestContext, traceState *traceState) *http.Transport {
|
||||||
|
if key, ok := newDynamicTransportCacheKey(rc); ok {
|
||||||
|
return t.getOrCreateCachedDynamicTransport(key, rc)
|
||||||
|
}
|
||||||
|
return t.buildDynamicTransport(rc, traceState)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Transport) getOrCreateCachedDynamicTransport(key dynamicTransportCacheKey, rc *RequestContext) *http.Transport {
|
||||||
|
t.mu.RLock()
|
||||||
|
if transport := t.dynamicCache[key]; transport != nil {
|
||||||
|
t.mu.RUnlock()
|
||||||
|
return transport
|
||||||
|
}
|
||||||
|
t.mu.RUnlock()
|
||||||
|
|
||||||
|
t.mu.Lock()
|
||||||
|
defer t.mu.Unlock()
|
||||||
|
|
||||||
|
t.ensureBaseLocked()
|
||||||
|
if transport := t.dynamicCache[key]; transport != nil {
|
||||||
|
return transport
|
||||||
|
}
|
||||||
|
|
||||||
|
transport := buildDynamicTransportFromBase(t.base, rc, nil)
|
||||||
|
if t.dynamicCache == nil {
|
||||||
|
t.dynamicCache = make(map[dynamicTransportCacheKey]*http.Transport)
|
||||||
|
}
|
||||||
|
if len(t.dynamicCacheOrder) >= dynamicTransportCacheMaxEntries {
|
||||||
|
oldestKey := t.dynamicCacheOrder[0]
|
||||||
|
t.dynamicCacheOrder = t.dynamicCacheOrder[1:]
|
||||||
|
if oldest := t.dynamicCache[oldestKey]; oldest != nil {
|
||||||
|
oldest.CloseIdleConnections()
|
||||||
|
delete(t.dynamicCache, oldestKey)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
t.dynamicCache[key] = transport
|
||||||
|
t.dynamicCacheOrder = append(t.dynamicCacheOrder, key)
|
||||||
|
return transport
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Transport) resetDynamicTransportCacheLocked() {
|
||||||
|
for _, key := range t.dynamicCacheOrder {
|
||||||
|
if transport := t.dynamicCache[key]; transport != nil {
|
||||||
|
transport.CloseIdleConnections()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
t.dynamicCache = nil
|
||||||
|
t.dynamicCacheOrder = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newDynamicTransportCacheKey(rc *RequestContext) (dynamicTransportCacheKey, bool) {
|
||||||
|
if rc == nil {
|
||||||
|
return dynamicTransportCacheKey{}, false
|
||||||
|
}
|
||||||
|
if rc.Transport != nil || rc.DialFn != nil || rc.LookupIPFn != nil {
|
||||||
|
return dynamicTransportCacheKey{}, false
|
||||||
|
}
|
||||||
|
if rc.TLSConfig != nil && !rc.TLSConfigCacheable {
|
||||||
|
return dynamicTransportCacheKey{}, false
|
||||||
|
}
|
||||||
|
|
||||||
|
key := dynamicTransportCacheKey{
|
||||||
|
proxyKey: normalizeProxyCacheKey(rc.Proxy),
|
||||||
|
dialTimeout: rc.DialTimeout,
|
||||||
|
customIPs: serializeTransportCacheList(rc.CustomIP),
|
||||||
|
customDNS: serializeTransportCacheList(rc.CustomDNS),
|
||||||
|
tlsServerName: effectiveTLSServerName(rc),
|
||||||
|
}
|
||||||
|
if rc.TLSConfig != nil {
|
||||||
|
key.skipVerify = rc.TLSConfig.InsecureSkipVerify
|
||||||
|
}
|
||||||
|
return key, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeProxyCacheKey(proxy string) string {
|
||||||
|
if proxy == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
proxyURL, err := parseProxyURL(proxy)
|
||||||
|
if err != nil {
|
||||||
|
return "\x00invalid:" + proxy
|
||||||
|
}
|
||||||
|
return proxyURL.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func serializeTransportCacheList(values []string) string {
|
||||||
|
if len(values) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
var builder strings.Builder
|
||||||
|
for _, value := range values {
|
||||||
|
builder.WriteString(value)
|
||||||
|
builder.WriteByte(0)
|
||||||
|
}
|
||||||
|
return builder.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func effectiveTLSServerName(rc *RequestContext) string {
|
||||||
|
if rc == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if rc.TLSConfig != nil && rc.TLSConfig.ServerName != "" {
|
||||||
|
return rc.TLSConfig.ServerName
|
||||||
|
}
|
||||||
|
return rc.TLSServerName
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildDynamicTransport 构建动态 Transport
|
||||||
|
func (t *Transport) buildDynamicTransport(rc *RequestContext, traceState *traceState) *http.Transport {
|
||||||
|
t.ensureBase()
|
||||||
|
t.mu.RLock()
|
||||||
|
baseTransport := t.base
|
||||||
|
t.mu.RUnlock()
|
||||||
|
return buildDynamicTransportFromBase(baseTransport, rc, traceState)
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildDynamicTransportFromBase(baseTransport *http.Transport, rc *RequestContext, traceState *traceState) *http.Transport {
|
||||||
|
transport := baseTransport.Clone()
|
||||||
|
|
||||||
|
// 应用 TLS 配置(即使为 nil 也要检查 SkipVerify)
|
||||||
|
if rc.TLSConfig != nil {
|
||||||
|
transport.TLSClientConfig = rc.TLSConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
// 应用代理配置
|
||||||
|
if rc.Proxy != "" {
|
||||||
|
proxyURL, err := parseProxyURL(rc.Proxy)
|
||||||
|
if err != nil {
|
||||||
|
transport.Proxy = func(*http.Request) (*url.URL, error) {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
transport.Proxy = http.ProxyURL(proxyURL)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 应用自定义 Dial 函数
|
||||||
|
if rc.DialFn != nil {
|
||||||
|
if traceState != nil && traceState.hooks != nil && (traceState.hooks.ConnectStart != nil || traceState.hooks.ConnectDone != nil) {
|
||||||
|
dialFn := rc.DialFn
|
||||||
|
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
if traceState.hooks.ConnectStart != nil {
|
||||||
|
traceState.hooks.ConnectStart(TraceConnectStartInfo{Network: network, Addr: addr})
|
||||||
|
}
|
||||||
|
conn, err := dialFn(ctx, network, addr)
|
||||||
|
if traceState.hooks.ConnectDone != nil {
|
||||||
|
traceState.hooks.ConnectDone(TraceConnectDoneInfo{Network: network, Addr: addr, Err: err})
|
||||||
|
}
|
||||||
|
return conn, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
transport.DialContext = rc.DialFn
|
||||||
|
}
|
||||||
|
} else if len(rc.CustomIP) > 0 || len(rc.CustomDNS) > 0 || rc.DialTimeout > 0 || rc.LookupIPFn != nil {
|
||||||
|
// 使用默认 Dial 函数(会从 context 读取配置)
|
||||||
|
transport.DialContext = defaultDialFunc
|
||||||
|
transport.DialTLSContext = defaultDialTLSFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
return transport
|
||||||
|
}
|
||||||
|
|
||||||
|
// Base 获取基础 Transport
|
||||||
|
func (t *Transport) Base() *http.Transport {
|
||||||
|
t.mu.RLock()
|
||||||
|
defer t.mu.RUnlock()
|
||||||
|
return t.base
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetBase 设置基础 Transport
|
||||||
|
func (t *Transport) SetBase(base *http.Transport) {
|
||||||
|
t.mu.Lock()
|
||||||
|
t.base = base
|
||||||
|
t.resetDynamicTransportCacheLocked()
|
||||||
|
t.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func prepareProxyTargetRequest(req *http.Request, reqCtx *RequestContext, traceState *traceState) (*http.Request, *RequestContext, []string, error) {
|
||||||
|
if req == nil || req.URL == nil || reqCtx == nil {
|
||||||
|
return req, reqCtx, nil, nil
|
||||||
|
}
|
||||||
|
if reqCtx.Proxy == "" || reqCtx.DialFn != nil {
|
||||||
|
return req, reqCtx, nil, nil
|
||||||
|
}
|
||||||
|
if len(reqCtx.CustomIP) == 0 && len(reqCtx.CustomDNS) == 0 && reqCtx.LookupIPFn == nil {
|
||||||
|
return req, reqCtx, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
host := req.URL.Hostname()
|
||||||
|
if host == "" {
|
||||||
|
return req, reqCtx, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
targetAddrs, err := resolveDialAddresses(req.Context(), reqCtx, host, req.URL.Port(), traceState)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, err
|
||||||
|
}
|
||||||
|
if len(targetAddrs) == 0 {
|
||||||
|
return req, reqCtx, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
execReqCtx := *reqCtx
|
||||||
|
execReqCtx.CustomIP = nil
|
||||||
|
execReqCtx.CustomDNS = nil
|
||||||
|
execReqCtx.LookupIPFn = nil
|
||||||
|
|
||||||
|
if req.URL.Scheme == "https" {
|
||||||
|
execReqCtx.TLSConfig = withDefaultServerName(execReqCtx.TLSConfig, host)
|
||||||
|
if execReqCtx.TLSConfigCacheable || reqCtx.TLSConfig == nil {
|
||||||
|
execReqCtx.TLSConfigCacheable = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
execCtx := clearTargetResolutionContext(req.Context())
|
||||||
|
execReq := req.Clone(execCtx)
|
||||||
|
execReq.Host = req.Host
|
||||||
|
if len(targetAddrs) == 1 {
|
||||||
|
execReq.URL.Host = targetAddrs[0]
|
||||||
|
return execReq, &execReqCtx, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return execReq, &execReqCtx, targetAddrs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func clearTargetResolutionContext(ctx context.Context) context.Context {
|
||||||
|
if v := ctx.Value(ctxKeyRequestContext); v != nil {
|
||||||
|
if rc, ok := v.(*RequestContext); ok && rc != nil {
|
||||||
|
cloned := cloneRequestContext(rc)
|
||||||
|
cloned.CustomIP = nil
|
||||||
|
cloned.CustomDNS = nil
|
||||||
|
cloned.LookupIPFn = nil
|
||||||
|
ctx = context.WithValue(ctx, ctxKeyRequestContext, cloned)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ctx = context.WithValue(ctx, ctxKeyCustomIP, []string(nil))
|
||||||
|
ctx = context.WithValue(ctx, ctxKeyCustomDNS, []string(nil))
|
||||||
|
ctx = context.WithValue(ctx, ctxKeyLookupIP, (func(context.Context, string) ([]net.IPAddr, error))(nil))
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
|
||||||
|
func withDefaultServerName(cfg *tls.Config, serverName string) *tls.Config {
|
||||||
|
if serverName == "" {
|
||||||
|
return cfg
|
||||||
|
}
|
||||||
|
if cfg != nil {
|
||||||
|
if cfg.ServerName != "" {
|
||||||
|
return cfg
|
||||||
|
}
|
||||||
|
cloned := cfg.Clone()
|
||||||
|
cloned.ServerName = serverName
|
||||||
|
return cloned
|
||||||
|
}
|
||||||
|
return &tls.Config{
|
||||||
|
ServerName: serverName,
|
||||||
|
NextProtos: []string{"h2", "http/1.1"},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func roundTripResolvedTargets(rt http.RoundTripper, baseReq *http.Request, targetAddrs []string) (*http.Response, error) {
|
||||||
|
if rt == nil || baseReq == nil || len(targetAddrs) == 0 {
|
||||||
|
return rt.RoundTrip(baseReq)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !requestAllowsResolvedTargetFallback(baseReq) && len(targetAddrs) > 1 {
|
||||||
|
targetAddrs = targetAddrs[:1]
|
||||||
|
}
|
||||||
|
|
||||||
|
var lastErr error
|
||||||
|
for _, targetAddr := range targetAddrs {
|
||||||
|
attemptReq, err := cloneRequestForResolvedTarget(baseReq, targetAddr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := rt.RoundTrip(attemptReq)
|
||||||
|
if err == nil {
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
lastErr = err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, lastErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func requestAllowsResolvedTargetFallback(req *http.Request) bool {
|
||||||
|
if req == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !isIdempotentMethod(req.Method) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if req.Body == nil || req.Body == http.NoBody {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return req.GetBody != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func cloneRequestForResolvedTarget(baseReq *http.Request, targetAddr string) (*http.Request, error) {
|
||||||
|
req := baseReq.Clone(baseReq.Context())
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case baseReq.Body == nil || baseReq.Body == http.NoBody:
|
||||||
|
req.Body = baseReq.Body
|
||||||
|
case baseReq.GetBody != nil:
|
||||||
|
body, err := baseReq.GetBody()
|
||||||
|
if err != nil {
|
||||||
|
return nil, wrapError(err, "clone request body for resolved target")
|
||||||
|
}
|
||||||
|
req.Body = body
|
||||||
|
default:
|
||||||
|
req.Body = baseReq.Body
|
||||||
|
}
|
||||||
|
|
||||||
|
req.URL.Host = targetAddr
|
||||||
|
req.Host = baseReq.Host
|
||||||
|
return req, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,224 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTransportDynamicCacheReusesSafeProfile(t *testing.T) {
|
||||||
|
transport := &Transport{base: newBaseHTTPTransport()}
|
||||||
|
|
||||||
|
first := transport.getDynamicTransport(&RequestContext{
|
||||||
|
Proxy: "http://127.0.0.1:8080",
|
||||||
|
DialTimeout: 2 * time.Second,
|
||||||
|
CustomIP: []string{"127.0.0.1"},
|
||||||
|
TLSServerName: "cache.test",
|
||||||
|
}, nil)
|
||||||
|
second := transport.getDynamicTransport(&RequestContext{
|
||||||
|
Proxy: "http://127.0.0.1:8080",
|
||||||
|
DialTimeout: 2 * time.Second,
|
||||||
|
CustomIP: []string{"127.0.0.1"},
|
||||||
|
TLSServerName: "cache.test",
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
if first != second {
|
||||||
|
t.Fatal("expected cached dynamic transport to be reused")
|
||||||
|
}
|
||||||
|
if got := len(transport.dynamicCache); got != 1 {
|
||||||
|
t.Fatalf("dynamic cache size=%d; want 1", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTransportDynamicCacheSeparatesTLSServerName(t *testing.T) {
|
||||||
|
transport := &Transport{base: newBaseHTTPTransport()}
|
||||||
|
|
||||||
|
first := transport.getDynamicTransport(&RequestContext{
|
||||||
|
CustomIP: []string{"127.0.0.1"},
|
||||||
|
TLSServerName: "first.test",
|
||||||
|
}, nil)
|
||||||
|
second := transport.getDynamicTransport(&RequestContext{
|
||||||
|
CustomIP: []string{"127.0.0.1"},
|
||||||
|
TLSServerName: "second.test",
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
if first == second {
|
||||||
|
t.Fatal("expected distinct tls server names to use different transports")
|
||||||
|
}
|
||||||
|
if got := len(transport.dynamicCache); got != 2 {
|
||||||
|
t.Fatalf("dynamic cache size=%d; want 2", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTransportDynamicCacheSkipsUserTLSConfig(t *testing.T) {
|
||||||
|
transport := &Transport{base: newBaseHTTPTransport()}
|
||||||
|
reqCtx := &RequestContext{
|
||||||
|
CustomIP: []string{"127.0.0.1"},
|
||||||
|
TLSConfig: &tls.Config{InsecureSkipVerify: true},
|
||||||
|
}
|
||||||
|
|
||||||
|
first := transport.getDynamicTransport(reqCtx, nil)
|
||||||
|
second := transport.getDynamicTransport(reqCtx, nil)
|
||||||
|
|
||||||
|
if first == second {
|
||||||
|
t.Fatal("expected user tls config to bypass dynamic transport cache")
|
||||||
|
}
|
||||||
|
if got := len(transport.dynamicCache); got != 0 {
|
||||||
|
t.Fatalf("dynamic cache size=%d; want 0", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTransportDynamicCacheResetOnDefaultTLSChange(t *testing.T) {
|
||||||
|
client := NewClientNoErr()
|
||||||
|
transport, ok := client.HTTPClient().Transport.(*Transport)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("transport type=%T; want *Transport", client.HTTPClient().Transport)
|
||||||
|
}
|
||||||
|
|
||||||
|
reqCtx := &RequestContext{CustomIP: []string{"127.0.0.1"}}
|
||||||
|
first := transport.getDynamicTransport(reqCtx, nil)
|
||||||
|
if got := len(transport.dynamicCache); got != 1 {
|
||||||
|
t.Fatalf("dynamic cache size=%d; want 1 before reset", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
client.SetDefaultSkipTLSVerify(true)
|
||||||
|
if got := len(transport.dynamicCache); got != 0 {
|
||||||
|
t.Fatalf("dynamic cache size=%d; want 0 after reset", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
second := transport.getDynamicTransport(reqCtx, nil)
|
||||||
|
if first == second {
|
||||||
|
t.Fatal("expected cache reset after default tls change")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDynamicTransportCacheReusesConnectionForCustomIP(t *testing.T) {
|
||||||
|
server := newIPv4Server(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("ok"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
addr, err := net.ResolveTCPAddr("tcp", server.Listener.Addr().String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ResolveTCPAddr() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
client := NewClientNoErr()
|
||||||
|
targetURL := "http://cache-reuse.test:" + strconv.Itoa(addr.Port)
|
||||||
|
|
||||||
|
runRequest := func() bool {
|
||||||
|
var (
|
||||||
|
mu sync.Mutex
|
||||||
|
gotConn bool
|
||||||
|
reused bool
|
||||||
|
)
|
||||||
|
|
||||||
|
resp, err := client.NewSimpleRequest(targetURL, http.MethodGet).
|
||||||
|
SetCustomIP([]string{"127.0.0.1"}).
|
||||||
|
SetTraceHooks(&TraceHooks{
|
||||||
|
GotConn: func(info TraceGotConnInfo) {
|
||||||
|
mu.Lock()
|
||||||
|
gotConn = true
|
||||||
|
reused = info.Reused
|
||||||
|
mu.Unlock()
|
||||||
|
},
|
||||||
|
}).
|
||||||
|
Do()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do() error: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
if _, err := resp.Body().Bytes(); err != nil {
|
||||||
|
t.Fatalf("Body().Bytes() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
if !gotConn {
|
||||||
|
t.Fatal("expected GotConn trace event")
|
||||||
|
}
|
||||||
|
return reused
|
||||||
|
}
|
||||||
|
|
||||||
|
if runRequest() {
|
||||||
|
t.Fatal("first request unexpectedly reused a connection")
|
||||||
|
}
|
||||||
|
if !runRequest() {
|
||||||
|
t.Fatal("second request did not reuse cached dynamic transport connection")
|
||||||
|
}
|
||||||
|
|
||||||
|
transport, ok := client.HTTPClient().Transport.(*Transport)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("transport type=%T; want *Transport", client.HTTPClient().Transport)
|
||||||
|
}
|
||||||
|
if got := len(transport.dynamicCache); got != 1 {
|
||||||
|
t.Fatalf("dynamic cache size=%d; want 1", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPrepareProxyTargetRequestSingleTargetRewritesExecRequest(t *testing.T) {
|
||||||
|
req, err := http.NewRequest(http.MethodGet, "https://proxy-single.test:8443/path", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("http.NewRequest() error: %v", err)
|
||||||
|
}
|
||||||
|
req.Host = req.URL.Host
|
||||||
|
|
||||||
|
execReq, execReqCtx, targetAddrs, err := prepareProxyTargetRequest(req, &RequestContext{
|
||||||
|
Proxy: "http://127.0.0.1:8080",
|
||||||
|
CustomIP: []string{"127.0.0.1"},
|
||||||
|
}, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("prepareProxyTargetRequest() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if execReq == req {
|
||||||
|
t.Fatal("expected cloned request for proxy target preparation")
|
||||||
|
}
|
||||||
|
if got := execReq.URL.Host; got != "127.0.0.1:8443" {
|
||||||
|
t.Fatalf("execReq.URL.Host=%q; want %q", got, "127.0.0.1:8443")
|
||||||
|
}
|
||||||
|
if got := req.URL.Host; got != "proxy-single.test:8443" {
|
||||||
|
t.Fatalf("original req.URL.Host=%q; want %q", got, "proxy-single.test:8443")
|
||||||
|
}
|
||||||
|
if len(targetAddrs) != 0 {
|
||||||
|
t.Fatalf("targetAddrs=%v; want empty after single target rewrite", targetAddrs)
|
||||||
|
}
|
||||||
|
if execReqCtx == nil || execReqCtx.TLSConfig == nil {
|
||||||
|
t.Fatal("expected synthesized tls config for single target proxy request")
|
||||||
|
}
|
||||||
|
if got := execReqCtx.TLSConfig.ServerName; got != "proxy-single.test" {
|
||||||
|
t.Fatalf("tls server name=%q; want %q", got, "proxy-single.test")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPrepareProxyTargetRequestMultiTargetPreservesFallbackList(t *testing.T) {
|
||||||
|
req, err := http.NewRequest(http.MethodGet, "https://proxy-multi.test:9443/path", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("http.NewRequest() error: %v", err)
|
||||||
|
}
|
||||||
|
req.Host = req.URL.Host
|
||||||
|
|
||||||
|
execReq, _, targetAddrs, err := prepareProxyTargetRequest(req, &RequestContext{
|
||||||
|
Proxy: "http://127.0.0.1:8080",
|
||||||
|
CustomIP: []string{"127.0.0.1", "127.0.0.2"},
|
||||||
|
}, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("prepareProxyTargetRequest() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := execReq.URL.Host; got != "proxy-multi.test:9443" {
|
||||||
|
t.Fatalf("execReq.URL.Host=%q; want original host", got)
|
||||||
|
}
|
||||||
|
if len(targetAddrs) != 2 {
|
||||||
|
t.Fatalf("targetAddrs=%v; want 2 targets", targetAddrs)
|
||||||
|
}
|
||||||
|
if targetAddrs[0] != "127.0.0.1:9443" || targetAddrs[1] != "127.0.0.2:9443" {
|
||||||
|
t.Fatalf("targetAddrs=%v; want ordered fallback targets", targetAddrs)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,149 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// HTTP Content-Type 常量
|
||||||
|
const (
|
||||||
|
ContentTypeFormURLEncoded = "application/x-www-form-urlencoded"
|
||||||
|
ContentTypeFormData = "multipart/form-data"
|
||||||
|
ContentTypeJSON = "application/json"
|
||||||
|
ContentTypeXML = "application/xml"
|
||||||
|
ContentTypePlain = "text/plain"
|
||||||
|
ContentTypeHTML = "text/html"
|
||||||
|
ContentTypeOctetStream = "application/octet-stream"
|
||||||
|
)
|
||||||
|
|
||||||
|
// 默认配置
|
||||||
|
const (
|
||||||
|
DefaultDialTimeout = 5 * time.Second
|
||||||
|
DefaultTimeout = 10 * time.Second
|
||||||
|
DefaultUserAgent = "starnet"
|
||||||
|
DefaultFetchRespBody = false
|
||||||
|
)
|
||||||
|
|
||||||
|
// RequestFile 表示要上传的文件
|
||||||
|
type RequestFile struct {
|
||||||
|
FormName string // 表单字段名
|
||||||
|
FileName string // 文件名
|
||||||
|
FilePath string // 文件路径(如果从文件读取)
|
||||||
|
FileData io.Reader // 文件数据流
|
||||||
|
FileSize int64 // 文件大小
|
||||||
|
FileType string // MIME 类型
|
||||||
|
}
|
||||||
|
|
||||||
|
// UploadProgressFunc 文件上传进度回调函数
|
||||||
|
type UploadProgressFunc func(filename string, uploaded int64, total int64)
|
||||||
|
|
||||||
|
// NetworkConfig 网络配置
|
||||||
|
type NetworkConfig struct {
|
||||||
|
Proxy string // 代理地址
|
||||||
|
DialTimeout time.Duration // 连接超时
|
||||||
|
Timeout time.Duration // 总超时
|
||||||
|
DialFunc func(ctx context.Context, network, addr string) (net.Conn, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TLSConfig TLS 配置
|
||||||
|
type TLSConfig struct {
|
||||||
|
Config *tls.Config // TLS 配置
|
||||||
|
SkipVerify bool // 跳过证书验证
|
||||||
|
ServerName string // 显式 TLS ServerName/SNI 覆盖
|
||||||
|
}
|
||||||
|
|
||||||
|
// DNSConfig DNS 配置
|
||||||
|
type DNSConfig struct {
|
||||||
|
CustomIP []string // 直接指定 IP(最高优先级)
|
||||||
|
CustomDNS []string // 自定义 DNS 服务器
|
||||||
|
LookupFunc func(ctx context.Context, host string) ([]net.IPAddr, error) // 自定义解析函数
|
||||||
|
}
|
||||||
|
|
||||||
|
type bodyMode uint8
|
||||||
|
|
||||||
|
const (
|
||||||
|
bodyModeUnset bodyMode = iota
|
||||||
|
bodyModeBytes
|
||||||
|
bodyModeReader
|
||||||
|
bodyModeForm
|
||||||
|
bodyModeMultipart
|
||||||
|
)
|
||||||
|
|
||||||
|
// BodyConfig 请求体配置
|
||||||
|
type BodyConfig struct {
|
||||||
|
Mode bodyMode // 当前 body 来源模式
|
||||||
|
Bytes []byte // 原始字节
|
||||||
|
Reader io.Reader // 数据流
|
||||||
|
FormData map[string][]string // 表单数据
|
||||||
|
Files []RequestFile // 文件列表
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestConfig 请求配置(内部使用)
|
||||||
|
type RequestConfig struct {
|
||||||
|
Network NetworkConfig
|
||||||
|
TLS TLSConfig
|
||||||
|
DNS DNSConfig
|
||||||
|
Body BodyConfig
|
||||||
|
Headers http.Header
|
||||||
|
Cookies []*http.Cookie
|
||||||
|
Queries map[string][]string
|
||||||
|
|
||||||
|
// 其他配置
|
||||||
|
BasicAuth [2]string // Basic 认证
|
||||||
|
Host string // 显式 Host 头覆盖
|
||||||
|
ContentLength int64 // 手动设置的 Content-Length
|
||||||
|
AutoCalcContentLength bool // 自动计算 Content-Length
|
||||||
|
MaxRespBodyBytes int64 // 响应体最大读取字节数(<=0 表示不限制)
|
||||||
|
UploadProgress UploadProgressFunc // 上传进度回调
|
||||||
|
|
||||||
|
// Transport 配置
|
||||||
|
CustomTransport bool // 是否使用自定义 Transport
|
||||||
|
Transport *http.Transport // 自定义 Transport
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clone 克隆配置
|
||||||
|
func (c *RequestConfig) Clone() *RequestConfig {
|
||||||
|
return &RequestConfig{
|
||||||
|
Network: NetworkConfig{
|
||||||
|
Proxy: c.Network.Proxy,
|
||||||
|
DialTimeout: c.Network.DialTimeout,
|
||||||
|
Timeout: c.Network.Timeout,
|
||||||
|
DialFunc: c.Network.DialFunc,
|
||||||
|
},
|
||||||
|
TLS: TLSConfig{
|
||||||
|
Config: cloneTLSConfig(c.TLS.Config),
|
||||||
|
SkipVerify: c.TLS.SkipVerify,
|
||||||
|
ServerName: c.TLS.ServerName,
|
||||||
|
},
|
||||||
|
DNS: DNSConfig{
|
||||||
|
CustomIP: cloneStringSlice(c.DNS.CustomIP),
|
||||||
|
CustomDNS: cloneStringSlice(c.DNS.CustomDNS),
|
||||||
|
LookupFunc: c.DNS.LookupFunc,
|
||||||
|
},
|
||||||
|
Body: BodyConfig{
|
||||||
|
Mode: c.Body.Mode,
|
||||||
|
Bytes: cloneBytes(c.Body.Bytes),
|
||||||
|
Reader: c.Body.Reader, // Reader 不可克隆
|
||||||
|
FormData: cloneStringMapSlice(c.Body.FormData),
|
||||||
|
Files: cloneFiles(c.Body.Files),
|
||||||
|
},
|
||||||
|
Headers: cloneHeader(c.Headers),
|
||||||
|
Cookies: cloneCookies(c.Cookies),
|
||||||
|
Queries: cloneStringMapSlice(c.Queries),
|
||||||
|
BasicAuth: c.BasicAuth,
|
||||||
|
Host: c.Host,
|
||||||
|
ContentLength: c.ContentLength,
|
||||||
|
AutoCalcContentLength: c.AutoCalcContentLength,
|
||||||
|
MaxRespBodyBytes: c.MaxRespBodyBytes,
|
||||||
|
UploadProgress: c.UploadProgress,
|
||||||
|
CustomTransport: c.CustomTransport,
|
||||||
|
Transport: c.Transport, // Transport 共享
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestOpt 请求选项函数
|
||||||
|
type RequestOpt func(*Request) error
|
||||||
@@ -0,0 +1,223 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// validMethod 验证 HTTP 方法是否有效
|
||||||
|
func validMethod(method string) bool {
|
||||||
|
return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1
|
||||||
|
}
|
||||||
|
|
||||||
|
// isNotToken 检查字符是否不是 token 字符
|
||||||
|
func isNotToken(r rune) bool {
|
||||||
|
return !isTokenRune(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// isTokenRune 检查字符是否是 token 字符
|
||||||
|
func isTokenRune(r rune) bool {
|
||||||
|
i := int(r)
|
||||||
|
return i < 127 && isTokenTable[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
// isTokenTable token 字符表
|
||||||
|
var isTokenTable = [127]bool{
|
||||||
|
'!': true, '#': true, '$': true, '%': true, '&': true, '\'': true, '*': true,
|
||||||
|
'+': true, '-': true, '.': true, '0': true, '1': true, '2': true, '3': true,
|
||||||
|
'4': true, '5': true, '6': true, '7': true, '8': true, '9': true, 'A': true,
|
||||||
|
'B': true, 'C': true, 'D': true, 'E': true, 'F': true, 'G': true, 'H': true,
|
||||||
|
'I': true, 'J': true, 'K': true, 'L': true, 'M': true, 'N': true, 'O': true,
|
||||||
|
'P': true, 'Q': true, 'R': true, 'S': true, 'T': true, 'U': true, 'V': true,
|
||||||
|
'W': true, 'X': true, 'Y': true, 'Z': true, '^': true, '_': true, '`': true,
|
||||||
|
'a': true, 'b': true, 'c': true, 'd': true, 'e': true, 'f': true, 'g': true,
|
||||||
|
'h': true, 'i': true, 'j': true, 'k': true, 'l': true, 'm': true, 'n': true,
|
||||||
|
'o': true, 'p': true, 'q': true, 'r': true, 's': true, 't': true, 'u': true,
|
||||||
|
'v': true, 'w': true, 'x': true, 'y': true, 'z': true, '|': true, '~': true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// hasPort 检查地址是否包含端口
|
||||||
|
func hasPort(s string) bool {
|
||||||
|
return strings.LastIndex(s, ":") > strings.LastIndex(s, "]")
|
||||||
|
}
|
||||||
|
|
||||||
|
// removeEmptyPort 移除空端口
|
||||||
|
func removeEmptyPort(host string) string {
|
||||||
|
if hasPort(host) {
|
||||||
|
return strings.TrimSuffix(host, ":")
|
||||||
|
}
|
||||||
|
return host
|
||||||
|
}
|
||||||
|
|
||||||
|
// UrlEncode URL 编码
|
||||||
|
func UrlEncode(str string) string {
|
||||||
|
return url.QueryEscape(str)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UrlEncodeRaw URL 编码(空格编码为 %20)
|
||||||
|
func UrlEncodeRaw(str string) string {
|
||||||
|
return strings.Replace(url.QueryEscape(str), "+", "%20", -1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UrlDecode URL 解码
|
||||||
|
func UrlDecode(str string) (string, error) {
|
||||||
|
return url.QueryUnescape(str)
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildQuery 构建查询字符串
|
||||||
|
func BuildQuery(data map[string]string) string {
|
||||||
|
query := url.Values{}
|
||||||
|
for k, v := range data {
|
||||||
|
query.Add(k, v)
|
||||||
|
}
|
||||||
|
return query.Encode()
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildPostForm 构建 POST 表单数据
|
||||||
|
func BuildPostForm(data map[string]string) []byte {
|
||||||
|
return []byte(BuildQuery(data))
|
||||||
|
}
|
||||||
|
|
||||||
|
// cloneHeader 克隆 Header
|
||||||
|
func cloneHeader(h http.Header) http.Header {
|
||||||
|
if h == nil {
|
||||||
|
return make(http.Header)
|
||||||
|
}
|
||||||
|
newHeader := make(http.Header, len(h))
|
||||||
|
for k, v := range h {
|
||||||
|
newHeader[k] = append([]string(nil), v...)
|
||||||
|
}
|
||||||
|
return newHeader
|
||||||
|
}
|
||||||
|
|
||||||
|
// cloneCookies 克隆 Cookies
|
||||||
|
func cloneCookies(cookies []*http.Cookie) []*http.Cookie {
|
||||||
|
if cookies == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
newCookies := make([]*http.Cookie, len(cookies))
|
||||||
|
for i, c := range cookies {
|
||||||
|
newCookies[i] = cloneCookie(c)
|
||||||
|
}
|
||||||
|
return newCookies
|
||||||
|
}
|
||||||
|
|
||||||
|
func cloneCookie(cookie *http.Cookie) *http.Cookie {
|
||||||
|
if cookie == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &http.Cookie{
|
||||||
|
Name: cookie.Name,
|
||||||
|
Value: cookie.Value,
|
||||||
|
Path: cookie.Path,
|
||||||
|
Domain: cookie.Domain,
|
||||||
|
Expires: cookie.Expires,
|
||||||
|
RawExpires: cookie.RawExpires,
|
||||||
|
MaxAge: cookie.MaxAge,
|
||||||
|
Secure: cookie.Secure,
|
||||||
|
HttpOnly: cookie.HttpOnly,
|
||||||
|
SameSite: cookie.SameSite,
|
||||||
|
Raw: cookie.Raw,
|
||||||
|
Unparsed: append([]string(nil), cookie.Unparsed...),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// cloneStringMapSlice 克隆 map[string][]string
|
||||||
|
func cloneStringMapSlice(m map[string][]string) map[string][]string {
|
||||||
|
if m == nil {
|
||||||
|
return make(map[string][]string)
|
||||||
|
}
|
||||||
|
newMap := make(map[string][]string, len(m))
|
||||||
|
for k, v := range m {
|
||||||
|
newMap[k] = append([]string(nil), v...)
|
||||||
|
}
|
||||||
|
return newMap
|
||||||
|
}
|
||||||
|
|
||||||
|
// cloneBytes 克隆字节切片
|
||||||
|
func cloneBytes(b []byte) []byte {
|
||||||
|
if b == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
newBytes := make([]byte, len(b))
|
||||||
|
copy(newBytes, b)
|
||||||
|
return newBytes
|
||||||
|
}
|
||||||
|
|
||||||
|
// cloneStringSlice 克隆字符串切片
|
||||||
|
func cloneStringSlice(s []string) []string {
|
||||||
|
if s == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
newSlice := make([]string, len(s))
|
||||||
|
copy(newSlice, s)
|
||||||
|
return newSlice
|
||||||
|
}
|
||||||
|
|
||||||
|
// cloneFiles 克隆文件列表
|
||||||
|
func cloneFiles(files []RequestFile) []RequestFile {
|
||||||
|
if files == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
newFiles := make([]RequestFile, len(files))
|
||||||
|
copy(newFiles, files)
|
||||||
|
return newFiles
|
||||||
|
}
|
||||||
|
|
||||||
|
// cloneTLSConfig 克隆 TLS 配置
|
||||||
|
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
|
||||||
|
if cfg == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return cfg.Clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
// copyWithProgress 带进度的复制
|
||||||
|
func copyWithProgress(ctx context.Context, dst io.Writer, src io.Reader, filename string, total int64, progress UploadProgressFunc) (int64, error) {
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
|
||||||
|
var written int64
|
||||||
|
buf := make([]byte, 32*1024) // 32KB buffer
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return written, ctx.Err()
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
nr, err := src.Read(buf)
|
||||||
|
if nr > 0 {
|
||||||
|
nw, ew := dst.Write(buf[:nr])
|
||||||
|
if nw > 0 {
|
||||||
|
written += int64(nw)
|
||||||
|
if progress != nil {
|
||||||
|
// 同步调用进度回调(不使用 goroutine)
|
||||||
|
progress(filename, written, total)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if ew != nil {
|
||||||
|
return written, ew
|
||||||
|
}
|
||||||
|
if nr != nw {
|
||||||
|
return written, io.ErrShortWrite
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
if err == io.EOF {
|
||||||
|
if progress != nil {
|
||||||
|
// 最后一次进度回调
|
||||||
|
progress(filename, written, total)
|
||||||
|
}
|
||||||
|
return written, nil
|
||||||
|
}
|
||||||
|
return written, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
+284
@@ -0,0 +1,284 @@
|
|||||||
|
package starnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestUrlEncodeRaw(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "basic string with space",
|
||||||
|
input: "hello world",
|
||||||
|
expected: "hello%20world",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "special characters",
|
||||||
|
input: "hello world!@#$%^&*()_+-=~`",
|
||||||
|
expected: "hello%20world%21%40%23%24%25%5E%26%2A%28%29_%2B-%3D~%60",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty string",
|
||||||
|
input: "",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "chinese characters",
|
||||||
|
input: "你好世界",
|
||||||
|
expected: "%E4%BD%A0%E5%A5%BD%E4%B8%96%E7%95%8C",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := UrlEncodeRaw(tt.input)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("UrlEncodeRaw(%q) = %q; want %q", tt.input, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUrlEncode(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "space encoded as plus",
|
||||||
|
input: "hello world",
|
||||||
|
expected: "hello+world",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "special characters",
|
||||||
|
input: "hello world!@#$%^&*()_+-=~`",
|
||||||
|
expected: "hello+world%21%40%23%24%25%5E%26%2A%28%29_%2B-%3D~%60",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := UrlEncode(tt.input)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("UrlEncode(%q) = %q; want %q", tt.input, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUrlDecode(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
expectErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "basic decode",
|
||||||
|
input: "hello%20world",
|
||||||
|
expected: "hello world",
|
||||||
|
expectErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "plus to space",
|
||||||
|
input: "hello+world",
|
||||||
|
expected: "hello world",
|
||||||
|
expectErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "special characters",
|
||||||
|
input: "hello%20world%21%40%23%24%25%5E%26*%28%29_%2B-%3D~%60",
|
||||||
|
expected: "hello world!@#$%^&*()_+-=~`",
|
||||||
|
expectErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid encoding",
|
||||||
|
input: "%zz",
|
||||||
|
expected: "",
|
||||||
|
expectErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result, err := UrlDecode(tt.input)
|
||||||
|
if tt.expectErr {
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("UrlDecode(%q) expected error, got nil", tt.input)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("UrlDecode(%q) unexpected error: %v", tt.input, err)
|
||||||
|
}
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("UrlDecode(%q) = %q; want %q", tt.input, result, tt.expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildQuery(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input map[string]string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "single parameter",
|
||||||
|
input: map[string]string{
|
||||||
|
"key": "value",
|
||||||
|
},
|
||||||
|
expected: "key=value",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty map",
|
||||||
|
input: map[string]string{},
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nil map",
|
||||||
|
input: nil,
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := BuildQuery(tt.input)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("BuildQuery(%v) = %q; want %q", tt.input, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildPostForm(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input map[string]string
|
||||||
|
expected []byte
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "basic form",
|
||||||
|
input: map[string]string{
|
||||||
|
"key1": "value1",
|
||||||
|
},
|
||||||
|
expected: []byte("key1=value1"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty map",
|
||||||
|
input: map[string]string{},
|
||||||
|
expected: []byte(""),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := BuildPostForm(tt.input)
|
||||||
|
if string(result) != string(tt.expected) {
|
||||||
|
t.Errorf("BuildPostForm(%v) = %v; want %v", tt.input, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidMethod(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
method string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{"GET", "GET", true},
|
||||||
|
{"POST", "POST", true},
|
||||||
|
{"PUT", "PUT", true},
|
||||||
|
{"DELETE", "DELETE", true},
|
||||||
|
{"PATCH", "PATCH", true},
|
||||||
|
{"OPTIONS", "OPTIONS", true},
|
||||||
|
{"HEAD", "HEAD", true},
|
||||||
|
{"TRACE", "TRACE", true},
|
||||||
|
{"CONNECT", "CONNECT", true},
|
||||||
|
{"invalid with space", "GET POST", false},
|
||||||
|
{"invalid with special char", "GET<>", false},
|
||||||
|
{"empty", "", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := validMethod(tt.method)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("validMethod(%q) = %v; want %v", tt.method, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCloneCookies_FullFields(t *testing.T) {
|
||||||
|
expire := time.Now().Add(2 * time.Hour)
|
||||||
|
|
||||||
|
src := []*http.Cookie{
|
||||||
|
{
|
||||||
|
Name: "sid",
|
||||||
|
Value: "abc123",
|
||||||
|
Path: "/",
|
||||||
|
Domain: "example.com",
|
||||||
|
Expires: expire,
|
||||||
|
RawExpires: expire.UTC().Format(time.RFC1123),
|
||||||
|
MaxAge: 3600,
|
||||||
|
Secure: true,
|
||||||
|
HttpOnly: true,
|
||||||
|
SameSite: http.SameSiteLaxMode,
|
||||||
|
Raw: "sid=abc123; Path=/; HttpOnly",
|
||||||
|
Unparsed: []string{"Priority=High", "Partitioned"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
got := cloneCookies(src)
|
||||||
|
if got == nil || len(got) != 1 {
|
||||||
|
t.Fatalf("cloneCookies() len=%v; want 1", len(got))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 指针应不同(不是浅拷贝)
|
||||||
|
if got[0] == src[0] {
|
||||||
|
t.Fatal("cookie pointer should be different (deep copy expected)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 字段值应一致
|
||||||
|
s := src[0]
|
||||||
|
g := got[0]
|
||||||
|
if g.Name != s.Name ||
|
||||||
|
g.Value != s.Value ||
|
||||||
|
g.Path != s.Path ||
|
||||||
|
g.Domain != s.Domain ||
|
||||||
|
!g.Expires.Equal(s.Expires) ||
|
||||||
|
g.RawExpires != s.RawExpires ||
|
||||||
|
g.MaxAge != s.MaxAge ||
|
||||||
|
g.Secure != s.Secure ||
|
||||||
|
g.HttpOnly != s.HttpOnly ||
|
||||||
|
g.SameSite != s.SameSite ||
|
||||||
|
g.Raw != s.Raw {
|
||||||
|
t.Fatalf("cloned cookie fields mismatch:\n got=%+v\n src=%+v", g, s)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unparsed 内容一致
|
||||||
|
if len(g.Unparsed) != len(s.Unparsed) {
|
||||||
|
t.Fatalf("Unparsed len=%d; want %d", len(g.Unparsed), len(s.Unparsed))
|
||||||
|
}
|
||||||
|
for i := range s.Unparsed {
|
||||||
|
if g.Unparsed[i] != s.Unparsed[i] {
|
||||||
|
t.Fatalf("Unparsed[%d]=%q; want %q", i, g.Unparsed[i], s.Unparsed[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证 Unparsed 是深拷贝(修改源不影响目标)
|
||||||
|
src[0].Unparsed[0] = "Modified=Yes"
|
||||||
|
if got[0].Unparsed[0] == "Modified=Yes" {
|
||||||
|
t.Fatal("Unparsed should be deep-copied, but was affected by source mutation")
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user