39 Commits

Author SHA1 Message Date
b612 b026953c74 fix(starnet): simplify default user agent 2026-05-24 12:08:21 +08:00
b612 2f4c7158cf feat: 增加请求级 trace 摘要与诊断能力
- 新增 TraceRecorder 和 TraceSummary,汇总 DNS、连接、TLS、写请求、首包等关键事件
  - 为请求执行链接入结构化 trace hooks,补充标准路径与动态路径的 TLS 元信息
  - 增加 Request.TraceSummary() 和 Response.TraceSummary(),提供请求级与响应级摘要快照
  - 修复共享 TraceRecorder 在 Client 默认选项、Clone 和请求复用场景下的状态串扰问题
  - 修复 Response.TraceSummary() 回读 Request 最近状态导致的非快照语义
  - 收口自定义 DialFunc 下的 TLS trace 元数据,避免伪造连接地址
  - 补充 trace 相关回归测试,覆盖 HTTPS、DNS/Connect、连接复用、共享 recorder、响应快照和自定义拨号场景
  - 更新 README,补充 trace、Host 与 TLSServerName 的行为说明
2026-04-20 17:54:43 +08:00
b612 732e81316c fix(starnet): 重构请求执行链路并补齐代理/重试/trace边界
- 分离 Request 的配置态与执行态,修复二次 Do、raw 模式网络配置失效和 body 来源互斥问题
  - 新增 starnet trace 抽象,补齐 DNS/连接/TLS/重试事件,并优化动态 transport 缓存与代理解析路径
  - 收紧非法代理为 fail-fast,多目标目标回退仅限幂等请求,修复 Host/TLS/SNI 等语义边界
  - 补充防御性拷贝、专项回归测试、本地代理/TLS 用例与 README 行为说明
2026-04-19 15:39:51 +08:00
b612 9ac9b65bc5 fix(starnet): 收紧 TLS ClientHello 嗅探并补齐边界测试
- 用轻量 ClientHello 解析替代假握手式 TLS 嗅探
  - 保留截断和 max-bytes 场景下的 TLS 分类与缓冲回放能力
  - 拒绝首个 record 完整但并非 ClientHello 的伪 TLS 流量
  - 为动态 TLS 配置选择透出更完整的 ClientHello 元数据
  - 拆分 TLS 初始化失败统计为 sniff/config/plain rejected
  - 补充正常、分片、截断、限长、伪 TLS 等回归测试
2026-03-27 12:05:23 +08:00
b612 b5bd7595a1 1. 优化ping功能
2. 新增重试机制
3. 优化错误处理逻辑
2026-03-19 16:42:45 +08:00
b612 4568e17f06 fix: 修复核心bug并完善API
- 修复NewRequest系列函数不返回opt错误的问题
- 修复prepare()幂等性问题,支持请求重试
- 修复defaultDialTLSFunc的ServerName解析错误
- 修复Client.Clone()并发安全问题
- 补齐Client.Trace/Connect方法
- 新增Request.HTTPClient/Client方法
- 增强NewSimpleRequest错误处理的健壮性
2026-03-10 19:55:37 +08:00
b612 1bb30514ec bug fix:tls自定义时,没有设置servername的问题 2026-03-08 21:38:45 +08:00
b612 50aef48d49 rewrite program 2026-03-08 20:19:40 +08:00
b612 0e2f91eee2 fix:使用Client时,设置的参数不生效 2025-10-14 10:08:53 +08:00
b612 b90c59d6e7 修改版本号 2025-08-21 21:40:29 +08:00
b612 4e154cc17b update benchmark 2025-08-21 21:37:21 +08:00
b612 67b0025f9c 更新content-length的默认处理方式 2025-08-21 19:17:19 +08:00
b612 c4fa62536a 为client新增部分函数 2025-08-21 15:32:19 +08:00
b612 260ceb90ed 重构http Client部分 2025-08-21 15:02:02 +08:00
b612 d260181adf update 2025-08-15 15:07:51 +08:00
b612 e3b7369e12 bug fix:nil pointer error 2025-08-13 10:16:08 +08:00
b612 4e17fee681 bug fix 2025-07-14 18:38:31 +08:00
b612 a8eed30db5 add http client control 2025-07-14 18:23:14 +08:00
b612 c1eaf43058 update 2025-06-17 12:36:57 +08:00
b612 9f5aca124d update 2025-06-17 12:09:12 +08:00
b612 54958724e7 bug fix 2025-06-13 17:16:38 +08:00
b612 7a17672149 update tls sniffer 2025-06-12 16:50:47 +08:00
b612 44b807d3d1 update 2025-06-06 15:43:38 +08:00
b612 0d847462b3 bug fix:nil pointer 2025-04-28 13:19:45 +08:00
b612 deed4207ea bug fix 2024-08-30 23:44:49 +08:00
b612 f6363fed07 move starqueue from starnet to stario 2024-08-18 17:18:52 +08:00
b612 1de78f2f06 rewrite curl.go 2024-08-08 22:03:10 +08:00
b612 d0122a9771 update go.mod & update que.go 2024-03-10 14:04:48 +08:00
b612 319518d71d update go mod 2023-02-11 17:17:03 +08:00
b612 be3df9703e update go mod and improve icmp result 2023-02-11 17:15:01 +08:00
b612 b92288bbc9 update go mod 2023-02-03 13:18:53 +08:00
b612 0805549006 add origin request/response http method 2022-09-06 15:15:33 +08:00
b612 033272f38a add tls config 2022-08-22 16:22:22 +08:00
b612 93b756d9fb add no auto redirect config 2022-06-06 11:18:42 +08:00
b612 d71eacdc91 optional function add 2022-03-14 15:43:56 +08:00
b612 747fc52c44 go mod support 2022-03-14 11:03:42 +08:00
b612 ce3ebbbf8a add bodyreader fn 2022-03-11 09:29:09 +08:00
b612 66c8abbcea bug fix 2021-11-12 15:56:23 +08:00
Starainrt b4bffa978c add ping functions 2021-06-04 10:49:23 +08:00
78 changed files with 17815 additions and 562 deletions
+6
View File
@@ -0,0 +1,6 @@
.idea
.sentrux/
agent_readme.md
target.md
agents.md
.codex
+201
View File
@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2026 starnet contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
+122
View File
@@ -0,0 +1,122 @@
# starnet
`starnet` 是一个面向 Go 的网络工具库,提供 HTTP 请求控制、TLS 嗅探和 ICMP Ping 能力。
## 功能概览
- 基于 `context` 的请求级超时控制,不修改共享 `http.Client` 的全局超时
- 请求级网络控制:代理、自定义 IP / DNS、拨号超时、TLS 配置
- 支持请求级 `Host` 覆盖、显式 `TLSServerName/SNI` 控制,以及结构化 trace 回调 / 摘要
- 内置重试机制,支持重试次数、退避、抖动、状态码白名单和自定义错误判定
- 响应体大小限制,避免一次性读取过大内容
- 错误分类辅助:`ClassifyError``IsTimeout``IsDNS``IsTLS``IsProxy``IsCanceled`
- TLS 嗅探监听 / 拨号工具,适用于 TLS 与明文混合场景
- ICMP Ping,支持 IPv4 / IPv6 目标和选项化探测
## 主要能力
### HTTP 客户端与请求构建
- 同时提供 `WithXxx` 选项和 `SetXxx` 链式调用两套接口
- 支持 `Get``Post``Put``Delete``Head``Patch``Options``Trace``Connect`
- 支持 JSON、表单、`multipart/form-data`、流式请求体等常见请求体形态
- 支持显式 `Host` 覆盖与 `TLSServerName` 设置,便于直连 IP、虚拟主机和证书校验场景分离控制
- Header、Cookie、Query 等输入在关键路径上做防御性拷贝,降低外部可变状态污染风险
- `Request.Clone()` 可用于并发场景或同一基础请求的变体构造
### Trace 与诊断
- 支持 `TraceHooks`,可接收 DNS、建连、TLS 握手、写请求、首包等结构化事件
- 支持 `TraceRecorder` / `TraceSummary`,用于汇总一次请求的关键网络过程和 TLS 摘要
- `Request.TraceSummary()` 返回该请求最近一次执行的摘要快照,`Response.TraceSummary()` 返回当前响应对应的摘要快照
- 若多个请求共享同一个 `TraceRecorder`,其 `Summary()` 表示最近一次完成请求的摘要
### 超时与重试
- 请求超时通过 `context` 截止时间控制,不污染共享客户端配置
- 重试支持:
- 最大尝试次数
- 基础退避、最大退避和退避因子
- 抖动比例
- 可重试状态码集合
- 仅幂等方法重试
- 自定义错误判定函数
- 重试成功后返回的 `Response` 仍保持对原始 `Request` 的引用
### 响应处理
- 提供 `Bytes``String``JSON``Reader` 等响应体读取接口
- 支持自动预取响应体
- 支持按字节数限制响应体读取上限
### Ping 模块
- 提供 `Ping``PingWithContext``Pingable` 以及兼容函数 `IsIpPingable`
- `PingOptions` 支持次数、超时、间隔、截止时间、地址族偏好、源地址、负载长度等参数
- 对权限不足、协议不支持、超时、解析失败等情况提供明确错误语义
## 安装
```bash
go get b612.me/starnet
```
## 快速示例
```go
package main
import (
"fmt"
"net/http"
"time"
"b612.me/starnet"
)
func main() {
resp, err := starnet.Get(
"https://example.com",
starnet.WithTimeout(2*time.Second),
starnet.WithRetry(2,
starnet.WithRetryBackoff(100*time.Millisecond, 1*time.Second, 2),
starnet.WithRetryJitter(0.1),
),
starnet.WithMaxRespBodyBytes(1<<20),
)
if err != nil {
fmt.Println("request failed:", starnet.ClassifyError(err), err)
return
}
defer resp.Close()
fmt.Println("status:", resp.StatusCode)
_, _ = resp.Body().Bytes()
ok, pingErr := starnet.Pingable("example.com", &starnet.PingOptions{
Count: 2,
Timeout: 2 * time.Second,
})
fmt.Println("pingable:", ok, pingErr == nil)
_ = http.MethodGet
}
```
## 行为说明
- `NewClient``NewRequest` 以及请求构造相关接口在遇到非法选项时会直接返回错误,例如格式不合法的代理地址。
- `NewClientNoErr` 是便利构造函数;如果选项校验失败,仍可能返回一个占位 `Client`,需要严格校验配置时应优先使用 `NewClient`
- `SetHost` / `WithHost` 只覆盖 HTTP 请求的 `Host`;如需单独控制 TLS SNI 或证书校验名,应配合 `SetTLSServerName` / `WithTLSServerName` 使用。
- 重试默认仅对幂等方法生效。即使显式关闭“仅幂等方法重试”,通过 `SetBodyReader``WithBodyReader` 构造的请求在非幂等方法上仍不会自动重试。
- 当同时使用 `proxy + custom IP/DNS` 且解析出多个目标地址时,自动目标回退仅对幂等请求生效,以避免重复写入。
- 绑定到请求上的 `TraceRecorder` 用于发布已完成请求的摘要;请求执行中的中间状态不保证通过共享 recorder 实时可见。
## 稳定性说明
- 原始 ICMP Ping 在部分系统上需要额外权限。
- 依赖外部网络环境的集成测试结果可能受运行环境影响。
## 许可证
本项目采用 Apache License 2.0,详见 [LICENSE](./LICENSE)。
+1657
View File
File diff suppressed because it is too large Load Diff
+224
View File
@@ -0,0 +1,224 @@
package starnet
import (
"fmt"
"net/http"
"net/http/httptest"
"testing"
)
func BenchmarkGetRequest(b *testing.B) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("OK"))
}))
defer server.Close()
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
resp, err := Get(server.URL)
if err != nil {
b.Fatalf("Get() error: %v", err)
}
resp.Body().String()
resp.Close()
}
}
func BenchmarkGetRequestWithHeaders(b *testing.B) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("OK"))
}))
defer server.Close()
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
resp, err := Get(server.URL,
WithHeader("X-Custom", "value"),
WithUserAgent("BenchmarkAgent"))
if err != nil {
b.Fatalf("Get() error: %v", err)
}
resp.Body().String()
resp.Close()
}
}
func BenchmarkPostRequest(b *testing.B) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("OK"))
}))
defer server.Close()
testData := []byte("test data for benchmark")
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
resp, err := Post(server.URL, WithBody(testData))
if err != nil {
b.Fatalf("Post() error: %v", err)
}
resp.Body().String()
resp.Close()
}
}
func BenchmarkJSONRequest(b *testing.B) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{"status":"ok"}`))
}))
defer server.Close()
type TestData struct {
Name string `json:"name"`
Value int `json:"value"`
}
data := TestData{Name: "test", Value: 123}
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
resp, err := Post(server.URL, WithJSON(data))
if err != nil {
b.Fatalf("Post() error: %v", err)
}
var result map[string]string
resp.Body().JSON(&result)
resp.Close()
}
}
func BenchmarkConcurrentRequests(b *testing.B) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("OK"))
}))
defer server.Close()
b.ResetTimer()
b.ReportAllocs()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
resp, err := Get(server.URL)
if err != nil {
b.Fatalf("Get() error: %v", err)
}
resp.Body().String()
resp.Close()
}
})
}
func BenchmarkRequestClone(b *testing.B) {
req := NewSimpleRequest("https://example.com", "GET").
SetHeader("X-Custom", "value").
AddQuery("key", "value")
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
_ = req.Clone()
}
}
func BenchmarkClientCreation(b *testing.B) {
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
_ = NewClientNoErr()
}
}
func BenchmarkRequestCreation(b *testing.B) {
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
_ = NewSimpleRequest("https://example.com", "GET")
}
}
func BenchmarkRequestPrepareDefaultPath(b *testing.B) {
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
req := NewSimpleRequest("https://example.com", "GET")
if err := req.prepare(); err != nil {
b.Fatalf("prepare() error: %v", err)
}
}
}
func BenchmarkRequestPrepareDynamicPath(b *testing.B) {
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
req := NewSimpleRequest("https://example.com", "GET",
WithCustomIP([]string{"127.0.0.1"}),
WithSkipTLSVerify(true),
)
if err := req.prepare(); err != nil {
b.Fatalf("prepare() error: %v", err)
}
}
}
func BenchmarkResponseBodyRead(b *testing.B) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("test response data"))
}))
defer server.Close()
// Pre-fetch response
resp, _ := Get(server.URL, WithAutoFetch(true))
defer resp.Close()
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
_, _ = resp.Body().String()
}
}
func BenchmarkDifferentResponseSizes(b *testing.B) {
sizes := []int{100, 1024, 10240, 102400} // 100B, 1KB, 10KB, 100KB
for _, size := range sizes {
responseData := make([]byte, size)
for i := 0; i < size; i++ {
responseData[i] = 'A'
}
b.Run(fmt.Sprintf("Size_%d", size), func(b *testing.B) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write(responseData)
}))
defer server.Close()
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
resp, err := Get(server.URL)
if err != nil {
b.Fatalf("Get() error: %v", err)
}
resp.Body().Bytes()
resp.Close()
}
})
}
}
+145
View File
@@ -0,0 +1,145 @@
package starnet
import (
"bytes"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func TestRequestBodyBytes(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
w.Write(body)
}))
defer server.Close()
testData := []byte("test data")
req := NewSimpleRequest(server.URL, "POST").SetBody(testData)
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
body, _ := resp.Body().Bytes()
if !bytes.Equal(body, testData) {
t.Errorf("Body = %v; want %v", body, testData)
}
}
func TestRequestBodyString(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
w.Write(body)
}))
defer server.Close()
testData := "test string data"
req := NewSimpleRequest(server.URL, "POST").SetBodyString(testData)
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
body, _ := resp.Body().String()
if body != testData {
t.Errorf("Body = %v; want %v", body, testData)
}
}
func TestRequestBodyReader(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
w.Write(body)
}))
defer server.Close()
testData := "test reader data"
reader := strings.NewReader(testData)
req := NewSimpleRequest(server.URL, "POST").SetBodyReader(reader)
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
body, _ := resp.Body().String()
if body != testData {
t.Errorf("Body = %v; want %v", body, testData)
}
}
func TestRequestJSON(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("Content-Type") != ContentTypeJSON {
t.Errorf("Content-Type = %v; want %v", r.Header.Get("Content-Type"), ContentTypeJSON)
}
var data map[string]string
json.NewDecoder(r.Body).Decode(&data)
json.NewEncoder(w).Encode(data)
}))
defer server.Close()
testData := map[string]string{
"name": "John",
"email": "john@example.com",
}
req := NewSimpleRequest(server.URL, "POST").SetJSON(testData)
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
var result map[string]string
resp.Body().JSON(&result)
if result["name"] != testData["name"] {
t.Errorf("name = %v; want %v", result["name"], testData["name"])
}
}
func TestRequestFormData(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r.ParseForm()
data := make(map[string]string)
for k, v := range r.Form {
if len(v) > 0 {
data[k] = v[0]
}
}
json.NewEncoder(w).Encode(data)
}))
defer server.Close()
req := NewSimpleRequest(server.URL, "POST").
AddFormData("name", "John").
AddFormData("email", "john@example.com")
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
var result map[string]string
resp.Body().JSON(&result)
if result["name"] != "John" {
t.Errorf("name = %v; want John", result["name"])
}
if result["email"] != "john@example.com" {
t.Errorf("email = %v; want john@example.com", result["email"])
}
}
+349
View File
@@ -0,0 +1,349 @@
package starnet
import (
"context"
"crypto/tls"
"fmt"
"net/http"
"sync"
)
// Client HTTP 客户端封装
type Client struct {
client *http.Client
opts []RequestOpt
mu sync.RWMutex
}
// NewClient 创建新的 Client
func NewClient(opts ...RequestOpt) (*Client, error) {
// 创建基础 Transport
baseTransport := newBaseHTTPTransport()
httpClient := &http.Client{
Transport: &Transport{base: baseTransport},
//Timeout: DefaultTimeout,
}
// 应用选项(如果有)
if len(opts) > 0 {
// 创建临时请求以应用选项
req, err := newRequest(context.Background(), "", http.MethodGet, opts...)
if err != nil {
return nil, wrapError(err, "create client")
}
if req.err != nil {
return nil, wrapError(req.err, "create client")
}
/*
// 如果选项中有自定义配置,应用到 httpClient
if req.config.Network.Timeout > 0 {
httpClient.Timeout = req.config.Network.Timeout
}
*/
// 如果有自定义 Transport
if req.config.CustomTransport && req.config.Transport != nil {
httpClient.Transport = &Transport{base: req.config.Transport}
}
}
return &Client{
client: httpClient,
opts: opts,
}, nil
}
// NewClientNoErr 创建新的 Client(忽略错误)。
// 当 opts 校验失败时,它仍会返回一个可用的 Client 占位对象;
// 如果调用方需要感知选项错误或依赖默认 starnet Transport 行为,应优先使用 NewClient。
func NewClientNoErr(opts ...RequestOpt) *Client {
client, _ := NewClient(opts...)
if client == nil {
client = &Client{
client: &http.Client{},
opts: opts,
}
}
return client
}
// NewClientFromHTTP 从 http.Client 创建 Client
func NewClientFromHTTP(httpClient *http.Client) (*Client, error) {
if httpClient == nil {
return nil, ErrNilClient
}
// 确保 Transport 是我们的自定义类型
if httpClient.Transport == nil {
httpClient.Transport = &Transport{
base: &http.Transport{},
}
} else {
switch t := httpClient.Transport.(type) {
case *Transport:
// 已经是我们的类型
if t.base == nil {
t.base = &http.Transport{}
}
case *http.Transport:
// 包装标准 Transport
httpClient.Transport = &Transport{
base: t,
}
default:
return nil, fmt.Errorf("unsupported transport type: %T", t)
}
}
return &Client{
client: httpClient,
}, nil
}
// HTTPClient 获取底层 http.Client
func (c *Client) HTTPClient() *http.Client {
return c.client
}
// RequestOptions 获取默认选项(返回副本)
func (c *Client) RequestOptions() []RequestOpt {
c.mu.RLock()
defer c.mu.RUnlock()
opts := make([]RequestOpt, len(c.opts))
copy(opts, c.opts)
return opts
}
// SetOptions 设置默认选项
func (c *Client) SetOptions(opts ...RequestOpt) *Client {
c.mu.Lock()
c.opts = opts
c.mu.Unlock()
return c
}
// AddOptions 追加默认选项
func (c *Client) AddOptions(opts ...RequestOpt) *Client {
c.mu.Lock()
c.opts = append(c.opts, opts...)
c.mu.Unlock()
return c
}
// Clone 克隆 Client(深拷贝)
func (c *Client) Clone() *Client {
c.mu.RLock()
defer c.mu.RUnlock()
// 克隆 Transport
var transport http.RoundTripper
if c.client.Transport != nil {
switch t := c.client.Transport.(type) {
case *Transport:
transport = &Transport{
base: t.base.Clone(),
}
case *http.Transport:
transport = t.Clone()
default:
transport = c.client.Transport
}
}
return &Client{
client: &http.Client{
Transport: transport,
CheckRedirect: c.client.CheckRedirect,
Jar: c.client.Jar,
Timeout: c.client.Timeout,
},
opts: append([]RequestOpt(nil), c.opts...),
}
}
// SetDefaultTLSConfig 设置默认 TLS 配置
func (c *Client) SetDefaultTLSConfig(tlsConfig *tls.Config) *Client {
if transport, ok := c.client.Transport.(*Transport); ok {
transport.mu.Lock()
transport.ensureBaseLocked()
if tlsConfig != nil {
transport.base.TLSClientConfig = tlsConfig.Clone()
} else {
transport.base.TLSClientConfig = nil
}
transport.resetDynamicTransportCacheLocked()
transport.mu.Unlock()
}
return c
}
// SetDefaultSkipTLSVerify 设置默认跳过 TLS 验证
func (c *Client) SetDefaultSkipTLSVerify(skip bool) *Client {
if transport, ok := c.client.Transport.(*Transport); ok {
transport.mu.Lock()
transport.ensureBaseLocked()
if transport.base.TLSClientConfig == nil {
transport.base.TLSClientConfig = &tls.Config{}
} else {
transport.base.TLSClientConfig = transport.base.TLSClientConfig.Clone()
}
transport.base.TLSClientConfig.InsecureSkipVerify = skip
transport.resetDynamicTransportCacheLocked()
transport.mu.Unlock()
}
return c
}
// DisableRedirect 禁用重定向
func (c *Client) DisableRedirect() *Client {
c.client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
}
return c
}
// EnableRedirect 启用重定向
func (c *Client) EnableRedirect() *Client {
c.client.CheckRedirect = nil
return c
}
// NewRequest 创建新请求
func (c *Client) NewRequest(url, method string, opts ...RequestOpt) (*Request, error) {
return c.NewRequestWithContext(context.Background(), url, method, opts...)
}
// NewRequestWithContext 创建新请求(带 context
func (c *Client) NewRequestWithContext(ctx context.Context, url, method string, opts ...RequestOpt) (*Request, error) {
// 合并 Client 级别和请求级别的选项
c.mu.RLock()
allOpts := append(append([]RequestOpt(nil), c.opts...), opts...)
c.mu.RUnlock()
req, err := newRequest(ctx, url, method, allOpts...)
if err != nil {
return nil, err
}
if req.err != nil {
return nil, req.err
}
req.client = c
req.httpClient = c.client
return req, nil
}
// Get 发送 GET 请求
func (c *Client) Get(url string, opts ...RequestOpt) (*Response, error) {
req, err := c.NewRequest(url, http.MethodGet, opts...)
if err != nil {
return nil, err
}
return req.Do()
}
// Post 发送 POST 请求
func (c *Client) Post(url string, opts ...RequestOpt) (*Response, error) {
req, err := c.NewRequest(url, http.MethodPost, opts...)
if err != nil {
return nil, err
}
return req.Do()
}
// Put 发送 PUT 请求
func (c *Client) Put(url string, opts ...RequestOpt) (*Response, error) {
req, err := c.NewRequest(url, http.MethodPut, opts...)
if err != nil {
return nil, err
}
return req.Do()
}
// Delete 发送 DELETE 请求
func (c *Client) Delete(url string, opts ...RequestOpt) (*Response, error) {
req, err := c.NewRequest(url, http.MethodDelete, opts...)
if err != nil {
return nil, err
}
return req.Do()
}
// Head 发送 HEAD 请求
func (c *Client) Head(url string, opts ...RequestOpt) (*Response, error) {
req, err := c.NewRequest(url, http.MethodHead, opts...)
if err != nil {
return nil, err
}
return req.Do()
}
// Patch 发送 PATCH 请求
func (c *Client) Patch(url string, opts ...RequestOpt) (*Response, error) {
req, err := c.NewRequest(url, http.MethodPatch, opts...)
if err != nil {
return nil, err
}
return req.Do()
}
// Options 发送 OPTIONS 请求
func (c *Client) Options(url string, opts ...RequestOpt) (*Response, error) {
req, err := c.NewRequest(url, http.MethodOptions, opts...)
if err != nil {
return nil, err
}
return req.Do()
}
// NewSimpleRequest 创建新请求(忽略错误,支持链式调用)
func (c *Client) NewSimpleRequest(url, method string, opts ...RequestOpt) *Request {
return c.NewSimpleRequestWithContext(context.Background(), url, method, opts...)
}
// NewSimpleRequestWithContext 创建新请求(带 context,忽略错误)
func (c *Client) NewSimpleRequestWithContext(ctx context.Context, url, method string, opts ...RequestOpt) *Request {
req, err := c.NewRequestWithContext(ctx, url, method, opts...)
if err != nil {
// 返回一个带错误的请求,保持与全局 NewSimpleRequest 行为一致
return &Request{
ctx: ctx,
url: url,
method: method,
err: err,
config: &RequestConfig{
Headers: make(http.Header),
Queries: make(map[string][]string),
Body: BodyConfig{
FormData: make(map[string][]string),
},
},
client: c,
httpClient: c.client,
autoFetch: DefaultFetchRespBody,
}
}
return req
}
// Trace 发送 TRACE 请求
func (c *Client) Trace(url string, opts ...RequestOpt) (*Response, error) {
req, err := c.NewRequest(url, http.MethodTrace, opts...)
if err != nil {
return nil, err
}
return req.Do()
}
// Connect 发送 CONNECT 请求
func (c *Client) Connect(url string, opts ...RequestOpt) (*Response, error) {
req, err := c.NewRequest(url, http.MethodConnect, opts...)
if err != nil {
return nil, err
}
return req.Do()
}
+223
View File
@@ -0,0 +1,223 @@
package starnet
import (
"net/http"
"net/http/httptest"
"testing"
"time"
)
func TestNewClient(t *testing.T) {
client, err := NewClient()
if err != nil {
t.Fatalf("NewClient() error: %v", err)
}
if client == nil {
t.Fatal("NewClient() returned nil")
}
}
func TestNewClientNoErr(t *testing.T) {
client := NewClientNoErr()
if client == nil {
t.Fatal("NewClientNoErr() returned nil")
}
}
func TestNewClientFromHTTP(t *testing.T) {
httpClient := &http.Client{
Timeout: 10 * time.Second,
}
client, err := NewClientFromHTTP(httpClient)
if err != nil {
t.Fatalf("NewClientFromHTTP() error: %v", err)
}
if client == nil {
t.Fatal("NewClientFromHTTP() returned nil")
}
// Test with nil client
_, err = NewClientFromHTTP(nil)
if err == nil {
t.Error("NewClientFromHTTP(nil) should return error")
}
}
func TestClientOptions(t *testing.T) {
client := NewClientNoErr()
// Set options
client.SetOptions(WithTimeout(5 * time.Second))
opts := client.RequestOptions()
if len(opts) != 1 {
t.Errorf("RequestOptions() length = %v; want 1", len(opts))
}
// Add options
client.AddOptions(WithUserAgent("TestAgent"))
opts = client.RequestOptions()
if len(opts) != 2 {
t.Errorf("RequestOptions() length = %v; want 2", len(opts))
}
}
func TestClientClone(t *testing.T) {
client := NewClientNoErr(WithTimeout(5 * time.Second))
cloned := client.Clone()
if cloned == nil {
t.Fatal("Clone() returned nil")
}
// 修改克隆的 client
cloned.SetOptions(WithTimeout(10 * time.Second))
origOpts := client.RequestOptions()
clonedOpts := cloned.RequestOptions()
// 原 client 应该还是 1 个选项
if len(origOpts) != 1 {
t.Errorf("Original client options = %v; want 1", len(origOpts))
}
// 克隆的 client 应该是 1 个选项(被 SetOptions 覆盖)
if len(clonedOpts) != 1 {
t.Errorf("Cloned client options = %v; want 1", len(clonedOpts))
}
}
func TestClientHTTPMethods(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte(r.Method))
}))
defer server.Close()
client := NewClientNoErr()
tests := []struct {
name string
method func(string, ...RequestOpt) (*Response, error)
want string
}{
{"GET", client.Get, "GET"},
{"POST", client.Post, "POST"},
{"PUT", client.Put, "PUT"},
{"DELETE", client.Delete, "DELETE"},
{"PATCH", client.Patch, "PATCH"},
{"HEAD", client.Head, "HEAD"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
resp, err := tt.method(server.URL)
if err != nil {
t.Fatalf("%s() error: %v", tt.name, err)
}
defer resp.Close()
if tt.want != "HEAD" {
body, _ := resp.Body().String()
if body != tt.want {
t.Errorf("Body = %v; want %v", body, tt.want)
}
}
})
}
}
func TestClientRedirect(t *testing.T) {
redirectCount := 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if redirectCount < 2 {
redirectCount++
http.Redirect(w, r, "/redirected", http.StatusFound)
return
}
w.WriteHeader(http.StatusOK)
w.Write([]byte("final"))
}))
defer server.Close()
// Test with redirect enabled (default)
client := NewClientNoErr()
resp, err := client.Get(server.URL)
if err != nil {
t.Fatalf("Get() error: %v", err)
}
resp.Close()
if redirectCount != 2 {
t.Errorf("Redirect count = %v; want 2", redirectCount)
}
// Test with redirect disabled
redirectCount = 0
client.DisableRedirect()
resp2, err := client.Get(server.URL)
if err != nil {
t.Fatalf("Get() error: %v", err)
}
defer resp2.Close()
if resp2.StatusCode != http.StatusFound {
t.Errorf("StatusCode = %v; want %v", resp2.StatusCode, http.StatusFound)
}
}
func TestClientTLSConfig(t *testing.T) {
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
}))
defer server.Close()
// Without skip verify (should fail with self-signed cert)
client := NewClientNoErr()
_, err := client.Get(server.URL)
if err == nil {
t.Error("Expected TLS error with self-signed cert, got nil")
}
// With skip verify
client.SetDefaultSkipTLSVerify(true)
resp, err := client.Get(server.URL)
if err != nil {
t.Fatalf("Get() with skip verify error: %v", err)
}
defer resp.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("StatusCode = %v; want %v", resp.StatusCode, http.StatusOK)
}
}
func TestClientNewSimpleRequest(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
}))
defer server.Close()
client := NewClientNoErr()
req := client.NewSimpleRequest(server.URL, "GET", WithHeader("X-Test", "v"))
if req == nil {
t.Fatal("NewSimpleRequest returned nil")
}
if req.Err() != nil {
t.Fatalf("NewSimpleRequest err: %v", req.Err())
}
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
body, _ := resp.Body().String()
if body != "OK" {
t.Errorf("Body = %v; want OK", body)
}
}
+111
View File
@@ -0,0 +1,111 @@
package starnet
import (
"fmt"
"net/http"
"net/http/httptest"
"sync"
"sync/atomic"
"testing"
"time"
)
func TestConcurrentRequests(t *testing.T) {
var counter int64
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
atomic.AddInt64(&counter, 1)
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
}))
defer server.Close()
client := NewClientNoErr()
concurrency := 100
var wg sync.WaitGroup
wg.Add(concurrency)
for i := 0; i < concurrency; i++ {
go func() {
defer wg.Done()
resp, err := client.Get(server.URL)
if err != nil {
t.Errorf("Get() error: %v", err)
return
}
resp.Close()
}()
}
wg.Wait()
if atomic.LoadInt64(&counter) != int64(concurrency) {
t.Errorf("counter = %v; want %v", counter, concurrency)
}
}
func TestConcurrentClientModification(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
client := NewClientNoErr()
var wg sync.WaitGroup
wg.Add(200)
// 100 goroutines reading
for i := 0; i < 100; i++ {
go func() {
defer wg.Done()
resp, err := client.Get(server.URL)
if err != nil {
t.Errorf("Get() error: %v", err)
return
}
resp.Close()
}()
}
// 100 goroutines modifying options
for i := 0; i < 100; i++ {
go func(i int) {
defer wg.Done()
if i%2 == 0 {
client.AddOptions(WithTimeout(5 * time.Second))
} else {
_ = client.RequestOptions()
}
}(i)
}
wg.Wait()
}
func TestConcurrentRequestClone(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
baseReq := NewSimpleRequest(server.URL, "GET").SetHeader("X-Base", "value")
var wg sync.WaitGroup
wg.Add(50)
for i := 0; i < 50; i++ {
go func(i int) {
defer wg.Done()
cloned := baseReq.Clone()
// 修复:使用有效的 header 值
cloned.SetHeader("X-Index", fmt.Sprintf("%d", i))
resp, err := cloned.Do()
if err != nil {
t.Errorf("Do() error: %v", err)
return
}
resp.Close()
}(i)
}
wg.Wait()
}
+188
View File
@@ -0,0 +1,188 @@
package starnet
import (
"context"
"crypto/tls"
"net"
"net/http"
"time"
)
// contextKey 私有的 context key 类型(防止冲突)
type contextKey int
const (
ctxKeyTransport contextKey = iota
ctxKeyTLSConfig
ctxKeyTLSConfigCacheable
ctxKeyTLSServerName
ctxKeyProxy
ctxKeyCustomIP
ctxKeyCustomDNS
ctxKeyDialTimeout
ctxKeyTimeout
ctxKeyLookupIP
ctxKeyDialFunc
ctxKeyRequestContext
)
// RequestContext 从 context 中提取的请求配置
type RequestContext struct {
Transport *http.Transport
TLSConfig *tls.Config
TLSConfigCacheable bool
TLSServerName string
Proxy string
CustomIP []string
CustomDNS []string
DialTimeout time.Duration
Timeout time.Duration
LookupIPFn func(ctx context.Context, host string) ([]net.IPAddr, error)
DialFn func(ctx context.Context, network, addr string) (net.Conn, error)
}
var emptyRequestContext = &RequestContext{}
// getRequestContext 从 context 中提取请求配置
func getRequestContext(ctx context.Context) *RequestContext {
if v := ctx.Value(ctxKeyRequestContext); v != nil {
if rc, ok := v.(*RequestContext); ok && rc != nil {
return rc
}
}
var rc *RequestContext
ensure := func() *RequestContext {
if rc == nil {
rc = &RequestContext{}
}
return rc
}
if v := ctx.Value(ctxKeyTransport); v != nil {
ensure().Transport, _ = v.(*http.Transport)
}
if v := ctx.Value(ctxKeyTLSConfig); v != nil {
ensure().TLSConfig, _ = v.(*tls.Config)
}
if v := ctx.Value(ctxKeyTLSConfigCacheable); v != nil {
ensure().TLSConfigCacheable, _ = v.(bool)
}
if v := ctx.Value(ctxKeyTLSServerName); v != nil {
ensure().TLSServerName, _ = v.(string)
}
if v := ctx.Value(ctxKeyProxy); v != nil {
ensure().Proxy, _ = v.(string)
}
if v := ctx.Value(ctxKeyCustomIP); v != nil {
ensure().CustomIP, _ = v.([]string)
}
if v := ctx.Value(ctxKeyCustomDNS); v != nil {
ensure().CustomDNS, _ = v.([]string)
}
if v := ctx.Value(ctxKeyDialTimeout); v != nil {
ensure().DialTimeout, _ = v.(time.Duration)
}
if v := ctx.Value(ctxKeyTimeout); v != nil {
ensure().Timeout, _ = v.(time.Duration)
}
if v := ctx.Value(ctxKeyLookupIP); v != nil {
ensure().LookupIPFn, _ = v.(func(context.Context, string) ([]net.IPAddr, error))
}
if v := ctx.Value(ctxKeyDialFunc); v != nil {
ensure().DialFn, _ = v.(func(context.Context, string, string) (net.Conn, error))
}
if rc == nil {
return emptyRequestContext
}
return rc
}
func cloneRequestContext(rc *RequestContext) *RequestContext {
if rc == nil {
return nil
}
cloned := *rc
cloned.CustomIP = cloneStringSlice(rc.CustomIP)
cloned.CustomDNS = cloneStringSlice(rc.CustomDNS)
return &cloned
}
// needsDynamicTransport 判断是否需要动态 Transport
func needsDynamicTransport(rc *RequestContext) bool {
if rc == nil {
return false
}
return rc.Transport != nil ||
rc.TLSConfig != nil ||
rc.Proxy != "" ||
rc.DialFn != nil ||
(rc.DialTimeout > 0 && rc.DialTimeout != DefaultDialTimeout) ||
len(rc.CustomIP) > 0 ||
len(rc.CustomDNS) > 0 ||
rc.LookupIPFn != nil
}
func buildRequestContext(config *RequestConfig, defaultTLSServerName string) *RequestContext {
if config == nil {
return nil
}
rc := &RequestContext{
DialTimeout: config.Network.DialTimeout,
Timeout: config.Network.Timeout,
}
// 处理 TLS 配置
var tlsConfig *tls.Config
tlsConfigCacheable := false
if config.TLS.Config != nil {
tlsConfig = config.TLS.Config.Clone()
} else if config.TLS.SkipVerify || config.TLS.ServerName != "" {
tlsConfig = &tls.Config{
NextProtos: []string{"h2", "http/1.1"},
}
tlsConfigCacheable = true
}
if config.TLS.SkipVerify && tlsConfig != nil {
tlsConfig.InsecureSkipVerify = true
}
if config.TLS.ServerName != "" && tlsConfig != nil {
tlsConfig.ServerName = config.TLS.ServerName
}
if tlsConfig != nil {
rc.TLSConfig = tlsConfig
rc.TLSConfigCacheable = tlsConfigCacheable
}
if config.TLS.ServerName != "" {
rc.TLSServerName = config.TLS.ServerName
} else if defaultTLSServerName != "" {
rc.TLSServerName = defaultTLSServerName
}
rc.Proxy = config.Network.Proxy
rc.CustomIP = cloneStringSlice(config.DNS.CustomIP)
rc.CustomDNS = cloneStringSlice(config.DNS.CustomDNS)
rc.LookupIPFn = config.DNS.LookupFunc
rc.DialFn = config.Network.DialFunc
if config.CustomTransport && config.Transport != nil {
rc.Transport = config.Transport
}
if !needsDynamicTransport(rc) {
return nil
}
return rc
}
// injectRequestConfig 将请求配置注入到 context
func injectRequestConfig(ctx context.Context, config *RequestConfig, defaultTLSServerName string) context.Context {
rc := buildRequestContext(config, defaultTLSServerName)
if rc == nil {
return ctx
}
return context.WithValue(ctx, ctxKeyRequestContext, rc)
}
-267
View File
@@ -1,267 +0,0 @@
package starnet
import (
"bytes"
"crypto/rand"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"os"
"strings"
"time"
"b612.me/stario"
)
const (
HEADER_FORM_URLENCODE = `application/x-www-form-urlencoded`
HEADER_FORM_DATA = `multipart/form-data`
HEADER_JSON = `application/json`
)
type RequestFile struct {
UploadFile string
UploadForm map[string]string
UploadName string
}
type Request struct {
TimeOut int
DialTimeOut int
Url string
Method string
RecvData []byte
RecvContentLength int64
WriteRecvData bool
RecvIo io.Writer
ReqHeader http.Header
ReqCookies []*http.Cookie
RespHeader http.Header
RespCookies []*http.Cookie
RequestFile
RespHttpCode int
PostBuffer *bytes.Buffer
CircleBuffer *stario.StarBuffer
Proxy string
Process func(float64)
}
func NewRequests(url string, postdata []byte, method string) Request {
req := Request{
TimeOut: 30,
DialTimeOut: 15,
Url: url,
PostBuffer: bytes.NewBuffer(postdata),
Method: method,
WriteRecvData: true,
}
req.ReqHeader = make(http.Header)
if strings.ToUpper(method) == "POST" {
req.ReqHeader.Set("Content-Type", HEADER_FORM_URLENCODE)
}
req.ReqHeader.Set("User-Agent", "B612 / 1.0.0")
return req
}
func (curl *Request) ResetReqHeader() {
curl.ReqHeader = make(http.Header)
}
func (curl *Request) ResetReqCookies() {
curl.ReqCookies = []*http.Cookie{}
}
func (curl *Request) AddSimpleCookie(key, value string) {
curl.ReqCookies = append(curl.ReqCookies, &http.Cookie{Name: key, Value: value, Path: "/"})
}
func randomBoundary() string {
var buf [30]byte
_, err := io.ReadFull(rand.Reader, buf[:])
if err != nil {
panic(err)
}
return fmt.Sprintf("%x", buf[:])
}
func Curl(curl Request) (resps Request, err error) {
var fpsrc *os.File
if curl.RequestFile.UploadFile != "" {
fpsrc, err = os.Open(curl.UploadFile)
if err != nil {
return
}
defer fpsrc.Close()
boundary := randomBoundary()
boundarybytes := []byte("\r\n--" + boundary + "\r\n")
endbytes := []byte("\r\n--" + boundary + "--\r\n")
fpstat, _ := fpsrc.Stat()
filebig := float64(fpstat.Size())
sum, n := 0, 0
fpdst := stario.NewStarBuffer(1048576)
if curl.UploadForm != nil {
for k, v := range curl.UploadForm {
header := fmt.Sprintf("Content-Disposition: form-data; name=\"%s\";\r\nContent-Type: x-www-form-urlencoded \r\n\r\n", k)
fpdst.Write(boundarybytes)
fpdst.Write([]byte(header))
fpdst.Write([]byte(v))
}
}
header := fmt.Sprintf("Content-Disposition: form-data; name=\"%s\"; filename=\"%s\"\r\nContent-Type: application/octet-stream\r\n\r\n", curl.UploadName, fpstat.Name())
fpdst.Write(boundarybytes)
fpdst.Write([]byte(header))
go func() {
for {
bufs := make([]byte, 393213)
n, err = fpsrc.Read(bufs)
if err != nil {
if err == io.EOF {
if n != 0 {
fpdst.Write(bufs[0:n])
if curl.Process != nil {
go curl.Process(float64(sum+n) / filebig * 100)
}
}
break
}
return
}
sum += n
if curl.Process != nil {
go curl.Process(float64(sum+n) / filebig * 100)
}
fpdst.Write(bufs[0:n])
}
fpdst.Write(endbytes)
fpdst.Write(nil)
}()
curl.CircleBuffer = fpdst
curl.ReqHeader.Set("Content-Type", "multipart/form-data;boundary="+boundary)
}
resp, err := netcurl(curl)
if err != nil {
return Request{}, err
}
defer resp.Body.Close()
curl.PostBuffer = nil
curl.CircleBuffer = nil
curl.RespHttpCode = resp.StatusCode
curl.RespHeader = resp.Header
curl.RespCookies = resp.Cookies()
curl.RecvContentLength = resp.ContentLength
readFunc := func(reader io.ReadCloser, writer io.Writer) error {
lengthall := resp.ContentLength
defer reader.Close()
var lengthsum int
buf := make([]byte, 65535)
for {
n, err := reader.Read(buf)
if n != 0 {
_, err := writer.Write(buf[:n])
lengthsum += n
if curl.Process != nil {
go curl.Process(float64(lengthsum) / float64(lengthall) * 100.00)
}
if err != nil {
return err
}
}
if err != nil && err != io.EOF {
return err
} else if err == io.EOF {
return nil
}
}
}
if curl.WriteRecvData {
buf := bytes.NewBuffer([]byte{})
err = readFunc(resp.Body, buf)
if err != nil {
return
}
curl.RecvData = buf.Bytes()
}
if curl.RecvIo != nil {
if curl.WriteRecvData {
_, err = curl.RecvIo.Write(curl.RecvData)
} else {
err = readFunc(resp.Body, curl.RecvIo)
if err != nil {
return
}
}
}
return curl, err
}
func netcurl(curl Request) (*http.Response, error) {
var req *http.Request
var err error
if curl.Method == "" {
return nil, errors.New("Error Method Not Entered")
}
if curl.PostBuffer != nil && curl.PostBuffer.Len() > 0 {
req, err = http.NewRequest(curl.Method, curl.Url, curl.PostBuffer)
} else if curl.CircleBuffer != nil && curl.CircleBuffer.Len() > 0 {
req, err = http.NewRequest(curl.Method, curl.Url, curl.CircleBuffer)
} else {
req, err = http.NewRequest(curl.Method, curl.Url, nil)
}
if err != nil {
return nil, err
}
req.Header = curl.ReqHeader
if len(curl.ReqCookies) != 0 {
for _, v := range curl.ReqCookies {
req.AddCookie(v)
}
}
transport := &http.Transport{
Dial: func(netw, addr string) (net.Conn, error) {
deadline := time.Now().Add(time.Duration(curl.TimeOut) * time.Second)
c, err := net.DialTimeout(netw, addr, time.Second*time.Duration(curl.DialTimeOut))
if err != nil {
return nil, err
}
if curl.TimeOut != 0 {
c.SetDeadline(deadline)
}
return c, nil
},
}
if curl.Proxy != "" {
purl, err := url.Parse(curl.Proxy)
if err != nil {
return nil, err
}
transport.Proxy = http.ProxyURL(purl)
}
client := &http.Client{
Transport: transport,
}
resp, err := client.Do(req)
return resp, err
}
func UrlEncodeRaw(str string) string {
strs := strings.Replace(url.QueryEscape(str), "+", "%20", -1)
return strs
}
func UrlEncode(str string) string {
return url.QueryEscape(str)
}
func UrlDecode(str string) (string, error) {
return url.QueryUnescape(str)
}
func Build_Query(queryData map[string]string) string {
query := url.Values{}
for k, v := range queryData {
query.Add(k, v)
}
return query.Encode()
}
+147
View File
@@ -0,0 +1,147 @@
package starnet
import (
"net/http"
"sync"
"time"
)
var (
defaultClient *Client
defaultHTTPClient *http.Client
defaultClientOnce sync.Once
defaultHTTPOnce sync.Once
defaultMu sync.RWMutex
)
// DefaultClient 获取默认 Client(单例)
func DefaultClient() *Client {
defaultMu.RLock()
if defaultClient != nil {
c := defaultClient
defaultMu.RUnlock()
return c
}
defaultMu.RUnlock()
defaultClientOnce.Do(func() {
c := NewClientNoErr()
defaultMu.Lock()
defaultClient = c
defaultMu.Unlock()
})
defaultMu.RLock()
c := defaultClient
defaultMu.RUnlock()
return c
}
// DefaultHTTPClient 获取默认 http.Client(单例)
func DefaultHTTPClient() *http.Client {
defaultMu.RLock()
if defaultHTTPClient != nil {
c := defaultHTTPClient
defaultMu.RUnlock()
return c
}
defaultMu.RUnlock()
defaultHTTPOnce.Do(func() {
c := &http.Client{
Transport: &Transport{
base: &http.Transport{
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
},
},
Timeout: 0, // 由请求级控制超时
}
defaultMu.Lock()
defaultHTTPClient = c
defaultMu.Unlock()
})
defaultMu.RLock()
c := defaultHTTPClient
defaultMu.RUnlock()
return c
}
// SetDefaultClient 设置默认 Client
func SetDefaultClient(client *Client) {
defaultMu.Lock()
defer defaultMu.Unlock()
defaultClient = client
// 标记 once 已完成,避免后续 DefaultClient() 再次初始化覆盖
defaultClientOnce.Do(func() {})
}
// SetDefaultHTTPClient 设置默认 http.Client
func SetDefaultHTTPClient(client *http.Client) {
defaultMu.Lock()
defer defaultMu.Unlock()
defaultHTTPClient = client
// 标记 once 已完成,避免后续 DefaultHTTPClient() 再次初始化覆盖
defaultHTTPOnce.Do(func() {})
}
// Get 发送 GET 请求(使用默认 Client)
func Get(url string, opts ...RequestOpt) (*Response, error) {
return DefaultClient().Get(url, opts...)
}
// Post 发送 POST 请求(使用默认 Client
func Post(url string, opts ...RequestOpt) (*Response, error) {
return DefaultClient().Post(url, opts...)
}
// Put 发送 PUT 请求(使用默认 Client)
func Put(url string, opts ...RequestOpt) (*Response, error) {
return DefaultClient().Put(url, opts...)
}
// Delete 发送 DELETE 请求(使用默认 Client
func Delete(url string, opts ...RequestOpt) (*Response, error) {
return DefaultClient().Delete(url, opts...)
}
// Head 发送 HEAD 请求(使用默认 Client
func Head(url string, opts ...RequestOpt) (*Response, error) {
return DefaultClient().Head(url, opts...)
}
// Patch 发送 PATCH 请求(使用默认 Client
func Patch(url string, opts ...RequestOpt) (*Response, error) {
return DefaultClient().Patch(url, opts...)
}
// Options 发送 OPTIONS 请求(使用默认 Client
func Options(url string, opts ...RequestOpt) (*Response, error) {
return DefaultClient().Options(url, opts...)
}
// Trace 发送 TRACE 请求(使用默认 Client
func Trace(url string, opts ...RequestOpt) (*Response, error) {
req, err := DefaultClient().NewRequest(url, http.MethodTrace, opts...)
if err != nil {
return nil, err
}
return req.Do()
}
// Connect 发送 CONNECT 请求(使用默认 Client
func Connect(url string, opts ...RequestOpt) (*Response, error) {
req, err := DefaultClient().NewRequest(url, http.MethodConnect, opts...)
if err != nil {
return nil, err
}
return req.Do()
}
+111
View File
@@ -0,0 +1,111 @@
package starnet
import (
"net/http"
"testing"
)
func TestWithRawRequestNil(t *testing.T) {
_, err := NewRequest("http://example.com", "GET", WithRawRequest(nil))
if err == nil {
t.Fatal("expected error when WithRawRequest(nil)")
}
}
func TestSetHeadersDefensiveCopy(t *testing.T) {
req := NewSimpleRequest("http://example.com", "GET")
headers := http.Header{
"X-Test": []string{"v1"},
}
req.SetHeaders(headers)
headers.Set("X-Test", "v2")
if got := req.GetHeader("X-Test"); got != "v1" {
t.Fatalf("header mutated by external map change: got=%q want=%q", got, "v1")
}
}
func TestSetQueriesDefensiveCopy(t *testing.T) {
req := NewSimpleRequest("http://example.com", "GET")
queries := map[string][]string{
"k": []string{"v1"},
}
req.SetQueries(queries)
queries["k"][0] = "v2"
queries["k"] = append(queries["k"], "v3")
got := req.config.Queries["k"]
if len(got) != 1 || got[0] != "v1" {
t.Fatalf("queries mutated by external map change: got=%v want=[v1]", got)
}
}
func TestSetFormDataDefensiveCopy(t *testing.T) {
req := NewSimpleRequest("http://example.com", "POST")
form := map[string][]string{
"name": []string{"alice"},
}
req.SetFormData(form)
form["name"][0] = "bob"
form["name"] = append(form["name"], "carol")
got := req.config.Body.FormData["name"]
if len(got) != 1 || got[0] != "alice" {
t.Fatalf("form data mutated by external map change: got=%v want=[alice]", got)
}
}
func TestWithBodyDefensiveCopy(t *testing.T) {
body := []byte("hello")
req, err := NewRequest("http://example.com", "POST", WithBody(body))
if err != nil {
t.Fatalf("NewRequest() error: %v", err)
}
body[0] = 'j'
if string(req.config.Body.Bytes) != "hello" {
t.Fatalf("body mutated by external slice change: got=%q want=%q", string(req.config.Body.Bytes), "hello")
}
}
func TestWithFormDataDefensiveCopy(t *testing.T) {
form := map[string][]string{
"name": []string{"alice"},
}
req, err := NewRequest("http://example.com", "POST", WithFormData(form))
if err != nil {
t.Fatalf("NewRequest() error: %v", err)
}
form["name"][0] = "bob"
form["name"] = append(form["name"], "carol")
got := req.config.Body.FormData["name"]
if len(got) != 1 || got[0] != "alice" {
t.Fatalf("form data mutated by external map change: got=%v want=[alice]", got)
}
}
func TestSetCustomIPDefensiveCopy(t *testing.T) {
ips := []string{"1.1.1.1", "8.8.8.8"}
req := NewSimpleRequest("http://example.com", "GET").SetCustomIP(ips)
ips[0] = "9.9.9.9"
if got := req.config.DNS.CustomIP[0]; got != "1.1.1.1" {
t.Fatalf("custom ip mutated by external slice change: got=%q want=%q", got, "1.1.1.1")
}
}
func TestSetCustomDNSDefensiveCopy(t *testing.T) {
servers := []string{"8.8.8.8", "1.1.1.1"}
req := NewSimpleRequest("http://example.com", "GET").SetCustomDNS(servers)
servers[0] = "9.9.9.9"
if got := req.config.DNS.CustomDNS[0]; got != "8.8.8.8" {
t.Fatalf("custom dns mutated by external slice change: got=%q want=%q", got, "8.8.8.8")
}
}
+238
View File
@@ -0,0 +1,238 @@
package starnet
import (
"context"
"crypto/tls"
"fmt"
"net"
"strings"
"time"
)
func traceDNSLookup(traceState *traceState, host string, lookup func() ([]net.IPAddr, error)) ([]net.IPAddr, error) {
if traceState != nil {
traceState.beginManualDNS()
defer traceState.endManualDNS()
traceState.dnsStart(TraceDNSStartInfo{Host: host})
}
ipAddrs, err := lookup()
if traceState != nil {
traceState.dnsDone(TraceDNSDoneInfo{
Addrs: append([]net.IPAddr(nil), ipAddrs...),
Err: err,
})
}
return ipAddrs, err
}
func resolveDialAddresses(ctx context.Context, reqCtx *RequestContext, host, port string, traceState *traceState) ([]string, error) {
if reqCtx == nil {
reqCtx = &RequestContext{}
}
var addrs []string
if len(reqCtx.CustomIP) > 0 {
for _, ip := range reqCtx.CustomIP {
addrs = append(addrs, joinResolvedHostPort(ip, port))
}
return addrs, nil
}
var (
ipAddrs []net.IPAddr
err error
)
if reqCtx.LookupIPFn != nil {
ipAddrs, err = traceDNSLookup(traceState, host, func() ([]net.IPAddr, error) {
return reqCtx.LookupIPFn(ctx, host)
})
} else if len(reqCtx.CustomDNS) > 0 {
dialTimeout := reqCtx.DialTimeout
if dialTimeout == 0 {
dialTimeout = DefaultDialTimeout
}
dialer := &net.Dialer{Timeout: dialTimeout}
resolver := &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
var lastErr error
for _, dnsServer := range reqCtx.CustomDNS {
conn, err := dialer.DialContext(ctx, "udp", net.JoinHostPort(dnsServer, "53"))
if err != nil {
lastErr = err
continue
}
return conn, nil
}
return nil, lastErr
},
}
ipAddrs, err = traceDNSLookup(traceState, host, func() ([]net.IPAddr, error) {
return resolver.LookupIPAddr(ctx, host)
})
} else {
ipAddrs, err = traceDNSLookup(traceState, host, func() ([]net.IPAddr, error) {
return net.DefaultResolver.LookupIPAddr(ctx, host)
})
}
if err != nil {
return nil, wrapError(err, "lookup ip")
}
for _, ipAddr := range ipAddrs {
addrs = append(addrs, joinResolvedHostPort(ipAddr.String(), port))
}
return addrs, nil
}
func joinResolvedHostPort(host, port string) string {
if port == "" {
if ip := net.ParseIP(host); ip != nil && ip.To4() == nil {
return "[" + host + "]"
}
return host
}
return net.JoinHostPort(host, port)
}
// defaultDialFunc 默认 Dial 函数(支持自定义 IP 和 DNS)
func defaultDialFunc(ctx context.Context, network, addr string) (net.Conn, error) {
// 提取配置
reqCtx := getRequestContext(ctx)
traceState := getTraceState(ctx)
dialTimeout := reqCtx.DialTimeout
if dialTimeout == 0 {
dialTimeout = DefaultDialTimeout
}
// 解析地址
host, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, wrapError(err, "split host port")
}
addrs, err := resolveDialAddresses(ctx, reqCtx, host, port, traceState)
if err != nil {
return nil, err
}
// 尝试连接所有地址
dialer := &net.Dialer{Timeout: dialTimeout}
var lastErr error
for _, addr := range addrs {
conn, err := dialer.DialContext(ctx, network, addr)
if err != nil {
lastErr = err
continue
}
return conn, nil
}
if lastErr != nil {
return nil, wrapError(lastErr, "dial all addresses failed")
}
return nil, fmt.Errorf("no addresses to dial")
}
// defaultDialTLSFunc 默认 TLS Dial 函数
func defaultDialTLSFunc(ctx context.Context, network, addr string) (net.Conn, error) {
// 先建立 TCP 连接
conn, err := defaultDialFunc(ctx, network, addr)
if err != nil {
return nil, err
}
// 提取 TLS 配置
reqCtx := getRequestContext(ctx)
traceState := getTraceState(ctx)
tlsConfig := reqCtx.TLSConfig
if tlsConfig == nil {
tlsConfig = &tls.Config{}
}
serverName := tlsConfig.ServerName
if serverName == "" {
serverName = reqCtx.TLSServerName
}
if serverName == "" && !tlsConfig.InsecureSkipVerify {
host, _, err := net.SplitHostPort(addr)
if err != nil {
if idx := strings.LastIndex(addr, ":"); idx > 0 {
host = addr[:idx]
} else {
host = addr
}
}
serverName = host
}
if serverName != "" && tlsConfig.ServerName != serverName {
tlsConfig = tlsConfig.Clone() // 避免修改原 config
tlsConfig.ServerName = serverName
}
if traceState != nil {
traceState.markCustomTLS()
traceState.tlsHandshakeStart(TraceTLSHandshakeStartInfo{
Network: network,
Addr: addr,
ServerName: serverName,
})
}
// 执行 TLS 握手
if deadline, ok := ctx.Deadline(); ok {
_ = conn.SetDeadline(deadline)
defer conn.SetDeadline(time.Time{})
}
tlsConn := tls.Client(conn, tlsConfig)
if err := tlsConn.Handshake(); err != nil {
if traceState != nil {
traceState.tlsHandshakeDone(TraceTLSHandshakeDoneInfo{
Network: network,
Addr: addr,
ServerName: serverName,
Err: err,
})
}
conn.Close()
return nil, wrapError(err, "tls handshake")
}
if traceState != nil {
traceState.tlsHandshakeDone(TraceTLSHandshakeDoneInfo{
Network: network,
Addr: addr,
ServerName: serverName,
ConnectionState: tlsConn.ConnectionState(),
})
}
return tlsConn, nil
}
/*
// defaultProxyFunc 默认代理函数
func defaultProxyFunc(req *http.Request) (*url.URL, error) {
if req == nil {
return nil, fmt.Errorf("request is nil")
}
reqCtx := getRequestContext(req.Context())
if reqCtx.Proxy == "" {
return nil, nil
}
proxyURL, err := url.Parse(reqCtx.Proxy)
if err != nil {
return nil, wrapError(err, "parse proxy url")
}
return proxyURL, nil
}
*/
+103
View File
@@ -0,0 +1,103 @@
package starnet
import (
"context"
"net"
"testing"
)
func TestRequestCustomIP(t *testing.T) {
customIPs := []string{"1.2.3.4", "5.6.7.8"}
req := NewSimpleRequest("http://example.com", "GET").
SetCustomIP(customIPs)
if len(req.config.DNS.CustomIP) != 2 {
t.Errorf("CustomIP length = %v; want 2", len(req.config.DNS.CustomIP))
}
for i, ip := range req.config.DNS.CustomIP {
if ip != customIPs[i] {
t.Errorf("CustomIP[%d] = %v; want %v", i, ip, customIPs[i])
}
}
}
func TestRequestCustomIPInvalid(t *testing.T) {
req := NewSimpleRequest("http://example.com", "GET").
SetCustomIP([]string{"invalid-ip"})
if req.Err() == nil {
t.Error("Expected error for invalid IP, got nil")
}
}
func TestRequestCustomDNS(t *testing.T) {
dnsServers := []string{"8.8.8.8", "1.1.1.1"}
req := NewSimpleRequest("http://example.com", "GET").
SetCustomDNS(dnsServers)
if len(req.config.DNS.CustomDNS) != 2 {
t.Errorf("CustomDNS length = %v; want 2", len(req.config.DNS.CustomDNS))
}
}
func TestRequestCustomDNSInvalid(t *testing.T) {
req := NewSimpleRequest("http://example.com", "GET").
SetCustomDNS([]string{"invalid-dns"})
if req.Err() == nil {
t.Error("Expected error for invalid DNS, got nil")
}
}
func TestRequestLookupFunc(t *testing.T) {
called := false
lookupFunc := func(ctx context.Context, host string) ([]net.IPAddr, error) {
called = true
return []net.IPAddr{
{IP: net.ParseIP("1.2.3.4")},
}, nil
}
req := NewSimpleRequest("http://example.com", "GET").
SetLookupFunc(lookupFunc)
if req.config.DNS.LookupFunc == nil {
t.Error("LookupFunc not set")
}
// Call the function to verify it works
ips, err := req.config.DNS.LookupFunc(context.Background(), "example.com")
if err != nil {
t.Errorf("LookupFunc error: %v", err)
}
if !called {
t.Error("LookupFunc was not called")
}
if len(ips) != 1 {
t.Errorf("IPs length = %v; want 1", len(ips))
}
}
func TestDNSPriority(t *testing.T) {
// CustomIP should have highest priority
req := NewSimpleRequest("http://example.com", "GET").
SetCustomIP([]string{"1.2.3.4"}).
SetCustomDNS([]string{"8.8.8.8"}).
SetLookupFunc(func(ctx context.Context, host string) ([]net.IPAddr, error) {
return []net.IPAddr{{IP: net.ParseIP("5.6.7.8")}}, nil
})
// CustomIP should be set
if len(req.config.DNS.CustomIP) == 0 {
t.Error("CustomIP should be set")
}
// Others should also be set (but CustomIP takes priority in actual use)
if len(req.config.DNS.CustomDNS) == 0 {
t.Error("CustomDNS should be set")
}
if req.config.DNS.LookupFunc == nil {
t.Error("LookupFunc should be set")
}
}
+144
View File
@@ -0,0 +1,144 @@
package starnet
import (
"crypto/tls"
"net/http"
"net/url"
"testing"
)
func BenchmarkDynamicTransportCustomIP(b *testing.B) {
server := newIPv4Server(b, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("OK"))
}))
defer server.Close()
targetURL := benchmarkTargetURL(b, server.URL, "bench-custom-ip.test")
client := NewClientNoErr()
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
resp, err := client.Get(targetURL, WithCustomIP([]string{"127.0.0.1"}))
if err != nil {
b.Fatalf("Get() error: %v", err)
}
_, _ = resp.Body().Bytes()
resp.Close()
}
}
func BenchmarkDynamicTransportProxyTLSCacheable(b *testing.B) {
server := newIPv4TLSServer(b, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("OK"))
}))
defer server.Close()
proxy := newIPv4ConnectProxyServer(b, nil)
defer proxy.Close()
targetURL := httpsURLForHost(b, server, "bench-proxy-cacheable.test")
client := NewClientNoErr()
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
resp, err := client.Get(targetURL,
WithProxy(proxy.URL),
WithCustomIP([]string{"127.0.0.1"}),
WithSkipTLSVerify(true),
)
if err != nil {
b.Fatalf("Get() error: %v", err)
}
_, _ = resp.Body().Bytes()
resp.Close()
}
}
func BenchmarkDynamicTransportCustomIPTLSCacheable(b *testing.B) {
server := newIPv4TLSServer(b, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("OK"))
}))
defer server.Close()
targetURL := httpsURLForHost(b, server, "bench-custom-ip-cacheable.test")
client := NewClientNoErr()
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
resp, err := client.Get(targetURL,
WithCustomIP([]string{"127.0.0.1"}),
WithSkipTLSVerify(true),
)
if err != nil {
b.Fatalf("Get() error: %v", err)
}
_, _ = resp.Body().Bytes()
resp.Close()
}
}
func BenchmarkDynamicTransportCustomIPUserTLSConfig(b *testing.B) {
server := newIPv4TLSServer(b, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("OK"))
}))
defer server.Close()
targetURL := httpsURLForHost(b, server, "bench-user-tls.test")
client := NewClientNoErr()
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
resp, err := client.Get(targetURL,
WithCustomIP([]string{"127.0.0.1"}),
WithTLSConfig(&tls.Config{InsecureSkipVerify: true}),
)
if err != nil {
b.Fatalf("Get() error: %v", err)
}
_, _ = resp.Body().Bytes()
resp.Close()
}
}
func benchmarkTargetURL(tb testing.TB, rawURL, host string) string {
tb.Helper()
parsed, err := url.Parse(rawURL)
if err != nil {
tb.Fatalf("url.Parse() error: %v", err)
}
port := parsed.Port()
if port == "" {
switch parsed.Scheme {
case "https":
port = "443"
default:
port = "80"
}
}
return parsed.Scheme + "://" + host + ":" + port + pathWithQuery(parsed.Path, parsed.RawQuery)
}
func pathWithQuery(path, rawQuery string) string {
if path == "" {
path = "/"
}
if rawQuery == "" {
return path
}
return path + "?" + rawQuery
}
+257
View File
@@ -0,0 +1,257 @@
package starnet
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"net"
"net/url"
"strings"
)
var (
// ErrInvalidMethod 无效的 HTTP 方法
ErrInvalidMethod = errors.New("starnet: invalid HTTP method")
// ErrInvalidURL 无效的 URL
ErrInvalidURL = errors.New("starnet: invalid URL")
// ErrInvalidIP 无效的 IP 地址
ErrInvalidIP = errors.New("starnet: invalid IP address")
// ErrInvalidDNS 无效的 DNS 服务器
ErrInvalidDNS = errors.New("starnet: invalid DNS server")
// ErrNilClient HTTP Client 为 nil
ErrNilClient = errors.New("starnet: http client is nil")
// ErrNilReader Reader 为 nil
ErrNilReader = errors.New("starnet: reader is nil")
// ErrFileNotFound 文件不存在
ErrFileNotFound = errors.New("starnet: file not found")
// ErrRequestNotPrepared 请求未准备好
ErrRequestNotPrepared = errors.New("starnet: request not prepared")
// ErrBodyAlreadyConsumed Body 已被消费
ErrBodyAlreadyConsumed = errors.New("starnet: response body already consumed")
// ErrRespBodyTooLarge 响应体超过允许上限
ErrRespBodyTooLarge = errors.New("starnet: response body too large")
// ErrPingInvalidTimeout ping 超时参数无效
ErrPingInvalidTimeout = errors.New("starnet: invalid ping timeout")
// ErrPingPermissionDenied ping 需要更高权限(raw socket
ErrPingPermissionDenied = errors.New("starnet: ping permission denied")
// ErrPingProtocolUnsupported ping 协议/地址族不受当前平台支持
ErrPingProtocolUnsupported = errors.New("starnet: ping protocol unsupported")
// ErrPingNoResolvedTarget ping 目标无法解析为可用地址
ErrPingNoResolvedTarget = errors.New("starnet: ping target not resolved")
)
// wrapError 包装错误,添加上下文信息
func wrapError(err error, format string, args ...interface{}) error {
if err == nil {
return nil
}
msg := fmt.Sprintf(format, args...)
return fmt.Errorf("%s: %w", msg, err)
}
var (
// ErrNilConn indicates a nil net.Conn argument.
ErrNilConn = errors.New("starnet: nil connection")
// ErrTLSSniffFailed indicates TLS sniffing/parsing failed before handshake setup.
ErrTLSSniffFailed = errors.New("starnet: tls sniff failed")
// ErrTLSConfigSelectionFailed indicates dynamic TLS config selection failed.
ErrTLSConfigSelectionFailed = errors.New("starnet: tls config selection failed")
// ErrNonTLSNotAllowed indicates plain TCP was detected while non-TLS is forbidden.
ErrNonTLSNotAllowed = errors.New("starnet: non-TLS connection not allowed")
// ErrNotTLS indicates caller asked for TLS-only object but conn is plain TCP.
ErrNotTLS = errors.New("starnet: connection is not TLS")
// ErrNoTLSConfig indicates TLS was detected but no usable TLS config is available.
ErrNoTLSConfig = errors.New("starnet: no TLS config available")
)
// ErrorKind is a normalized high-level category for request errors.
type ErrorKind string
const (
ErrorKindNone ErrorKind = "none"
ErrorKindCanceled ErrorKind = "canceled"
ErrorKindTimeout ErrorKind = "timeout"
ErrorKindDNS ErrorKind = "dns"
ErrorKindTLS ErrorKind = "tls"
ErrorKindProxy ErrorKind = "proxy"
ErrorKindOther ErrorKind = "other"
)
// IsCanceled reports whether err is a cancellation-related error.
func IsCanceled(err error) bool {
if err == nil {
return false
}
if errors.Is(err, context.Canceled) {
return true
}
msg := strings.ToLower(err.Error())
return strings.Contains(msg, "context canceled") ||
strings.Contains(msg, "operation was canceled") ||
strings.Contains(msg, "request canceled")
}
// ClassifyError maps low-level errors to a stable category for business handling.
func ClassifyError(err error) ErrorKind {
if err == nil {
return ErrorKindNone
}
if IsCanceled(err) {
return ErrorKindCanceled
}
if IsProxy(err) {
return ErrorKindProxy
}
if IsDNS(err) {
return ErrorKindDNS
}
if IsTLS(err) {
return ErrorKindTLS
}
if IsTimeout(err) {
return ErrorKindTimeout
}
return ErrorKindOther
}
// IsTimeout reports whether err is a timeout-related error.
func IsTimeout(err error) bool {
if err == nil {
return false
}
if errors.Is(err, context.DeadlineExceeded) {
return true
}
var uerr *url.Error
if errors.As(err, &uerr) && uerr.Timeout() {
return true
}
var nerr net.Error
if errors.As(err, &nerr) && nerr.Timeout() {
return true
}
msg := strings.ToLower(err.Error())
return strings.Contains(msg, "timeout") || strings.Contains(msg, "deadline exceeded")
}
// IsDNS reports whether err is a DNS resolution related error.
func IsDNS(err error) bool {
if err == nil {
return false
}
var derr *net.DNSError
if errors.As(err, &derr) {
return true
}
msg := strings.ToLower(err.Error())
if strings.Contains(msg, "no such host") ||
strings.Contains(msg, "server misbehaving") ||
strings.Contains(msg, "temporary failure in name resolution") {
return true
}
return strings.Contains(msg, "lookup ") &&
(strings.Contains(msg, "dns") || strings.Contains(msg, "i/o timeout"))
}
// IsTLS reports whether err is TLS/Certificate related.
func IsTLS(err error) bool {
if err == nil {
return false
}
if errors.Is(err, ErrNotTLS) || errors.Is(err, ErrNoTLSConfig) || errors.Is(err, ErrNonTLSNotAllowed) ||
errors.Is(err, ErrTLSSniffFailed) || errors.Is(err, ErrTLSConfigSelectionFailed) {
return true
}
var recErr tls.RecordHeaderError
if errors.As(err, &recErr) {
return true
}
var uaErr x509.UnknownAuthorityError
if errors.As(err, &uaErr) {
return true
}
var hnErr x509.HostnameError
if errors.As(err, &hnErr) {
return true
}
var certErr x509.CertificateInvalidError
if errors.As(err, &certErr) {
return true
}
var rootsErr x509.SystemRootsError
if errors.As(err, &rootsErr) {
return true
}
msg := strings.ToLower(err.Error())
return strings.Contains(msg, "tls:") || strings.Contains(msg, "x509:")
}
// IsProxy reports whether err is proxy related.
func IsProxy(err error) bool {
if err == nil {
return false
}
if isProxyMessage(strings.ToLower(err.Error())) {
return true
}
var uerr *url.Error
if errors.As(err, &uerr) {
if strings.Contains(strings.ToLower(uerr.Op), "proxy") {
return true
}
if uerr.Err != nil && isProxyMessage(strings.ToLower(uerr.Err.Error())) {
return true
}
}
var opErr *net.OpError
if errors.As(err, &opErr) && strings.Contains(strings.ToLower(opErr.Op), "proxy") {
return true
}
return false
}
func isProxyMessage(msg string) bool {
return strings.Contains(msg, "proxyconnect") ||
strings.Contains(msg, "proxy error") ||
strings.Contains(msg, "proxy authentication required") ||
strings.Contains(msg, "proxy: unknown scheme") ||
strings.Contains(msg, "socks connect") ||
strings.Contains(msg, "socks5")
}
+116
View File
@@ -0,0 +1,116 @@
package starnet
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"net/url"
"testing"
)
type timeoutErr struct{}
func (timeoutErr) Error() string { return "i/o timeout" }
func (timeoutErr) Timeout() bool { return true }
func (timeoutErr) Temporary() bool { return true }
func TestIsTimeout(t *testing.T) {
if !IsTimeout(context.DeadlineExceeded) {
t.Fatal("context deadline should be timeout")
}
uerr := &url.Error{
Op: "Get",
URL: "http://example.com",
Err: timeoutErr{},
}
if !IsTimeout(uerr) {
t.Fatal("url timeout error should be timeout")
}
if !IsTimeout(fmt.Errorf("wrapped: %w", uerr)) {
t.Fatal("wrapped timeout should be timeout")
}
if IsTimeout(errors.New("plain error")) {
t.Fatal("plain error must not be timeout")
}
}
func TestIsDNS(t *testing.T) {
dnsErr := &net.DNSError{
Err: "no such host",
Name: "example.invalid",
IsNotFound: true,
}
if !IsDNS(dnsErr) {
t.Fatal("dns error should be dns")
}
if !IsDNS(fmt.Errorf("wrapped: %w", dnsErr)) {
t.Fatal("wrapped dns error should be dns")
}
if !IsDNS(errors.New("lookup example.invalid: no such host")) {
t.Fatal("lookup no such host should be dns")
}
if IsDNS(errors.New("connection reset by peer")) {
t.Fatal("non dns error should not be dns")
}
}
func TestIsTLS(t *testing.T) {
tlsErr := tls.RecordHeaderError{Msg: "first record does not look like a TLS handshake"}
if !IsTLS(tlsErr) {
t.Fatal("tls record header error should be tls")
}
if !IsTLS(fmt.Errorf("wrapped: %w", tlsErr)) {
t.Fatal("wrapped tls error should be tls")
}
if !IsTLS(errors.New("x509: certificate signed by unknown authority")) {
t.Fatal("x509 error text should be tls")
}
if !IsTLS(ErrNotTLS) {
t.Fatal("ErrNotTLS should be tls related")
}
if IsTLS(errors.New("plain error")) {
t.Fatal("plain error should not be tls")
}
}
func TestIsProxy(t *testing.T) {
raw := errors.New("proxyconnect tcp: dial tcp 127.0.0.1:8080: connect: connection refused")
if !IsProxy(raw) {
t.Fatal("proxyconnect error should be proxy")
}
uerr := &url.Error{
Op: "Get",
URL: "http://example.com",
Err: raw,
}
if !IsProxy(uerr) {
t.Fatal("wrapped proxy error should be proxy")
}
opErr := &net.OpError{
Op: "proxyconnect",
Net: "tcp",
Err: errors.New("connect failed"),
}
if !IsProxy(opErr) {
t.Fatal("net.OpError proxyconnect should be proxy")
}
if IsProxy(errors.New("dial tcp 127.0.0.1:8080: connect: connection refused")) {
t.Fatal("non proxy dial error should not be proxy")
}
}
+49
View File
@@ -0,0 +1,49 @@
package starnet
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"testing"
)
func TestIsCanceled(t *testing.T) {
if !IsCanceled(context.Canceled) {
t.Fatal("context canceled should be canceled")
}
if !IsCanceled(fmt.Errorf("wrapped: %w", context.Canceled)) {
t.Fatal("wrapped context canceled should be canceled")
}
if IsCanceled(context.DeadlineExceeded) {
t.Fatal("deadline exceeded must not be canceled")
}
}
func TestClassifyError(t *testing.T) {
dnsErr := &net.DNSError{Err: "no such host", Name: "example.invalid", IsNotFound: true}
tlsErr := tls.RecordHeaderError{Msg: "bad tls record"}
tests := []struct {
name string
err error
want ErrorKind
}{
{name: "nil", err: nil, want: ErrorKindNone},
{name: "canceled", err: context.Canceled, want: ErrorKindCanceled},
{name: "proxy", err: errors.New("proxyconnect tcp: dial tcp 127.0.0.1:8080: i/o timeout"), want: ErrorKindProxy},
{name: "dns", err: dnsErr, want: ErrorKindDNS},
{name: "tls", err: tlsErr, want: ErrorKindTLS},
{name: "timeout", err: context.DeadlineExceeded, want: ErrorKindTimeout},
{name: "other", err: errors.New("boom"), want: ErrorKindOther},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := ClassifyError(tt.err); got != tt.want {
t.Fatalf("ClassifyError()=%s want=%s err=%v", got, tt.want, tt.err)
}
})
}
}
+256
View File
@@ -0,0 +1,256 @@
package starnet_test
import (
"fmt"
"net/http"
"net/http/httptest"
"sync/atomic"
"time"
"b612.me/starnet"
)
func ExampleGet() {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("Hello, World!"))
}))
defer server.Close()
resp, err := starnet.Get(server.URL)
if err != nil {
panic(err)
}
defer resp.Close()
body, _ := resp.Body().String()
fmt.Println(body)
// Output: Hello, World!
}
func ExamplePost() {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("Posted"))
}))
defer server.Close()
resp, err := starnet.Post(server.URL,
starnet.WithBodyString("test data"))
if err != nil {
panic(err)
}
defer resp.Close()
body, _ := resp.Body().String()
fmt.Println(body)
// Output: Posted
}
func ExampleNewSimpleRequest() {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("OK"))
}))
defer server.Close()
req := starnet.NewSimpleRequest(server.URL, "GET").
SetHeader("X-Custom", "value").
AddQuery("name", "test")
resp, err := req.Do()
if err != nil {
panic(err)
}
defer resp.Close()
fmt.Println(resp.StatusCode)
// Output: 200
}
func ExampleClient_Get() {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("Client GET"))
}))
defer server.Close()
client := starnet.NewClientNoErr(
starnet.WithTimeout(10*time.Second),
starnet.WithUserAgent("MyApp/1.0"),
)
resp, err := client.Get(server.URL)
if err != nil {
panic(err)
}
defer resp.Close()
body, _ := resp.Body().String()
fmt.Println(body)
// Output: Client GET
}
func ExampleRequest_SetJSON() {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{"status":"ok"}`))
}))
defer server.Close()
type User struct {
Name string `json:"name"`
Email string `json:"email"`
}
user := User{Name: "John", Email: "john@example.com"}
resp, err := starnet.NewSimpleRequest(server.URL, "POST").
SetJSON(user).
Do()
if err != nil {
panic(err)
}
defer resp.Close()
var result map[string]string
resp.Body().JSON(&result)
fmt.Println(result["status"])
// Output: ok
}
func ExampleRequest_AddFormData() {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r.ParseForm()
fmt.Fprintf(w, "name=%s", r.FormValue("name"))
}))
defer server.Close()
resp, err := starnet.NewSimpleRequest(server.URL, "POST").
AddFormData("name", "John").
AddFormData("age", "30").
Do()
if err != nil {
panic(err)
}
defer resp.Close()
body, _ := resp.Body().String()
fmt.Println(body)
// Output: name=John
}
func ExampleRequest_SetSkipTLSVerify() {
// This example shows how to skip TLS verification
// Useful for testing with self-signed certificates
req := starnet.NewSimpleRequest("https://self-signed.example.com", "GET").
SetSkipTLSVerify(true)
// In a real scenario, you would call req.Do()
fmt.Println(req.Method())
// Output: GET
}
func ExampleRequest_Clone() {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("OK"))
}))
defer server.Close()
baseReq := starnet.NewSimpleRequest(server.URL, "GET").
SetHeader("X-API-Key", "secret")
// Clone and modify
req1 := baseReq.Clone().AddQuery("page", "1")
req2 := baseReq.Clone().AddQuery("page", "2")
resp1, _ := req1.Do()
resp2, _ := req2.Do()
defer resp1.Close()
defer resp2.Close()
fmt.Println(resp1.StatusCode, resp2.StatusCode)
// Output: 200 200
}
func ExampleClient_SetDefaultSkipTLSVerify() {
client := starnet.NewClientNoErr()
client.SetDefaultSkipTLSVerify(true)
// All requests from this client will skip TLS verification
// unless overridden at request level
fmt.Println("Client configured")
// Output: Client configured
}
func ExampleWithTimeout() {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(100 * time.Millisecond)
w.Write([]byte("OK"))
}))
defer server.Close()
resp, err := starnet.Get(server.URL,
starnet.WithTimeout(200*time.Millisecond))
if err != nil {
panic(err)
}
defer resp.Close()
fmt.Println(resp.StatusCode)
// Output: 200
}
func ExampleWithRetry() {
var hits int32
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
n := atomic.AddInt32(&hits, 1)
if n <= 2 {
w.WriteHeader(http.StatusServiceUnavailable)
return
}
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
resp, err := starnet.Get(server.URL,
starnet.WithRetry(2,
starnet.WithRetryBackoff(0, 0, 1),
starnet.WithRetryJitter(0),
),
)
if err != nil {
panic(err)
}
defer resp.Close()
fmt.Println(resp.StatusCode, atomic.LoadInt32(&hits))
// Output: 200 3
}
func ExampleRequest_SetRetry() {
var hits int32
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
n := atomic.AddInt32(&hits, 1)
if n == 1 {
w.WriteHeader(http.StatusServiceUnavailable)
return
}
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
resp, err := starnet.NewSimpleRequest(server.URL, http.MethodPost).
SetBodyString("hello"). // 可重放 body 才能安全重试
SetRetry(1).
SetRetryIdempotentOnly(false).
SetRetryBackoff(0, 0, 1).
SetRetryJitter(0).
Do()
if err != nil {
panic(err)
}
defer resp.Close()
fmt.Println(resp.StatusCode, atomic.LoadInt32(&hits))
// Output: 200 2
}
+172
View File
@@ -0,0 +1,172 @@
package starnet
import (
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
)
func TestRequestAddFileStream(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
err := r.ParseMultipartForm(10 << 20) // 10 MB
if err != nil {
t.Fatalf("ParseMultipartForm error: %v", err)
}
file, header, err := r.FormFile("file")
if err != nil {
t.Fatalf("FormFile error: %v", err)
}
defer file.Close()
content, _ := io.ReadAll(file)
w.Write([]byte(header.Filename + ":" + string(content)))
}))
defer server.Close()
fileContent := "test file content"
reader := strings.NewReader(fileContent)
req := NewSimpleRequest(server.URL, "POST").
AddFileStream("file", "test.txt", int64(len(fileContent)), reader)
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
body, _ := resp.Body().String()
expected := "test.txt:" + fileContent
if body != expected {
t.Errorf("Body = %v; want %v", body, expected)
}
}
func TestRequestAddFileWithFormData(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
err := r.ParseMultipartForm(10 << 20)
if err != nil {
t.Fatalf("ParseMultipartForm error: %v", err)
}
// Check form field
name := r.FormValue("name")
if name != "John" {
t.Errorf("name = %v; want John", name)
}
// Check file
file, header, err := r.FormFile("file")
if err != nil {
t.Fatalf("FormFile error: %v", err)
}
defer file.Close()
w.Write([]byte("OK:" + header.Filename))
}))
defer server.Close()
fileContent := "file data"
reader := strings.NewReader(fileContent)
req := NewSimpleRequest(server.URL, "POST").
AddFormData("name", "John").
AddFileStream("file", "document.txt", int64(len(fileContent)), reader)
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
body, _ := resp.Body().String()
if !strings.Contains(body, "document.txt") {
t.Errorf("Body should contain filename, got: %v", body)
}
}
func TestRequestUploadProgress(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r.ParseMultipartForm(10 << 20)
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
progressCalled := false
var lastUploaded int64
fileContent := strings.Repeat("a", 1024*10) // 10KB
reader := strings.NewReader(fileContent)
req := NewSimpleRequest(server.URL, "POST").
SetUploadProgress(func(filename string, uploaded, total int64) {
progressCalled = true
lastUploaded = uploaded
if filename != "test.txt" {
t.Errorf("filename = %v; want test.txt", filename)
}
}).
AddFileStream("file", "test.txt", int64(len(fileContent)), reader)
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
if !progressCalled {
t.Error("Progress callback was not called")
}
if lastUploaded != int64(len(fileContent)) {
t.Errorf("lastUploaded = %v; want %v", lastUploaded, len(fileContent))
}
}
// TestRequestAddFileFromDisk tests uploading a real file from disk
func TestRequestAddFileFromDisk(t *testing.T) {
// Create a temporary file
tmpDir := t.TempDir()
tmpFile := filepath.Join(tmpDir, "test.txt")
fileContent := []byte("test file content from disk")
err := os.WriteFile(tmpFile, fileContent, 0644)
if err != nil {
t.Fatalf("WriteFile error: %v", err)
}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
err := r.ParseMultipartForm(10 << 20)
if err != nil {
t.Fatalf("ParseMultipartForm error: %v", err)
}
file, header, err := r.FormFile("file")
if err != nil {
t.Fatalf("FormFile error: %v", err)
}
defer file.Close()
content, _ := io.ReadAll(file)
w.Write([]byte(header.Filename + ":" + string(content)))
}))
defer server.Close()
req := NewSimpleRequest(server.URL, "POST").AddFile("file", tmpFile)
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
body, _ := resp.Body().String()
if !strings.Contains(body, string(fileContent)) {
t.Errorf("Body should contain file content, got: %v", body)
}
}
+3
View File
@@ -0,0 +1,3 @@
module b612.me/starnet
go 1.16
View File
+140
View File
@@ -0,0 +1,140 @@
package starnet
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
)
func TestRequestHeaders(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
headers := make(map[string]string)
for k, v := range r.Header {
if len(v) > 0 {
headers[k] = v[0]
}
}
json.NewEncoder(w).Encode(headers)
}))
defer server.Close()
req := NewSimpleRequest(server.URL, "GET").
SetHeader("X-Custom-Header", "value1").
AddHeader("X-Multi-Header", "value1").
AddHeader("X-Multi-Header", "value2")
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
var headers map[string]string
resp.Body().JSON(&headers)
if headers["X-Custom-Header"] != "value1" {
t.Errorf("X-Custom-Header = %v; want value1", headers["X-Custom-Header"])
}
}
func TestRequestCookies(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
cookies := make(map[string]string)
for _, cookie := range r.Cookies() {
cookies[cookie.Name] = cookie.Value
}
json.NewEncoder(w).Encode(cookies)
}))
defer server.Close()
req := NewSimpleRequest(server.URL, "GET").
AddSimpleCookie("session", "abc123").
AddSimpleCookie("user", "john")
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
var cookies map[string]string
resp.Body().JSON(&cookies)
if cookies["session"] != "abc123" {
t.Errorf("session cookie = %v; want abc123", cookies["session"])
}
if cookies["user"] != "john" {
t.Errorf("user cookie = %v; want john", cookies["user"])
}
}
func TestRequestUserAgent(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(r.UserAgent()))
}))
defer server.Close()
req := NewSimpleRequest(server.URL, "GET").
SetUserAgent("CustomAgent/1.0")
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
body, _ := resp.Body().String()
if body != "CustomAgent/1.0" {
t.Errorf("User-Agent = %v; want CustomAgent/1.0", body)
}
}
func TestRequestBearerToken(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
auth := r.Header.Get("Authorization")
w.Write([]byte(auth))
}))
defer server.Close()
req := NewSimpleRequest(server.URL, "GET").
SetBearerToken("mytoken123")
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
body, _ := resp.Body().String()
expected := "Bearer mytoken123"
if body != expected {
t.Errorf("Authorization = %v; want %v", body, expected)
}
}
func TestRequestBasicAuth(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
username, password, ok := r.BasicAuth()
if !ok {
w.WriteHeader(http.StatusUnauthorized)
return
}
w.Write([]byte(username + ":" + password))
}))
defer server.Close()
req := NewSimpleRequest(server.URL, "GET").
SetBasicAuth("user", "pass")
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
body, _ := resp.Body().String()
if body != "user:pass" {
t.Errorf("BasicAuth = %v; want user:pass", body)
}
}
+150
View File
@@ -0,0 +1,150 @@
package starnet
import (
"crypto/tls"
"net/http"
"net/http/httptest"
"testing"
)
func TestRequestSetURLDoesNotMutateProvidedTLSConfig(t *testing.T) {
cfg := &tls.Config{}
req := NewSimpleRequest("https://example.com", http.MethodGet).
SetTLSConfig(cfg).
SetURL("https://other.example")
if req.Err() != nil {
t.Fatalf("unexpected request error: %v", req.Err())
}
if cfg.ServerName != "" {
t.Fatalf("provided tls.Config was mutated, ServerName=%q", cfg.ServerName)
}
}
func TestRequestPrepareSetTLSServerNameDoesNotMutateProvidedTLSConfig(t *testing.T) {
cfg := &tls.Config{InsecureSkipVerify: true}
req := NewSimpleRequest("https://example.com", http.MethodGet).
SetTLSConfig(cfg).
SetTLSServerName("override.example")
if err := req.prepare(); err != nil {
t.Fatalf("prepare error: %v", err)
}
if cfg.ServerName != "" {
t.Fatalf("provided tls.Config was mutated, ServerName=%q", cfg.ServerName)
}
rc := getRequestContext(req.execCtx)
if rc.TLSConfig == nil {
t.Fatal("expected injected tls config")
}
if rc.TLSConfig == cfg {
t.Fatal("expected injected tls config to be cloned")
}
if rc.TLSConfig.ServerName != "override.example" {
t.Fatalf("injected ServerName=%q", rc.TLSConfig.ServerName)
}
}
func TestRequestPrepareWithTLSServerNameWithoutTLSConfig(t *testing.T) {
req := NewSimpleRequest("https://example.com", http.MethodGet).
SetTLSServerName("override.example")
if err := req.prepare(); err != nil {
t.Fatalf("prepare error: %v", err)
}
rc := getRequestContext(req.execCtx)
if rc.TLSConfig == nil {
t.Fatal("expected injected tls config")
}
if rc.TLSConfig.ServerName != "override.example" {
t.Fatalf("injected ServerName=%q", rc.TLSConfig.ServerName)
}
}
func TestRequestPrepareDefaultPathSkipsRequestContextInjection(t *testing.T) {
req := NewSimpleRequest("https://example.com", http.MethodGet)
if err := req.prepare(); err != nil {
t.Fatalf("prepare error: %v", err)
}
if got := req.execCtx.Value(ctxKeyRequestContext); got != nil {
t.Fatalf("unexpected request context injection: %#v", got)
}
rc := getRequestContext(req.execCtx)
if needsDynamicTransport(rc) {
t.Fatalf("default path unexpectedly marked dynamic: %#v", rc)
}
if rc.TLSServerName != "" {
t.Fatalf("default path unexpectedly injected tls server name: %q", rc.TLSServerName)
}
}
func TestRequestPrepareDynamicPathInjectsAggregatedRequestContext(t *testing.T) {
req := NewSimpleRequest("https://example.com", http.MethodGet).
SetCustomIP([]string{"127.0.0.1"}).
SetSkipTLSVerify(true)
if err := req.prepare(); err != nil {
t.Fatalf("prepare error: %v", err)
}
raw := req.execCtx.Value(ctxKeyRequestContext)
rc, ok := raw.(*RequestContext)
if !ok || rc == nil {
t.Fatalf("expected aggregated request context, got %#v", raw)
}
if len(rc.CustomIP) != 1 || rc.CustomIP[0] != "127.0.0.1" {
t.Fatalf("custom ip=%v", rc.CustomIP)
}
if rc.TLSConfig == nil || !rc.TLSConfig.InsecureSkipVerify {
t.Fatal("expected tls config with skip verify")
}
if rc.TLSServerName != "example.com" {
t.Fatalf("default tls server name=%q", rc.TLSServerName)
}
}
func TestRequestSetHostOverridesRequestHost(t *testing.T) {
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Host != "override.example" {
t.Fatalf("host=%q", r.Host)
}
w.WriteHeader(http.StatusOK)
}))
defer s.Close()
resp, err := NewSimpleRequest(s.URL, http.MethodGet).
SetHost("override.example").
Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
}
func TestWithHostOverridesRequestHost(t *testing.T) {
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Host != "option.example" {
t.Fatalf("host=%q", r.Host)
}
w.WriteHeader(http.StatusOK)
}))
defer s.Close()
resp, err := NewRequest(s.URL, http.MethodGet, WithHost("option.example"))
if err != nil {
t.Fatalf("NewRequest() error: %v", err)
}
got, err := resp.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer got.Close()
}
+258
View File
@@ -0,0 +1,258 @@
package starnet
import (
"os"
"testing"
"time"
)
// 这些测试使用 httpbin.org 作为测试服务
// 可以通过环境变量 STARNET_INTEGRATION_TEST=1 来启用
func skipIfNoIntegration(t *testing.T) {
if os.Getenv("STARNET_INTEGRATION_TEST") != "1" {
t.Skip("Skipping integration test. Set STARNET_INTEGRATION_TEST=1 to run")
}
}
func TestIntegrationHTTPBinGet(t *testing.T) {
skipIfNoIntegration(t)
resp, err := Get("https://httpbin.org/get",
WithQuery("name", "starnet"),
WithQuery("version", "1.0"))
if err != nil {
t.Fatalf("Get() error: %v", err)
}
defer resp.Close()
if resp.StatusCode != 200 {
t.Errorf("StatusCode = %v; want 200", resp.StatusCode)
}
var result map[string]interface{}
err = resp.Body().JSON(&result)
if err != nil {
t.Fatalf("JSON() error: %v", err)
}
args, ok := result["args"].(map[string]interface{})
if !ok {
t.Fatal("args not found in response")
}
if args["name"] != "starnet" {
t.Errorf("args[name] = %v; want starnet", args["name"])
}
}
func TestIntegrationHTTPBinPost(t *testing.T) {
skipIfNoIntegration(t)
type PostData struct {
Name string `json:"name"`
Email string `json:"email"`
}
data := PostData{
Name: "John Doe",
Email: "john@example.com",
}
resp, err := Post("https://httpbin.org/post", WithJSON(data))
if err != nil {
t.Fatalf("Post() error: %v", err)
}
defer resp.Close()
if resp.StatusCode != 200 {
t.Errorf("StatusCode = %v; want 200", resp.StatusCode)
}
var result map[string]interface{}
err = resp.Body().JSON(&result)
if err != nil {
t.Fatalf("JSON() error: %v", err)
}
jsonData, ok := result["json"].(map[string]interface{})
if !ok {
t.Fatal("json not found in response")
}
if jsonData["name"] != data.Name {
t.Errorf("name = %v; want %v", jsonData["name"], data.Name)
}
}
func TestIntegrationHTTPBinHeaders(t *testing.T) {
skipIfNoIntegration(t)
resp, err := Get("https://httpbin.org/headers",
WithHeader("X-Custom-Header", "test-value"),
WithUserAgent("Starnet-Test/1.0"))
if err != nil {
t.Fatalf("Get() error: %v", err)
}
defer resp.Close()
var result map[string]interface{}
err = resp.Body().JSON(&result)
if err != nil {
t.Fatalf("JSON() error: %v", err)
}
headers, ok := result["headers"].(map[string]interface{})
if !ok {
t.Fatal("headers not found in response")
}
if headers["X-Custom-Header"] != "test-value" {
t.Errorf("X-Custom-Header = %v; want test-value", headers["X-Custom-Header"])
}
}
func TestIntegrationHTTPBinBasicAuth(t *testing.T) {
skipIfNoIntegration(t)
resp, err := Get("https://httpbin.org/basic-auth/user/passwd",
WithBasicAuth("user", "passwd"))
if err != nil {
t.Fatalf("Get() error: %v", err)
}
defer resp.Close()
if resp.StatusCode != 200 {
t.Errorf("StatusCode = %v; want 200", resp.StatusCode)
}
var result map[string]interface{}
err = resp.Body().JSON(&result)
if err != nil {
t.Fatalf("JSON() error: %v", err)
}
if result["authenticated"] != true {
t.Error("authenticated should be true")
}
}
func TestIntegrationHTTPBinDelay(t *testing.T) {
skipIfNoIntegration(t)
// Test timeout
start := time.Now()
_, err := Get("https://httpbin.org/delay/3",
WithTimeout(1*time.Second))
elapsed := time.Since(start)
if err == nil {
t.Error("Expected timeout error, got nil")
}
if elapsed > 2*time.Second {
t.Errorf("Timeout took too long: %v", elapsed)
}
}
func TestIntegrationHTTPBinRedirect(t *testing.T) {
skipIfNoIntegration(t)
// Test with redirect enabled
client := NewClientNoErr()
resp, err := client.Get("https://httpbin.org/redirect/2")
if err != nil {
t.Fatalf("Get() error: %v", err)
}
defer resp.Close()
if resp.StatusCode != 200 {
t.Errorf("StatusCode = %v; want 200 (after redirect)", resp.StatusCode)
}
// Test with redirect disabled
client.DisableRedirect()
resp2, err := client.Get("https://httpbin.org/redirect/2")
if err != nil {
t.Fatalf("Get() error: %v", err)
}
defer resp2.Close()
if resp2.StatusCode != 302 {
t.Errorf("StatusCode = %v; want 302 (redirect disabled)", resp2.StatusCode)
}
}
func TestIntegrationHTTPBinCookies(t *testing.T) {
skipIfNoIntegration(t)
// 创建一个禁用重定向的 Client
client := NewClientNoErr()
client.DisableRedirect()
resp, err := client.Get("https://httpbin.org/cookies/set?name=value")
if err != nil {
t.Fatalf("Get() error: %v", err)
}
defer resp.Close()
// 现在应该能获取到 Set-Cookie
cookies := resp.Cookies()
if len(cookies) == 0 {
t.Error("Expected cookies in response")
}
// 验证 cookie
found := false
for _, cookie := range cookies {
if cookie.Name == "name" && cookie.Value == "value" {
found = true
break
}
}
if !found {
t.Error("Expected cookie 'name=value' not found")
}
}
func TestIntegrationHTTPBinUserAgent(t *testing.T) {
skipIfNoIntegration(t)
customUA := "Starnet-Integration-Test/1.0"
resp, err := Get("https://httpbin.org/user-agent",
WithUserAgent(customUA))
if err != nil {
t.Fatalf("Get() error: %v", err)
}
defer resp.Close()
var result map[string]interface{}
err = resp.Body().JSON(&result)
if err != nil {
t.Fatalf("JSON() error: %v", err)
}
if result["user-agent"] != customUA {
t.Errorf("user-agent = %v; want %v", result["user-agent"], customUA)
}
}
func TestIntegrationHTTPBinGzip(t *testing.T) {
skipIfNoIntegration(t)
resp, err := Get("https://httpbin.org/gzip")
if err != nil {
t.Fatalf("Get() error: %v", err)
}
defer resp.Close()
var result map[string]interface{}
err = resp.Body().JSON(&result)
if err != nil {
t.Fatalf("JSON() error: %v", err)
}
if result["gzipped"] != true {
t.Error("Response should be gzipped")
}
}
+230
View File
@@ -0,0 +1,230 @@
package pingcore
import (
"encoding/binary"
"net"
"os"
"sync/atomic"
"time"
)
const icmpHeaderLen = 8
type ICMP struct {
Type uint8
Code uint8
CheckSum uint16
Identifier uint16
SequenceNum uint16
}
type Options struct {
Count int
Timeout time.Duration
Interval time.Duration
Deadline time.Time
PreferIPv4 bool
PreferIPv6 bool
SourceIP net.IP
PayloadSize int
}
type Result struct {
Duration time.Duration
RecvCount int
RemoteIP string
}
var identifierSeed uint32
func NextIdentifier() uint16 {
pid := uint32(os.Getpid() & 0xffff)
n := atomic.AddUint32(&identifierSeed, 1)
return uint16((pid + n) & 0xffff)
}
func Payload(size int) []byte {
if size <= 0 {
return nil
}
payload := make([]byte, size)
for index := 0; index < len(payload); index++ {
payload[index] = byte(index)
}
return payload
}
func BuildICMP(seq, identifier uint16, typ uint8, payload []byte) ICMP {
icmp := ICMP{
Type: typ,
Code: 0,
CheckSum: 0,
Identifier: identifier,
SequenceNum: seq,
}
buf := MarshalPacket(icmp, payload)
icmp.CheckSum = Checksum(buf)
return icmp
}
func Checksum(data []byte) uint16 {
var (
sum uint32
length = len(data)
index int
)
for length > 1 {
sum += uint32(data[index])<<8 + uint32(data[index+1])
index += 2
length -= 2
}
if length > 0 {
sum += uint32(data[index]) << 8
}
for sum>>16 != 0 {
sum = (sum & 0xffff) + (sum >> 16)
}
return uint16(^sum)
}
func Marshal(icmp ICMP) []byte {
return MarshalPacket(icmp, nil)
}
func MarshalPacket(icmp ICMP, payload []byte) []byte {
buf := make([]byte, icmpHeaderLen+len(payload))
buf[0] = icmp.Type
buf[1] = icmp.Code
binary.BigEndian.PutUint16(buf[2:], icmp.CheckSum)
binary.BigEndian.PutUint16(buf[4:], icmp.Identifier)
binary.BigEndian.PutUint16(buf[6:], icmp.SequenceNum)
copy(buf[icmpHeaderLen:], payload)
return buf
}
func IsExpectedEchoReply(packet []byte, family int, expectedType uint8, identifier, seq uint16) bool {
for _, offset := range CandidateICMPOffsets(packet, family) {
if offset < 0 || offset+icmpHeaderLen > len(packet) {
continue
}
if packet[offset] != expectedType || packet[offset+1] != 0 {
continue
}
if binary.BigEndian.Uint16(packet[offset+4:offset+6]) != identifier {
continue
}
if binary.BigEndian.Uint16(packet[offset+6:offset+8]) != seq {
continue
}
return true
}
return false
}
func CandidateICMPOffsets(packet []byte, family int) []int {
offsets := []int{0}
if len(packet) == 0 {
return offsets
}
version := packet[0] >> 4
if version == 4 && len(packet) >= 20 {
ihl := int(packet[0]&0x0f) * 4
if ihl >= 20 && ihl <= len(packet)-icmpHeaderLen {
offsets = append(offsets, ihl)
}
} else if version == 6 && len(packet) >= 40+icmpHeaderLen {
offsets = append(offsets, 40)
}
if family == 4 && len(packet) >= 20+icmpHeaderLen {
offsets = append(offsets, 20)
}
if family == 6 && len(packet) >= 40+icmpHeaderLen {
offsets = append(offsets, 40)
}
return DedupOffsets(offsets)
}
func DedupOffsets(offsets []int) []int {
if len(offsets) <= 1 {
return offsets
}
seen := make(map[int]struct{}, len(offsets))
out := make([]int, 0, len(offsets))
for _, offset := range offsets {
if _, ok := seen[offset]; ok {
continue
}
seen[offset] = struct{}{}
out = append(out, offset)
}
return out
}
func ResolveTargets(host string, preferIPv4, preferIPv6 bool) ([]*net.IPAddr, error) {
if parsed := net.ParseIP(host); parsed != nil {
return []*net.IPAddr{{IP: parsed}}, nil
}
var targets []*net.IPAddr
var err4 error
var err6 error
if ip4, err := net.ResolveIPAddr("ip4", host); err == nil && ip4 != nil && ip4.IP != nil {
targets = append(targets, ip4)
} else {
err4 = err
}
if ip6, err := net.ResolveIPAddr("ip6", host); err == nil && ip6 != nil && ip6.IP != nil {
targets = append(targets, ip6)
} else {
err6 = err
}
if len(targets) > 0 {
return OrderTargets(targets, preferIPv4, preferIPv6), nil
}
if err4 != nil {
return nil, err4
}
if err6 != nil {
return nil, err6
}
return nil, nil
}
func OrderTargets(targets []*net.IPAddr, preferIPv4, preferIPv6 bool) []*net.IPAddr {
if len(targets) <= 1 || preferIPv4 == preferIPv6 {
return targets
}
ordered := make([]*net.IPAddr, 0, len(targets))
if preferIPv4 {
for _, target := range targets {
if target != nil && target.IP != nil && target.IP.To4() != nil {
ordered = append(ordered, target)
}
}
for _, target := range targets {
if target != nil && target.IP != nil && target.IP.To4() == nil {
ordered = append(ordered, target)
}
}
return ordered
}
for _, target := range targets {
if target != nil && target.IP != nil && target.IP.To4() == nil {
ordered = append(ordered, target)
}
}
for _, target := range targets {
if target != nil && target.IP != nil && target.IP.To4() != nil {
ordered = append(ordered, target)
}
}
return ordered
}
+123
View File
@@ -0,0 +1,123 @@
package tlssniffercore
import "crypto/tls"
func ComposeServerTLSConfig(base, selected *tls.Config) *tls.Config {
if base == nil {
return selected
}
if selected == nil {
return base
}
out := base.Clone()
ApplyServerTLSOverrides(out, selected)
return out
}
func ApplyServerTLSOverrides(dst, src *tls.Config) {
if dst == nil || src == nil {
return
}
if src.Rand != nil {
dst.Rand = src.Rand
}
if src.Time != nil {
dst.Time = src.Time
}
if len(src.Certificates) > 0 {
dst.Certificates = append([]tls.Certificate(nil), src.Certificates...)
}
if len(src.NameToCertificate) > 0 {
copied := make(map[string]*tls.Certificate, len(src.NameToCertificate))
for name, cert := range src.NameToCertificate {
copied[name] = cert
}
dst.NameToCertificate = copied
}
if src.GetCertificate != nil {
dst.GetCertificate = src.GetCertificate
}
if src.GetClientCertificate != nil {
dst.GetClientCertificate = src.GetClientCertificate
}
if src.GetConfigForClient != nil {
dst.GetConfigForClient = src.GetConfigForClient
}
if src.VerifyPeerCertificate != nil {
dst.VerifyPeerCertificate = src.VerifyPeerCertificate
}
if src.VerifyConnection != nil {
dst.VerifyConnection = src.VerifyConnection
}
if src.RootCAs != nil {
dst.RootCAs = src.RootCAs
}
if len(src.NextProtos) > 0 {
dst.NextProtos = append([]string(nil), src.NextProtos...)
}
if src.ServerName != "" {
dst.ServerName = src.ServerName
}
if src.ClientAuth > dst.ClientAuth {
dst.ClientAuth = src.ClientAuth
}
if src.ClientCAs != nil {
dst.ClientCAs = src.ClientCAs
}
if src.InsecureSkipVerify {
dst.InsecureSkipVerify = true
}
if len(src.CipherSuites) > 0 {
dst.CipherSuites = append([]uint16(nil), src.CipherSuites...)
}
if src.PreferServerCipherSuites {
dst.PreferServerCipherSuites = true
}
if src.SessionTicketsDisabled {
dst.SessionTicketsDisabled = true
}
if src.SessionTicketKey != ([32]byte{}) {
dst.SessionTicketKey = src.SessionTicketKey
}
if src.ClientSessionCache != nil {
dst.ClientSessionCache = src.ClientSessionCache
}
if src.UnwrapSession != nil {
dst.UnwrapSession = src.UnwrapSession
}
if src.WrapSession != nil {
dst.WrapSession = src.WrapSession
}
if src.MinVersion != 0 && (dst.MinVersion == 0 || src.MinVersion > dst.MinVersion) {
dst.MinVersion = src.MinVersion
}
if src.MaxVersion != 0 && (dst.MaxVersion == 0 || src.MaxVersion < dst.MaxVersion) {
dst.MaxVersion = src.MaxVersion
}
if len(src.CurvePreferences) > 0 {
dst.CurvePreferences = append([]tls.CurveID(nil), src.CurvePreferences...)
}
if src.DynamicRecordSizingDisabled {
dst.DynamicRecordSizingDisabled = true
}
if src.Renegotiation != 0 {
dst.Renegotiation = src.Renegotiation
}
if src.KeyLogWriter != nil {
dst.KeyLogWriter = src.KeyLogWriter
}
if len(src.EncryptedClientHelloConfigList) > 0 {
dst.EncryptedClientHelloConfigList = append([]byte(nil), src.EncryptedClientHelloConfigList...)
}
if src.EncryptedClientHelloRejectionVerify != nil {
dst.EncryptedClientHelloRejectionVerify = src.EncryptedClientHelloRejectionVerify
}
if src.GetEncryptedClientHelloKeys != nil {
dst.GetEncryptedClientHelloKeys = src.GetEncryptedClientHelloKeys
}
if len(src.EncryptedClientHelloKeys) > 0 {
dst.EncryptedClientHelloKeys = append([]tls.EncryptedClientHelloKey(nil), src.EncryptedClientHelloKeys...)
}
}
+237
View File
@@ -0,0 +1,237 @@
package tlssniffercore
import (
"bytes"
"encoding/binary"
"io"
"net"
)
type ClientHelloMeta struct {
ServerName string
LocalAddr net.Addr
RemoteAddr net.Addr
SupportedProtos []string
SupportedVersions []uint16
CipherSuites []uint16
}
type SniffResult struct {
IsTLS bool
ClientHello *ClientHelloMeta
Buffer *bytes.Buffer
}
type Sniffer struct{}
func (s Sniffer) Sniff(conn net.Conn, maxBytes int) (SniffResult, error) {
if maxBytes <= 0 {
maxBytes = 64 * 1024
}
var buf bytes.Buffer
limited := &io.LimitedReader{R: conn, N: int64(maxBytes)}
meta, isTLS := sniffClientHello(limited, &buf, conn)
out := SniffResult{
IsTLS: isTLS,
Buffer: bytes.NewBuffer(append([]byte(nil), buf.Bytes()...)),
}
if isTLS {
out.ClientHello = meta
}
return out, nil
}
func sniffClientHello(reader io.Reader, buf *bytes.Buffer, conn net.Conn) (*ClientHelloMeta, bool) {
meta := &ClientHelloMeta{
LocalAddr: conn.LocalAddr(),
RemoteAddr: conn.RemoteAddr(),
}
header, complete := readTLSRecordHeader(reader, buf)
if len(header) < 3 {
return nil, false
}
isTLS := header[0] == 0x16 && header[1] == 0x03
if !isTLS {
return nil, false
}
if len(header) < 5 || !complete {
return meta, true
}
recordLen := int(binary.BigEndian.Uint16(header[3:5]))
recordBody, bodyOK := readBufferedBytes(reader, buf, recordLen)
if !bodyOK {
return meta, true
}
if len(recordBody) < 4 || recordBody[0] != 0x01 {
return nil, false
}
helloLen := int(recordBody[1])<<16 | int(recordBody[2])<<8 | int(recordBody[3])
helloBytes := append([]byte(nil), recordBody[4:]...)
for len(helloBytes) < helloLen {
nextHeader, ok := readTLSRecordHeader(reader, buf)
if len(nextHeader) < 5 || !ok {
return meta, true
}
if nextHeader[0] != 0x16 || nextHeader[1] != 0x03 {
return meta, true
}
nextLen := int(binary.BigEndian.Uint16(nextHeader[3:5]))
nextBody, bodyOK := readBufferedBytes(reader, buf, nextLen)
if !bodyOK {
return meta, true
}
helloBytes = append(helloBytes, nextBody...)
}
parseClientHelloBody(meta, helloBytes[:helloLen])
return meta, true
}
func readTLSRecordHeader(reader io.Reader, buf *bytes.Buffer) ([]byte, bool) {
return readBufferedBytes(reader, buf, 5)
}
func readBufferedBytes(reader io.Reader, buf *bytes.Buffer, count int) ([]byte, bool) {
if count <= 0 {
return nil, true
}
tmp := make([]byte, count)
readN, err := io.ReadFull(reader, tmp)
if readN > 0 {
buf.Write(tmp[:readN])
}
return append([]byte(nil), tmp[:readN]...), err == nil
}
func parseClientHelloBody(meta *ClientHelloMeta, body []byte) {
if meta == nil || len(body) < 34 {
return
}
offset := 2 + 32
sessionIDLen := int(body[offset])
offset++
if offset+sessionIDLen > len(body) {
return
}
offset += sessionIDLen
if offset+2 > len(body) {
return
}
cipherSuitesLen := int(binary.BigEndian.Uint16(body[offset : offset+2]))
offset += 2
if offset+cipherSuitesLen > len(body) {
return
}
for index := 0; index+1 < cipherSuitesLen; index += 2 {
meta.CipherSuites = append(meta.CipherSuites, binary.BigEndian.Uint16(body[offset+index:offset+index+2]))
}
offset += cipherSuitesLen
if offset >= len(body) {
return
}
compressionMethodsLen := int(body[offset])
offset++
if offset+compressionMethodsLen > len(body) {
return
}
offset += compressionMethodsLen
if offset+2 > len(body) {
return
}
extensionsLen := int(binary.BigEndian.Uint16(body[offset : offset+2]))
offset += 2
if offset+extensionsLen > len(body) {
return
}
parseClientHelloExtensions(meta, body[offset:offset+extensionsLen])
}
func parseClientHelloExtensions(meta *ClientHelloMeta, exts []byte) {
for offset := 0; offset+4 <= len(exts); {
extType := binary.BigEndian.Uint16(exts[offset : offset+2])
extLen := int(binary.BigEndian.Uint16(exts[offset+2 : offset+4]))
offset += 4
if offset+extLen > len(exts) {
return
}
extData := exts[offset : offset+extLen]
offset += extLen
switch extType {
case 0:
parseServerNameExtension(meta, extData)
case 16:
parseALPNExtension(meta, extData)
case 43:
parseSupportedVersionsExtension(meta, extData)
}
}
}
func parseServerNameExtension(meta *ClientHelloMeta, data []byte) {
if len(data) < 2 {
return
}
listLen := int(binary.BigEndian.Uint16(data[:2]))
if listLen == 0 || 2+listLen > len(data) {
return
}
list := data[2 : 2+listLen]
for offset := 0; offset+3 <= len(list); {
nameType := list[offset]
nameLen := int(binary.BigEndian.Uint16(list[offset+1 : offset+3]))
offset += 3
if offset+nameLen > len(list) {
return
}
if nameType == 0 {
meta.ServerName = string(list[offset : offset+nameLen])
return
}
offset += nameLen
}
}
func parseALPNExtension(meta *ClientHelloMeta, data []byte) {
if len(data) < 2 {
return
}
listLen := int(binary.BigEndian.Uint16(data[:2]))
if listLen == 0 || 2+listLen > len(data) {
return
}
list := data[2 : 2+listLen]
for offset := 0; offset < len(list); {
nameLen := int(list[offset])
offset++
if offset+nameLen > len(list) {
return
}
meta.SupportedProtos = append(meta.SupportedProtos, string(list[offset:offset+nameLen]))
offset += nameLen
}
}
func parseSupportedVersionsExtension(meta *ClientHelloMeta, data []byte) {
if len(data) < 1 {
return
}
listLen := int(data[0])
if listLen == 0 || 1+listLen > len(data) {
return
}
list := data[1 : 1+listLen]
for offset := 0; offset+1 < len(list); offset += 2 {
meta.SupportedVersions = append(meta.SupportedVersions, binary.BigEndian.Uint16(list[offset:offset+2]))
}
}
+112
View File
@@ -0,0 +1,112 @@
package starnet
import (
"encoding/json"
"io"
"os"
)
// WithBody 设置请求体(字节)
func WithBody(body []byte) RequestOpt {
return func(r *Request) error {
setBytesBodyConfig(&r.config.Body, body)
return nil
}
}
// WithBodyString 设置请求体(字符串)
func WithBodyString(body string) RequestOpt {
return func(r *Request) error {
setBytesBodyConfig(&r.config.Body, []byte(body))
return nil
}
}
// WithBodyReader 设置请求体(Reader)。
// 出于避免重复写的保守策略,Reader 形态的 body 在非幂等方法上不会自动参与 retry。
func WithBodyReader(reader io.Reader) RequestOpt {
return func(r *Request) error {
setReaderBodyConfig(&r.config.Body, reader)
return nil
}
}
// WithJSON 设置 JSON 请求体
func WithJSON(v interface{}) RequestOpt {
return func(r *Request) error {
data, err := json.Marshal(v)
if err != nil {
return wrapError(err, "marshal json")
}
r.config.Headers.Set("Content-Type", ContentTypeJSON)
setBytesBodyConfig(&r.config.Body, data)
return nil
}
}
// WithFormData 设置表单数据
func WithFormData(data map[string][]string) RequestOpt {
return func(r *Request) error {
setFormBodyConfig(&r.config.Body, data)
return nil
}
}
// WithFormDataMap 设置表单数据(简化版)
func WithFormDataMap(data map[string]string) RequestOpt {
return func(r *Request) error {
setFormBodyConfig(&r.config.Body, nil)
for key, value := range data {
r.config.Body.FormData[key] = []string{value}
}
return nil
}
}
// WithAddFormData 添加表单数据
func WithAddFormData(key, value string) RequestOpt {
return func(r *Request) error {
ensureFormMode(&r.config.Body)
r.config.Body.FormData[key] = append(r.config.Body.FormData[key], value)
return nil
}
}
// WithFile 添加文件
func WithFile(formName, filePath string) RequestOpt {
return func(r *Request) error {
stat, err := os.Stat(filePath)
if err != nil {
return wrapError(ErrFileNotFound, "file: %s", filePath)
}
ensureMultipartMode(&r.config.Body)
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
FormName: formName,
FileName: stat.Name(),
FilePath: filePath,
FileSize: stat.Size(),
FileType: ContentTypeOctetStream,
})
return nil
}
}
// WithFileStream 添加文件流
func WithFileStream(formName, fileName string, size int64, reader io.Reader) RequestOpt {
return func(r *Request) error {
if reader == nil {
return ErrNilReader
}
ensureMultipartMode(&r.config.Body)
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
FormName: formName,
FileName: fileName,
FileData: reader,
FileSize: size,
FileType: ContentTypeOctetStream,
})
return nil
}
}
+137
View File
@@ -0,0 +1,137 @@
package starnet
import (
"context"
"crypto/tls"
"net"
"net/http"
"time"
)
// WithTimeout 设置请求总超时时间
// timeout > 0: 为本次请求注入 context 超时
// timeout = 0: 不额外设置请求总超时
// timeout < 0: 禁用 starnet 默认总超时
func WithTimeout(timeout time.Duration) RequestOpt {
return requestOptFromMutation(mutateTimeout(timeout))
}
// WithDialTimeout 设置连接超时时间
func WithDialTimeout(timeout time.Duration) RequestOpt {
return requestOptFromMutation(mutateDialTimeout(timeout))
}
// WithProxy 设置代理
func WithProxy(proxy string) RequestOpt {
return requestOptFromMutation(mutateProxy(proxy))
}
// WithDialFunc 设置自定义 Dial 函数
func WithDialFunc(fn func(ctx context.Context, network, addr string) (net.Conn, error)) RequestOpt {
return requestOptFromMutation(mutateDialFunc(fn))
}
// WithTLSConfig 设置 TLS 配置
func WithTLSConfig(tlsConfig *tls.Config) RequestOpt {
return requestOptFromMutation(mutateTLSConfig(tlsConfig))
}
// WithTLSServerName 设置显式 TLS ServerName/SNI。
func WithTLSServerName(serverName string) RequestOpt {
return requestOptFromMutation(mutateTLSServerName(serverName))
}
// WithTraceHooks 设置请求 trace 回调。
func WithTraceHooks(hooks *TraceHooks) RequestOpt {
return requestOptFromMutation(mutateTraceHooks(hooks))
}
// WithTraceRecorder 设置请求级 trace 摘要记录器。
func WithTraceRecorder(recorder *TraceRecorder) RequestOpt {
return requestOptFromMutation(mutateTraceRecorder(recorder))
}
// WithSkipTLSVerify 设置是否跳过 TLS 验证
func WithSkipTLSVerify(skip bool) RequestOpt {
return requestOptFromMutation(mutateSkipTLSVerify(skip))
}
// WithCustomIP 设置自定义 IP
func WithCustomIP(ips []string) RequestOpt {
return requestOptFromMutation(mutateCustomIP(ips))
}
// WithAddCustomIP 添加自定义 IP
func WithAddCustomIP(ip string) RequestOpt {
return requestOptFromMutation(mutateAddCustomIP(ip))
}
// WithCustomDNS 设置自定义 DNS 服务器
func WithCustomDNS(dnsServers []string) RequestOpt {
return requestOptFromMutation(mutateCustomDNS(dnsServers))
}
// WithAddCustomDNS 添加自定义 DNS 服务器
func WithAddCustomDNS(dns string) RequestOpt {
return requestOptFromMutation(mutateAddCustomDNS(dns))
}
// WithLookupFunc 设置自定义 DNS 解析函数
func WithLookupFunc(fn func(ctx context.Context, host string) ([]net.IPAddr, error)) RequestOpt {
return requestOptFromMutation(mutateLookupFunc(fn))
}
// WithBasicAuth 设置 Basic 认证
func WithBasicAuth(username, password string) RequestOpt {
return requestOptFromMutation(mutateBasicAuth(username, password))
}
// WithQuery 添加查询参数
func WithQuery(key, value string) RequestOpt {
return requestOptFromMutation(mutateAddQuery(key, value))
}
// WithQueries 批量添加查询参数
func WithQueries(queries map[string]string) RequestOpt {
return requestOptFromMutation(mutateAddQueries(queries))
}
// WithContentLength 设置 Content-Length
func WithContentLength(length int64) RequestOpt {
return requestOptFromMutation(mutateContentLength(length))
}
// WithAutoCalcContentLength 设置是否自动计算 Content-Length
func WithAutoCalcContentLength(auto bool) RequestOpt {
return requestOptFromMutation(mutateAutoCalcContentLength(auto))
}
// WithUploadProgress 设置文件上传进度回调
func WithUploadProgress(fn UploadProgressFunc) RequestOpt {
return requestOptFromMutation(mutateUploadProgress(fn))
}
// WithTransport 设置自定义 Transport
func WithTransport(transport *http.Transport) RequestOpt {
return requestOptFromMutation(mutateTransport(transport))
}
// WithAutoFetch 设置是否自动获取响应体
func WithAutoFetch(auto bool) RequestOpt {
return requestOptFromMutation(mutateAutoFetch(auto))
}
// WithMaxRespBodyBytes 设置响应体最大读取字节数(<=0 表示不限制)
func WithMaxRespBodyBytes(maxBytes int64) RequestOpt {
return requestOptFromMutation(mutateMaxRespBodyBytes(maxBytes))
}
// WithRawRequest 设置原始请求
func WithRawRequest(httpReq *http.Request) RequestOpt {
return requestOptFromMutation(mutateRawRequest(httpReq))
}
// WithContext 设置 context
func WithContext(ctx context.Context) RequestOpt {
return requestOptFromMutation(mutateContext(ctx))
}
+99
View File
@@ -0,0 +1,99 @@
package starnet
import "net/http"
// WithHeader 设置 Header
func WithHeader(key, value string) RequestOpt {
return func(r *Request) error {
if isHostHeaderKey(key) {
setRequestHostConfig(r.config, value)
return nil
}
r.config.Headers.Set(key, value)
return nil
}
}
// WithHost 设置显式 Host 头覆盖。
func WithHost(host string) RequestOpt {
return func(r *Request) error {
setRequestHostConfig(r.config, host)
return nil
}
}
// WithHeaders 批量设置 Headers
func WithHeaders(headers map[string]string) RequestOpt {
return func(r *Request) error {
for key, value := range headers {
if isHostHeaderKey(key) {
setRequestHostConfig(r.config, value)
continue
}
r.config.Headers.Set(key, value)
}
return nil
}
}
// WithContentType 设置 Content-Type
func WithContentType(contentType string) RequestOpt {
return func(r *Request) error {
r.config.Headers.Set("Content-Type", contentType)
return nil
}
}
// WithUserAgent 设置 User-Agent
func WithUserAgent(userAgent string) RequestOpt {
return func(r *Request) error {
r.config.Headers.Set("User-Agent", userAgent)
return nil
}
}
// WithBearerToken 设置 Bearer Token
func WithBearerToken(token string) RequestOpt {
return func(r *Request) error {
r.config.Headers.Set("Authorization", "Bearer "+token)
return nil
}
}
// WithCookie 添加 Cookie
func WithCookie(name, value, path string) RequestOpt {
return func(r *Request) error {
r.config.Cookies = append(r.config.Cookies, &http.Cookie{
Name: name,
Value: value,
Path: path,
})
return nil
}
}
// WithSimpleCookie 添加简单 Cookiepath 为 /
func WithSimpleCookie(name, value string) RequestOpt {
return func(r *Request) error {
r.config.Cookies = append(r.config.Cookies, &http.Cookie{
Name: name,
Value: value,
Path: "/",
})
return nil
}
}
// WithCookies 批量添加 Cookies
func WithCookies(cookies map[string]string) RequestOpt {
return func(r *Request) error {
for name, value := range cookies {
r.config.Cookies = append(r.config.Cookies, &http.Cookie{
Name: name,
Value: value,
Path: "/",
})
}
return nil
}
}
+234
View File
@@ -0,0 +1,234 @@
package starnet
import (
"context"
"encoding/json"
"io"
"net"
"net/http"
"net/http/httptest"
"os"
"strings"
"sync/atomic"
"testing"
"time"
)
func TestWithJSONOpt(t *testing.T) {
type payload struct {
Name string `json:"name"`
Age int `json:"age"`
}
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if ct := r.Header.Get("Content-Type"); ct != ContentTypeJSON {
t.Fatalf("content-type=%s", ct)
}
var p payload
if err := json.NewDecoder(r.Body).Decode(&p); err != nil {
t.Fatalf("decode err: %v", err)
}
if p.Name != "alice" || p.Age != 18 {
t.Fatalf("payload mismatch: %+v", p)
}
w.WriteHeader(http.StatusOK)
}))
defer s.Close()
resp, err := Post(s.URL, WithJSON(payload{Name: "alice", Age: 18}))
if err != nil {
t.Fatalf("Post error: %v", err)
}
resp.Close()
}
func TestWithFileOpt(t *testing.T) {
// temp file + cleanup
f, err := os.CreateTemp("", "starnet-upload-*.txt")
if err != nil {
t.Fatal(err)
}
defer os.Remove(f.Name())
_, _ = f.WriteString("hello-file")
_ = f.Close()
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseMultipartForm(10 << 20); err != nil {
t.Fatalf("parse form err: %v", err)
}
file, header, err := r.FormFile("file")
if err != nil {
t.Fatalf("form file err: %v", err)
}
defer file.Close()
b, _ := io.ReadAll(file)
if header.Filename == "" || string(b) != "hello-file" {
t.Fatalf("upload mismatch filename=%q body=%q", header.Filename, string(b))
}
w.WriteHeader(http.StatusOK)
}))
defer s.Close()
resp, err := Post(s.URL, WithFile("file", f.Name()))
if err != nil {
t.Fatalf("Post error: %v", err)
}
resp.Close()
}
func TestWithFileStreamOpt(t *testing.T) {
content := "stream-content"
reader := strings.NewReader(content)
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseMultipartForm(10 << 20); err != nil {
t.Fatalf("parse form err: %v", err)
}
file, header, err := r.FormFile("up")
if err != nil {
t.Fatalf("form file err: %v", err)
}
defer file.Close()
b, _ := io.ReadAll(file)
if header.Filename != "a.txt" || string(b) != content {
t.Fatalf("upload mismatch filename=%q body=%q", header.Filename, string(b))
}
w.WriteHeader(http.StatusOK)
}))
defer s.Close()
resp, err := Post(s.URL, WithFileStream("up", "a.txt", int64(len(content)), reader))
if err != nil {
t.Fatalf("Post error: %v", err)
}
resp.Close()
}
func TestWithQueryOpt(t *testing.T) {
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Query().Get("k") != "v" {
t.Fatalf("query mismatch: %v", r.URL.Query())
}
w.WriteHeader(http.StatusOK)
}))
defer s.Close()
resp, err := Get(s.URL, WithQuery("k", "v"))
if err != nil {
t.Fatalf("Get error: %v", err)
}
resp.Close()
}
func TestWithUploadProgressOpt(t *testing.T) {
var called int32
var last int64
content := strings.Repeat("x", 4096)
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_ = r.ParseMultipartForm(10 << 20)
w.WriteHeader(http.StatusOK)
}))
defer s.Close()
resp, err := Post(s.URL,
WithUploadProgress(func(filename string, uploaded, total int64) {
atomic.StoreInt32(&called, 1)
last = uploaded
}),
WithFileStream("f", "p.txt", int64(len(content)), strings.NewReader(content)),
)
if err != nil {
t.Fatalf("Post error: %v", err)
}
resp.Close()
if atomic.LoadInt32(&called) == 0 {
t.Fatal("progress not called")
}
if last != int64(len(content)) {
t.Fatalf("last uploaded=%d want=%d", last, len(content))
}
}
func TestWithTransportOpt(t *testing.T) {
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer s.Close()
resp, err := Get(s.URL, WithTransport(&http.Transport{}))
if err != nil {
t.Fatalf("Get error: %v", err)
}
resp.Close()
}
func TestWithContextOpt(t *testing.T) {
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(200 * time.Millisecond)
w.WriteHeader(http.StatusOK)
}))
defer s.Close()
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
defer cancel()
_, err := Get(s.URL, WithContext(ctx))
if err == nil {
t.Fatal("expected context timeout error")
}
}
func TestWithCustomDNSOpt_ConfigApplied(t *testing.T) {
req := NewSimpleRequest("http://example.com", "GET", WithCustomDNS([]string{"8.8.8.8", "1.1.1.1"}))
if req.Err() != nil {
t.Fatalf("unexpected err: %v", req.Err())
}
if len(req.config.DNS.CustomDNS) != 2 {
t.Fatalf("custom dns len=%d", len(req.config.DNS.CustomDNS))
}
}
func TestWithAddCustomIPOpt(t *testing.T) {
req := NewSimpleRequest("http://example.com", "GET", WithAddCustomIP("1.2.3.4"))
if req.Err() != nil {
t.Fatalf("unexpected err: %v", req.Err())
}
if len(req.config.DNS.CustomIP) != 1 || req.config.DNS.CustomIP[0] != "1.2.3.4" {
t.Fatalf("custom ip mismatch: %v", req.config.DNS.CustomIP)
}
}
func TestWithCustomIPOpt(t *testing.T) {
req := NewSimpleRequest("http://example.com", "GET", WithCustomIP([]string{"1.1.1.1", "8.8.8.8"}))
if req.Err() != nil {
t.Fatalf("unexpected err: %v", req.Err())
}
if len(req.config.DNS.CustomIP) != 2 {
t.Fatalf("custom ip len=%d", len(req.config.DNS.CustomIP))
}
}
func TestWithDialFuncOpt(t *testing.T) {
called := int32(0)
fn := func(ctx context.Context, network, addr string) (net.Conn, error) {
atomic.StoreInt32(&called, 1)
return nil, io.EOF
}
req := NewSimpleRequest("http://example.com", "GET", WithDialFunc(fn))
if req.config.Network.DialFunc == nil {
t.Fatal("dial func not set")
}
_, _ = req.config.Network.DialFunc(context.Background(), "tcp", "x:1")
if atomic.LoadInt32(&called) == 0 {
t.Fatal("dial func not called")
}
}
func TestWithDialTimeoutOpt(t *testing.T) {
req := NewSimpleRequest("http://example.com", "GET", WithDialTimeout(123*time.Millisecond))
if req.config.Network.DialTimeout != 123*time.Millisecond {
t.Fatalf("dial timeout=%v", req.config.Network.DialTimeout)
}
}
+375
View File
@@ -0,0 +1,375 @@
package starnet
import (
"context"
"errors"
"fmt"
"net"
"os"
"strings"
"time"
"b612.me/starnet/internal/pingcore"
)
const (
icmpTypeEchoReplyV4 = 0
icmpTypeEchoRequestV4 = 8
icmpTypeEchoRequestV6 = 128
icmpTypeEchoReplyV6 = 129
icmpReadBufSz = 1500
defaultPingAttemptTimeout = 2 * time.Second
defaultPingableCount = 3
maxPingPayloadSize = 65499 // 65507 - ICMP header(8)
)
type ICMP = pingcore.ICMP
type pingSocketSpec struct {
network string
family int
requestType uint8
replyType uint8
}
// PingOptions controls ping probing behavior.
type PingOptions = pingcore.Options
type PingResult = pingcore.Result
func nextPingIdentifier() uint16 {
return pingcore.NextIdentifier()
}
func pingPayload(size int) []byte {
return pingcore.Payload(size)
}
func getICMP(seq, identifier uint16, typ uint8, payload []byte) ICMP {
return pingcore.BuildICMP(seq, identifier, typ, payload)
}
func sendICMPRequest(ctx context.Context, icmp ICMP, payload []byte, destAddr *net.IPAddr, sourceIP net.IP, spec pingSocketSpec, timeout time.Duration) (PingResult, error) {
var res PingResult
if ctx == nil {
ctx = context.Background()
}
if err := ctx.Err(); err != nil {
return res, wrapError(err, "ping context done")
}
if destAddr == nil || destAddr.IP == nil {
return res, fmt.Errorf("destination ip is nil")
}
res.RemoteIP = destAddr.String()
localAddr, err := localIPAddrForFamily(sourceIP, spec.family)
if err != nil {
return res, err
}
conn, err := net.DialIP(spec.network, localAddr, destAddr)
if err != nil {
return res, normalizePingDialError(err)
}
defer conn.Close()
packet := marshalICMPPacket(icmp, payload)
if _, err := conn.Write(packet); err != nil {
return res, wrapError(err, "ping write request")
}
startedAt := time.Now()
deadline := startedAt.Add(timeout)
if d, ok := ctx.Deadline(); ok && d.Before(deadline) {
deadline = d
}
if err := conn.SetReadDeadline(deadline); err != nil {
return res, wrapError(err, "ping set read deadline")
}
doneCh := make(chan struct{})
go func() {
select {
case <-ctx.Done():
_ = conn.SetReadDeadline(time.Now())
case <-doneCh:
}
}()
defer close(doneCh)
recv := make([]byte, icmpReadBufSz)
for {
n, err := conn.Read(recv)
if err != nil {
if ctx.Err() != nil {
return res, wrapError(ctx.Err(), "ping context done")
}
return res, wrapError(err, "ping read reply")
}
if isExpectedEchoReply(recv[:n], spec.family, spec.replyType, icmp.Identifier, icmp.SequenceNum) {
res.RecvCount = n
res.Duration = time.Since(startedAt)
return res, nil
}
}
}
func checkSum(data []byte) uint16 {
return pingcore.Checksum(data)
}
func marshalICMP(icmp ICMP) []byte {
return pingcore.Marshal(icmp)
}
func marshalICMPPacket(icmp ICMP, payload []byte) []byte {
return pingcore.MarshalPacket(icmp, payload)
}
func isExpectedEchoReply(packet []byte, family int, expectedType uint8, identifier, seq uint16) bool {
return pingcore.IsExpectedEchoReply(packet, family, expectedType, identifier, seq)
}
func candidateICMPOffsets(packet []byte, family int) []int {
return pingcore.CandidateICMPOffsets(packet, family)
}
func dedupOffsets(offsets []int) []int {
return pingcore.DedupOffsets(offsets)
}
func socketSpecForIP(ip net.IP) (pingSocketSpec, error) {
if ip == nil {
return pingSocketSpec{}, wrapError(ErrInvalidIP, "ip is nil")
}
if ip4 := ip.To4(); ip4 != nil {
return pingSocketSpec{
network: "ip4:icmp",
family: 4,
requestType: icmpTypeEchoRequestV4,
replyType: icmpTypeEchoReplyV4,
}, nil
}
if ip16 := ip.To16(); ip16 != nil {
return pingSocketSpec{
network: "ip6:ipv6-icmp",
family: 6,
requestType: icmpTypeEchoRequestV6,
replyType: icmpTypeEchoReplyV6,
}, nil
}
return pingSocketSpec{}, wrapError(ErrInvalidIP, "invalid ip: %q", ip.String())
}
func localIPAddrForFamily(sourceIP net.IP, family int) (*net.IPAddr, error) {
if sourceIP == nil {
return nil, nil
}
if sourceIP.To16() == nil {
return nil, wrapError(ErrInvalidIP, "invalid source ip: %q", sourceIP.String())
}
if family == 4 && sourceIP.To4() == nil {
return nil, wrapError(ErrInvalidIP, "source ip family mismatch with IPv4 target")
}
if family == 6 && sourceIP.To4() != nil {
return nil, wrapError(ErrInvalidIP, "source ip family mismatch with IPv6 target")
}
return &net.IPAddr{IP: sourceIP}, nil
}
func resolvePingTargets(host string, preferIPv4, preferIPv6 bool) ([]*net.IPAddr, error) {
targets, err := pingcore.ResolveTargets(host, preferIPv4, preferIPv6)
if err != nil {
return nil, err
}
if len(targets) == 0 {
return nil, ErrPingNoResolvedTarget
}
return targets, nil
}
func orderPingTargets(targets []*net.IPAddr, preferIPv4, preferIPv6 bool) []*net.IPAddr {
return pingcore.OrderTargets(targets, preferIPv4, preferIPv6)
}
func normalizePingDialError(err error) error {
if err == nil {
return nil
}
msg := strings.ToLower(err.Error())
if errors.Is(err, os.ErrPermission) ||
strings.Contains(msg, "operation not permitted") ||
strings.Contains(msg, "permission denied") {
return fmt.Errorf("%w: %v", ErrPingPermissionDenied, err)
}
if strings.Contains(msg, "unknown network") ||
strings.Contains(msg, "protocol not available") ||
strings.Contains(msg, "address family not supported by protocol") ||
strings.Contains(msg, "socket type not supported") {
return fmt.Errorf("%w: %v", ErrPingProtocolUnsupported, err)
}
return wrapError(err, "ping dial")
}
func normalizePingOptions(opts *PingOptions, defaultCount int, defaultTimeout time.Duration) (PingOptions, error) {
out := PingOptions{
Count: defaultCount,
Timeout: defaultTimeout,
Interval: 0,
PayloadSize: 0,
}
if opts != nil {
out = *opts
if out.Count == 0 {
out.Count = defaultCount
}
if out.Timeout == 0 {
out.Timeout = defaultTimeout
}
}
if out.Count < 0 {
return out, fmt.Errorf("ping count must be >= 0")
}
if out.Timeout <= 0 {
return out, wrapError(ErrPingInvalidTimeout, "timeout must be > 0")
}
if out.Interval < 0 {
return out, fmt.Errorf("ping interval must be >= 0")
}
if out.PayloadSize < 0 || out.PayloadSize > maxPingPayloadSize {
return out, fmt.Errorf("ping payload size must be in [0,%d]", maxPingPayloadSize)
}
if out.SourceIP != nil && out.SourceIP.To16() == nil {
return out, wrapError(ErrInvalidIP, "invalid source ip: %q", out.SourceIP.String())
}
return out, nil
}
func pingOnceWithOptions(ctx context.Context, host string, seq int, opts PingOptions) (PingResult, error) {
var res PingResult
if ctx == nil {
ctx = context.Background()
}
if err := ctx.Err(); err != nil {
return res, wrapError(err, "ping context done")
}
targets, err := resolvePingTargets(host, opts.PreferIPv4, opts.PreferIPv6)
if err != nil {
return res, wrapError(err, "resolve ping target")
}
payload := pingPayload(opts.PayloadSize)
var lastErr error
for _, target := range targets {
spec, err := socketSpecForIP(target.IP)
if err != nil {
lastErr = err
continue
}
icmp := getICMP(uint16(seq), nextPingIdentifier(), spec.requestType, payload)
resp, err := sendICMPRequest(ctx, icmp, payload, target, opts.SourceIP, spec, opts.Timeout)
if err == nil {
return resp, nil
}
if errors.Is(err, ErrPingPermissionDenied) {
return res, err
}
lastErr = err
}
if lastErr != nil {
return res, wrapError(lastErr, "ping all resolved targets failed")
}
return res, ErrPingNoResolvedTarget
}
// PingWithContext sends one ICMP echo request with context cancel support.
func PingWithContext(ctx context.Context, host string, seq int, timeout time.Duration) (PingResult, error) {
opts, err := normalizePingOptions(&PingOptions{
Count: 1,
Timeout: timeout,
}, 1, timeout)
if err != nil {
return PingResult{}, err
}
if !opts.Deadline.IsZero() {
var cancel context.CancelFunc
ctx, cancel = context.WithDeadline(ctx, opts.Deadline)
defer cancel()
}
return pingOnceWithOptions(ctx, host, seq, opts)
}
// Ping sends one ICMP echo request.
func Ping(ip string, seq int, timeout time.Duration) (PingResult, error) {
return PingWithContext(context.Background(), ip, seq, timeout)
}
// Pingable checks host reachability with retry options.
func Pingable(host string, opts *PingOptions) (bool, error) {
cfg, err := normalizePingOptions(opts, defaultPingableCount, defaultPingAttemptTimeout)
if err != nil {
return false, err
}
ctx := context.Background()
if !cfg.Deadline.IsZero() {
var cancel context.CancelFunc
ctx, cancel = context.WithDeadline(ctx, cfg.Deadline)
defer cancel()
}
var lastErr error
for index := 0; index < cfg.Count; index++ {
_, err := pingOnceWithOptions(ctx, host, 29+index, cfg)
if err == nil {
return true, nil
}
lastErr = err
if errors.Is(err, ErrPingPermissionDenied) || errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
break
}
if index < cfg.Count-1 && cfg.Interval > 0 {
timer := time.NewTimer(cfg.Interval)
select {
case <-ctx.Done():
timer.Stop()
return false, wrapError(ctx.Err(), "pingable context done")
case <-timer.C:
}
}
}
if lastErr == nil {
lastErr = ErrPingNoResolvedTarget
}
return false, lastErr
}
// IsIpPingable keeps backward-compatible bool-only behavior.
func IsIpPingable(ip string, timeout time.Duration, retryLimit int) bool {
if retryLimit <= 0 {
return false
}
ok, _ := Pingable(ip, &PingOptions{
Count: retryLimit,
Timeout: timeout,
})
return ok
}
+214
View File
@@ -0,0 +1,214 @@
package starnet
import (
"context"
"errors"
"net"
"testing"
"time"
)
func buildICMPPacket(typ uint8, identifier, seq uint16) []byte {
icmp := ICMP{
Type: typ,
Code: 0,
CheckSum: 0,
Identifier: identifier,
SequenceNum: seq,
}
buf := marshalICMPPacket(icmp, nil)
cs := checkSum(buf)
buf[2] = byte(cs >> 8)
buf[3] = byte(cs)
return buf
}
func TestNextPingIdentifierChanges(t *testing.T) {
id1 := nextPingIdentifier()
id2 := nextPingIdentifier()
if id1 == id2 {
t.Fatalf("identifier should change between calls: %d == %d", id1, id2)
}
}
func TestIsExpectedEchoReplyIPv4(t *testing.T) {
identifier := uint16(0x1234)
seq := uint16(0x0102)
reply := buildICMPPacket(icmpTypeEchoReplyV4, identifier, seq)
if !isExpectedEchoReply(reply, 4, icmpTypeEchoReplyV4, identifier, seq) {
t.Fatal("expected IPv4 echo reply to match")
}
if isExpectedEchoReply(reply, 4, icmpTypeEchoReplyV4, identifier, seq+1) {
t.Fatal("mismatched sequence should not match")
}
}
func TestIsExpectedEchoReplyIPv4WithIPHeader(t *testing.T) {
identifier := uint16(0x1111)
seq := uint16(0x2222)
ipHeader := make([]byte, 20)
ipHeader[0] = 0x45 // version=4, ihl=5
reply := buildICMPPacket(icmpTypeEchoReplyV4, identifier, seq)
packet := append(ipHeader, reply...)
if !isExpectedEchoReply(packet, 4, icmpTypeEchoReplyV4, identifier, seq) {
t.Fatal("expected IPv4 packet with header to match")
}
}
func TestIsExpectedEchoReplyIPv6WithHeader(t *testing.T) {
identifier := uint16(0xabcd)
seq := uint16(0x00ff)
ipv6Header := make([]byte, 40)
ipv6Header[0] = 0x60 // version=6
reply := buildICMPPacket(icmpTypeEchoReplyV6, identifier, seq)
packet := append(ipv6Header, reply...)
if !isExpectedEchoReply(packet, 6, icmpTypeEchoReplyV6, identifier, seq) {
t.Fatal("expected IPv6 packet with header to match")
}
}
func TestPingInvalidTimeout(t *testing.T) {
_, err := Ping("127.0.0.1", 1, 0)
if err == nil {
t.Fatal("expected error for non-positive timeout")
}
if !errors.Is(err, ErrPingInvalidTimeout) {
t.Fatalf("expected ErrPingInvalidTimeout, got: %v", err)
}
}
func TestIsIPPingableInvalidRetry(t *testing.T) {
if IsIpPingable("127.0.0.1", time.Millisecond, 0) {
t.Fatal("retryLimit=0 should return false")
}
}
func TestSocketSpecForIP(t *testing.T) {
v4, err := socketSpecForIP(net.ParseIP("127.0.0.1"))
if err != nil {
t.Fatalf("unexpected v4 error: %v", err)
}
if v4.network != "ip4:icmp" || v4.family != 4 || v4.requestType != icmpTypeEchoRequestV4 || v4.replyType != icmpTypeEchoReplyV4 {
t.Fatalf("unexpected v4 spec: %+v", v4)
}
v6, err := socketSpecForIP(net.ParseIP("::1"))
if err != nil {
t.Fatalf("unexpected v6 error: %v", err)
}
if v6.network != "ip6:ipv6-icmp" || v6.family != 6 || v6.requestType != icmpTypeEchoRequestV6 || v6.replyType != icmpTypeEchoReplyV6 {
t.Fatalf("unexpected v6 spec: %+v", v6)
}
_, err = socketSpecForIP(nil)
if err == nil {
t.Fatal("expected error for nil ip")
}
if !errors.Is(err, ErrInvalidIP) {
t.Fatalf("expected ErrInvalidIP, got: %v", err)
}
}
func TestResolvePingTargetsLiteral(t *testing.T) {
v4, err := resolvePingTargets("127.0.0.1", false, false)
if err != nil {
t.Fatalf("unexpected v4 resolve error: %v", err)
}
if len(v4) != 1 || v4[0] == nil || v4[0].IP == nil || v4[0].IP.To4() == nil {
t.Fatalf("unexpected v4 targets: %+v", v4)
}
v6, err := resolvePingTargets("::1", false, false)
if err != nil {
t.Fatalf("unexpected v6 resolve error: %v", err)
}
if len(v6) != 1 || v6[0] == nil || v6[0].IP == nil || v6[0].IP.To16() == nil || v6[0].IP.To4() != nil {
t.Fatalf("unexpected v6 targets: %+v", v6)
}
}
func TestNormalizePingDialError(t *testing.T) {
perr := normalizePingDialError(errors.New("socket: operation not permitted"))
if !errors.Is(perr, ErrPingPermissionDenied) {
t.Fatalf("expected ErrPingPermissionDenied, got: %v", perr)
}
uerr := normalizePingDialError(errors.New("unknown network ip6:ipv6-icmp"))
if !errors.Is(uerr, ErrPingProtocolUnsupported) {
t.Fatalf("expected ErrPingProtocolUnsupported, got: %v", uerr)
}
}
func TestOrderPingTargets(t *testing.T) {
targets := []*net.IPAddr{
{IP: net.ParseIP("::1")},
{IP: net.ParseIP("127.0.0.1")},
}
v4First := orderPingTargets(targets, true, false)
if v4First[0].IP.To4() == nil {
t.Fatalf("expected IPv4 first, got: %v", v4First[0].IP)
}
v6First := orderPingTargets(targets, false, true)
if v6First[0].IP.To4() != nil {
t.Fatalf("expected IPv6 first, got: %v", v6First[0].IP)
}
}
func TestNormalizePingOptions(t *testing.T) {
opts, err := normalizePingOptions(nil, 3, 2*time.Second)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if opts.Count != 3 || opts.Timeout != 2*time.Second {
t.Fatalf("unexpected defaults: %+v", opts)
}
_, err = normalizePingOptions(&PingOptions{Count: -1}, 3, 2*time.Second)
if err == nil {
t.Fatal("expected error for negative count")
}
_, err = normalizePingOptions(&PingOptions{Timeout: -1}, 3, 2*time.Second)
if err == nil {
t.Fatal("expected error for negative timeout")
}
_, err = normalizePingOptions(&PingOptions{PayloadSize: maxPingPayloadSize + 1}, 3, 2*time.Second)
if err == nil {
t.Fatal("expected error for too large payload")
}
}
func TestPingWithContextCanceled(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err := PingWithContext(ctx, "127.0.0.1", 1, time.Second)
if err == nil {
t.Fatal("expected canceled error")
}
if !errors.Is(err, context.Canceled) {
t.Fatalf("expected context.Canceled, got: %v", err)
}
}
func TestPingableInvalidOptions(t *testing.T) {
_, err := Pingable("127.0.0.1", &PingOptions{Count: -1})
if err == nil {
t.Fatal("expected invalid count error")
}
_, err = Pingable("127.0.0.1", &PingOptions{Interval: -1})
if err == nil {
t.Fatal("expected invalid interval error")
}
}
+15
View File
@@ -0,0 +1,15 @@
package starnet
import (
"fmt"
"testing"
"time"
)
func Test_Ping(t *testing.T) {
fmt.Println(Ping("baidu.com", 29, time.Second*2))
fmt.Println(Ping("www.b612.me", 29, time.Second*2))
fmt.Println(IsIpPingable("baidu.com", time.Second*2, 3))
fmt.Println(IsIpPingable("www.b612.me", time.Second*2, 3))
}
+110
View File
@@ -0,0 +1,110 @@
package starnet
import (
"fmt"
"net"
"net/http"
"testing"
)
func TestRequestProxyWithCustomIPTargetsOriginWithoutRewritingProxyDial(t *testing.T) {
tlsReqInfo := make(chan struct {
host string
sni string
}, 1)
tlsServer := newIPv4TLSServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
tlsReqInfo <- struct {
host string
sni string
}{
host: r.Host,
sni: r.TLS.ServerName,
}
_, _ = w.Write([]byte("ok"))
}))
defer tlsServer.Close()
_, port, err := net.SplitHostPort(tlsServer.Listener.Addr().String())
if err != nil {
t.Fatalf("split tls server addr: %v", err)
}
proxyServer := newIPv4ConnectProxyServer(t, nil)
defer proxyServer.Close()
targetHost := "proxy-custom-ip.test"
reqURL := fmt.Sprintf("https://%s:%s", targetHost, port)
req := NewSimpleRequest(reqURL, http.MethodGet).
SetProxy(proxyServer.URL).
SetCustomIP([]string{"127.0.0.1"}).
SetSkipTLSVerify(true)
resp, err := req.Do()
if err != nil {
t.Fatalf("Do error: %v", err)
}
defer resp.Close()
targets := proxyServer.Targets()
if len(targets) != 1 {
t.Fatalf("connect targets=%v; want 1 target", targets)
}
gotConnectTarget := targets[0]
wantConnectTarget := net.JoinHostPort("127.0.0.1", port)
if gotConnectTarget != wantConnectTarget {
t.Fatalf("CONNECT target = %q; want %q", gotConnectTarget, wantConnectTarget)
}
gotTLS := <-tlsReqInfo
wantHost := net.JoinHostPort(targetHost, port)
if gotTLS.host != wantHost {
t.Fatalf("request host = %q; want %q", gotTLS.host, wantHost)
}
if gotTLS.sni != targetHost {
t.Fatalf("tls sni = %q; want %q", gotTLS.sni, targetHost)
}
}
func TestRequestCustomIPPreservesOriginalHostAndSNI(t *testing.T) {
tlsReqInfo := make(chan struct {
host string
sni string
}, 1)
tlsServer := newIPv4TLSServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
tlsReqInfo <- struct {
host string
sni string
}{
host: r.Host,
sni: r.TLS.ServerName,
}
_, _ = w.Write([]byte("ok"))
}))
defer tlsServer.Close()
_, port, err := net.SplitHostPort(tlsServer.Listener.Addr().String())
if err != nil {
t.Fatalf("split tls server addr: %v", err)
}
targetHost := "custom-ip-direct.test"
reqURL := fmt.Sprintf("https://%s:%s", targetHost, port)
req := NewSimpleRequest(reqURL, http.MethodGet).
SetCustomIP([]string{"127.0.0.1"}).
SetSkipTLSVerify(true)
resp, err := req.Do()
if err != nil {
t.Fatalf("Do error: %v", err)
}
defer resp.Close()
gotTLS := <-tlsReqInfo
wantHost := net.JoinHostPort(targetHost, port)
if gotTLS.host != wantHost {
t.Fatalf("request host = %q; want %q", gotTLS.host, wantHost)
}
if gotTLS.sni != targetHost {
t.Fatalf("tls sni = %q; want %q", gotTLS.sni, targetHost)
}
}
+331
View File
@@ -0,0 +1,331 @@
package starnet
import (
"crypto/tls"
"crypto/x509"
"encoding/binary"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"sync"
"testing"
)
type connectProxyServer struct {
*httptest.Server
mu sync.Mutex
targets []string
}
func newIPv4Server(t testing.TB, handler http.Handler) *httptest.Server {
t.Helper()
listener, err := net.Listen("tcp4", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen tcp4: %v", err)
}
server := httptest.NewUnstartedServer(handler)
server.Listener = listener
server.Start()
return server
}
func newIPv4TLSServer(t testing.TB, handler http.Handler) *httptest.Server {
t.Helper()
listener, err := net.Listen("tcp4", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen tcp4: %v", err)
}
server := httptest.NewUnstartedServer(handler)
server.Listener = listener
server.StartTLS()
return server
}
func newTrustedIPv4TLSServer(t testing.TB, dnsName string, handler http.Handler) (*httptest.Server, *x509.CertPool) {
t.Helper()
testT, ok := t.(*testing.T)
if !ok {
t.Fatal("newTrustedIPv4TLSServer requires *testing.T")
}
certPEM, keyPEM := genSelfSignedCertPEM(testT, dnsName)
cert, err := tls.X509KeyPair(certPEM, keyPEM)
if err != nil {
t.Fatalf("X509KeyPair: %v", err)
}
pool := x509.NewCertPool()
if !pool.AppendCertsFromPEM(certPEM) {
t.Fatal("AppendCertsFromPEM returned false")
}
server := httptest.NewUnstartedServer(handler)
listener, err := net.Listen("tcp4", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen tcp4: %v", err)
}
server.Listener = listener
server.TLS = &tls.Config{
Certificates: []tls.Certificate{cert},
}
server.StartTLS()
return server, pool
}
func httpsURLForHost(t testing.TB, server *httptest.Server, host string) string {
t.Helper()
_, port, err := net.SplitHostPort(server.Listener.Addr().String())
if err != nil {
t.Fatalf("split host port: %v", err)
}
return fmt.Sprintf("https://%s:%s", host, port)
}
func newIPv4ConnectProxyServer(t testing.TB, dialTarget func(target string) (net.Conn, error)) *connectProxyServer {
t.Helper()
proxy := &connectProxyServer{}
if dialTarget == nil {
dialTarget = func(target string) (net.Conn, error) {
return net.Dial("tcp", target)
}
}
proxy.Server = newIPv4Server(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodConnect {
http.Error(w, "connect required", http.StatusMethodNotAllowed)
return
}
proxy.mu.Lock()
proxy.targets = append(proxy.targets, r.Host)
proxy.mu.Unlock()
targetConn, err := dialTarget(r.Host)
if err != nil {
http.Error(w, err.Error(), http.StatusBadGateway)
return
}
hijacker, ok := w.(http.Hijacker)
if !ok {
targetConn.Close()
t.Fatal("proxy response writer is not a hijacker")
}
clientConn, rw, err := hijacker.Hijack()
if err != nil {
targetConn.Close()
t.Fatalf("hijack proxy conn: %v", err)
}
if _, err := rw.WriteString("HTTP/1.1 200 Connection Established\r\n\r\n"); err != nil {
clientConn.Close()
targetConn.Close()
t.Fatalf("write connect response: %v", err)
}
if err := rw.Flush(); err != nil {
clientConn.Close()
targetConn.Close()
t.Fatalf("flush connect response: %v", err)
}
relayProxyConns(clientConn, targetConn)
}))
return proxy
}
func (p *connectProxyServer) Targets() []string {
p.mu.Lock()
defer p.mu.Unlock()
return append([]string(nil), p.targets...)
}
type socks5ProxyServer struct {
ln net.Listener
addr string
dial func(target string) (net.Conn, error)
stopCh chan struct{}
wg sync.WaitGroup
mu sync.Mutex
targets []string
}
func newSOCKS5ProxyServer(t testing.TB, dialTarget func(target string) (net.Conn, error)) *socks5ProxyServer {
t.Helper()
if dialTarget == nil {
dialTarget = func(target string) (net.Conn, error) {
return net.Dial("tcp", target)
}
}
ln, err := net.Listen("tcp4", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen tcp4 socks5: %v", err)
}
proxy := &socks5ProxyServer{
ln: ln,
addr: ln.Addr().String(),
dial: dialTarget,
stopCh: make(chan struct{}),
}
proxy.wg.Add(1)
go func() {
defer proxy.wg.Done()
for {
conn, err := ln.Accept()
if err != nil {
select {
case <-proxy.stopCh:
return
default:
return
}
}
proxy.wg.Add(1)
go func(c net.Conn) {
defer proxy.wg.Done()
proxy.handleConn(t, c)
}(conn)
}
}()
return proxy
}
func (p *socks5ProxyServer) URL() string {
return "socks5://" + p.addr
}
func (p *socks5ProxyServer) Targets() []string {
p.mu.Lock()
defer p.mu.Unlock()
return append([]string(nil), p.targets...)
}
func (p *socks5ProxyServer) Close() {
close(p.stopCh)
_ = p.ln.Close()
p.wg.Wait()
}
func (p *socks5ProxyServer) handleConn(t testing.TB, conn net.Conn) {
t.Helper()
closeConn := true
defer func() {
if closeConn {
_ = conn.Close()
}
}()
header := make([]byte, 2)
if _, err := io.ReadFull(conn, header); err != nil {
return
}
if header[0] != 0x05 {
return
}
methods := make([]byte, int(header[1]))
if _, err := io.ReadFull(conn, methods); err != nil {
return
}
if _, err := conn.Write([]byte{0x05, 0x00}); err != nil {
return
}
reqHeader := make([]byte, 4)
if _, err := io.ReadFull(conn, reqHeader); err != nil {
return
}
if reqHeader[0] != 0x05 || reqHeader[1] != 0x01 {
_, _ = conn.Write([]byte{0x05, 0x07, 0x00, 0x01, 0, 0, 0, 0, 0, 0})
return
}
host, err := readSOCKS5Addr(conn, reqHeader[3])
if err != nil {
_, _ = conn.Write([]byte{0x05, 0x08, 0x00, 0x01, 0, 0, 0, 0, 0, 0})
return
}
portBytes := make([]byte, 2)
if _, err := io.ReadFull(conn, portBytes); err != nil {
return
}
target := net.JoinHostPort(host, fmt.Sprintf("%d", binary.BigEndian.Uint16(portBytes)))
p.mu.Lock()
p.targets = append(p.targets, target)
p.mu.Unlock()
targetConn, err := p.dial(target)
if err != nil {
_, _ = conn.Write([]byte{0x05, 0x05, 0x00, 0x01, 0, 0, 0, 0, 0, 0})
return
}
if _, err := conn.Write([]byte{0x05, 0x00, 0x00, 0x01, 0, 0, 0, 0, 0, 0}); err != nil {
targetConn.Close()
return
}
closeConn = false
relayProxyConns(conn, targetConn)
}
func readSOCKS5Addr(r io.Reader, atyp byte) (string, error) {
switch atyp {
case 0x01:
buf := make([]byte, 4)
if _, err := io.ReadFull(r, buf); err != nil {
return "", err
}
return net.IP(buf).String(), nil
case 0x03:
var size [1]byte
if _, err := io.ReadFull(r, size[:]); err != nil {
return "", err
}
buf := make([]byte, int(size[0]))
if _, err := io.ReadFull(r, buf); err != nil {
return "", err
}
return string(buf), nil
case 0x04:
buf := make([]byte, 16)
if _, err := io.ReadFull(r, buf); err != nil {
return "", err
}
return net.IP(buf).String(), nil
default:
return "", fmt.Errorf("unsupported atyp: %d", atyp)
}
}
func relayProxyConns(left, right net.Conn) {
var once sync.Once
closeBoth := func() {
_ = left.Close()
_ = right.Close()
}
go func() {
_, _ = io.Copy(left, right)
once.Do(closeBoth)
}()
go func() {
_, _ = io.Copy(right, left)
once.Do(closeBoth)
}()
}
+50
View File
@@ -0,0 +1,50 @@
package starnet
import (
"net/http"
"net/http/httptest"
"testing"
)
func TestRequestProxy(t *testing.T) {
// Create a proxy server
proxyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Proxy received the request
w.Header().Set("X-Proxied", "true")
w.WriteHeader(http.StatusOK)
w.Write([]byte("proxied"))
}))
defer proxyServer.Close()
// Note: This is a simplified test. Real proxy testing requires more setup
req := NewSimpleRequest("http://example.com", "GET").
SetProxy(proxyServer.URL)
// Just verify the proxy is set in config
if req.config.Network.Proxy != proxyServer.URL {
t.Errorf("Proxy = %v; want %v", req.config.Network.Proxy, proxyServer.URL)
}
}
func TestClientLevelProxy(t *testing.T) {
proxyURL := "http://proxy.example.com:8080"
client := NewClientNoErr(WithProxy(proxyURL))
req, _ := client.NewRequest("http://example.com", "GET")
if req.config.Network.Proxy != proxyURL {
t.Errorf("Proxy = %v; want %v", req.config.Network.Proxy, proxyURL)
}
}
func TestRequestLevelProxyOverride(t *testing.T) {
clientProxy := "http://client-proxy.com:8080"
requestProxy := "http://request-proxy.com:8080"
client := NewClientNoErr(WithProxy(clientProxy))
req, _ := client.NewRequest("http://example.com", "GET", WithProxy(requestProxy))
// Request level should override client level
if req.config.Network.Proxy != requestProxy {
t.Errorf("Proxy = %v; want %v", req.config.Network.Proxy, requestProxy)
}
}
-295
View File
@@ -1,295 +0,0 @@
package starnet
import (
"bytes"
"context"
"encoding/binary"
"errors"
"sync"
"time"
)
// 识别头
var header = []byte{11, 27, 19, 96, 12, 25, 02, 20}
// MsgQueue 为基本的信息单位
type MsgQueue struct {
ID uint16
Msg []byte
Conn interface{}
}
// StarQueue 为流数据中的消息队列分发
type StarQueue struct {
Encode bool
Reserve uint16
Msgid uint16
MsgPool []MsgQueue
UnFinMsg sync.Map
LastID int //= -1
ctx context.Context
cancel context.CancelFunc
duration time.Duration
EncodeFunc func([]byte) []byte
DecodeFunc func([]byte) []byte
//parseMu sync.Mutex
restoreMu sync.Mutex
}
// NewQueue 建立一个新消息队列
func NewQueue() *StarQueue {
var que StarQueue
que.Encode = false
que.ctx, que.cancel = context.WithCancel(context.Background())
que.duration = 0
return &que
}
// Uint32ToByte 4位uint32转byte
func Uint32ToByte(src uint32) []byte {
res := make([]byte, 4)
res[3] = uint8(src)
res[2] = uint8(src >> 8)
res[1] = uint8(src >> 16)
res[0] = uint8(src >> 24)
return res
}
// ByteToUint32 byte转4位uint32
func ByteToUint32(src []byte) uint32 {
var res uint32
buffer := bytes.NewBuffer(src)
binary.Read(buffer, binary.BigEndian, &res)
return res
}
// Uint16ToByte 2位uint16转byte
func Uint16ToByte(src uint16) []byte {
res := make([]byte, 2)
res[1] = uint8(src)
res[0] = uint8(src >> 8)
return res
}
// ByteToUint16 用于byte转uint16
func ByteToUint16(src []byte) uint16 {
var res uint16
buffer := bytes.NewBuffer(src)
binary.Read(buffer, binary.BigEndian, &res)
return res
}
// BuildMessage 生成编码后的信息用于发送
func (que *StarQueue) BuildMessage(src []byte) []byte {
var buff bytes.Buffer
que.Msgid++
if que.Encode {
src = que.EncodeFunc(src)
}
length := uint32(len(src))
buff.Write(header)
buff.Write(Uint32ToByte(length))
buff.Write(Uint16ToByte(que.Msgid))
buff.Write(src)
return buff.Bytes()
}
// BuildHeader 生成编码后的Header用于发送
func (que *StarQueue) BuildHeader(length uint32) []byte {
var buff bytes.Buffer
que.Msgid++
buff.Write(header)
buff.Write(Uint32ToByte(length))
buff.Write(Uint16ToByte(que.Msgid))
return buff.Bytes()
}
type unFinMsg struct {
ID uint16
LengthRecv uint32
// HeaderMsg 信息头,应当为14位:8位识别码+4位长度码+2位id
HeaderMsg []byte
RecvMsg []byte
}
// ParseMessage 用于解析收到的msg信息
func (que *StarQueue) ParseMessage(msg []byte, conn interface{}) error {
tmp, ok := que.UnFinMsg.Load(conn)
if ok { //存在未完成的信息
lastMsg := tmp.(*unFinMsg)
headerLen := len(lastMsg.HeaderMsg)
if headerLen < 14 { //未完成头标题
//传输的数据不能填充header头
if len(msg) < 14-headerLen {
//加入header头并退出
lastMsg.HeaderMsg = bytesMerge(lastMsg.HeaderMsg, msg)
que.UnFinMsg.Store(conn, lastMsg)
return nil
}
//获取14字节完整的header
header := msg[0 : 14-headerLen]
lastMsg.HeaderMsg = bytesMerge(lastMsg.HeaderMsg, header)
//检查收到的header是否为认证header
//若不是,丢弃并重新来过
if !checkHeader(lastMsg.HeaderMsg[0:8]) {
que.UnFinMsg.Delete(conn)
if len(msg) == 0 {
return nil
}
return que.ParseMessage(msg, conn)
}
//获得本数据包长度
lastMsg.LengthRecv = ByteToUint32(lastMsg.HeaderMsg[8:12])
//获得本数据包ID
lastMsg.ID = ByteToUint16(lastMsg.HeaderMsg[12:14])
//存入列表
que.UnFinMsg.Store(conn, lastMsg)
msg = msg[14-headerLen:]
if uint32(len(msg)) < lastMsg.LengthRecv {
lastMsg.RecvMsg = msg
que.UnFinMsg.Store(conn, lastMsg)
return nil
}
if uint32(len(msg)) >= lastMsg.LengthRecv {
lastMsg.RecvMsg = msg[0:lastMsg.LengthRecv]
if que.Encode {
lastMsg.RecvMsg = que.DecodeFunc(lastMsg.RecvMsg)
}
msg = msg[lastMsg.LengthRecv:]
stroeMsg := MsgQueue{
ID: lastMsg.ID,
Msg: lastMsg.RecvMsg,
Conn: conn,
}
que.MsgPool = append(que.MsgPool, stroeMsg)
que.UnFinMsg.Delete(conn)
return que.ParseMessage(msg, conn)
}
} else {
lastID := int(lastMsg.LengthRecv) - len(lastMsg.RecvMsg)
if lastID < 0 {
que.UnFinMsg.Delete(conn)
return que.ParseMessage(msg, conn)
}
if len(msg) >= lastID {
lastMsg.RecvMsg = bytesMerge(lastMsg.RecvMsg, msg[0:lastID])
if que.Encode {
lastMsg.RecvMsg = que.DecodeFunc(lastMsg.RecvMsg)
}
stroeMsg := MsgQueue{
ID: lastMsg.ID,
Msg: lastMsg.RecvMsg,
Conn: conn,
}
que.MsgPool = append(que.MsgPool, stroeMsg)
que.UnFinMsg.Delete(conn)
if len(msg) == lastID {
return nil
}
msg = msg[lastID:]
return que.ParseMessage(msg, conn)
}
lastMsg.RecvMsg = bytesMerge(lastMsg.RecvMsg, msg)
que.UnFinMsg.Store(conn, lastMsg)
return nil
}
}
if len(msg) == 0 {
return nil
}
var start int
if start = searchHeader(msg); start == -1 {
return errors.New("data format error")
}
msg = msg[start:]
lastMsg := unFinMsg{}
que.UnFinMsg.Store(conn, &lastMsg)
return que.ParseMessage(msg, conn)
}
func checkHeader(msg []byte) bool {
if len(msg) != 8 {
return false
}
for k, v := range msg {
if v != header[k] {
return false
}
}
return true
}
func searchHeader(msg []byte) int {
if len(msg) < 8 {
return 0
}
for k, v := range msg {
find := 0
if v == header[0] {
for k2, v2 := range header {
if msg[k+k2] == v2 {
find++
} else {
break
}
}
if find == 8 {
return k
}
}
}
return -1
}
func bytesMerge(src ...[]byte) []byte {
var buff bytes.Buffer
for _, v := range src {
buff.Write(v)
}
return buff.Bytes()
}
// Restore 获取收到的信息
func (que *StarQueue) Restore(n int) ([]MsgQueue, error) {
que.restoreMu.Lock()
defer que.restoreMu.Unlock()
var res []MsgQueue
dura := time.Duration(0)
for len(que.MsgPool) < n {
select {
case <-que.ctx.Done():
return res, errors.New("Stoped By External Function Call")
default:
time.Sleep(time.Millisecond * 20)
dura = time.Millisecond*20 + dura
if que.duration != 0 && dura > que.duration {
return res, errors.New("Time Exceed")
}
}
}
if len(que.MsgPool) < n {
return res, errors.New("Result Not Enough")
}
res = que.MsgPool[0:n]
que.MsgPool = que.MsgPool[n:]
return res, nil
}
// RestoreOne 获取收到的一个信息
func (que *StarQueue) RestoreOne() (MsgQueue, error) {
data, err := que.Restore(1)
if len(data) == 1 {
return data[0], err
}
return MsgQueue{}, err
}
// Stop 立即停止Restore
func (que *StarQueue) Stop() {
que.cancel()
}
// RestoreDuration Restore最大超时时间
func (que *StarQueue) RestoreDuration(tm time.Duration) {
que.duration = tm
}
+98
View File
@@ -0,0 +1,98 @@
package starnet
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
)
func TestRequestQuery(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
query := r.URL.Query()
result := make(map[string][]string)
for k, v := range query {
result[k] = v
}
json.NewEncoder(w).Encode(result)
}))
defer server.Close()
req := NewSimpleRequest(server.URL, "GET").
AddQuery("name", "John").
AddQuery("age", "30").
AddQuery("tags", "go").
AddQuery("tags", "http")
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
var result map[string][]string
resp.Body().JSON(&result)
if len(result["name"]) != 1 || result["name"][0] != "John" {
t.Errorf("name = %v; want [John]", result["name"])
}
if len(result["tags"]) != 2 {
t.Errorf("tags length = %v; want 2", len(result["tags"]))
}
}
func TestRequestSetQuery(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
query := r.URL.Query()
w.Write([]byte(query.Get("key")))
}))
defer server.Close()
req := NewSimpleRequest(server.URL, "GET").
SetQuery("key", "value1").
SetQuery("key", "value2") // Should overwrite
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
body, _ := resp.Body().String()
if body != "value2" {
t.Errorf("query value = %v; want value2", body)
}
}
func TestRequestDeleteQuery(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
query := r.URL.Query()
result := make(map[string]string)
for k := range query {
result[k] = query.Get(k)
}
json.NewEncoder(w).Encode(result)
}))
defer server.Close()
req := NewSimpleRequest(server.URL, "GET").
AddQuery("keep", "yes").
AddQuery("delete", "no").
DeleteQuery("delete")
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
var result map[string]string
resp.Body().JSON(&result)
if _, exists := result["delete"]; exists {
t.Error("delete query should not exist")
}
if result["keep"] != "yes" {
t.Errorf("keep = %v; want yes", result["keep"])
}
}
+607
View File
@@ -0,0 +1,607 @@
package starnet
import (
"context"
"fmt"
"net/http"
"net/url"
"strings"
)
// Request HTTP 请求
type Request struct {
ctx context.Context
execCtx context.Context // 执行时的 context(注入了配置)
cancel context.CancelFunc
url string
method string
err error // 累积的错误
config *RequestConfig
client *Client
httpClient *http.Client
httpReq *http.Request
retry *retryPolicy
traceHooks *TraceHooks
traceRecorder *TraceRecorder
traceRun *TraceRecorder
lastTraceSummary *TraceSummary
traceState *traceState
applied bool // 是否已应用配置
doRaw bool // 是否使用原始请求(不修改)
autoFetch bool // 是否自动获取响应体
rawSourceExternal bool // 是否由 SetRawRequest/WithRawRequest 注入外部 raw request
rawTemplate *http.Request
}
func normalizeContext(ctx context.Context) context.Context {
if ctx != nil {
return ctx
}
return context.Background()
}
func cloneRawHTTPRequest(httpReq *http.Request, ctx context.Context) (*http.Request, error) {
if httpReq == nil {
return nil, fmt.Errorf("http request is nil")
}
cloned := httpReq.Clone(normalizeContext(ctx))
switch {
case httpReq.Body == nil || httpReq.Body == http.NoBody:
cloned.Body = httpReq.Body
case httpReq.GetBody != nil:
body, err := httpReq.GetBody()
if err != nil {
return cloned, wrapError(err, "clone raw request body")
}
cloned.Body = body
default:
return cloned, fmt.Errorf("cannot clone raw request with non-replayable body")
}
return cloned, nil
}
func (r *Request) rawBaseRequest() *http.Request {
if r == nil {
return nil
}
if r.rawTemplate != nil {
return r.rawTemplate
}
return r.httpReq
}
func (r *Request) invalidatePreparedState() {
if r == nil {
return
}
if r.cancel != nil {
r.cancel()
r.cancel = nil
}
r.execCtx = nil
r.traceRun = nil
r.traceState = nil
r.httpClient = nil
wasApplied := r.applied
r.applied = false
if !wasApplied || r.doRaw {
return
}
if err := r.rebuildPreparedRequestBase(); err != nil && r.err == nil {
r.err = err
}
}
func (r *Request) rebuildPreparedRequestBase() error {
if r == nil || r.doRaw {
return nil
}
ctx := r.ctx
if ctx == nil {
ctx = context.Background()
}
httpReq, err := http.NewRequestWithContext(ctx, r.method, r.url, nil)
if err != nil {
return wrapError(err, "rebuild http request")
}
r.httpReq = httpReq
r.syncRequestHost()
return nil
}
func (r *Request) rebuildRawRequestBase() error {
if r == nil || !r.doRaw {
return nil
}
baseReq := r.rawBaseRequest()
rawReq, err := cloneRawHTTPRequest(baseReq, normalizeContext(r.ctx))
if err != nil && baseReq != nil && baseReq == r.httpReq {
r.httpReq = baseReq.WithContext(normalizeContext(r.ctx))
return nil
}
if rawReq != nil {
r.httpReq = rawReq
}
return err
}
func (r *Request) rebuildExecutionRequestBase() error {
if r == nil {
return nil
}
if r.cancel != nil {
r.cancel()
r.cancel = nil
}
r.execCtx = nil
r.traceState = nil
r.applied = false
if r.doRaw {
return r.rebuildRawRequestBase()
}
return r.rebuildPreparedRequestBase()
}
// newRequest 创建新请求(内部使用)
func newRequest(ctx context.Context, urlStr string, method string, opts ...RequestOpt) (*Request, error) {
ctx = normalizeContext(ctx)
if method == "" {
method = http.MethodGet
}
method = strings.ToUpper(method)
// 创建 http.Request
httpReq, err := http.NewRequestWithContext(ctx, method, urlStr, nil)
if err != nil {
return nil, wrapError(err, "create http request")
}
// 初始化配置
config := &RequestConfig{
Network: NetworkConfig{
DialTimeout: DefaultDialTimeout,
Timeout: DefaultTimeout,
},
Headers: make(http.Header),
Queries: make(map[string][]string),
Body: BodyConfig{
FormData: make(map[string][]string),
},
}
// 设置默认 User-Agent
config.Headers.Set("User-Agent", DefaultUserAgent)
// POST 请求默认 Content-Type
if method == http.MethodPost {
config.Headers.Set("Content-Type", ContentTypeFormURLEncoded)
}
req := &Request{
ctx: ctx,
url: urlStr,
method: method,
config: config,
httpReq: httpReq,
autoFetch: DefaultFetchRespBody,
}
// 应用选项
for _, opt := range opts {
if opt != nil {
if err := opt(req); err != nil {
req.err = err
return req, nil // 不返回错误,累积到 req.err
}
}
}
return req, nil
}
// NewRequest 创建新请求
func NewRequest(url, method string, opts ...RequestOpt) (*Request, error) {
req, err := newRequest(context.Background(), url, method, opts...)
if err != nil {
return nil, err
}
if req.err != nil {
return nil, req.err
}
return req, nil
}
// NewRequestWithContext 创建新请求(带 context
func NewRequestWithContext(ctx context.Context, url, method string, opts ...RequestOpt) (*Request, error) {
req, err := newRequest(ctx, url, method, opts...)
if err != nil {
return nil, err
}
// 新增
if req.err != nil {
return nil, req.err
}
return req, nil
}
// NewSimpleRequest 创建新请求(忽略错误,支持链式调用)
func NewSimpleRequest(url, method string, opts ...RequestOpt) *Request {
req, err := newRequest(context.Background(), url, method, opts...)
if err != nil {
// 返回一个带错误的请求
return &Request{
ctx: context.Background(),
url: url,
method: method,
err: err,
config: &RequestConfig{
Headers: make(http.Header),
Queries: make(map[string][]string),
Body: BodyConfig{
FormData: make(map[string][]string),
},
},
}
}
return req
}
// NewSimpleRequestWithContext 创建新请求(带 context,忽略错误)
func NewSimpleRequestWithContext(ctx context.Context, url, method string, opts ...RequestOpt) *Request {
ctx = normalizeContext(ctx)
req, err := newRequest(ctx, url, method, opts...)
if err != nil {
return &Request{
ctx: ctx,
url: url,
method: method,
err: err,
config: &RequestConfig{
Headers: make(http.Header),
Queries: make(map[string][]string),
Body: BodyConfig{
FormData: make(map[string][]string),
},
},
}
}
return req
}
// Clone 克隆请求
func (r *Request) Clone() *Request {
cloned := &Request{
ctx: r.ctx,
url: r.url,
method: r.method,
err: r.err,
config: r.config.Clone(),
client: r.client,
httpClient: r.httpClient,
retry: cloneRetryPolicy(r.retry),
traceHooks: r.traceHooks,
traceRecorder: r.traceRecorder,
applied: false, // 重置应用状态
doRaw: r.doRaw,
autoFetch: r.autoFetch,
rawSourceExternal: r.rawSourceExternal,
}
// 重新创建 http.Request
if !r.doRaw {
cloned.httpReq, _ = http.NewRequestWithContext(cloned.ctx, cloned.method, cloned.url, nil)
} else {
rawTemplate, err := cloneRawHTTPRequest(r.rawBaseRequest(), cloned.ctx)
cloned.rawTemplate = rawTemplate
cloned.httpReq = rawTemplate
if err != nil && cloned.err == nil {
cloned.err = err
}
}
return cloned
}
// Err 获取累积的错误
func (r *Request) Err() error {
return r.err
}
// Context 获取 context
func (r *Request) Context() context.Context {
return r.ctx
}
// SetContext 设置 context
func (r *Request) SetContext(ctx context.Context) *Request {
return r.applyMutation(mutateContext(ctx))
}
// Method 获取 HTTP 方法
func (r *Request) Method() string {
return r.method
}
// SetMethod 设置 HTTP 方法
func (r *Request) SetMethod(method string) *Request {
if r.err != nil {
return r
}
method = strings.ToUpper(method)
if !validMethod(method) {
r.err = wrapError(ErrInvalidMethod, "method: %s", method)
return r
}
r.method = method
if r.httpReq != nil {
r.httpReq.Method = method
}
if r.doRaw && r.rawTemplate != nil {
r.rawTemplate.Method = method
}
r.invalidatePreparedState()
return r
}
// URL 获取 URL
func (r *Request) URL() string {
return r.url
}
// SetURL 设置 URL
func (r *Request) SetURL(urlStr string) *Request {
if r.err != nil {
return r
}
if r.doRaw {
r.err = fmt.Errorf("cannot set URL when using raw request")
return r
}
u, err := url.Parse(urlStr)
if err != nil {
r.err = wrapError(ErrInvalidURL, "url: %s", urlStr)
return r
}
r.url = urlStr
u.Host = removeEmptyPort(u.Host)
r.httpReq.URL = u
r.syncRequestHost()
r.invalidatePreparedState()
return r
}
func (r *Request) effectiveRequestHost() string {
if r == nil {
return ""
}
if r.config != nil && r.config.Host != "" {
return r.config.Host
}
if r.httpReq != nil && r.httpReq.URL != nil {
return removeEmptyPort(r.httpReq.URL.Host)
}
if r.url == "" {
return ""
}
u, err := url.Parse(r.url)
if err != nil {
return ""
}
return removeEmptyPort(u.Host)
}
func (r *Request) syncRequestHost() {
if r == nil || r.httpReq == nil {
return
}
r.httpReq.Host = r.effectiveRequestHost()
}
// RawRequest 获取底层 http.Request
func (r *Request) RawRequest() *http.Request {
if r != nil && r.doRaw && r.rawTemplate != nil && !r.applied {
return r.rawTemplate
}
return r.httpReq
}
// SetRawRequest 设置底层 http.Request(启用原始模式)
func (r *Request) SetRawRequest(httpReq *http.Request) *Request {
return r.applyMutation(mutateRawRequest(httpReq))
}
// EnableRawMode 启用原始模式(不修改请求)
func (r *Request) EnableRawMode() *Request {
if r.doRaw {
return r
}
r.doRaw = true
r.invalidatePreparedState()
return r
}
// DisableRawMode 禁用原始模式
func (r *Request) DisableRawMode() *Request {
if !r.doRaw {
return r
}
if r.rawSourceExternal {
r.err = fmt.Errorf("cannot disable raw mode after SetRawRequest")
return r
}
r.doRaw = false
r.invalidatePreparedState()
return r
}
// SetAutoFetch 设置是否自动获取响应体
func (r *Request) SetAutoFetch(auto bool) *Request {
r.autoFetch = auto
return r
}
// HTTPClient 获取底层 http.Client(只读)
func (r *Request) HTTPClient() (*http.Client, error) {
if r.err != nil {
return nil, r.err
}
if r.httpClient != nil {
return r.httpClient, nil
}
// 如果还没构建,先准备
if err := r.prepare(); err != nil {
return nil, err
}
return r.httpClient, nil
}
// Client 获取关联的 Client(只读)
func (r *Request) Client() *Client {
return r.client
}
// Do 执行请求
func (r *Request) Do() (*Response, error) {
// 检查累积的错误
if r.err != nil {
return nil, r.err
}
r.startTraceExecution()
var (
resp *Response
err error
)
if r.hasRetryPolicy() {
resp, err = r.doWithRetry()
} else {
resp, err = r.doOnce()
}
r.finishTraceExecution(resp)
return resp, err
}
func (r *Request) doOnce() (*Response, error) {
if err := r.rebuildExecutionRequestBase(); err != nil {
return nil, wrapError(err, "rebuild execution request")
}
// 准备请求
if err := r.prepare(); err != nil {
return nil, wrapError(err, "prepare request")
}
if r.traceRun != nil {
r.traceRun.observePreparedRequest(r.httpReq)
}
// 执行请求
httpResp, err := r.httpClient.Do(r.httpReq)
if err != nil {
if r.cancel != nil {
r.cancel()
r.cancel = nil
}
return &Response{
Response: &http.Response{},
request: r,
httpClient: r.httpClient,
body: &Body{},
}, wrapError(err, "do request")
}
if r.traceRun != nil {
r.traceRun.observeResponse(httpResp)
}
rawBody := httpResp.Body
if r.cancel != nil {
rawBody = &cancelReadCloser{
ReadCloser: httpResp.Body,
cancel: r.cancel,
}
}
// 创建响应
resp := &Response{
Response: httpResp,
request: r,
httpClient: r.httpClient,
cancel: r.cancel,
body: &Body{
raw: rawBody,
maxBytes: r.config.MaxRespBodyBytes,
},
}
r.cancel = nil
// 自动获取响应体
if r.autoFetch {
if err := resp.body.readAll(); err != nil {
_ = resp.Close()
return resp, err
}
}
return resp, nil
}
// Get 发送 GET 请求
func (r *Request) Get() (*Response, error) {
return r.SetMethod(http.MethodGet).Do()
}
// Post 发送 POST 请求
func (r *Request) Post() (*Response, error) {
return r.SetMethod(http.MethodPost).Do()
}
// Put 发送 PUT 请求
func (r *Request) Put() (*Response, error) {
return r.SetMethod(http.MethodPut).Do()
}
// Delete 发送 DELETE 请求
func (r *Request) Delete() (*Response, error) {
return r.SetMethod(http.MethodDelete).Do()
}
// Head 发送 HEAD 请求
func (r *Request) Head() (*Response, error) {
return r.SetMethod(http.MethodHead).Do()
}
// Patch 发送 PATCH 请求
func (r *Request) Patch() (*Response, error) {
return r.SetMethod(http.MethodPatch).Do()
}
// Options 发送 OPTIONS 请求
func (r *Request) Options() (*Response, error) {
return r.SetMethod(http.MethodOptions).Do()
}
// Trace 发送 TRACE 请求
func (r *Request) Trace() (*Response, error) {
return r.SetMethod(http.MethodTrace).Do()
}
// Connect 发送 CONNECT 请求
func (r *Request) Connect() (*Response, error) {
return r.SetMethod(http.MethodConnect).Do()
}
+220
View File
@@ -0,0 +1,220 @@
package starnet
import (
"encoding/json"
"io"
"os"
)
// SetBody 设置请求体(字节)
func (r *Request) SetBody(body []byte) *Request {
if r.err != nil {
return r
}
if r.doRaw {
return r
}
setBytesBodyConfig(&r.config.Body, body)
r.invalidatePreparedState()
return r
}
// SetBodyReader 设置请求体(Reader)。
// 出于避免重复写的保守策略,Reader 形态的 body 在非幂等方法上不会自动参与 retry。
func (r *Request) SetBodyReader(reader io.Reader) *Request {
if r.err != nil {
return r
}
if r.doRaw {
return r
}
setReaderBodyConfig(&r.config.Body, reader)
r.invalidatePreparedState()
return r
}
// SetBodyString 设置请求体(字符串)
func (r *Request) SetBodyString(body string) *Request {
return r.SetBody([]byte(body))
}
// SetJSON 设置 JSON 请求体
func (r *Request) SetJSON(v interface{}) *Request {
if r.err != nil {
return r
}
data, err := json.Marshal(v)
if err != nil {
r.err = wrapError(err, "marshal json")
return r
}
return r.SetContentType(ContentTypeJSON).SetBody(data)
}
// SetFormData 设置表单数据(覆盖)
func (r *Request) SetFormData(data map[string][]string) *Request {
if r.err != nil {
return r
}
if r.doRaw {
return r
}
setFormBodyConfig(&r.config.Body, data)
r.invalidatePreparedState()
return r
}
// AddFormData 添加表单数据
func (r *Request) AddFormData(key, value string) *Request {
if r.err != nil {
return r
}
if r.doRaw {
return r
}
ensureFormMode(&r.config.Body)
r.config.Body.FormData[key] = append(r.config.Body.FormData[key], value)
r.invalidatePreparedState()
return r
}
// AddFormDataMap 批量添加表单数据
func (r *Request) AddFormDataMap(data map[string]string) *Request {
if r.err != nil {
return r
}
if r.doRaw {
return r
}
ensureFormMode(&r.config.Body)
for key, value := range data {
r.config.Body.FormData[key] = append(r.config.Body.FormData[key], value)
}
r.invalidatePreparedState()
return r
}
// AddFile 添加文件(从路径)
func (r *Request) AddFile(formName, filePath string) *Request {
if r.err != nil {
return r
}
stat, err := os.Stat(filePath)
if err != nil {
r.err = wrapError(ErrFileNotFound, "file: %s", filePath)
return r
}
ensureMultipartMode(&r.config.Body)
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
FormName: formName,
FileName: stat.Name(),
FilePath: filePath,
FileSize: stat.Size(),
FileType: ContentTypeOctetStream,
})
r.invalidatePreparedState()
return r
}
// AddFileWithName 添加文件(指定文件名)
func (r *Request) AddFileWithName(formName, filePath, fileName string) *Request {
if r.err != nil {
return r
}
stat, err := os.Stat(filePath)
if err != nil {
r.err = wrapError(ErrFileNotFound, "file: %s", filePath)
return r
}
ensureMultipartMode(&r.config.Body)
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
FormName: formName,
FileName: fileName,
FilePath: filePath,
FileSize: stat.Size(),
FileType: ContentTypeOctetStream,
})
r.invalidatePreparedState()
return r
}
// AddFileWithType 添加文件(指定 MIME 类型)
func (r *Request) AddFileWithType(formName, filePath, fileType string) *Request {
if r.err != nil {
return r
}
stat, err := os.Stat(filePath)
if err != nil {
r.err = wrapError(ErrFileNotFound, "file: %s", filePath)
return r
}
ensureMultipartMode(&r.config.Body)
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
FormName: formName,
FileName: stat.Name(),
FilePath: filePath,
FileSize: stat.Size(),
FileType: fileType,
})
r.invalidatePreparedState()
return r
}
// AddFileStream 添加文件流
func (r *Request) AddFileStream(formName, fileName string, size int64, reader io.Reader) *Request {
if r.err != nil {
return r
}
if reader == nil {
r.err = ErrNilReader
return r
}
ensureMultipartMode(&r.config.Body)
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
FormName: formName,
FileName: fileName,
FileData: reader,
FileSize: size,
FileType: ContentTypeOctetStream,
})
r.invalidatePreparedState()
return r
}
// AddFileStreamWithType 添加文件流(指定 MIME 类型)
func (r *Request) AddFileStreamWithType(formName, fileName, fileType string, size int64, reader io.Reader) *Request {
if r.err != nil {
return r
}
if reader == nil {
r.err = ErrNilReader
return r
}
ensureMultipartMode(&r.config.Body)
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
FormName: formName,
FileName: fileName,
FileData: reader,
FileSize: size,
FileType: fileType,
})
r.invalidatePreparedState()
return r
}
+34
View File
@@ -0,0 +1,34 @@
package starnet
import "net/http"
// SetBasicAuth 设置 Basic 认证
func (r *Request) SetBasicAuth(username, password string) *Request {
return r.applyMutation(mutateBasicAuth(username, password))
}
// SetContentLength 设置 Content-Length
func (r *Request) SetContentLength(length int64) *Request {
return r.applyMutation(mutateContentLength(length))
}
// SetAutoCalcContentLength 设置是否自动计算 Content-Length
// 警告:启用后会将整个 body 读入内存
func (r *Request) SetAutoCalcContentLength(auto bool) *Request {
return r.applyMutation(mutateAutoCalcContentLength(auto))
}
// SetTransport 设置自定义 Transport
func (r *Request) SetTransport(transport *http.Transport) *Request {
return r.applyMutation(mutateTransport(transport))
}
// SetUploadProgress 设置文件上传进度回调
func (r *Request) SetUploadProgress(fn UploadProgressFunc) *Request {
return r.applyMutation(mutateUploadProgress(fn))
}
// SetMaxRespBodyBytes 设置响应体最大读取字节数(<=0 表示不限制)
func (r *Request) SetMaxRespBodyBytes(maxBytes int64) *Request {
return r.applyMutation(mutateMaxRespBodyBytes(maxBytes))
}
+172
View File
@@ -0,0 +1,172 @@
package starnet
import (
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"sync/atomic"
"testing"
)
func TestRequestDoTwiceRebuildsExecutionState(t *testing.T) {
var attempts int32
req := NewSimpleRequest("http://example.com/path", http.MethodPost).
SetHeader("X-Test", "one").
AddQuery("q", "v").
SetBodyReader(strings.NewReader("payload"))
req.client = &Client{client: &http.Client{
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
if err := r.Context().Err(); err != nil {
t.Fatalf("request context already done: %v", err)
}
if values := r.Header.Values("X-Test"); len(values) != 1 || values[0] != "one" {
t.Fatalf("header values=%v", values)
}
if values := r.URL.Query()["q"]; len(values) != 1 || values[0] != "v" {
t.Fatalf("query values=%v", values)
}
body, err := io.ReadAll(r.Body)
if err != nil {
return nil, err
}
_ = r.Body.Close()
if string(body) != "payload" {
t.Fatalf("body=%q", string(body))
}
n := atomic.AddInt32(&attempts, 1)
return &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader(fmt.Sprintf("ok-%d", n))),
Request: r,
}, nil
}),
}}
resp1, err := req.Do()
if err != nil {
t.Fatalf("first Do() error: %v", err)
}
if err := resp1.Close(); err != nil {
t.Fatalf("first Close() error: %v", err)
}
resp2, err := req.Do()
if err != nil {
t.Fatalf("second Do() error: %v", err)
}
defer resp2.Close()
if got := atomic.LoadInt32(&attempts); got != 2 {
t.Fatalf("attempts=%d; want 2", got)
}
}
func TestRequestPrepareRawDynamicPathInjectsAggregatedRequestContext(t *testing.T) {
rawReq, err := http.NewRequest(http.MethodGet, "https://example.com/resource", nil)
if err != nil {
t.Fatalf("NewRequest() error: %v", err)
}
req := NewSimpleRequest("", http.MethodGet).
SetRawRequest(rawReq).
SetProxy("http://proxy.example:8080").
SetCustomIP([]string{"127.0.0.1"}).
SetSkipTLSVerify(true).
SetTLSServerName("override.example")
if err := req.prepare(); err != nil {
t.Fatalf("prepare() error: %v", err)
}
raw := req.execCtx.Value(ctxKeyRequestContext)
rc, ok := raw.(*RequestContext)
if !ok || rc == nil {
t.Fatalf("expected request context, got %#v", raw)
}
if rc.Proxy != "http://proxy.example:8080" {
t.Fatalf("proxy=%q", rc.Proxy)
}
if len(rc.CustomIP) != 1 || rc.CustomIP[0] != "127.0.0.1" {
t.Fatalf("custom ip=%v", rc.CustomIP)
}
if rc.TLSConfig == nil || !rc.TLSConfig.InsecureSkipVerify {
t.Fatalf("tls config=%#v", rc.TLSConfig)
}
if rc.TLSServerName != "override.example" {
t.Fatalf("tls server name=%q", rc.TLSServerName)
}
}
func TestRequestSetFormDataOverridesBytesBody(t *testing.T) {
req := NewSimpleRequest("http://example.com", http.MethodPost).
SetBodyString("stale").
SetFormData(map[string][]string{"k": []string{"v"}})
if req.config.Body.Mode != bodyModeForm {
t.Fatalf("body mode=%v", req.config.Body.Mode)
}
if req.config.Body.Reader != nil || req.config.Body.Bytes != nil || len(req.config.Body.Files) != 0 {
t.Fatalf("unexpected stale body state: %#v", req.config.Body)
}
if err := req.prepare(); err != nil {
t.Fatalf("prepare() error: %v", err)
}
body, err := req.httpReq.GetBody()
if err != nil {
t.Fatalf("GetBody() error: %v", err)
}
defer body.Close()
data, err := io.ReadAll(body)
if err != nil {
t.Fatalf("ReadAll() error: %v", err)
}
if string(data) != "k=v" {
t.Fatalf("body=%q; want k=v", string(data))
}
}
func TestRequestAddFileClearsPreviousBytesBody(t *testing.T) {
tmpDir := t.TempDir()
filePath := filepath.Join(tmpDir, "payload.txt")
if err := os.WriteFile(filePath, []byte("file-body"), 0644); err != nil {
t.Fatalf("WriteFile() error: %v", err)
}
req := NewSimpleRequest("http://example.com", http.MethodPost).
SetJSON(map[string]string{"old": "json-only"}).
AddFile("file", filePath)
if req.config.Body.Mode != bodyModeMultipart {
t.Fatalf("body mode=%v", req.config.Body.Mode)
}
if req.config.Body.Reader != nil || req.config.Body.Bytes != nil {
t.Fatalf("unexpected stale simple body state: %#v", req.config.Body)
}
if err := req.prepare(); err != nil {
t.Fatalf("prepare() error: %v", err)
}
data, err := io.ReadAll(req.httpReq.Body)
if err != nil {
t.Fatalf("ReadAll() error: %v", err)
}
if !strings.Contains(req.httpReq.Header.Get("Content-Type"), "multipart/form-data") {
t.Fatalf("content-type=%q", req.httpReq.Header.Get("Content-Type"))
}
if !strings.Contains(string(data), "file-body") {
t.Fatalf("multipart body missing file content: %q", string(data))
}
if strings.Contains(string(data), "json-only") {
t.Fatalf("multipart body still contains stale json: %q", string(data))
}
}
+264
View File
@@ -0,0 +1,264 @@
package starnet
import (
"net/http"
)
func isHostHeaderKey(key string) bool {
return http.CanonicalHeaderKey(key) == "Host"
}
func setRequestHostConfig(config *RequestConfig, host string) {
if config == nil {
return
}
if config.Headers == nil {
config.Headers = make(http.Header)
}
config.Host = host
if host == "" {
config.Headers.Del("Host")
return
}
config.Headers.Set("Host", host)
}
// SetHeader 设置 Header(覆盖)
func (r *Request) SetHeader(key, value string) *Request {
if r.err != nil {
return r
}
if r.doRaw {
return r
}
if isHostHeaderKey(key) {
return r.SetHost(value)
}
r.config.Headers.Set(key, value)
r.invalidatePreparedState()
return r
}
// AddHeader 添加 Header
func (r *Request) AddHeader(key, value string) *Request {
if r.err != nil {
return r
}
if r.doRaw {
return r
}
if isHostHeaderKey(key) {
return r.SetHost(value)
}
r.config.Headers.Add(key, value)
r.invalidatePreparedState()
return r
}
// SetHeaders 设置所有 Headers(覆盖)
func (r *Request) SetHeaders(headers http.Header) *Request {
if r.err != nil {
return r
}
if r.doRaw {
return r
}
r.config.Headers = cloneHeader(headers)
r.config.Host = r.config.Headers.Get("Host")
r.syncRequestHost()
r.invalidatePreparedState()
return r
}
// AddHeaders 批量添加 Headers
func (r *Request) AddHeaders(headers map[string]string) *Request {
if r.err != nil {
return r
}
if r.doRaw {
return r
}
for k, v := range headers {
if isHostHeaderKey(k) {
setRequestHostConfig(r.config, v)
continue
}
r.config.Headers.Add(k, v)
}
r.syncRequestHost()
r.invalidatePreparedState()
return r
}
// DeleteHeader 删除 Header
func (r *Request) DeleteHeader(key string) *Request {
if r.err != nil {
return r
}
if r.doRaw {
return r
}
if isHostHeaderKey(key) {
setRequestHostConfig(r.config, "")
r.syncRequestHost()
r.invalidatePreparedState()
return r
}
r.config.Headers.Del(key)
r.invalidatePreparedState()
return r
}
// GetHeader 获取 Header
func (r *Request) GetHeader(key string) string {
if isHostHeaderKey(key) {
return r.config.Host
}
return r.config.Headers.Get(key)
}
// Headers 获取所有 Headers
func (r *Request) Headers() http.Header {
if r == nil || r.config == nil {
return make(http.Header)
}
return cloneHeader(r.config.Headers)
}
// SetHost 设置请求 Host 头覆盖。
func (r *Request) SetHost(host string) *Request {
if r.err != nil {
return r
}
if r.doRaw {
return r
}
setRequestHostConfig(r.config, host)
r.syncRequestHost()
r.invalidatePreparedState()
return r
}
// Host 获取显式 Host 覆盖。
func (r *Request) Host() string {
if r.config != nil && r.config.Host != "" {
return r.config.Host
}
if r.httpReq != nil {
return r.httpReq.Host
}
return ""
}
// SetContentType 设置 Content-Type
func (r *Request) SetContentType(contentType string) *Request {
return r.SetHeader("Content-Type", contentType)
}
// SetUserAgent 设置 User-Agent
func (r *Request) SetUserAgent(userAgent string) *Request {
return r.SetHeader("User-Agent", userAgent)
}
// SetReferer 设置 Referer
func (r *Request) SetReferer(referer string) *Request {
return r.SetHeader("Referer", referer)
}
// SetBearerToken 设置 Bearer Token
func (r *Request) SetBearerToken(token string) *Request {
return r.SetHeader("Authorization", "Bearer "+token)
}
// AddCookie 添加 Cookie
func (r *Request) AddCookie(cookie *http.Cookie) *Request {
if r.err != nil {
return r
}
if r.doRaw {
return r
}
r.config.Cookies = append(r.config.Cookies, cloneCookie(cookie))
r.invalidatePreparedState()
return r
}
// AddSimpleCookie 添加简单 Cookiepath 为 /
func (r *Request) AddSimpleCookie(name, value string) *Request {
return r.AddCookie(&http.Cookie{
Name: name,
Value: value,
Path: "/",
})
}
// AddCookieKV 添加 Cookie(指定 path
func (r *Request) AddCookieKV(name, value, path string) *Request {
return r.AddCookie(&http.Cookie{
Name: name,
Value: value,
Path: path,
})
}
// SetCookies 设置所有 Cookies(覆盖)
func (r *Request) SetCookies(cookies []*http.Cookie) *Request {
if r.err != nil {
return r
}
if r.doRaw {
return r
}
r.config.Cookies = cloneCookies(cookies)
r.invalidatePreparedState()
return r
}
// AddCookies 批量添加 Cookies
func (r *Request) AddCookies(cookies map[string]string) *Request {
if r.err != nil {
return r
}
if r.doRaw {
return r
}
for name, value := range cookies {
r.config.Cookies = append(r.config.Cookies, &http.Cookie{
Name: name,
Value: value,
Path: "/",
})
}
r.invalidatePreparedState()
return r
}
// Cookies 获取所有 Cookies
func (r *Request) Cookies() []*http.Cookie {
if r == nil || r.config == nil {
return nil
}
return cloneCookies(r.config.Cookies)
}
// ResetHeaders 重置所有 Headers
func (r *Request) ResetHeaders() *Request {
if r.err != nil {
return r
}
r.config.Headers = make(http.Header)
r.config.Host = ""
r.syncRequestHost()
r.invalidatePreparedState()
return r
}
// ResetCookies 重置所有 Cookies
func (r *Request) ResetCookies() *Request {
if r.err != nil {
return r
}
r.config.Cookies = []*http.Cookie{}
r.invalidatePreparedState()
return r
}
+43
View File
@@ -0,0 +1,43 @@
package starnet
import (
"net/http"
"net/http/httptest"
"testing"
)
func TestRequestConvenienceMethods(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Method", r.Method)
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer server.Close()
tests := []struct {
name string
method string
do func(r *Request) (*Response, error)
}{
{name: "Head", method: http.MethodHead, do: (*Request).Head},
{name: "Patch", method: http.MethodPatch, do: (*Request).Patch},
{name: "Options", method: http.MethodOptions, do: (*Request).Options},
{name: "Trace", method: http.MethodTrace, do: (*Request).Trace},
{name: "Connect", method: http.MethodConnect, do: (*Request).Connect},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := NewSimpleRequest(server.URL, http.MethodGet)
resp, err := tt.do(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer resp.Close()
if got := resp.Header.Get("X-Method"); got != tt.method {
t.Fatalf("method=%s want=%s", got, tt.method)
}
})
}
}
+69
View File
@@ -0,0 +1,69 @@
package starnet
import (
"context"
"io"
"mime/multipart"
"os"
)
// applyMultipartBody 应用 multipart 请求体
func (r *Request) applyMultipartBody(execCtx context.Context) error {
pr, pw := io.Pipe()
writer := multipart.NewWriter(pw)
r.httpReq.Header.Set("Content-Type", writer.FormDataContentType())
r.httpReq.Body = pr
go func() {
defer pw.Close()
defer writer.Close()
for key, values := range r.config.Body.FormData {
for _, value := range values {
if err := writer.WriteField(key, value); err != nil {
pw.CloseWithError(wrapError(err, "write form field"))
return
}
}
}
for _, file := range r.config.Body.Files {
if err := r.writeFile(execCtx, writer, file); err != nil {
pw.CloseWithError(err)
return
}
}
}()
return nil
}
// writeFile 写入文件到 multipart writer
func (r *Request) writeFile(execCtx context.Context, writer *multipart.Writer, file RequestFile) error {
part, err := writer.CreateFormFile(file.FormName, file.FileName)
if err != nil {
return wrapError(err, "create form file")
}
var reader io.Reader
if file.FileData != nil {
reader = file.FileData
} else if file.FilePath != "" {
f, err := os.Open(file.FilePath)
if err != nil {
return wrapError(err, "open file")
}
defer f.Close()
reader = f
} else {
return ErrNilReader
}
_, err = copyWithProgress(execCtx, part, reader, file.FileName, file.FileSize, r.config.UploadProgress)
if err != nil {
return wrapError(err, "copy file data")
}
return nil
}
+333
View File
@@ -0,0 +1,333 @@
package starnet
import (
"context"
"crypto/tls"
"fmt"
"net"
"net/http"
"net/url"
"time"
)
type requestMutation func(*Request) error
func (r *Request) applyMutation(mutation requestMutation) *Request {
if r == nil || r.err != nil {
return r
}
if err := mutation(r); err != nil {
r.err = err
return r
}
r.invalidatePreparedState()
return r
}
func requestOptFromMutation(mutation requestMutation) RequestOpt {
return func(r *Request) error {
if r == nil {
return nil
}
return mutation(r)
}
}
func validateCustomIPs(ips []string) error {
for _, ip := range ips {
if net.ParseIP(ip) == nil {
return wrapError(ErrInvalidIP, "ip: %s", ip)
}
}
return nil
}
func validateCustomDNS(dnsServers []string) error {
for _, dns := range dnsServers {
if net.ParseIP(dns) == nil {
return wrapError(ErrInvalidDNS, "dns: %s", dns)
}
}
return nil
}
func parseProxyURL(proxy string) (*url.URL, error) {
if proxy == "" {
return nil, nil
}
proxyURL, err := url.Parse(proxy)
if err != nil {
return nil, wrapError(err, "parse proxy url")
}
if proxyURL.Scheme == "" {
return nil, fmt.Errorf("proxy scheme is required: %s", proxy)
}
if proxyURL.Host == "" {
return nil, fmt.Errorf("proxy host is required: %s", proxy)
}
return proxyURL, nil
}
func mutateTimeout(timeout time.Duration) requestMutation {
return func(r *Request) error {
r.config.Network.Timeout = timeout
return nil
}
}
func mutateDialTimeout(timeout time.Duration) requestMutation {
return func(r *Request) error {
r.config.Network.DialTimeout = timeout
return nil
}
}
func mutateProxy(proxy string) requestMutation {
return func(r *Request) error {
if _, err := parseProxyURL(proxy); err != nil {
return err
}
r.config.Network.Proxy = proxy
return nil
}
}
func mutateDialFunc(fn func(ctx context.Context, network, addr string) (net.Conn, error)) requestMutation {
return func(r *Request) error {
r.config.Network.DialFunc = fn
return nil
}
}
func mutateTLSConfig(tlsConfig *tls.Config) requestMutation {
return func(r *Request) error {
r.config.TLS.Config = tlsConfig
return nil
}
}
func mutateTLSServerName(serverName string) requestMutation {
return func(r *Request) error {
r.config.TLS.ServerName = serverName
return nil
}
}
func mutateTraceHooks(hooks *TraceHooks) requestMutation {
return func(r *Request) error {
r.traceHooks = hooks
return nil
}
}
func mutateTraceRecorder(recorder *TraceRecorder) requestMutation {
return func(r *Request) error {
r.traceRecorder = recorder
return nil
}
}
func mutateSkipTLSVerify(skip bool) requestMutation {
return func(r *Request) error {
r.config.TLS.SkipVerify = skip
return nil
}
}
func mutateCustomIP(ips []string) requestMutation {
return func(r *Request) error {
if err := validateCustomIPs(ips); err != nil {
return err
}
r.config.DNS.CustomIP = cloneStringSlice(ips)
return nil
}
}
func mutateAddCustomIP(ip string) requestMutation {
return func(r *Request) error {
if err := validateCustomIPs([]string{ip}); err != nil {
return err
}
r.config.DNS.CustomIP = append(r.config.DNS.CustomIP, ip)
return nil
}
}
func mutateCustomDNS(dnsServers []string) requestMutation {
return func(r *Request) error {
if err := validateCustomDNS(dnsServers); err != nil {
return err
}
r.config.DNS.CustomDNS = cloneStringSlice(dnsServers)
return nil
}
}
func mutateAddCustomDNS(dns string) requestMutation {
return func(r *Request) error {
if err := validateCustomDNS([]string{dns}); err != nil {
return err
}
r.config.DNS.CustomDNS = append(r.config.DNS.CustomDNS, dns)
return nil
}
}
func mutateLookupFunc(fn func(ctx context.Context, host string) ([]net.IPAddr, error)) requestMutation {
return func(r *Request) error {
r.config.DNS.LookupFunc = fn
return nil
}
}
func mutateBasicAuth(username, password string) requestMutation {
return func(r *Request) error {
r.config.BasicAuth = [2]string{username, password}
return nil
}
}
func mutateContentLength(length int64) requestMutation {
return func(r *Request) error {
r.config.ContentLength = length
return nil
}
}
func mutateAutoCalcContentLength(auto bool) requestMutation {
return func(r *Request) error {
if r.doRaw {
return fmt.Errorf("cannot set auto calc content length in raw mode")
}
r.config.AutoCalcContentLength = auto
return nil
}
}
func mutateTransport(transport *http.Transport) requestMutation {
return func(r *Request) error {
r.config.Transport = transport
r.config.CustomTransport = true
return nil
}
}
func mutateUploadProgress(fn UploadProgressFunc) requestMutation {
return func(r *Request) error {
r.config.UploadProgress = fn
return nil
}
}
func mutateAutoFetch(auto bool) requestMutation {
return func(r *Request) error {
r.autoFetch = auto
return nil
}
}
func mutateMaxRespBodyBytes(maxBytes int64) requestMutation {
return func(r *Request) error {
if maxBytes < 0 {
return fmt.Errorf("max response body bytes must be >= 0")
}
r.config.MaxRespBodyBytes = maxBytes
return nil
}
}
func mutateContext(ctx context.Context) requestMutation {
return func(r *Request) error {
ctx = normalizeContext(ctx)
r.ctx = ctx
if r.doRaw && r.rawTemplate != nil {
r.rawTemplate = r.rawTemplate.WithContext(ctx)
}
if r.httpReq != nil {
r.httpReq = r.httpReq.WithContext(ctx)
}
return nil
}
}
func mutateRawRequest(httpReq *http.Request) requestMutation {
return func(r *Request) error {
if httpReq == nil {
return fmt.Errorf("httpReq cannot be nil")
}
r.httpReq = httpReq
r.rawTemplate = httpReq
r.ctx = normalizeContext(httpReq.Context())
r.method = httpReq.Method
if httpReq.URL != nil {
r.url = httpReq.URL.String()
}
r.doRaw = true
r.rawSourceExternal = true
return nil
}
}
func mutateAddQuery(key, value string) requestMutation {
return func(r *Request) error {
r.config.Queries[key] = append(r.config.Queries[key], value)
return nil
}
}
func mutateSetQuery(key, value string) requestMutation {
return func(r *Request) error {
r.config.Queries[key] = []string{value}
return nil
}
}
func mutateSetQueries(queries map[string][]string) requestMutation {
return func(r *Request) error {
r.config.Queries = cloneStringMapSlice(queries)
return nil
}
}
func mutateAddQueries(queries map[string]string) requestMutation {
return func(r *Request) error {
for key, value := range queries {
r.config.Queries[key] = append(r.config.Queries[key], value)
}
return nil
}
}
func mutateDeleteQuery(key string) requestMutation {
return func(r *Request) error {
delete(r.config.Queries, key)
return nil
}
}
func mutateDeleteQueryValue(key, value string) requestMutation {
return func(r *Request) error {
values, ok := r.config.Queries[key]
if !ok {
return nil
}
newValues := make([]string, 0, len(values))
for _, item := range values {
if item != value {
newValues = append(newValues, item)
}
}
if len(newValues) == 0 {
delete(r.config.Queries, key)
return nil
}
r.config.Queries[key] = newValues
return nil
}
}
+71
View File
@@ -0,0 +1,71 @@
package starnet
import (
"context"
"crypto/tls"
"net"
"time"
)
// SetTimeout 设置请求总超时时间
// timeout > 0: 为本次请求注入 context 超时
// timeout = 0: 不额外设置请求总超时
// timeout < 0: 禁用 starnet 默认总超时
func (r *Request) SetTimeout(timeout time.Duration) *Request {
return r.applyMutation(mutateTimeout(timeout))
}
// SetDialTimeout 设置连接超时时间
func (r *Request) SetDialTimeout(timeout time.Duration) *Request {
return r.applyMutation(mutateDialTimeout(timeout))
}
// SetProxy 设置代理
func (r *Request) SetProxy(proxy string) *Request {
return r.applyMutation(mutateProxy(proxy))
}
// SetDialFunc 设置自定义 Dial 函数
func (r *Request) SetDialFunc(fn func(ctx context.Context, network, addr string) (net.Conn, error)) *Request {
return r.applyMutation(mutateDialFunc(fn))
}
// SetTLSConfig 设置 TLS 配置
func (r *Request) SetTLSConfig(tlsConfig *tls.Config) *Request {
return r.applyMutation(mutateTLSConfig(tlsConfig))
}
// SetTLSServerName 设置显式 TLS ServerName/SNI。
func (r *Request) SetTLSServerName(serverName string) *Request {
return r.applyMutation(mutateTLSServerName(serverName))
}
// SetSkipTLSVerify 设置是否跳过 TLS 验证
func (r *Request) SetSkipTLSVerify(skip bool) *Request {
return r.applyMutation(mutateSkipTLSVerify(skip))
}
// SetCustomIP 设置自定义 IP(直接指定 IP,跳过 DNS)
func (r *Request) SetCustomIP(ips []string) *Request {
return r.applyMutation(mutateCustomIP(ips))
}
// AddCustomIP 添加自定义 IP
func (r *Request) AddCustomIP(ip string) *Request {
return r.applyMutation(mutateAddCustomIP(ip))
}
// SetCustomDNS 设置自定义 DNS 服务器
func (r *Request) SetCustomDNS(dnsServers []string) *Request {
return r.applyMutation(mutateCustomDNS(dnsServers))
}
// AddCustomDNS 添加自定义 DNS 服务器
func (r *Request) AddCustomDNS(dns string) *Request {
return r.applyMutation(mutateAddCustomDNS(dns))
}
// SetLookupFunc 设置自定义 DNS 解析函数
func (r *Request) SetLookupFunc(fn func(ctx context.Context, host string) ([]net.IPAddr, error)) *Request {
return r.applyMutation(mutateLookupFunc(fn))
}
+346
View File
@@ -0,0 +1,346 @@
package starnet
import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"net/http/httptrace"
"net/url"
"strings"
)
func setReplayableRequestBodyBytes(httpReq *http.Request, data []byte) {
if httpReq == nil {
return
}
httpReq.Body = io.NopCloser(bytes.NewReader(data))
httpReq.ContentLength = int64(len(data))
httpReq.GetBody = func() (io.ReadCloser, error) {
return io.NopCloser(bytes.NewReader(data)), nil
}
}
func clearSimpleBodyState(body *BodyConfig) {
if body == nil {
return
}
body.Bytes = nil
body.Reader = nil
}
func resetFormBodyState(body *BodyConfig) {
if body == nil {
return
}
body.FormData = make(map[string][]string)
}
func resetMultipartBodyState(body *BodyConfig) {
if body == nil {
return
}
body.Files = nil
}
func setBytesBodyConfig(body *BodyConfig, data []byte) {
if body == nil {
return
}
body.Mode = bodyModeBytes
body.Bytes = cloneBytes(data)
body.Reader = nil
resetFormBodyState(body)
resetMultipartBodyState(body)
}
func setReaderBodyConfig(body *BodyConfig, reader io.Reader) {
if body == nil {
return
}
body.Mode = bodyModeReader
body.Reader = reader
body.Bytes = nil
resetFormBodyState(body)
resetMultipartBodyState(body)
}
func setFormBodyConfig(body *BodyConfig, data map[string][]string) {
if body == nil {
return
}
body.Mode = bodyModeForm
clearSimpleBodyState(body)
resetMultipartBodyState(body)
body.FormData = cloneStringMapSlice(data)
}
func ensureFormMode(body *BodyConfig) {
if body == nil {
return
}
if body.Mode == bodyModeForm || body.Mode == bodyModeMultipart {
if body.FormData == nil {
body.FormData = make(map[string][]string)
}
return
}
clearSimpleBodyState(body)
resetMultipartBodyState(body)
body.FormData = make(map[string][]string)
body.Mode = bodyModeForm
}
func ensureMultipartMode(body *BodyConfig) {
if body == nil {
return
}
if body.Mode == bodyModeMultipart {
if body.FormData == nil {
body.FormData = make(map[string][]string)
}
return
}
if body.Mode != bodyModeForm {
clearSimpleBodyState(body)
body.FormData = make(map[string][]string)
}
body.Mode = bodyModeMultipart
if body.FormData == nil {
body.FormData = make(map[string][]string)
}
}
func snapshotBytesReader(reader *bytes.Reader) ([]byte, error) {
if reader == nil {
return nil, nil
}
data := make([]byte, reader.Len())
_, err := reader.ReadAt(data, reader.Size()-int64(reader.Len()))
if err != nil && err != io.EOF {
return nil, err
}
return data, nil
}
func snapshotStringReader(reader *strings.Reader) ([]byte, error) {
if reader == nil {
return nil, nil
}
data := make([]byte, reader.Len())
_, err := reader.ReadAt(data, reader.Size()-int64(reader.Len()))
if err != nil && err != io.EOF {
return nil, err
}
return data, nil
}
// applyBody 应用请求体
func (r *Request) applyBody(execCtx context.Context) error {
r.httpReq.Body = nil
r.httpReq.GetBody = nil
r.httpReq.ContentLength = 0
switch r.config.Body.Mode {
case bodyModeReader:
if r.config.Body.Reader == nil {
return nil
}
switch reader := r.config.Body.Reader.(type) {
case *bytes.Buffer:
setReplayableRequestBodyBytes(r.httpReq, append([]byte(nil), reader.Bytes()...))
case *bytes.Reader:
data, err := snapshotBytesReader(reader)
if err != nil {
return wrapError(err, "snapshot bytes reader")
}
setReplayableRequestBodyBytes(r.httpReq, data)
case *strings.Reader:
data, err := snapshotStringReader(reader)
if err != nil {
return wrapError(err, "snapshot strings reader")
}
setReplayableRequestBodyBytes(r.httpReq, data)
default:
r.httpReq.Body = io.NopCloser(r.config.Body.Reader)
}
switch reader := r.config.Body.Reader.(type) {
case *bytes.Buffer:
r.httpReq.ContentLength = int64(reader.Len())
case *bytes.Reader:
r.httpReq.ContentLength = int64(reader.Len())
case *strings.Reader:
r.httpReq.ContentLength = int64(reader.Len())
}
return nil
case bodyModeBytes:
setReplayableRequestBodyBytes(r.httpReq, r.config.Body.Bytes)
return nil
case bodyModeMultipart:
return r.applyMultipartBody(execCtx)
case bodyModeForm:
values := url.Values{}
for key, items := range r.config.Body.FormData {
for _, value := range items {
values.Add(key, value)
}
}
encoded := values.Encode()
setReplayableRequestBodyBytes(r.httpReq, []byte(encoded))
return nil
}
return nil
}
func buildTraceTLSHandshakeInfo(req *http.Request, execCtx context.Context, defaultServerName string) TraceTLSHandshakeStartInfo {
if req == nil || req.URL == nil || req.URL.Scheme != "https" {
return TraceTLSHandshakeStartInfo{}
}
reqCtx := getRequestContext(execCtx)
info := TraceTLSHandshakeStartInfo{}
// 自定义 DialFunc 的真实落点由调用方决定,这里只在默认拨号路径下预填地址,避免 trace 元信息误导。
if reqCtx == nil || reqCtx.DialFn == nil {
info.Network = "tcp"
info.Addr = req.URL.Host
}
if reqCtx != nil {
if reqCtx.TLSConfig != nil && reqCtx.TLSConfig.ServerName != "" {
info.ServerName = reqCtx.TLSConfig.ServerName
} else if reqCtx.TLSServerName != "" {
info.ServerName = reqCtx.TLSServerName
}
}
if info.ServerName == "" {
if defaultServerName != "" {
info.ServerName = defaultServerName
} else {
info.ServerName = req.URL.Hostname()
}
}
return info
}
// prepare 准备请求(应用配置)
func (r *Request) prepare() (err error) {
if r.applied {
return nil
}
if r.httpReq == nil {
return fmt.Errorf("http request is nil")
}
execCtx := r.ctx
if execCtx == nil {
execCtx = context.Background()
}
defaultTLSServerName := ""
if r.httpReq.URL != nil && r.httpReq.URL.Scheme == "https" {
defaultTLSServerName = r.httpReq.URL.Hostname()
}
execCtx = injectRequestConfig(execCtx, r.config, defaultTLSServerName)
var traceState *traceState
traceHooks := composeTraceHooks(r.traceHooks, traceRecorderHooks(r.traceRun))
if traceHooks != nil {
traceState = newTraceState(traceHooks)
traceState.setDefaultTLSHandshakeInfo(buildTraceTLSHandshakeInfo(r.httpReq, execCtx, defaultTLSServerName))
execCtx = withTraceState(execCtx, traceState)
if clientTrace := traceState.clientTrace(); clientTrace != nil {
execCtx = httptrace.WithClientTrace(execCtx, clientTrace)
}
}
var cancel context.CancelFunc
if r.config.Network.Timeout > 0 {
execCtx, cancel = context.WithTimeout(execCtx, r.config.Network.Timeout)
}
defer func() {
if err != nil && cancel != nil {
cancel()
}
}()
if r.httpClient == nil {
r.httpClient, err = r.buildHTTPClient()
if err != nil {
return err
}
}
if !r.doRaw {
if len(r.config.Queries) > 0 {
query := r.httpReq.URL.Query()
for key, values := range r.config.Queries {
for _, value := range values {
query.Add(key, value)
}
}
r.httpReq.URL.RawQuery = query.Encode()
}
for key, values := range r.config.Headers {
if isHostHeaderKey(key) {
continue
}
for _, value := range values {
r.httpReq.Header.Add(key, value)
}
}
for _, cookie := range r.config.Cookies {
r.httpReq.AddCookie(cookie)
}
if r.config.BasicAuth[0] != "" || r.config.BasicAuth[1] != "" {
r.httpReq.SetBasicAuth(r.config.BasicAuth[0], r.config.BasicAuth[1])
}
if err := r.applyBody(execCtx); err != nil {
return err
}
if r.config.ContentLength > 0 {
r.httpReq.ContentLength = r.config.ContentLength
} else if r.config.ContentLength < 0 {
r.httpReq.ContentLength = 0
}
if r.config.AutoCalcContentLength && r.httpReq.Body != nil {
data, err := io.ReadAll(r.httpReq.Body)
if err != nil {
return wrapError(err, "read body for content length")
}
setReplayableRequestBodyBytes(r.httpReq, data)
}
r.syncRequestHost()
}
r.execCtx = execCtx
r.traceState = traceState
r.cancel = cancel
r.httpReq = r.httpReq.WithContext(r.execCtx)
r.applied = true
return nil
}
// buildHTTPClient 构建 HTTP Client
func (r *Request) buildHTTPClient() (*http.Client, error) {
if r.client != nil {
return r.client.HTTPClient(), nil
}
if r.config.CustomTransport && r.config.Transport != nil {
return &http.Client{
Transport: &Transport{base: r.config.Transport},
Timeout: 0,
}, nil
}
return DefaultHTTPClient(), nil
}
+335
View File
@@ -0,0 +1,335 @@
package starnet
import (
"bytes"
"context"
"errors"
"io"
"mime/multipart"
"net/http"
"strings"
"sync/atomic"
"testing"
"time"
)
type roundTripFunc func(*http.Request) (*http.Response, error)
func (fn roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return fn(req)
}
func TestRequestPreparedMutationReappliesHeadersAndBody(t *testing.T) {
req := NewSimpleRequest("http://example.com", http.MethodPost).
SetHeader("X-Test", "one").
SetBodyString("first")
req.client = &Client{client: &http.Client{
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
body, err := io.ReadAll(r.Body)
if err != nil {
return nil, err
}
_ = r.Body.Close()
return &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader(r.Header.Get("X-Test") + ":" + string(body))),
Request: r,
}, nil
}),
}}
if _, err := req.HTTPClient(); err != nil {
t.Fatalf("HTTPClient() error: %v", err)
}
req.SetHeader("X-Test", "two").SetBodyString("second")
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
body, err := resp.Body().String()
if err != nil {
t.Fatalf("Body().String() error: %v", err)
}
if body != "two:second" {
t.Fatalf("body=%q; want %q", body, "two:second")
}
}
func TestRequestPreparedMutationReappliesTimeout(t *testing.T) {
var attempts int32
req := NewSimpleRequest("http://example.com", http.MethodGet)
req.client = &Client{client: &http.Client{
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
if atomic.AddInt32(&attempts, 1) == 1 {
return &http.Response{
StatusCode: http.StatusNoContent,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader("")),
Request: r,
}, nil
}
select {
case <-time.After(50 * time.Millisecond):
return &http.Response{
StatusCode: http.StatusNoContent,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader("")),
Request: r,
}, nil
case <-r.Context().Done():
return nil, r.Context().Err()
}
}),
}}
resp, err := req.Do()
if err != nil {
t.Fatalf("first Do() error: %v", err)
}
_ = resp.Close()
_, err = req.SetTimeout(10 * time.Millisecond).Do()
if err == nil {
t.Fatal("second Do() succeeded; want timeout error")
}
if !IsTimeout(err) && !errors.Is(err, context.DeadlineExceeded) {
t.Fatalf("second Do() error=%v; want timeout", err)
}
}
func TestWriteFileUsesExecContextWithoutProgressHook(t *testing.T) {
req := NewSimpleRequest("http://example.com", http.MethodPost)
pr, pw := io.Pipe()
writer := multipart.NewWriter(pw)
done := make(chan struct{})
go func() {
_, _ = io.Copy(io.Discard, pr)
_ = pr.Close()
close(done)
}()
ctx, cancel := context.WithCancel(context.Background())
cancel()
err := req.writeFile(ctx, writer, RequestFile{
FormName: "file",
FileName: "payload.txt",
FileData: strings.NewReader("payload"),
FileSize: int64(len("payload")),
})
_ = writer.Close()
_ = pw.Close()
<-done
if !errors.Is(err, context.Canceled) {
t.Fatalf("writeFile() error=%v; want context.Canceled", err)
}
}
func TestCopyWithProgressHonorsCanceledContextWithoutHook(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err := copyWithProgress(ctx, io.Discard, strings.NewReader("payload"), "payload.txt", int64(len("payload")), nil)
if !errors.Is(err, context.Canceled) {
t.Fatalf("copyWithProgress() error=%v; want context.Canceled", err)
}
}
func TestPrepareSetsGetBodyForReplayableBodies(t *testing.T) {
tests := []struct {
name string
req *Request
want string
}{
{
name: "bytes",
req: NewSimpleRequest("http://example.com", http.MethodPost).SetBody([]byte("payload")),
want: "payload",
},
{
name: "bytes-reader",
req: NewSimpleRequest("http://example.com", http.MethodPost).SetBodyReader(bytes.NewReader([]byte("payload"))),
want: "payload",
},
{
name: "strings-reader",
req: NewSimpleRequest("http://example.com", http.MethodPost).SetBodyReader(strings.NewReader("payload")),
want: "payload",
},
{
name: "form-data",
req: NewSimpleRequest("http://example.com", http.MethodPost).AddFormData("k", "v"),
want: "k=v",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.req.prepare(); err != nil {
t.Fatalf("prepare() error: %v", err)
}
if tt.req.httpReq.GetBody == nil {
t.Fatal("GetBody is nil")
}
body, err := tt.req.httpReq.GetBody()
if err != nil {
t.Fatalf("GetBody() error: %v", err)
}
defer body.Close()
data, err := io.ReadAll(body)
if err != nil {
t.Fatalf("ReadAll() error: %v", err)
}
if string(data) != tt.want {
t.Fatalf("body=%q; want %q", string(data), tt.want)
}
})
}
}
type replayRoundTripper struct {
attempts int
bodies []string
}
func (rt *replayRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
body, err := io.ReadAll(req.Body)
if err != nil {
return nil, err
}
_ = req.Body.Close()
rt.attempts++
rt.bodies = append(rt.bodies, string(body))
if rt.attempts == 1 {
return nil, errors.New("first target failed")
}
return &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader("ok")),
Request: req,
}, nil
}
func TestRoundTripResolvedTargetsReplaysPreparedBody(t *testing.T) {
req := NewSimpleRequest("http://example.com/upload", http.MethodPut).
SetBodyReader(strings.NewReader("payload"))
if err := req.prepare(); err != nil {
t.Fatalf("prepare() error: %v", err)
}
rt := &replayRoundTripper{}
resp, err := roundTripResolvedTargets(rt, req.httpReq, []string{"127.0.0.2:80", "127.0.0.1:80"})
if err != nil {
t.Fatalf("roundTripResolvedTargets() error: %v", err)
}
defer resp.Body.Close()
if len(rt.bodies) != 2 {
t.Fatalf("attempt bodies=%v; want 2 attempts", rt.bodies)
}
if rt.bodies[0] != "payload" || rt.bodies[1] != "payload" {
t.Fatalf("attempt bodies=%v; want both payload", rt.bodies)
}
}
func TestRoundTripResolvedTargetsDoesNotFallbackNonIdempotentRequest(t *testing.T) {
req := NewSimpleRequest("http://example.com/upload", http.MethodPost).
SetBodyReader(strings.NewReader("payload"))
if err := req.prepare(); err != nil {
t.Fatalf("prepare() error: %v", err)
}
rt := &replayRoundTripper{}
_, err := roundTripResolvedTargets(rt, req.httpReq, []string{"127.0.0.2:80", "127.0.0.1:80"})
if err == nil {
t.Fatal("roundTripResolvedTargets() succeeded; want first target error")
}
if len(rt.bodies) != 1 {
t.Fatalf("attempt bodies=%v; want only first target attempt", rt.bodies)
}
if rt.bodies[0] != "payload" {
t.Fatalf("attempt body=%q; want payload", rt.bodies[0])
}
}
func TestRetryReplayableReaderBody(t *testing.T) {
var attempts int32
req := NewSimpleRequest("http://example.com/upload", http.MethodPut).
SetBodyReader(strings.NewReader("payload")).
SetRetry(1, WithRetryBackoff(0, 0, 1), WithRetryJitter(0))
req.client = &Client{client: &http.Client{
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
body, err := io.ReadAll(r.Body)
if err != nil {
return nil, err
}
_ = r.Body.Close()
if string(body) != "payload" {
t.Fatalf("body=%q; want payload", string(body))
}
if atomic.AddInt32(&attempts, 1) == 1 {
return &http.Response{
StatusCode: http.StatusServiceUnavailable,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader("retry")),
Request: r,
}, nil
}
return &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader("ok")),
Request: r,
}, nil
}),
}}
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
if got := atomic.LoadInt32(&attempts); got != 2 {
t.Fatalf("attempts=%d; want 2", got)
}
}
func TestWithProxyInvalidReturnsError(t *testing.T) {
_, err := NewRequest("http://example.com", http.MethodGet, WithProxy("://bad-proxy"))
if err == nil {
t.Fatal("NewRequest() succeeded; want invalid proxy error")
}
}
func TestClientNewRequestWithInvalidProxyReturnsError(t *testing.T) {
client := NewClientNoErr()
_, err := client.NewRequest("http://example.com", http.MethodGet, WithProxy("://bad-proxy"))
if err == nil {
t.Fatal("Client.NewRequest() succeeded; want invalid proxy error")
}
}
func TestNewClientWithInvalidProxyReturnsError(t *testing.T) {
_, err := NewClient(WithProxy("://bad-proxy"))
if err == nil {
t.Fatal("NewClient() succeeded; want invalid proxy error")
}
}
+31
View File
@@ -0,0 +1,31 @@
package starnet
// AddQuery 添加查询参数
func (r *Request) AddQuery(key, value string) *Request {
return r.applyMutation(mutateAddQuery(key, value))
}
// SetQuery 设置查询参数(覆盖)
func (r *Request) SetQuery(key, value string) *Request {
return r.applyMutation(mutateSetQuery(key, value))
}
// SetQueries 设置所有查询参数(覆盖)
func (r *Request) SetQueries(queries map[string][]string) *Request {
return r.applyMutation(mutateSetQueries(queries))
}
// AddQueries 批量添加查询参数
func (r *Request) AddQueries(queries map[string]string) *Request {
return r.applyMutation(mutateAddQueries(queries))
}
// DeleteQuery 删除查询参数
func (r *Request) DeleteQuery(key string) *Request {
return r.applyMutation(mutateDeleteQuery(key))
}
// DeleteQueryValue 删除查询参数的特定值
func (r *Request) DeleteQueryValue(key, value string) *Request {
return r.applyMutation(mutateDeleteQueryValue(key, value))
}
+168
View File
@@ -0,0 +1,168 @@
package starnet
import (
"io"
"net/http"
"net/url"
"strings"
"sync/atomic"
"testing"
)
type stateRoundTripperFunc func(*http.Request) (*http.Response, error)
func (fn stateRoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return fn(req)
}
func TestSetContextNilUsesBackground(t *testing.T) {
req := NewSimpleRequest("http://example.com", http.MethodGet)
req.client = &Client{client: &http.Client{
Transport: stateRoundTripperFunc(func(r *http.Request) (*http.Response, error) {
if r.Context() == nil {
t.Fatal("request context is nil")
}
return &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader("ok")),
Request: r,
}, nil
}),
}}
resp, err := req.SetContext(nil).Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
if req.Context() == nil {
t.Fatal("request Context() is nil")
}
}
func TestWithContextNilRetryPathDoesNotPanic(t *testing.T) {
var hits int32
req, err := NewRequest("http://example.com", http.MethodGet, WithContext(nil))
if err != nil {
t.Fatalf("NewRequest() error: %v", err)
}
req.client = &Client{client: &http.Client{
Transport: stateRoundTripperFunc(func(r *http.Request) (*http.Response, error) {
if r.Context() == nil {
t.Fatal("retry request context is nil")
}
if atomic.AddInt32(&hits, 1) == 1 {
return &http.Response{
StatusCode: http.StatusServiceUnavailable,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader("retry")),
Request: r,
}, nil
}
return &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader("ok")),
Request: r,
}, nil
}),
}}
resp, err := req.
SetTimeout(DefaultTimeout).
SetRetry(1, WithRetryBackoff(0, 0, 1), WithRetryJitter(0)).
Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
if got := atomic.LoadInt32(&hits); got != 2 {
t.Fatalf("hits=%d; want 2", got)
}
}
func TestCloneRawRequestCreatesIndependentCopy(t *testing.T) {
rawReq, err := http.NewRequest(http.MethodPost, "http://example.com/upload", strings.NewReader("payload"))
if err != nil {
t.Fatalf("NewRequest() error: %v", err)
}
rawReq.Header.Set("X-Test", "one")
req := NewSimpleRequest("", http.MethodPost).SetRawRequest(rawReq)
cloned := req.Clone()
if cloned.Err() != nil {
t.Fatalf("Clone() err = %v", cloned.Err())
}
if cloned.RawRequest() == rawReq {
t.Fatal("raw request pointer reused")
}
cloned.RawRequest().Header.Set("X-Test", "two")
if rawReq.Header.Get("X-Test") != "one" {
t.Fatalf("original header mutated: %q", rawReq.Header.Get("X-Test"))
}
body, err := cloned.RawRequest().GetBody()
if err != nil {
t.Fatalf("GetBody() error: %v", err)
}
defer body.Close()
data, err := io.ReadAll(body)
if err != nil {
t.Fatalf("ReadAll() error: %v", err)
}
if string(data) != "payload" {
t.Fatalf("body=%q; want payload", string(data))
}
}
func TestCloneRawRequestWithNonReplayableBodyFailsExplicitly(t *testing.T) {
rawReq := &http.Request{
Method: http.MethodPost,
URL: mustParseURL(t, "http://example.com/upload"),
Header: make(http.Header),
Body: io.NopCloser(io.MultiReader(strings.NewReader("payload"))),
}
req := NewSimpleRequest("", http.MethodPost).SetRawRequest(rawReq)
cloned := req.Clone()
if cloned.Err() == nil {
t.Fatal("Clone() should fail for non-replayable raw body")
}
if !strings.Contains(cloned.Err().Error(), "non-replayable") {
t.Fatalf("Clone() err=%v; want non-replayable body error", cloned.Err())
}
}
func TestDisableRawModeAfterSetRawRequestReturnsError(t *testing.T) {
rawReq, err := http.NewRequest(http.MethodGet, "http://example.com", nil)
if err != nil {
t.Fatalf("NewRequest() error: %v", err)
}
req := NewSimpleRequest("", http.MethodGet).SetRawRequest(rawReq).DisableRawMode()
if req.Err() == nil {
t.Fatal("DisableRawMode() should set error")
}
if !strings.Contains(req.Err().Error(), "cannot disable raw mode") {
t.Fatalf("DisableRawMode() err=%v", req.Err())
}
if !req.doRaw {
t.Fatal("request should remain in raw mode")
}
}
func mustParseURL(t *testing.T, raw string) *url.URL {
t.Helper()
parsed, err := url.Parse(raw)
if err != nil {
t.Fatalf("url.Parse() error: %v", err)
}
return parsed
}
+172
View File
@@ -0,0 +1,172 @@
package starnet
import (
"context"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
)
func TestNewSimpleRequest(t *testing.T) {
tests := []struct {
name string
url string
method string
expectErr bool
}{
{
name: "valid GET request",
url: "https://example.com",
method: "GET",
expectErr: false,
},
{
name: "valid POST request",
url: "https://example.com",
method: "POST",
expectErr: false,
},
{
name: "invalid URL",
url: "://invalid",
method: "GET",
expectErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req, err := NewRequest(tt.url, tt.method)
if tt.expectErr {
if err == nil && req.Err() == nil {
t.Errorf("NewRequest() expected error, got nil")
}
} else {
if err != nil {
t.Errorf("NewRequest() unexpected error: %v", err)
}
if req.Method() != strings.ToUpper(tt.method) {
t.Errorf("Method = %v; want %v", req.Method(), tt.method)
}
}
})
}
}
func TestRequestMethods(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Method", r.Method)
w.WriteHeader(http.StatusOK)
w.Write([]byte(r.Method))
}))
defer server.Close()
methods := []string{"GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"}
for _, method := range methods {
t.Run(method, func(t *testing.T) {
req := NewSimpleRequest(server.URL, method)
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("StatusCode = %v; want %v", resp.StatusCode, http.StatusOK)
}
if method != "HEAD" {
body, _ := resp.Body().String()
if body != method {
t.Errorf("Body = %v; want %v", body, method)
}
}
})
}
}
func TestRequestSetMethod(t *testing.T) {
req := NewSimpleRequest("https://example.com", "GET")
req.SetMethod("POST")
if req.Method() != "POST" {
t.Errorf("Method = %v; want POST", req.Method())
}
req.SetMethod("invalid method!")
if req.Err() == nil {
t.Error("SetMethod with invalid method should set error")
}
}
func TestRequestSetURL(t *testing.T) {
req := NewSimpleRequest("https://example.com", "GET")
req.SetURL("https://newexample.com")
if req.URL() != "https://newexample.com" {
t.Errorf("URL = %v; want https://newexample.com", req.URL())
}
req2 := NewSimpleRequest("https://example.com", "GET")
req2.SetURL("://invalid")
if req2.Err() == nil {
t.Error("SetURL with invalid URL should set error")
}
}
func TestRequestClone(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("X-Test") != "value" {
w.WriteHeader(http.StatusBadRequest)
return
}
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
}))
defer server.Close()
req := NewSimpleRequest(server.URL, "GET").
SetHeader("X-Test", "value")
// 第一次请求
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
resp.Close()
// 克隆请求
cloned := req.Clone()
cloned.SetHeader("X-Extra", "extra")
// 克隆的请求应该也能成功
resp2, err := cloned.Do()
if err != nil {
t.Fatalf("Cloned Do() error: %v", err)
}
defer resp2.Close()
if resp2.StatusCode != http.StatusOK {
t.Errorf("Cloned request StatusCode = %v; want %v", resp2.StatusCode, http.StatusOK)
}
}
func TestRequestContext(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(100 * time.Millisecond)
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
defer cancel()
req := NewSimpleRequest(server.URL, "GET").SetContext(ctx)
_, err := req.Do()
if err == nil {
t.Error("Expected timeout error, got nil")
}
}
+56
View File
@@ -0,0 +1,56 @@
package starnet
// SetTraceHooks 设置请求 trace 回调。
func (r *Request) SetTraceHooks(hooks *TraceHooks) *Request {
return r.applyMutation(mutateTraceHooks(hooks))
}
// SetTraceRecorder 设置请求级 trace 摘要记录器。
// 记录器会保存最近一次已完成请求的摘要;若多个请求共享同一个记录器,则以最后一次完成的请求为准。
func (r *Request) SetTraceRecorder(recorder *TraceRecorder) *Request {
return r.applyMutation(mutateTraceRecorder(recorder))
}
// TraceSummary 返回当前请求最近一次执行的 trace 摘要快照。
func (r *Request) TraceSummary() *TraceSummary {
if r == nil || r.lastTraceSummary == nil {
return nil
}
summary := cloneTraceSummary(*r.lastTraceSummary)
return &summary
}
func (r *Request) startTraceExecution() {
if r == nil {
return
}
r.traceRun = nil
if r.traceRecorder == nil {
return
}
r.traceRun = r.traceRecorder.forkExecution()
if r.traceRun != nil {
r.traceRun.startRequest()
}
}
func (r *Request) finishTraceExecution(resp *Response) {
if r == nil {
return
}
if r.traceRun == nil {
r.lastTraceSummary = nil
if resp != nil {
resp.traceSummary = nil
}
return
}
summary := r.traceRun.Summary()
r.lastTraceSummary = cloneTraceSummaryPtr(summary)
r.traceRecorder.publishSummary(summary)
if resp != nil {
resp.traceSummary = cloneTraceSummaryPtr(summary)
}
r.traceRun = nil
}
+209
View File
@@ -0,0 +1,209 @@
package starnet
import (
"bytes"
"encoding/json"
"io"
"net/http"
"sync"
)
// Response HTTP 响应
type Response struct {
*http.Response
request *Request
httpClient *http.Client
cancel func()
body *Body
traceSummary *TraceSummary
}
// Body 响应体
type Body struct {
raw io.ReadCloser
data []byte
consumed bool
maxBytes int64
mu sync.Mutex
}
type cancelReadCloser struct {
io.ReadCloser
cancel func()
once sync.Once
}
func (c *cancelReadCloser) Close() error {
err := c.ReadCloser.Close()
c.once.Do(func() {
if c.cancel != nil {
c.cancel()
}
})
return err
}
// Request 获取原始请求
func (r *Response) Request() *Request {
return r.request
}
// TraceSummary 获取当前响应对应的 trace 摘要快照。
func (r *Response) TraceSummary() *TraceSummary {
if r == nil || r.traceSummary == nil {
return nil
}
summary := cloneTraceSummary(*r.traceSummary)
return &summary
}
// Body 获取响应体
func (r *Response) Body() *Body {
return r.body
}
// Close 关闭响应体
func (r *Response) Close() error {
if r == nil {
return nil
}
if r.body != nil && r.body.raw != nil {
return r.body.raw.Close()
}
if r.cancel != nil {
r.cancel()
r.cancel = nil
}
return nil
}
// CloseWithClient 关闭响应体并关闭空闲连接
func (r *Response) CloseWithClient() error {
if r == nil {
return nil
}
if r.httpClient != nil {
r.httpClient.CloseIdleConnections()
}
return r.Close()
}
// readAll 读取所有数据
func (b *Body) readAll() error {
b.mu.Lock()
defer b.mu.Unlock()
if b.consumed {
return nil
}
if b.raw == nil {
b.consumed = true
return nil
}
reader := io.Reader(b.raw)
if b.maxBytes > 0 {
reader = io.LimitReader(b.raw, b.maxBytes+1)
}
data, err := io.ReadAll(reader)
if err != nil {
return wrapError(err, "read response body")
}
if b.maxBytes > 0 && int64(len(data)) > b.maxBytes {
b.consumed = true
_ = b.raw.Close()
return wrapError(ErrRespBodyTooLarge, "response body exceeds max bytes: %d > %d", len(data), b.maxBytes)
}
b.data = data
b.consumed = true
_ = b.raw.Close()
return nil
}
// Bytes 获取响应体字节
func (b *Body) Bytes() ([]byte, error) {
if err := b.readAll(); err != nil {
return nil, err
}
return b.data, nil
}
// String 获取响应体字符串
func (b *Body) String() (string, error) {
data, err := b.Bytes()
if err != nil {
return "", err
}
return string(data), nil
}
// JSON 解析 JSON 响应
func (b *Body) JSON(v interface{}) error {
data, err := b.Bytes()
if err != nil {
return err
}
return json.Unmarshal(data, v)
}
// Reader 获取 Reader(只能调用一次)
func (b *Body) Reader() (io.ReadCloser, error) {
b.mu.Lock()
defer b.mu.Unlock()
if b.consumed {
if b.data != nil {
// 已读取,返回缓存数据的 Reader
return io.NopCloser(bytes.NewReader(b.data)), nil
}
return nil, ErrBodyAlreadyConsumed
}
b.consumed = true
return b.raw, nil
}
// IsConsumed 检查是否已消费
func (b *Body) IsConsumed() bool {
b.mu.Lock()
defer b.mu.Unlock()
return b.consumed
}
// Close 关闭 Body
func (b *Body) Close() error {
b.mu.Lock()
defer b.mu.Unlock()
if b.raw != nil {
return b.raw.Close()
}
return nil
}
// MustBytes 获取响应体字节(忽略错误,失败返回 nil)
func (b *Body) MustBytes() []byte {
data, err := b.Bytes()
if err != nil {
return nil
}
return data
}
// MustString 获取响应体字符串(忽略错误,失败返回空串)
func (b *Body) MustString() string {
s, err := b.String()
if err != nil {
return ""
}
return s
}
// Unmarshal 解析 JSON 响应(兼容旧 API)
func (b *Body) Unmarshal(v interface{}) error {
return b.JSON(v)
}
+84
View File
@@ -0,0 +1,84 @@
package starnet
import (
"errors"
"net/http"
"net/http/httptest"
"testing"
)
func TestWithMaxRespBodyBytes(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte("123456"))
}))
defer server.Close()
resp, err := Get(server.URL, WithMaxRespBodyBytes(4))
if err != nil {
t.Fatalf("unexpected request error: %v", err)
}
defer resp.Close()
_, err = resp.Body().Bytes()
if err == nil {
t.Fatal("expected body too large error")
}
if !errors.Is(err, ErrRespBodyTooLarge) {
t.Fatalf("expected ErrRespBodyTooLarge, got: %v", err)
}
}
func TestSetMaxRespBodyBytes(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte("1234"))
}))
defer server.Close()
resp, err := NewSimpleRequest(server.URL, http.MethodGet).
SetMaxRespBodyBytes(4).
Do()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer resp.Close()
body, err := resp.Body().String()
if err != nil {
t.Fatalf("unexpected read error: %v", err)
}
if body != "1234" {
t.Fatalf("body=%q want=1234", body)
}
}
func TestSetMaxRespBodyBytesWithAutoFetch(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte("123456"))
}))
defer server.Close()
_, err := NewSimpleRequest(server.URL, http.MethodGet).
SetAutoFetch(true).
SetMaxRespBodyBytes(4).
Do()
if err == nil {
t.Fatal("expected body too large error with auto fetch")
}
if !errors.Is(err, ErrRespBodyTooLarge) {
t.Fatalf("expected ErrRespBodyTooLarge, got: %v", err)
}
}
func TestSetMaxRespBodyBytesInvalid(t *testing.T) {
req := NewSimpleRequest("http://example.com", http.MethodGet).SetMaxRespBodyBytes(-1)
if req.Err() == nil {
t.Fatal("expected error for negative max bytes")
}
}
func TestWithMaxRespBodyBytesInvalid(t *testing.T) {
_, err := NewRequest("http://example.com", http.MethodGet, WithMaxRespBodyBytes(-1))
if err == nil {
t.Fatal("expected error for negative max bytes")
}
}
+179
View File
@@ -0,0 +1,179 @@
package starnet
import (
"io"
"net/http"
"net/http/httptest"
"testing"
)
func TestResponseBody(t *testing.T) {
testData := "test response data"
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(testData))
}))
defer server.Close()
resp, err := Get(server.URL)
if err != nil {
t.Fatalf("Get() error: %v", err)
}
defer resp.Close()
// Test String()
body, err := resp.Body().String()
if err != nil {
t.Fatalf("Body().String() error: %v", err)
}
if body != testData {
t.Errorf("Body = %v; want %v", body, testData)
}
// Test multiple reads (should work because body is cached)
body2, err := resp.Body().String()
if err != nil {
t.Fatalf("Second Body().String() error: %v", err)
}
if body2 != testData {
t.Errorf("Second Body = %v; want %v", body2, testData)
}
}
func TestResponseJSON(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{"name":"John","age":30}`))
}))
defer server.Close()
resp, err := Get(server.URL)
if err != nil {
t.Fatalf("Get() error: %v", err)
}
defer resp.Close()
var result struct {
Name string `json:"name"`
Age int `json:"age"`
}
err = resp.Body().JSON(&result)
if err != nil {
t.Fatalf("Body().JSON() error: %v", err)
}
if result.Name != "John" {
t.Errorf("Name = %v; want John", result.Name)
}
if result.Age != 30 {
t.Errorf("Age = %v; want 30", result.Age)
}
}
func TestResponseBytes(t *testing.T) {
testData := []byte("binary data")
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write(testData)
}))
defer server.Close()
resp, err := Get(server.URL)
if err != nil {
t.Fatalf("Get() error: %v", err)
}
defer resp.Close()
body, err := resp.Body().Bytes()
if err != nil {
t.Fatalf("Body().Bytes() error: %v", err)
}
if string(body) != string(testData) {
t.Errorf("Body = %v; want %v", body, testData)
}
}
func TestResponseReader(t *testing.T) {
testData := "stream data"
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(testData))
}))
defer server.Close()
resp, err := Get(server.URL)
if err != nil {
t.Fatalf("Get() error: %v", err)
}
defer resp.Close()
reader, err := resp.Body().Reader()
if err != nil {
t.Fatalf("Body().Reader() error: %v", err)
}
defer reader.Close()
body, err := io.ReadAll(reader)
if err != nil {
t.Fatalf("ReadAll() error: %v", err)
}
if string(body) != testData {
t.Errorf("Body = %v; want %v", string(body), testData)
}
}
func TestResponseAutoFetch(t *testing.T) {
testData := "auto fetch data"
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(testData))
}))
defer server.Close()
// With auto fetch
resp, err := Get(server.URL, WithAutoFetch(true))
if err != nil {
t.Fatalf("Get() error: %v", err)
}
defer resp.Close()
if !resp.Body().IsConsumed() {
t.Error("Body should be consumed with auto fetch")
}
body, _ := resp.Body().String()
if body != testData {
t.Errorf("Body = %v; want %v", body, testData)
}
}
func TestResponseStatusCode(t *testing.T) {
tests := []struct {
name string
statusCode int
}{
{"OK", http.StatusOK},
{"Created", http.StatusCreated},
{"BadRequest", http.StatusBadRequest},
{"NotFound", http.StatusNotFound},
{"InternalServerError", http.StatusInternalServerError},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(tt.statusCode)
}))
defer server.Close()
resp, err := Get(server.URL)
if err != nil {
t.Fatalf("Get() error: %v", err)
}
defer resp.Close()
if resp.StatusCode != tt.statusCode {
t.Errorf("StatusCode = %v; want %v", resp.StatusCode, tt.statusCode)
}
})
}
}
+468
View File
@@ -0,0 +1,468 @@
package starnet
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"math"
"math/rand"
"net"
"net/http"
"strings"
"time"
)
type RetryOpt func(*retryPolicy) error
type retryPolicy struct {
maxRetries int
baseDelay time.Duration
maxDelay time.Duration
factor float64
jitter float64
idempotentOnly bool
statuses map[int]struct{}
onError func(error) bool
}
func cloneRetryPolicy(p *retryPolicy) *retryPolicy {
if p == nil {
return nil
}
cloned := &retryPolicy{
maxRetries: p.maxRetries,
baseDelay: p.baseDelay,
maxDelay: p.maxDelay,
factor: p.factor,
jitter: p.jitter,
idempotentOnly: p.idempotentOnly,
onError: p.onError,
}
if p.statuses != nil {
cloned.statuses = make(map[int]struct{}, len(p.statuses))
for code := range p.statuses {
cloned.statuses[code] = struct{}{}
}
}
return cloned
}
func defaultRetryPolicy(max int) *retryPolicy {
return &retryPolicy{
maxRetries: max,
baseDelay: 100 * time.Millisecond,
maxDelay: 2 * time.Second,
factor: 2.0,
jitter: 0.1,
idempotentOnly: true,
statuses: map[int]struct{}{
http.StatusRequestTimeout: {},
http.StatusTooEarly: {},
http.StatusTooManyRequests: {},
http.StatusInternalServerError: {},
http.StatusBadGateway: {},
http.StatusServiceUnavailable: {},
http.StatusGatewayTimeout: {},
},
}
}
func buildRetryPolicy(max int, opts ...RetryOpt) (*retryPolicy, error) {
if max < 0 {
return nil, fmt.Errorf("max retry must be >= 0")
}
if max == 0 {
return nil, nil
}
policy := defaultRetryPolicy(max)
for _, opt := range opts {
if opt == nil {
continue
}
if err := opt(policy); err != nil {
return nil, err
}
}
return policy, nil
}
// WithRetry 为请求启用自动重试。
// 默认只重试幂等方法;即使显式关闭幂等限制,Reader 形态的 body 仍会对非幂等方法保持保守禁用,
// 以避免请求体已落地后再次发送。
func WithRetry(max int, opts ...RetryOpt) RequestOpt {
return func(r *Request) error {
policy, err := buildRetryPolicy(max, opts...)
if err != nil {
return err
}
r.retry = policy
return nil
}
}
// SetRetry 为请求启用自动重试。
// 默认只重试幂等方法;即使显式关闭幂等限制,Reader 形态的 body 仍会对非幂等方法保持保守禁用,
// 以避免请求体已落地后再次发送。
func (r *Request) SetRetry(max int, opts ...RetryOpt) *Request {
if r.err != nil {
return r
}
policy, err := buildRetryPolicy(max, opts...)
if err != nil {
r.err = err
return r
}
r.retry = policy
return r
}
func (r *Request) DisableRetry() *Request {
if r.err != nil {
return r
}
r.retry = nil
return r
}
func (r *Request) applyRetryOpt(opt RetryOpt) *Request {
if r.err != nil {
return r
}
if opt == nil {
return r
}
if r.retry == nil {
r.err = fmt.Errorf("retry policy is not enabled, call SetRetry first")
return r
}
if err := opt(r.retry); err != nil {
r.err = err
}
return r
}
func (r *Request) SetRetryBackoff(base, max time.Duration, factor float64) *Request {
return r.applyRetryOpt(WithRetryBackoff(base, max, factor))
}
func (r *Request) SetRetryJitter(ratio float64) *Request {
return r.applyRetryOpt(WithRetryJitter(ratio))
}
func (r *Request) SetRetryStatuses(codes ...int) *Request {
return r.applyRetryOpt(WithRetryStatuses(codes...))
}
func (r *Request) SetRetryIdempotentOnly(enabled bool) *Request {
return r.applyRetryOpt(WithRetryIdempotentOnly(enabled))
}
func (r *Request) SetRetryOnError(fn func(error) bool) *Request {
return r.applyRetryOpt(WithRetryOnError(fn))
}
func WithRetryBackoff(base, max time.Duration, factor float64) RetryOpt {
return func(p *retryPolicy) error {
if base < 0 {
return fmt.Errorf("retry base delay must be >= 0")
}
if max < 0 {
return fmt.Errorf("retry max delay must be >= 0")
}
if factor <= 0 {
return fmt.Errorf("retry factor must be > 0")
}
p.baseDelay = base
p.maxDelay = max
p.factor = factor
return nil
}
}
func WithRetryJitter(ratio float64) RetryOpt {
return func(p *retryPolicy) error {
if ratio < 0 || ratio > 1 {
return fmt.Errorf("retry jitter ratio must be in [0,1]")
}
p.jitter = ratio
return nil
}
}
func WithRetryStatuses(codes ...int) RetryOpt {
return func(p *retryPolicy) error {
statuses := make(map[int]struct{}, len(codes))
for _, code := range codes {
if code < 100 || code > 999 {
return fmt.Errorf("invalid retry status code: %d", code)
}
statuses[code] = struct{}{}
}
p.statuses = statuses
return nil
}
}
func WithRetryIdempotentOnly(enabled bool) RetryOpt {
return func(p *retryPolicy) error {
p.idempotentOnly = enabled
return nil
}
}
func WithRetryOnError(fn func(error) bool) RetryOpt {
return func(p *retryPolicy) error {
p.onError = fn
return nil
}
}
func (r *Request) hasRetryPolicy() bool {
return r.retry != nil && r.retry.maxRetries > 0
}
func (r *Request) doWithRetry() (*Response, error) {
policy := cloneRetryPolicy(r.retry)
if policy == nil || policy.maxRetries <= 0 {
return r.doOnce()
}
if !policy.canRetryRequest(r) {
return r.doOnce()
}
retryCtx := normalizeContext(r.ctx)
retryCancel := func() {}
if r.config.Network.Timeout > 0 {
retryCtx, retryCancel = context.WithTimeout(retryCtx, r.config.Network.Timeout)
}
defer retryCancel()
maxAttempts := policy.maxRetries + 1
var lastResp *Response
var lastErr error
for attempt := 0; attempt < maxAttempts; attempt++ {
attemptNo := attempt + 1
emitRetryAttemptStart(r.traceHooks, TraceRetryAttemptStartInfo{
Attempt: attemptNo,
MaxAttempts: maxAttempts,
})
attemptReq, err := r.newRetryAttempt(retryCtx)
if err != nil {
return nil, wrapError(err, "build retry attempt")
}
resp, err := attemptReq.doOnce()
if resp != nil {
resp.request = r
}
willRetry := policy.shouldRetry(resp, err, attempt, maxAttempts, retryCtx)
statusCode := 0
if resp != nil {
statusCode = resp.StatusCode
}
emitRetryAttemptDone(r.traceHooks, TraceRetryAttemptDoneInfo{
Attempt: attemptNo,
MaxAttempts: maxAttempts,
StatusCode: statusCode,
Err: err,
WillRetry: willRetry,
})
if !willRetry {
return resp, err
}
lastResp = resp
lastErr = err
if lastResp != nil {
_ = lastResp.Close()
}
delay := policy.nextDelay(attempt)
if delay <= 0 {
continue
}
emitRetryBackoff(r.traceHooks, TraceRetryBackoffInfo{
Attempt: attemptNo,
Delay: delay,
})
timer := time.NewTimer(delay)
select {
case <-retryCtx.Done():
timer.Stop()
return lastResp, wrapError(retryCtx.Err(), "retry context done")
case <-timer.C:
}
}
return lastResp, lastErr
}
func (r *Request) newRetryAttempt(ctx context.Context) (*Request, error) {
attempt := r.Clone()
attempt.retry = nil
attempt.cancel = nil
attempt.applied = false
attempt.execCtx = nil
attempt.ctx = ctx
attempt.traceRun = r.traceRun
attempt.lastTraceSummary = nil
// 共享总超时上下文后,避免每次 attempt 再创建一次 timeout context。
if attempt.config != nil && attempt.config.Network.Timeout > 0 {
attempt.config.Network.Timeout = 0
}
if !attempt.doRaw {
attempt.httpReq = attempt.httpReq.WithContext(ctx)
return attempt, nil
}
raw, err := cloneRawHTTPRequest(r.httpReq, ctx)
if err != nil {
return nil, err
}
attempt.httpReq = raw
return attempt, nil
}
func (p *retryPolicy) canRetryRequest(r *Request) bool {
if p.idempotentOnly && !isIdempotentMethod(r.method) {
return false
}
if hasReaderRequestBody(r) && !isIdempotentMethod(r.method) {
return false
}
return isReplayableRequest(r)
}
func isIdempotentMethod(method string) bool {
switch method {
case http.MethodGet, http.MethodHead, http.MethodPut, http.MethodDelete, http.MethodOptions, http.MethodTrace:
return true
default:
return false
}
}
func isReplayableRequest(r *Request) bool {
if r == nil {
return false
}
if r.doRaw {
if r.httpReq == nil {
return false
}
if r.httpReq.Body == nil || r.httpReq.Body == http.NoBody {
return true
}
return r.httpReq.GetBody != nil
}
if r.config == nil {
return false
}
return isReplayableConfiguredBody(r.config.Body)
}
func hasReaderRequestBody(r *Request) bool {
if r == nil || r.config == nil {
return false
}
return r.config.Body.Mode == bodyModeReader && r.config.Body.Reader != nil
}
func isReplayableConfiguredBody(body BodyConfig) bool {
switch body.Mode {
case bodyModeReader:
return isReplayableBodyReader(body.Reader)
case bodyModeMultipart:
for _, file := range body.Files {
if file.FileData != nil || file.FilePath == "" {
return false
}
}
}
return true
}
func isReplayableBodyReader(reader io.Reader) bool {
switch reader.(type) {
case *bytes.Buffer, *bytes.Reader, *strings.Reader:
return true
default:
return false
}
}
func (p *retryPolicy) shouldRetry(resp *Response, err error, attempt, maxAttempts int, ctx context.Context) bool {
if attempt >= maxAttempts-1 {
return false
}
if ctx != nil && ctx.Err() != nil {
return false
}
if err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return false
}
if p.onError != nil {
return p.onError(err)
}
return isRetryableError(err)
}
if resp == nil || resp.Response == nil {
return false
}
_, ok := p.statuses[resp.StatusCode]
return ok
}
func isRetryableError(err error) bool {
var netErr net.Error
if errors.As(err, &netErr) {
if netErr.Timeout() {
return true
}
if netErr.Temporary() {
return true
}
}
return errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF)
}
func (p *retryPolicy) nextDelay(attempt int) time.Duration {
if p.baseDelay <= 0 {
return 0
}
delay := time.Duration(float64(p.baseDelay) * math.Pow(p.factor, float64(attempt)))
if p.maxDelay > 0 && delay > p.maxDelay {
delay = p.maxDelay
}
if p.jitter <= 0 {
return delay
}
low := 1 - p.jitter
if low < 0 {
low = 0
}
high := 1 + p.jitter
scale := low + rand.Float64()*(high-low)
return time.Duration(float64(delay) * scale)
}
+298
View File
@@ -0,0 +1,298 @@
package starnet
import (
"context"
"errors"
"io"
"net"
"net/http"
"net/http/httptest"
"strings"
"sync/atomic"
"testing"
"time"
)
func TestWithRetrySmokeGet(t *testing.T) {
var hits int32
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
n := atomic.AddInt32(&hits, 1)
if n <= 2 {
w.WriteHeader(http.StatusServiceUnavailable)
return
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer s.Close()
resp, err := Get(s.URL,
WithRetry(2,
WithRetryBackoff(0, 0, 1),
WithRetryJitter(0),
),
)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer resp.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("status=%d want=%d", resp.StatusCode, http.StatusOK)
}
if atomic.LoadInt32(&hits) != 3 {
t.Fatalf("hits=%d want=3", hits)
}
}
func TestWithRetryResponseRequestPointerStable(t *testing.T) {
var hits int32
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
n := atomic.AddInt32(&hits, 1)
if n == 1 {
w.WriteHeader(http.StatusServiceUnavailable)
return
}
w.WriteHeader(http.StatusOK)
}))
defer s.Close()
req := NewSimpleRequest(s.URL, http.MethodGet).
SetRetry(1, WithRetryBackoff(0, 0, 1), WithRetryJitter(0))
resp, err := req.Do()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer resp.Close()
if resp.Request() != req {
t.Fatal("response request pointer should point to original request")
}
}
func TestWithRetryNoRetryForNonReplayableBodyReader(t *testing.T) {
var hits int32
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
atomic.AddInt32(&hits, 1)
w.WriteHeader(http.StatusServiceUnavailable)
}))
defer s.Close()
req := NewSimpleRequest(s.URL, http.MethodPost).
SetBodyReader(strings.NewReader("payload")).
SetRetry(3,
WithRetryIdempotentOnly(false),
WithRetryBackoff(0, 0, 1),
WithRetryJitter(0),
)
resp, err := req.Do()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer resp.Close()
if resp.StatusCode != http.StatusServiceUnavailable {
t.Fatalf("status=%d want=%d", resp.StatusCode, http.StatusServiceUnavailable)
}
if atomic.LoadInt32(&hits) != 1 {
t.Fatalf("hits=%d want=1", hits)
}
}
func TestWithRetryPostWhenIdempotentDisabled(t *testing.T) {
var hits int32
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
n := atomic.AddInt32(&hits, 1)
_, _ = io.Copy(io.Discard, r.Body)
if n == 1 {
w.WriteHeader(http.StatusServiceUnavailable)
return
}
w.WriteHeader(http.StatusOK)
}))
defer s.Close()
req := NewSimpleRequest(s.URL, http.MethodPost).
SetBodyString("hello").
SetRetry(1,
WithRetryIdempotentOnly(false),
WithRetryBackoff(0, 0, 1),
WithRetryJitter(0),
)
resp, err := req.Do()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer resp.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("status=%d want=%d", resp.StatusCode, http.StatusOK)
}
if atomic.LoadInt32(&hits) != 2 {
t.Fatalf("hits=%d want=2", hits)
}
}
func TestWithRetryRawWithoutGetBodyNoRetry(t *testing.T) {
var hits int32
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
atomic.AddInt32(&hits, 1)
w.WriteHeader(http.StatusServiceUnavailable)
}))
defer s.Close()
rawReq, _ := http.NewRequest(http.MethodPost, s.URL, io.MultiReader(strings.NewReader("raw")))
if rawReq.GetBody != nil {
t.Fatal("raw request GetBody should be nil in this test")
}
req := NewSimpleRequest("", http.MethodPost, WithRawRequest(rawReq)).
SetRetry(2,
WithRetryIdempotentOnly(false),
WithRetryBackoff(0, 0, 1),
WithRetryJitter(0),
)
resp, err := req.Do()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer resp.Close()
if atomic.LoadInt32(&hits) != 1 {
t.Fatalf("hits=%d want=1", hits)
}
}
func TestWithRetryRespectsTotalTimeoutBudget(t *testing.T) {
var hits int32
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
atomic.AddInt32(&hits, 1)
time.Sleep(80 * time.Millisecond)
w.WriteHeader(http.StatusServiceUnavailable)
}))
defer s.Close()
req := NewSimpleRequest(s.URL, http.MethodGet).
SetTimeout(120*time.Millisecond).
SetRetry(3,
WithRetryBackoff(0, 0, 1),
WithRetryJitter(0),
)
_, err := req.Do()
if err == nil {
t.Fatal("expected timeout error")
}
if !errors.Is(err, context.DeadlineExceeded) {
t.Fatalf("expected context deadline exceeded, got: %v", err)
}
if h := atomic.LoadInt32(&hits); h > 2 {
t.Fatalf("hits=%d want<=2 under tight timeout budget", h)
}
}
func TestSetRetryInvalidMax(t *testing.T) {
req := NewSimpleRequest("http://example.com", http.MethodGet).SetRetry(-1)
if req.Err() == nil {
t.Fatal("expected error for negative retry max")
}
}
func TestSetRetrySeriesSmokeGet(t *testing.T) {
var hits int32
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
n := atomic.AddInt32(&hits, 1)
if n <= 2 {
w.WriteHeader(http.StatusTooManyRequests)
return
}
w.WriteHeader(http.StatusOK)
}))
defer s.Close()
req := NewSimpleRequest(s.URL, http.MethodGet).
SetRetry(2).
SetRetryBackoff(0, 0, 1).
SetRetryJitter(0).
SetRetryStatuses(http.StatusTooManyRequests)
resp, err := req.Do()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer resp.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("status=%d want=%d", resp.StatusCode, http.StatusOK)
}
if h := atomic.LoadInt32(&hits); h != 3 {
t.Fatalf("hits=%d want=3", h)
}
}
func TestSetRetryIdempotentOnlyWithPost(t *testing.T) {
var hits int32
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
n := atomic.AddInt32(&hits, 1)
_, _ = io.Copy(io.Discard, r.Body)
if n == 1 {
w.WriteHeader(http.StatusServiceUnavailable)
return
}
w.WriteHeader(http.StatusOK)
}))
defer s.Close()
req := NewSimpleRequest(s.URL, http.MethodPost).
SetBodyString("hello").
SetRetry(1).
SetRetryIdempotentOnly(false).
SetRetryBackoff(0, 0, 1).
SetRetryJitter(0)
resp, err := req.Do()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer resp.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("status=%d want=%d", resp.StatusCode, http.StatusOK)
}
if h := atomic.LoadInt32(&hits); h != 2 {
t.Fatalf("hits=%d want=2", h)
}
}
func TestSetRetryOnErrorOverridesDefault(t *testing.T) {
var dials int32
dialErr := errors.New("dial failed")
req := NewSimpleRequest("http://example.com", http.MethodGet).
SetDialFunc(func(ctx context.Context, network, addr string) (net.Conn, error) {
atomic.AddInt32(&dials, 1)
return nil, dialErr
}).
SetRetry(1).
SetRetryBackoff(0, 0, 1).
SetRetryJitter(0).
SetRetryOnError(func(err error) bool {
return true
})
_, err := req.Do()
if err == nil {
t.Fatal("expected error")
}
if h := atomic.LoadInt32(&dials); h != 2 {
t.Fatalf("dial attempts=%d want=2", h)
}
}
func TestSetRetryOptionRequireEnableRetry(t *testing.T) {
req := NewSimpleRequest("http://example.com", http.MethodGet).SetRetryBackoff(10*time.Millisecond, 100*time.Millisecond, 2)
if req.Err() == nil {
t.Fatal("expected error when setting retry options before SetRetry")
}
if !strings.Contains(req.Err().Error(), "call SetRetry first") {
t.Fatalf("unexpected error: %v", req.Err())
}
}
+244
View File
@@ -0,0 +1,244 @@
package starnet
import (
"context"
"fmt"
"net"
"net/http"
"strconv"
"sync"
"testing"
"time"
)
func TestRequestProxyWithCustomIPFallbackTriesNextResolvedTarget(t *testing.T) {
tlsServer := newIPv4TLSServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer tlsServer.Close()
_, port, err := net.SplitHostPort(tlsServer.Listener.Addr().String())
if err != nil {
t.Fatalf("split tls server addr: %v", err)
}
firstTarget := net.JoinHostPort("127.0.0.2", port)
secondTarget := net.JoinHostPort("127.0.0.1", port)
var (
mu sync.Mutex
connectTargets []string
)
proxyServer := newIPv4Server(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodConnect {
http.Error(w, "connect required", http.StatusMethodNotAllowed)
return
}
mu.Lock()
connectTargets = append(connectTargets, r.Host)
mu.Unlock()
if r.Host == firstTarget {
http.Error(w, "first target failed", http.StatusBadGateway)
return
}
targetConn, err := net.Dial("tcp", r.Host)
if err != nil {
http.Error(w, err.Error(), http.StatusBadGateway)
return
}
hijacker, ok := w.(http.Hijacker)
if !ok {
targetConn.Close()
t.Fatal("proxy response writer is not a hijacker")
}
clientConn, rw, err := hijacker.Hijack()
if err != nil {
targetConn.Close()
t.Fatalf("hijack proxy conn: %v", err)
}
if _, err := rw.WriteString("HTTP/1.1 200 Connection Established\r\n\r\n"); err != nil {
clientConn.Close()
targetConn.Close()
t.Fatalf("write connect response: %v", err)
}
if err := rw.Flush(); err != nil {
clientConn.Close()
targetConn.Close()
t.Fatalf("flush connect response: %v", err)
}
relayProxyConns(clientConn, targetConn)
}))
defer proxyServer.Close()
reqURL := fmt.Sprintf("https://proxy-fallback.test:%s", port)
resp, err := NewSimpleRequest(reqURL, http.MethodGet).
SetProxy(proxyServer.URL).
SetCustomIP([]string{"127.0.0.2", "127.0.0.1"}).
SetSkipTLSVerify(true).
Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
mu.Lock()
defer mu.Unlock()
if len(connectTargets) != 2 {
t.Fatalf("connect target attempts=%d; want 2 (%v)", len(connectTargets), connectTargets)
}
if connectTargets[0] != firstTarget {
t.Fatalf("first connect target=%q; want %q", connectTargets[0], firstTarget)
}
if connectTargets[1] != secondTarget {
t.Fatalf("second connect target=%q; want %q", connectTargets[1], secondTarget)
}
}
func TestTraceHooksDefaultResolverEmitsDNSEvents(t *testing.T) {
server := newIPv4Server(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
addr, err := net.ResolveTCPAddr("tcp", server.Listener.Addr().String())
if err != nil {
t.Fatalf("ResolveTCPAddr() error: %v", err)
}
var (
mu sync.Mutex
dnsStartCount int
dnsDoneCount int
lastHost string
)
hooks := &TraceHooks{
DNSStart: func(info TraceDNSStartInfo) {
mu.Lock()
dnsStartCount++
lastHost = info.Host
mu.Unlock()
},
DNSDone: func(info TraceDNSDoneInfo) {
mu.Lock()
dnsDoneCount++
mu.Unlock()
if info.Err != nil {
t.Errorf("unexpected dns error: %v", info.Err)
}
},
}
reqURL := "http://localhost:" + strconv.Itoa(addr.Port)
resp, err := NewSimpleRequest(reqURL, http.MethodGet).
SetDialTimeout(DefaultDialTimeout + 200*time.Millisecond).
SetTraceHooks(hooks).
Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
mu.Lock()
defer mu.Unlock()
if dnsStartCount != 1 {
t.Fatalf("dnsStartCount=%d", dnsStartCount)
}
if dnsDoneCount != 1 {
t.Fatalf("dnsDoneCount=%d", dnsDoneCount)
}
if lastHost != "localhost" {
t.Fatalf("lastHost=%q; want localhost", lastHost)
}
}
func TestRequestHeadersReturnsCopy(t *testing.T) {
req := NewSimpleRequest("http://example.com", http.MethodGet).
SetHeader("X-Test", "one").
SetHost("origin.example")
headers := req.Headers()
headers.Set("X-Test", "two")
headers.Set("Host", "mutated.example")
if got := req.GetHeader("X-Test"); got != "one" {
t.Fatalf("request header=%q; want one", got)
}
if got := req.Host(); got != "origin.example" {
t.Fatalf("request host=%q; want origin.example", got)
}
}
func TestRequestCookiesIsolation(t *testing.T) {
req := NewSimpleRequest("http://example.com", http.MethodGet)
source := []*http.Cookie{{
Name: "session",
Value: "one",
Path: "/",
}}
req.SetCookies(source)
source[0].Value = "mutated-outside"
got := req.Cookies()
if len(got) != 1 || got[0].Value != "one" {
t.Fatalf("cookies after SetCookies=%v", got)
}
got[0].Value = "mutated-copy"
if latest := req.Cookies()[0].Value; latest != "one" {
t.Fatalf("internal cookie mutated via getter, got %q", latest)
}
cookie := &http.Cookie{Name: "auth", Value: "token"}
req.ResetCookies().AddCookie(cookie)
cookie.Value = "changed"
latest := req.Cookies()
if len(latest) != 1 || latest[0].Value != "token" {
t.Fatalf("cookies after AddCookie=%v", latest)
}
}
func TestTraceHooksLookupFuncStillEmitsDNSEvents(t *testing.T) {
server := newIPv4Server(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
addr, err := net.ResolveTCPAddr("tcp", server.Listener.Addr().String())
if err != nil {
t.Fatalf("ResolveTCPAddr() error: %v", err)
}
var dnsStartCount int
var dnsDoneCount int
hooks := &TraceHooks{
DNSStart: func(info TraceDNSStartInfo) {
dnsStartCount++
},
DNSDone: func(info TraceDNSDoneInfo) {
dnsDoneCount++
},
}
resp, err := NewSimpleRequest("http://lookup-copy.test:"+strconv.Itoa(addr.Port), http.MethodGet).
SetLookupFunc(func(ctx context.Context, host string) ([]net.IPAddr, error) {
return []net.IPAddr{{IP: addr.IP}}, nil
}).
SetTraceHooks(hooks).
Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
if dnsStartCount != 1 || dnsDoneCount != 1 {
t.Fatalf("dns trace counts start=%d done=%d", dnsStartCount, dnsDoneCount)
}
}
+115
View File
@@ -0,0 +1,115 @@
package starnet
import (
"io"
"net"
"net/http"
"net/http/httptest"
"testing"
"time"
)
func TestRequestTimeoutDoesNotMutateClientTimeout(t *testing.T) {
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(200 * time.Millisecond)
w.WriteHeader(http.StatusOK)
}))
defer s.Close()
client := NewClientNoErr()
baseTimeout := client.HTTPClient().Timeout
_, err := client.Get(s.URL, WithTimeout(80*time.Millisecond))
if err == nil {
t.Fatal("expected request timeout error")
}
if client.HTTPClient().Timeout != baseTimeout {
t.Fatalf("client timeout mutated: got=%v want=%v", client.HTTPClient().Timeout, baseTimeout)
}
resp, err := client.Get(s.URL, WithTimeout(400*time.Millisecond))
if err != nil {
t.Fatalf("second request should succeed with larger timeout: %v", err)
}
defer resp.Close()
}
func TestRequestTimeoutNoLingeringConnDeadline(t *testing.T) {
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.Copy(io.Discard, r.Body)
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer s.Close()
client := NewClientNoErr()
resp1, err := client.Post(s.URL, WithBodyString("first"), WithTimeout(120*time.Millisecond))
if err != nil {
t.Fatalf("first request should succeed: %v", err)
}
_ = resp1.Close()
// 如果请求超时依赖连接级绝对 deadline,经过该等待后复用连接会出现误超时。
time.Sleep(220 * time.Millisecond)
resp2, err := client.Post(s.URL, WithBodyString("second"), WithTimeout(1*time.Second))
if err != nil {
t.Fatalf("second request should not be affected by previous timeout window: %v", err)
}
_ = resp2.Close()
}
func TestConnReadDeadlineTimeoutAndRecover(t *testing.T) {
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen error: %v", err)
}
defer ln.Close()
done := make(chan struct{})
go func() {
defer close(done)
conn, err := ln.Accept()
if err != nil {
return
}
defer conn.Close()
time.Sleep(180 * time.Millisecond)
_, _ = conn.Write([]byte("x"))
}()
c, err := Dial("tcp", ln.Addr().String())
if err != nil {
t.Fatalf("dial error: %v", err)
}
defer c.Close()
if err := c.SetReadDeadline(time.Now().Add(60 * time.Millisecond)); err != nil {
t.Fatalf("set read deadline error: %v", err)
}
buf := make([]byte, 1)
_, err = c.Read(buf)
if err == nil {
t.Fatal("expected read timeout error")
}
if ne, ok := err.(net.Error); !ok || !ne.Timeout() {
t.Fatalf("expected net timeout error, got: %v", err)
}
if err := c.SetReadDeadline(time.Time{}); err != nil {
t.Fatalf("clear read deadline error: %v", err)
}
if _, err := io.ReadFull(c, buf); err != nil {
t.Fatalf("read after clearing deadline should succeed: %v", err)
}
if string(buf) != "x" {
t.Fatalf("unexpected payload: %q", string(buf))
}
<-done
}
+66
View File
@@ -0,0 +1,66 @@
package starnet
import (
"net/http"
"net/http/httptest"
"testing"
"time"
)
func TestRequestTimeout(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(200 * time.Millisecond)
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
// Should timeout
req := NewSimpleRequest(server.URL, "GET").SetTimeout(100 * time.Millisecond)
_, err := req.Do()
if err == nil {
t.Error("Expected timeout error, got nil")
}
// Should succeed
req2 := NewSimpleRequest(server.URL, "GET").SetTimeout(300 * time.Millisecond)
resp, err := req2.Do()
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if resp != nil {
resp.Close()
}
}
func TestRequestDialTimeout(t *testing.T) {
// Use a non-routable IP to test dial timeout
req := NewSimpleRequest("http://192.0.2.1:80", "GET").
SetDialTimeout(100 * time.Millisecond)
start := time.Now()
_, err := req.Do()
elapsed := time.Since(start)
if err == nil {
t.Error("Expected dial timeout error, got nil")
}
// Should timeout within reasonable time (not wait forever)
if elapsed > 2*time.Second {
t.Errorf("Dial timeout took too long: %v", elapsed)
}
}
func TestClientTimeout(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(200 * time.Millisecond)
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
client := NewClientNoErr(WithTimeout(100 * time.Millisecond))
_, err := client.Get(server.URL)
if err == nil {
t.Error("Expected timeout error, got nil")
}
}
+336
View File
@@ -0,0 +1,336 @@
package starnet
import (
"context"
"crypto/tls"
"net/http"
"net/http/httptest"
"testing"
"time"
)
func TestRequestSkipTLSVerify(t *testing.T) {
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
}))
defer server.Close()
// Without skip verify (should fail)
req := NewSimpleRequest(server.URL, "GET")
_, err := req.Do()
if err == nil {
t.Error("Expected TLS error without skip verify, got nil")
}
// With skip verify (should succeed)
req2 := NewSimpleRequest(server.URL, "GET").SetSkipTLSVerify(true)
resp, err := req2.Do()
if err != nil {
t.Fatalf("Do() with skip verify error: %v", err)
}
defer resp.Close()
body, _ := resp.Body().String()
if body != "OK" {
t.Errorf("Body = %v; want OK", body)
}
}
func TestRequestCustomTLSConfig(t *testing.T) {
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
}))
defer server.Close()
tlsConfig := &tls.Config{
InsecureSkipVerify: true,
MinVersion: tls.VersionTLS12,
}
req := NewSimpleRequest(server.URL, "GET").SetTLSConfig(tlsConfig)
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("StatusCode = %v; want %v", resp.StatusCode, http.StatusOK)
}
}
func TestClientDefaultTLSConfig(t *testing.T) {
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
client := NewClientNoErr()
client.SetDefaultSkipTLSVerify(true)
resp, err := client.Get(server.URL)
if err != nil {
t.Fatalf("Get() error: %v", err)
}
defer resp.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("StatusCode = %v; want %v", resp.StatusCode, http.StatusOK)
}
}
func TestRequestLevelTLSOverride(t *testing.T) {
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
// Client level: skip verify = false
client := NewClientNoErr()
client.SetDefaultSkipTLSVerify(false)
// Request level: skip verify = true (should override)
resp, err := client.Get(server.URL, WithSkipTLSVerify(true))
if err != nil {
t.Fatalf("Get() error: %v", err)
}
defer resp.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("StatusCode = %v; want %v", resp.StatusCode, http.StatusOK)
}
}
func TestRequestTls(t *testing.T) {
var requestCount int
server, pool := newTrustedIPv4TLSServer(t, "localhost", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requestCount++
switch requestCount {
case 1:
if r.Header.Get("Hello") != "" {
t.Fatalf("unexpected hello header on first request: %q", r.Header.Get("Hello"))
}
if auth := r.Header.Get("Authorization"); auth != "" {
t.Fatalf("unexpected authorization on first request: %q", auth)
}
case 2:
if got := r.Header.Get("Hello"); got != "world" {
t.Fatalf("hello header=%q; want world", got)
}
if got := r.Header.Get("Authorization"); got != "Bearer ddddddd" {
t.Fatalf("authorization=%q; want bearer token", got)
}
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("OK"))
}))
defer server.Close()
localURL := httpsURLForHost(t, server, "localhost")
resp, err := NewSimpleRequest(localURL, "GET").
SetTLSConfig(&tls.Config{RootCAs: pool}).
Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("StatusCode = %v; want %v", resp.StatusCode, http.StatusOK)
}
t.Logf("Response: %v", resp.Body().MustString())
client, err := NewClient()
if err != nil {
t.Fatalf("NewClient() error: %v", err)
}
resp, err = client.NewSimpleRequest(localURL, "GET",
WithTLSConfig(&tls.Config{RootCAs: pool}),
WithHeader("hello", "world"),
WithContext(context.Background()),
WithBearerToken("ddddddd")).Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("StatusCode = %v; want %v", resp.StatusCode, http.StatusOK)
}
t.Logf("Response: %v", resp.Body().MustString())
}
func TestTLSWithProxyPath(t *testing.T) {
server, pool := newTrustedIPv4TLSServer(t, "localhost", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("proxied"))
}))
defer server.Close()
proxy := newIPv4ConnectProxyServer(t, nil)
defer proxy.Close()
client, err := NewClient()
if err != nil {
t.Fatal(err)
}
req, err := client.NewRequest(httpsURLForHost(t, server, "localhost"), "GET",
WithTimeout(10*time.Second),
WithProxy(proxy.URL),
WithTLSConfig(&tls.Config{RootCAs: pool}),
)
if err != nil {
t.Fatal(err)
}
resp, err := req.Do()
if err != nil {
t.Fatalf("Do error: %v", err)
}
defer resp.Close()
if targets := proxy.Targets(); len(targets) != 1 {
t.Fatalf("proxy targets=%v; want 1 target", targets)
}
t.Log(resp.Status)
}
func TestTLSWithProxyBug(t *testing.T) {
server, pool := newTrustedIPv4TLSServer(t, "proxy-bug.test", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer server.Close()
proxy := newIPv4ConnectProxyServer(t, nil)
defer proxy.Close()
client, err := NewClient()
if err != nil {
t.Fatal(err)
}
// 关键:使用 WithProxy 触发 needsDynamicTransport
// 即使 proxy 是空串或无效地址,只要设置了就会走 buildDynamicTransport 分支
req, err := client.NewRequest(httpsURLForHost(t, server, "proxy-bug.test"), "GET",
WithTimeout(10*time.Second),
WithProxy(proxy.URL),
WithCustomIP([]string{"127.0.0.1"}),
WithTLSConfig(&tls.Config{RootCAs: pool}),
)
if err != nil {
t.Fatal(err)
}
resp, err := req.Do()
if err != nil {
// 修复前会报:tls: either ServerName or InsecureSkipVerify must be specified
t.Fatalf("Do error: %v", err)
}
defer resp.Close()
if targets := proxy.Targets(); len(targets) != 1 || targets[0] == "" {
t.Fatalf("proxy targets=%v", targets)
}
t.Logf("Status: %s", resp.Status)
}
// 更精准的复现:直接测试有问题的分支
func TestTLSDialWithoutServerName(t *testing.T) {
server, pool := newTrustedIPv4TLSServer(t, "custom-ip.test", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer server.Close()
client, err := NewClient()
if err != nil {
t.Fatal(err)
}
// 使用 WithCustomIP 也能触发 defaultDialTLSFunc
req, err := client.NewRequest(httpsURLForHost(t, server, "custom-ip.test"), "GET",
WithTimeout(10*time.Second),
WithCustomIP([]string{"127.0.0.1"}),
WithTLSConfig(&tls.Config{RootCAs: pool}),
)
if err != nil {
t.Fatal(err)
}
resp, err := req.Do()
if err != nil {
t.Fatalf("Do error: %v", err)
}
defer resp.Close()
t.Logf("Status: %s", resp.Status)
}
// 最小复现:只要触发 needsDynamicTransport 即可
func TestMinimalTLSBug(t *testing.T) {
server, pool := newTrustedIPv4TLSServer(t, "localhost", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer server.Close()
client, err := NewClient()
if err != nil {
t.Fatal(err)
}
// WithDialTimeout 也会触发动态 transport
req, err := client.NewRequest(httpsURLForHost(t, server, "localhost"), "GET",
WithDialTimeout(5*time.Second),
WithTLSConfig(&tls.Config{RootCAs: pool}),
)
if err != nil {
t.Fatal(err)
}
resp, err := req.Do()
if err != nil {
// 修复前必现:tls handshake: tls: either ServerName or InsecureSkipVerify must be specified
t.Fatalf("Do error: %v", err)
}
defer resp.Close()
t.Logf("Status: %s", resp.Status)
}
func TestTLSWithSOCKS5ProxyPath(t *testing.T) {
server, pool := newTrustedIPv4TLSServer(t, "socks5-proxy.test", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer server.Close()
proxy := newSOCKS5ProxyServer(t, nil)
defer proxy.Close()
client, err := NewClient()
if err != nil {
t.Fatal(err)
}
req, err := client.NewRequest(httpsURLForHost(t, server, "socks5-proxy.test"), "GET",
WithTimeout(10*time.Second),
WithProxy(proxy.URL()),
WithCustomIP([]string{"127.0.0.1"}),
WithTLSConfig(&tls.Config{RootCAs: pool}),
)
if err != nil {
t.Fatal(err)
}
resp, err := req.Do()
if err != nil {
t.Fatalf("Do error: %v", err)
}
defer resp.Close()
if targets := proxy.Targets(); len(targets) != 1 || targets[0] == "" {
t.Fatalf("socks5 targets=%v", targets)
}
t.Logf("Status: %s", resp.Status)
}
+90
View File
@@ -0,0 +1,90 @@
package starnet
import (
"crypto/tls"
"net"
"time"
)
// GetConfigForClientFunc selects TLS config by hostname/SNI.
type GetConfigForClientFunc func(hostname string) (*tls.Config, error)
// ClientHelloMeta carries sniffed TLS metadata and connection context.
type ClientHelloMeta struct {
ServerName string
LocalAddr net.Addr
RemoteAddr net.Addr
SupportedProtos []string
SupportedVersions []uint16
CipherSuites []uint16
}
// Clone returns a detached copy safe for callers to mutate.
func (m *ClientHelloMeta) Clone() *ClientHelloMeta {
if m == nil {
return nil
}
out := *m
if m.SupportedProtos != nil {
out.SupportedProtos = append([]string(nil), m.SupportedProtos...)
}
if m.SupportedVersions != nil {
out.SupportedVersions = append([]uint16(nil), m.SupportedVersions...)
}
if m.CipherSuites != nil {
out.CipherSuites = append([]uint16(nil), m.CipherSuites...)
}
return &out
}
// GetConfigForClientHelloFunc selects TLS config by sniffed TLS metadata.
type GetConfigForClientHelloFunc func(hello *ClientHelloMeta) (*tls.Config, error)
// ListenerConfig controls listener behavior.
type ListenerConfig struct {
// BaseTLSConfig is used for TLS when dynamic selection returns nil.
BaseTLSConfig *tls.Config
// GetConfigForClient selects TLS config for a hostname/SNI.
// Deprecated: prefer GetConfigForClientHello for richer context.
GetConfigForClient GetConfigForClientFunc
// GetConfigForClientHello selects TLS config for sniffed TLS metadata.
GetConfigForClientHello GetConfigForClientHelloFunc
// AllowNonTLS allows plain TCP fallback.
AllowNonTLS bool
// SniffTimeout bounds protocol sniffing time. 0 means no timeout.
SniffTimeout time.Duration
// MaxClientHelloBytes limits buffered sniff data.
// If <= 0, default 64KiB.
MaxClientHelloBytes int
// Logger is optional.
Logger Logger
}
// DefaultListenerConfig returns a conservative default config.
func DefaultListenerConfig() ListenerConfig {
return ListenerConfig{
AllowNonTLS: false,
SniffTimeout: 5 * time.Second,
MaxClientHelloBytes: 64 * 1024,
}
}
// TLSDefaults returns a TLS config baseline.
// Caller should set Certificates / GetCertificate as needed.
func TLSDefaults() *tls.Config {
return &tls.Config{
MinVersion: tls.VersionTLS12,
}
}
// DialConfig controls dialing behavior.
type DialConfig struct {
Timeout time.Duration
LocalAddr net.Addr
}
+543
View File
@@ -0,0 +1,543 @@
package starnet
import (
"bytes"
"context"
"crypto/tls"
"errors"
"io"
"net"
"sync"
"time"
"b612.me/starnet/internal/tlssniffercore"
)
// replayConn replays buffered bytes first, then reads from live conn.
type replayConn struct {
reader io.Reader
conn net.Conn
}
func newReplayConn(buffered io.Reader, conn net.Conn) *replayConn {
return &replayConn{
reader: io.MultiReader(buffered, conn),
conn: conn,
}
}
func (c *replayConn) Read(p []byte) (int, error) { return c.reader.Read(p) }
func (c *replayConn) Write(p []byte) (int, error) { return c.conn.Write(p) }
func (c *replayConn) Close() error { return c.conn.Close() }
func (c *replayConn) LocalAddr() net.Addr { return c.conn.LocalAddr() }
func (c *replayConn) RemoteAddr() net.Addr { return c.conn.RemoteAddr() }
func (c *replayConn) SetDeadline(t time.Time) error { return c.conn.SetDeadline(t) }
func (c *replayConn) SetReadDeadline(t time.Time) error { return c.conn.SetReadDeadline(t) }
func (c *replayConn) SetWriteDeadline(t time.Time) error { return c.conn.SetWriteDeadline(t) }
// SniffResult describes protocol sniffing result.
type SniffResult struct {
IsTLS bool
ClientHello *ClientHelloMeta
Buffer *bytes.Buffer
}
// Sniffer detects protocol and metadata from initial bytes.
type Sniffer interface {
Sniff(conn net.Conn, maxBytes int) (SniffResult, error)
}
// TLSSniffer is the default sniffer implementation.
type TLSSniffer struct{}
// Sniff detects TLS and extracts SNI when possible.
func (s TLSSniffer) Sniff(conn net.Conn, maxBytes int) (SniffResult, error) {
res, err := (tlssniffercore.Sniffer{}).Sniff(conn, maxBytes)
if err != nil {
return SniffResult{}, err
}
return convertCoreSniffResult(res), nil
}
func convertCoreSniffResult(res tlssniffercore.SniffResult) SniffResult {
out := SniffResult{
IsTLS: res.IsTLS,
Buffer: res.Buffer,
}
if res.ClientHello != nil {
out.ClientHello = convertCoreClientHelloMeta(res.ClientHello)
}
return out
}
func convertCoreClientHelloMeta(meta *tlssniffercore.ClientHelloMeta) *ClientHelloMeta {
if meta == nil {
return nil
}
return &ClientHelloMeta{
ServerName: meta.ServerName,
LocalAddr: meta.LocalAddr,
RemoteAddr: meta.RemoteAddr,
SupportedProtos: append([]string(nil), meta.SupportedProtos...),
SupportedVersions: append([]uint16(nil), meta.SupportedVersions...),
CipherSuites: append([]uint16(nil), meta.CipherSuites...),
}
}
// Conn wraps net.Conn with lazy protocol initialization.
type Conn struct {
net.Conn
once sync.Once
initErr error
closeOnce sync.Once
isTLS bool
tlsConn *tls.Conn
plainConn net.Conn
clientHello *ClientHelloMeta
baseTLSConfig *tls.Config
getConfigForClient GetConfigForClientFunc
getConfigForClientHello GetConfigForClientHelloFunc
allowNonTLS bool
sniffer Sniffer
sniffTimeout time.Duration
maxClientHello int
logger Logger
stats *Stats
skipSniff bool
}
func newConn(raw net.Conn, cfg ListenerConfig, stats *Stats) *Conn {
return &Conn{
Conn: raw,
plainConn: raw,
baseTLSConfig: cfg.BaseTLSConfig,
getConfigForClient: cfg.GetConfigForClient,
getConfigForClientHello: cfg.GetConfigForClientHello,
allowNonTLS: cfg.AllowNonTLS,
sniffer: TLSSniffer{},
sniffTimeout: cfg.SniffTimeout,
maxClientHello: cfg.MaxClientHelloBytes,
logger: cfg.Logger,
stats: stats,
}
}
func (c *Conn) init() {
c.once.Do(func() {
if c.skipSniff {
return
}
if c.baseTLSConfig == nil && c.getConfigForClient == nil && c.getConfigForClientHello == nil {
c.isTLS = false
return
}
if c.sniffTimeout > 0 {
_ = c.Conn.SetReadDeadline(time.Now().Add(c.sniffTimeout))
}
res, err := c.sniffer.Sniff(c.Conn, c.maxClientHello)
if c.sniffTimeout > 0 {
_ = c.Conn.SetReadDeadline(time.Time{})
}
if err != nil {
c.initErr = errors.Join(ErrTLSSniffFailed, err)
c.failSniff(err)
return
}
c.isTLS = res.IsTLS
c.clientHello = res.ClientHello
if c.isTLS {
if c.stats != nil {
c.stats.incTLSDetected()
}
tlsCfg, errCfg := c.selectTLSConfig()
if errCfg != nil {
c.initErr = errors.Join(ErrTLSConfigSelectionFailed, errCfg)
c.failTLSConfigSelection(errCfg)
return
}
rc := newReplayConn(bytes.NewBuffer(res.Buffer.Bytes()), c.Conn)
c.tlsConn = tls.Server(rc, tlsCfg)
return
}
if c.stats != nil {
c.stats.incPlainDetected()
}
if !c.allowNonTLS {
c.initErr = ErrNonTLSNotAllowed
c.failPlainRejected()
return
}
c.plainConn = newReplayConn(bytes.NewBuffer(res.Buffer.Bytes()), c.Conn)
})
}
func (c *Conn) failAndClose(format string, v ...interface{}) {
if c.logger != nil {
c.logger.Printf("starnet: "+format, v...)
}
_ = c.Close()
}
func (c *Conn) failSniff(err error) {
if c.stats != nil {
c.stats.incSniffFailures()
}
c.failAndClose("tls sniff failed: %v", err)
}
func (c *Conn) failTLSConfigSelection(err error) {
if c.stats != nil {
c.stats.incTLSConfigFailures()
}
c.failAndClose("tls config selection failed: %v", err)
}
func (c *Conn) failPlainRejected() {
if c.stats != nil {
c.stats.incPlainRejected()
}
c.failAndClose("plain tcp rejected")
}
func (c *Conn) selectTLSConfig() (*tls.Config, error) {
var selected *tls.Config
if c.getConfigForClientHello != nil {
cfg, err := c.getConfigForClientHello(c.clientHello.Clone())
if err != nil {
return nil, err
}
if cfg != nil {
selected = cfg
}
}
if selected == nil && c.getConfigForClient != nil {
cfg, err := c.getConfigForClient(c.serverName())
if err != nil {
return nil, err
}
if cfg != nil {
selected = cfg
}
}
composed := composeServerTLSConfig(c.baseTLSConfig, selected)
if composed != nil {
return composed, nil
}
return nil, ErrNoTLSConfig
}
// Hostname returns sniffed SNI hostname (if any).
func (c *Conn) Hostname() string {
c.init()
return c.serverName()
}
// ClientHello returns sniffed TLS metadata (if any).
func (c *Conn) ClientHello() *ClientHelloMeta {
c.init()
return c.clientHello.Clone()
}
func (c *Conn) serverName() string {
if c.clientHello == nil {
return ""
}
return c.clientHello.ServerName
}
func composeServerTLSConfig(base, selected *tls.Config) *tls.Config {
return tlssniffercore.ComposeServerTLSConfig(base, selected)
}
func applyServerTLSOverrides(dst, src *tls.Config) {
tlssniffercore.ApplyServerTLSOverrides(dst, src)
}
func (c *Conn) IsTLS() bool {
c.init()
return c.initErr == nil && c.isTLS
}
func (c *Conn) TLSConn() (*tls.Conn, error) {
c.init()
if c.initErr != nil {
return nil, c.initErr
}
if !c.isTLS || c.tlsConn == nil {
return nil, ErrNotTLS
}
return c.tlsConn, nil
}
func (c *Conn) Read(b []byte) (int, error) {
c.init()
if c.initErr != nil {
return 0, c.initErr
}
if c.isTLS {
return c.tlsConn.Read(b)
}
return c.plainConn.Read(b)
}
func (c *Conn) Write(b []byte) (int, error) {
c.init()
if c.initErr != nil {
return 0, c.initErr
}
if c.isTLS {
return c.tlsConn.Write(b)
}
return c.plainConn.Write(b)
}
func (c *Conn) Close() error {
var err error
c.closeOnce.Do(func() {
if c.tlsConn != nil {
err = c.tlsConn.Close()
} else {
err = c.Conn.Close()
}
if c.stats != nil {
c.stats.incClosed()
}
})
return err
}
func (c *Conn) SetDeadline(t time.Time) error {
c.init()
if c.initErr != nil {
return c.initErr
}
if c.isTLS && c.tlsConn != nil {
return c.tlsConn.SetDeadline(t)
}
return c.plainConn.SetDeadline(t)
}
func (c *Conn) SetReadDeadline(t time.Time) error {
c.init()
if c.initErr != nil {
return c.initErr
}
if c.isTLS && c.tlsConn != nil {
return c.tlsConn.SetReadDeadline(t)
}
return c.plainConn.SetReadDeadline(t)
}
func (c *Conn) SetWriteDeadline(t time.Time) error {
c.init()
if c.initErr != nil {
return c.initErr
}
if c.isTLS && c.tlsConn != nil {
return c.tlsConn.SetWriteDeadline(t)
}
return c.plainConn.SetWriteDeadline(t)
}
// Listener wraps net.Listener and returns starnet.Conn from Accept.
type Listener struct {
net.Listener
mu sync.RWMutex
cfg ListenerConfig
stats Stats
}
// Listen creates a plain listener config (no TLS detection).
func Listen(network, address string) (*Listener, error) {
ln, err := net.Listen(network, address)
if err != nil {
return nil, err
}
cfg := DefaultListenerConfig()
cfg.AllowNonTLS = true
cfg.BaseTLSConfig = nil
cfg.GetConfigForClient = nil
return &Listener{Listener: ln, cfg: cfg}, nil
}
// ListenWithConfig creates a listener with full config.
func ListenWithConfig(network, address string, cfg ListenerConfig) (*Listener, error) {
ln, err := net.Listen(network, address)
if err != nil {
return nil, err
}
return &Listener{Listener: ln, cfg: normalizeConfig(cfg)}, nil
}
// ListenWithListenConfig creates listener using net.ListenConfig.
func ListenWithListenConfig(lc net.ListenConfig, network, address string, cfg ListenerConfig) (*Listener, error) {
ln, err := lc.Listen(context.Background(), network, address)
if err != nil {
return nil, err
}
return &Listener{Listener: ln, cfg: normalizeConfig(cfg)}, nil
}
// ListenTLS creates TLS listener from cert/key paths.
func ListenTLS(network, address, certFile, keyFile string, allowNonTLS bool) (*Listener, error) {
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return nil, err
}
cfg := DefaultListenerConfig()
cfg.AllowNonTLS = allowNonTLS
cfg.BaseTLSConfig = TLSDefaults()
cfg.BaseTLSConfig.Certificates = []tls.Certificate{cert}
return ListenWithConfig(network, address, cfg)
}
func normalizeConfig(cfg ListenerConfig) ListenerConfig {
out := DefaultListenerConfig()
out.AllowNonTLS = cfg.AllowNonTLS
out.SniffTimeout = cfg.SniffTimeout
out.MaxClientHelloBytes = cfg.MaxClientHelloBytes
out.BaseTLSConfig = cfg.BaseTLSConfig
out.GetConfigForClient = cfg.GetConfigForClient
out.GetConfigForClientHello = cfg.GetConfigForClientHello
out.Logger = cfg.Logger
if out.MaxClientHelloBytes <= 0 {
out.MaxClientHelloBytes = 64 * 1024
}
return out
}
// SetConfig atomically replaces listener config for new accepted connections.
func (l *Listener) SetConfig(cfg ListenerConfig) {
l.mu.Lock()
l.cfg = normalizeConfig(cfg)
l.mu.Unlock()
}
// Config returns a copy of current config.
func (l *Listener) Config() ListenerConfig {
l.mu.RLock()
cfg := l.cfg
l.mu.RUnlock()
return cfg
}
// Stats returns current counters snapshot.
func (l *Listener) Stats() StatsSnapshot {
return l.stats.Snapshot()
}
func (l *Listener) Accept() (net.Conn, error) {
raw, err := l.Listener.Accept()
if err != nil {
return nil, err
}
l.stats.incAccepted()
l.mu.RLock()
cfg := l.cfg
l.mu.RUnlock()
return newConn(raw, cfg, &l.stats), nil
}
// AcceptContext supports cancellation by closing accepted conn when ctx is done early.
func (l *Listener) AcceptContext(ctx context.Context) (net.Conn, error) {
type result struct {
c net.Conn
err error
}
ch := make(chan result, 1)
go func() {
c, err := l.Accept()
ch <- result{c: c, err: err}
}()
select {
case <-ctx.Done():
return nil, ctx.Err()
case r := <-ch:
return r.c, r.err
}
}
// Dial creates a plain TCP starnet.Conn.
func Dial(network, address string) (*Conn, error) {
raw, err := net.Dial(network, address)
if err != nil {
return nil, err
}
cfg := DefaultListenerConfig()
cfg.AllowNonTLS = true
cfg.BaseTLSConfig = nil
cfg.GetConfigForClient = nil
c := newConn(raw, cfg, nil)
c.isTLS = false
return c, nil
}
// DialWithConfig dials with net.Dialer options.
func DialWithConfig(network, address string, dc DialConfig) (*Conn, error) {
d := net.Dialer{
Timeout: dc.Timeout,
LocalAddr: dc.LocalAddr,
}
raw, err := d.Dial(network, address)
if err != nil {
return nil, err
}
cfg := DefaultListenerConfig()
cfg.AllowNonTLS = true
c := newConn(raw, cfg, nil)
c.isTLS = false
return c, nil
}
// DialTLSWithConfig creates a TLS client connection wrapper.
func DialTLSWithConfig(network, address string, tlsCfg *tls.Config, timeout time.Duration) (*Conn, error) {
d := net.Dialer{Timeout: timeout}
raw, err := d.Dial(network, address)
if err != nil {
return nil, err
}
tc := tls.Client(raw, tlsCfg)
return &Conn{
Conn: raw,
plainConn: raw,
isTLS: true,
tlsConn: tc,
initErr: nil,
allowNonTLS: false,
skipSniff: true,
}, nil
}
// DialTLS creates TLS client conn from cert/key paths.
func DialTLS(network, address, certFile, keyFile string) (*Conn, error) {
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return nil, err
}
cfg := TLSDefaults()
cfg.Certificates = []tls.Certificate{cert}
return DialTLSWithConfig(network, address, cfg, 0)
}
func WrapListener(listener net.Listener, cfg ListenerConfig) (*Listener, error) {
if listener == nil {
return nil, ErrNilConn
}
return &Listener{
Listener: listener,
cfg: normalizeConfig(cfg),
}, nil
}
+1210
View File
File diff suppressed because it is too large Load Diff
+55
View File
@@ -0,0 +1,55 @@
package starnet
import "sync/atomic"
// StatsSnapshot is a read-only copy of runtime counters.
type StatsSnapshot struct {
Accepted uint64
TLSDetected uint64
PlainDetected uint64
InitFailures uint64
SniffFailures uint64
TLSConfigFailures uint64
PlainRejected uint64
Closed uint64
}
// Stats provides lock-free counters.
type Stats struct {
accepted uint64
tlsDetected uint64
plainDetected uint64
initFailures uint64
sniffFailures uint64
tlsConfigFailures uint64
plainRejected uint64
closed uint64
}
func (s *Stats) incAccepted() { atomic.AddUint64(&s.accepted, 1) }
func (s *Stats) incTLSDetected() { atomic.AddUint64(&s.tlsDetected, 1) }
func (s *Stats) incPlainDetected() { atomic.AddUint64(&s.plainDetected, 1) }
func (s *Stats) incInitFailures() { atomic.AddUint64(&s.initFailures, 1) }
func (s *Stats) incClosed() { atomic.AddUint64(&s.closed, 1) }
func (s *Stats) incSniffFailures() { atomic.AddUint64(&s.sniffFailures, 1); s.incInitFailures() }
func (s *Stats) incTLSConfigFailures() { atomic.AddUint64(&s.tlsConfigFailures, 1); s.incInitFailures() }
func (s *Stats) incPlainRejected() { atomic.AddUint64(&s.plainRejected, 1); s.incInitFailures() }
// Snapshot returns a stable view of counters.
func (s *Stats) Snapshot() StatsSnapshot {
return StatsSnapshot{
Accepted: atomic.LoadUint64(&s.accepted),
TLSDetected: atomic.LoadUint64(&s.tlsDetected),
PlainDetected: atomic.LoadUint64(&s.plainDetected),
InitFailures: atomic.LoadUint64(&s.initFailures),
SniffFailures: atomic.LoadUint64(&s.sniffFailures),
TLSConfigFailures: atomic.LoadUint64(&s.tlsConfigFailures),
PlainRejected: atomic.LoadUint64(&s.plainRejected),
Closed: atomic.LoadUint64(&s.closed),
}
}
// Logger is a minimal logging abstraction.
type Logger interface {
Printf(format string, v ...interface{})
}
+604
View File
@@ -0,0 +1,604 @@
package starnet
import (
"context"
"crypto/tls"
"net"
"net/http/httptrace"
"sync/atomic"
"time"
)
type traceContextKey struct{}
// TraceHooks defines optional callbacks for network lifecycle events.
// Hooks may be called concurrently.
type TraceHooks struct {
GetConn func(TraceGetConnInfo)
GotConn func(TraceGotConnInfo)
PutIdleConn func(TracePutIdleConnInfo)
DNSStart func(TraceDNSStartInfo)
DNSDone func(TraceDNSDoneInfo)
ConnectStart func(TraceConnectStartInfo)
ConnectDone func(TraceConnectDoneInfo)
TLSHandshakeStart func(TraceTLSHandshakeStartInfo)
TLSHandshakeDone func(TraceTLSHandshakeDoneInfo)
WroteHeaderField func(TraceWroteHeaderFieldInfo)
WroteHeaders func()
WroteRequest func(TraceWroteRequestInfo)
GotFirstResponseByte func()
RetryAttemptStart func(TraceRetryAttemptStartInfo)
RetryAttemptDone func(TraceRetryAttemptDoneInfo)
RetryBackoff func(TraceRetryBackoffInfo)
}
type TraceGetConnInfo struct {
Addr string
}
type TraceGotConnInfo struct {
Conn net.Conn
Reused bool
WasIdle bool
IdleTime time.Duration
}
type TracePutIdleConnInfo struct {
Err error
}
type TraceDNSStartInfo struct {
Host string
}
type TraceDNSDoneInfo struct {
Addrs []net.IPAddr
Coalesced bool
Err error
}
type TraceConnectStartInfo struct {
Network string
Addr string
}
type TraceConnectDoneInfo struct {
Network string
Addr string
Err error
}
type TraceTLSHandshakeStartInfo struct {
Network string
Addr string
ServerName string
}
type TraceTLSHandshakeDoneInfo struct {
Network string
Addr string
ServerName string
ConnectionState tls.ConnectionState
Err error
}
type TraceWroteHeaderFieldInfo struct {
Key string
Values []string
}
type TraceWroteRequestInfo struct {
Err error
}
type TraceRetryAttemptStartInfo struct {
Attempt int
MaxAttempts int
}
type TraceRetryAttemptDoneInfo struct {
Attempt int
MaxAttempts int
StatusCode int
Err error
WillRetry bool
}
type TraceRetryBackoffInfo struct {
Attempt int
Delay time.Duration
}
type traceState struct {
hooks *TraceHooks
customTLS atomic.Uint32
manualDNSRefs atomic.Int32
defaultTLSHandshakeInfo TraceTLSHandshakeStartInfo
}
func newTraceState(hooks *TraceHooks) *traceState {
if hooks == nil {
return nil
}
return &traceState{hooks: hooks}
}
func (t *traceState) setDefaultTLSHandshakeInfo(info TraceTLSHandshakeStartInfo) {
if t == nil {
return
}
t.defaultTLSHandshakeInfo = info
}
func (t *traceState) getDefaultTLSHandshakeInfo() TraceTLSHandshakeStartInfo {
if t == nil {
return TraceTLSHandshakeStartInfo{}
}
return t.defaultTLSHandshakeInfo
}
func withTraceState(ctx context.Context, state *traceState) context.Context {
if state == nil {
return ctx
}
return context.WithValue(ctx, traceContextKey{}, state)
}
func getTraceState(ctx context.Context) *traceState {
if ctx == nil {
return nil
}
state, _ := ctx.Value(traceContextKey{}).(*traceState)
return state
}
func (t *traceState) needsHTTPTrace() bool {
if t == nil || t.hooks == nil {
return false
}
h := t.hooks
return h.GetConn != nil ||
h.GotConn != nil ||
h.PutIdleConn != nil ||
h.DNSStart != nil ||
h.DNSDone != nil ||
h.ConnectStart != nil ||
h.ConnectDone != nil ||
h.TLSHandshakeStart != nil ||
h.TLSHandshakeDone != nil ||
h.WroteHeaderField != nil ||
h.WroteHeaders != nil ||
h.WroteRequest != nil ||
h.GotFirstResponseByte != nil
}
func (t *traceState) clientTrace() *httptrace.ClientTrace {
if !t.needsHTTPTrace() {
return nil
}
h := t.hooks
trace := &httptrace.ClientTrace{}
if h.GetConn != nil {
trace.GetConn = func(hostPort string) {
h.GetConn(TraceGetConnInfo{Addr: hostPort})
}
}
if h.GotConn != nil {
trace.GotConn = func(info httptrace.GotConnInfo) {
h.GotConn(TraceGotConnInfo{
Conn: info.Conn,
Reused: info.Reused,
WasIdle: info.WasIdle,
IdleTime: info.IdleTime,
})
}
}
if h.PutIdleConn != nil {
trace.PutIdleConn = func(err error) {
h.PutIdleConn(TracePutIdleConnInfo{Err: err})
}
}
if h.DNSStart != nil {
trace.DNSStart = func(info httptrace.DNSStartInfo) {
if t.usesManualDNS() {
return
}
h.DNSStart(TraceDNSStartInfo{Host: info.Host})
}
}
if h.DNSDone != nil {
trace.DNSDone = func(info httptrace.DNSDoneInfo) {
if t.usesManualDNS() {
return
}
h.DNSDone(TraceDNSDoneInfo{
Addrs: append([]net.IPAddr(nil), info.Addrs...),
Coalesced: info.Coalesced,
Err: info.Err,
})
}
}
if h.ConnectStart != nil {
trace.ConnectStart = func(network, addr string) {
h.ConnectStart(TraceConnectStartInfo{Network: network, Addr: addr})
}
}
if h.ConnectDone != nil {
trace.ConnectDone = func(network, addr string, err error) {
h.ConnectDone(TraceConnectDoneInfo{Network: network, Addr: addr, Err: err})
}
}
if h.TLSHandshakeStart != nil {
trace.TLSHandshakeStart = func() {
if t.usesCustomTLS() {
return
}
h.TLSHandshakeStart(t.getDefaultTLSHandshakeInfo())
}
}
if h.TLSHandshakeDone != nil {
trace.TLSHandshakeDone = func(state tls.ConnectionState, err error) {
if t.usesCustomTLS() {
return
}
info := t.getDefaultTLSHandshakeInfo()
h.TLSHandshakeDone(TraceTLSHandshakeDoneInfo{
Network: info.Network,
Addr: info.Addr,
ServerName: info.ServerName,
ConnectionState: state,
Err: err,
})
}
}
if h.WroteHeaderField != nil {
trace.WroteHeaderField = func(key string, value []string) {
h.WroteHeaderField(TraceWroteHeaderFieldInfo{
Key: key,
Values: append([]string(nil), value...),
})
}
}
if h.WroteHeaders != nil {
trace.WroteHeaders = h.WroteHeaders
}
if h.WroteRequest != nil {
trace.WroteRequest = func(info httptrace.WroteRequestInfo) {
h.WroteRequest(TraceWroteRequestInfo{Err: info.Err})
}
}
if h.GotFirstResponseByte != nil {
trace.GotFirstResponseByte = h.GotFirstResponseByte
}
return trace
}
func (t *traceState) markCustomTLS() {
if t == nil {
return
}
t.customTLS.Store(1)
}
func (t *traceState) usesCustomTLS() bool {
if t == nil {
return false
}
return t.customTLS.Load() != 0
}
func (t *traceState) beginManualDNS() {
if t == nil {
return
}
t.manualDNSRefs.Add(1)
}
func (t *traceState) endManualDNS() {
if t == nil {
return
}
t.manualDNSRefs.Add(-1)
}
func (t *traceState) usesManualDNS() bool {
if t == nil {
return false
}
return t.manualDNSRefs.Load() > 0
}
func (t *traceState) tlsHandshakeStart(info TraceTLSHandshakeStartInfo) {
if t == nil || t.hooks == nil || t.hooks.TLSHandshakeStart == nil {
return
}
t.hooks.TLSHandshakeStart(info)
}
func (t *traceState) tlsHandshakeDone(info TraceTLSHandshakeDoneInfo) {
if t == nil || t.hooks == nil || t.hooks.TLSHandshakeDone == nil {
return
}
t.hooks.TLSHandshakeDone(info)
}
func (t *traceState) dnsStart(info TraceDNSStartInfo) {
if t == nil || t.hooks == nil || t.hooks.DNSStart == nil {
return
}
t.hooks.DNSStart(info)
}
func (t *traceState) dnsDone(info TraceDNSDoneInfo) {
if t == nil || t.hooks == nil || t.hooks.DNSDone == nil {
return
}
t.hooks.DNSDone(info)
}
func emitRetryAttemptStart(hooks *TraceHooks, info TraceRetryAttemptStartInfo) {
if hooks == nil || hooks.RetryAttemptStart == nil {
return
}
hooks.RetryAttemptStart(info)
}
func emitRetryAttemptDone(hooks *TraceHooks, info TraceRetryAttemptDoneInfo) {
if hooks == nil || hooks.RetryAttemptDone == nil {
return
}
hooks.RetryAttemptDone(info)
}
func emitRetryBackoff(hooks *TraceHooks, info TraceRetryBackoffInfo) {
if hooks == nil || hooks.RetryBackoff == nil {
return
}
hooks.RetryBackoff(info)
}
func traceRecorderHooks(recorder *TraceRecorder) *TraceHooks {
if recorder == nil {
return nil
}
return recorder.Hooks()
}
func composeTraceHooks(first, second *TraceHooks) *TraceHooks {
switch {
case first == nil:
return second
case second == nil:
return first
}
return &TraceHooks{
GetConn: composeTraceGetConnHook(first.GetConn, second.GetConn),
GotConn: composeTraceGotConnHook(first.GotConn, second.GotConn),
PutIdleConn: composeTracePutIdleConnHook(first.PutIdleConn, second.PutIdleConn),
DNSStart: composeTraceDNSStartHook(first.DNSStart, second.DNSStart),
DNSDone: composeTraceDNSDoneHook(first.DNSDone, second.DNSDone),
ConnectStart: composeTraceConnectStartHook(first.ConnectStart, second.ConnectStart),
ConnectDone: composeTraceConnectDoneHook(first.ConnectDone, second.ConnectDone),
TLSHandshakeStart: composeTraceTLSHandshakeStartHook(first.TLSHandshakeStart, second.TLSHandshakeStart),
TLSHandshakeDone: composeTraceTLSHandshakeDoneHook(first.TLSHandshakeDone, second.TLSHandshakeDone),
WroteHeaderField: composeTraceWroteHeaderFieldHook(first.WroteHeaderField, second.WroteHeaderField),
WroteHeaders: composeTraceSimpleHook(first.WroteHeaders, second.WroteHeaders),
WroteRequest: composeTraceWroteRequestHook(first.WroteRequest, second.WroteRequest),
GotFirstResponseByte: composeTraceSimpleHook(first.GotFirstResponseByte, second.GotFirstResponseByte),
RetryAttemptStart: composeTraceRetryAttemptStartHook(first.RetryAttemptStart, second.RetryAttemptStart),
RetryAttemptDone: composeTraceRetryAttemptDoneHook(first.RetryAttemptDone, second.RetryAttemptDone),
RetryBackoff: composeTraceRetryBackoffHook(first.RetryBackoff, second.RetryBackoff),
}
}
func composeTraceGetConnHook(first, second func(TraceGetConnInfo)) func(TraceGetConnInfo) {
switch {
case first == nil:
return second
case second == nil:
return first
default:
return func(info TraceGetConnInfo) {
first(info)
second(info)
}
}
}
func composeTraceGotConnHook(first, second func(TraceGotConnInfo)) func(TraceGotConnInfo) {
switch {
case first == nil:
return second
case second == nil:
return first
default:
return func(info TraceGotConnInfo) {
first(info)
second(info)
}
}
}
func composeTracePutIdleConnHook(first, second func(TracePutIdleConnInfo)) func(TracePutIdleConnInfo) {
switch {
case first == nil:
return second
case second == nil:
return first
default:
return func(info TracePutIdleConnInfo) {
first(info)
second(info)
}
}
}
func composeTraceDNSStartHook(first, second func(TraceDNSStartInfo)) func(TraceDNSStartInfo) {
switch {
case first == nil:
return second
case second == nil:
return first
default:
return func(info TraceDNSStartInfo) {
first(info)
second(info)
}
}
}
func composeTraceDNSDoneHook(first, second func(TraceDNSDoneInfo)) func(TraceDNSDoneInfo) {
switch {
case first == nil:
return second
case second == nil:
return first
default:
return func(info TraceDNSDoneInfo) {
first(info)
second(info)
}
}
}
func composeTraceConnectStartHook(first, second func(TraceConnectStartInfo)) func(TraceConnectStartInfo) {
switch {
case first == nil:
return second
case second == nil:
return first
default:
return func(info TraceConnectStartInfo) {
first(info)
second(info)
}
}
}
func composeTraceConnectDoneHook(first, second func(TraceConnectDoneInfo)) func(TraceConnectDoneInfo) {
switch {
case first == nil:
return second
case second == nil:
return first
default:
return func(info TraceConnectDoneInfo) {
first(info)
second(info)
}
}
}
func composeTraceTLSHandshakeStartHook(first, second func(TraceTLSHandshakeStartInfo)) func(TraceTLSHandshakeStartInfo) {
switch {
case first == nil:
return second
case second == nil:
return first
default:
return func(info TraceTLSHandshakeStartInfo) {
first(info)
second(info)
}
}
}
func composeTraceTLSHandshakeDoneHook(first, second func(TraceTLSHandshakeDoneInfo)) func(TraceTLSHandshakeDoneInfo) {
switch {
case first == nil:
return second
case second == nil:
return first
default:
return func(info TraceTLSHandshakeDoneInfo) {
first(info)
second(info)
}
}
}
func composeTraceWroteHeaderFieldHook(first, second func(TraceWroteHeaderFieldInfo)) func(TraceWroteHeaderFieldInfo) {
switch {
case first == nil:
return second
case second == nil:
return first
default:
return func(info TraceWroteHeaderFieldInfo) {
first(info)
second(info)
}
}
}
func composeTraceWroteRequestHook(first, second func(TraceWroteRequestInfo)) func(TraceWroteRequestInfo) {
switch {
case first == nil:
return second
case second == nil:
return first
default:
return func(info TraceWroteRequestInfo) {
first(info)
second(info)
}
}
}
func composeTraceRetryAttemptStartHook(first, second func(TraceRetryAttemptStartInfo)) func(TraceRetryAttemptStartInfo) {
switch {
case first == nil:
return second
case second == nil:
return first
default:
return func(info TraceRetryAttemptStartInfo) {
first(info)
second(info)
}
}
}
func composeTraceRetryAttemptDoneHook(first, second func(TraceRetryAttemptDoneInfo)) func(TraceRetryAttemptDoneInfo) {
switch {
case first == nil:
return second
case second == nil:
return first
default:
return func(info TraceRetryAttemptDoneInfo) {
first(info)
second(info)
}
}
}
func composeTraceRetryBackoffHook(first, second func(TraceRetryBackoffInfo)) func(TraceRetryBackoffInfo) {
switch {
case first == nil:
return second
case second == nil:
return first
default:
return func(info TraceRetryBackoffInfo) {
first(info)
second(info)
}
}
}
func composeTraceSimpleHook(first, second func()) func() {
switch {
case first == nil:
return second
case second == nil:
return first
default:
return func() {
first()
second()
}
}
}
+550
View File
@@ -0,0 +1,550 @@
package starnet
import (
"crypto/tls"
"net"
"net/http"
"sync"
"time"
)
// TraceSummary 是一次请求执行的 trace 摘要。
type TraceSummary struct {
Method string
URL string
StartedAt time.Time
ResponseAt time.Time
StatusCode int
ResponseProto string
Conn TraceConnSummary
DNS *TraceDNSSummary
DNSEvents []TraceDNSSummary
Connect []TraceConnectSummary
TLS *TraceTLSSummary
RequestWrittenAt time.Time
RequestWriteErr error
FirstResponseByteAt time.Time
}
// TraceConnSummary 是连接复用与套接字信息摘要。
type TraceConnSummary struct {
Addr string
LocalAddr string
RemoteAddr string
Reused bool
WasIdle bool
IdleTime time.Duration
}
// TraceDNSSummary 是 DNS 解析摘要。
type TraceDNSSummary struct {
Host string
Addrs []string
Coalesced bool
StartedAt time.Time
CompletedAt time.Time
Duration time.Duration
Err error
}
// TraceConnectSummary 是单次建连尝试摘要。
type TraceConnectSummary struct {
Network string
Addr string
StartedAt time.Time
CompletedAt time.Time
Duration time.Duration
Err error
}
// TraceTLSSummary 是 TLS 握手与连接状态摘要。
type TraceTLSSummary struct {
Network string
Addr string
ServerName string
Version uint16
VersionName string
CipherSuite uint16
CipherSuiteName string
CurveID tls.CurveID
CurveName string
NegotiatedProtocol string
DidResume bool
ECHAccepted bool
VerifiedChains int
StartedAt time.Time
CompletedAt time.Time
Duration time.Duration
Err error
PeerCertificates []TraceCertificateSummary
}
// TraceCertificateSummary 是单张证书的关键信息摘要。
type TraceCertificateSummary struct {
Subject string
Issuer string
DNSNames []string
IPAddresses []string
}
// TraceRecorder 聚合最近一次发布的 trace 摘要。
// 通过 Request/Client 绑定时,starnet 会为每次执行创建私有运行态并在完成后发布摘要;
// 直接使用 Hooks() 时,调用方仍需自行管理 Reset 与生命周期。
type TraceRecorder struct {
mu sync.Mutex
summary TraceSummary
pendingDNS []TraceDNSSummary
pendingConnectStarts map[string][]time.Time
pendingTLSStart time.Time
hooks *TraceHooks
}
// NewTraceRecorder 创建请求级 trace 记录器。
func NewTraceRecorder() *TraceRecorder {
recorder := &TraceRecorder{}
recorder.hooks = &TraceHooks{
GetConn: recorder.onGetConn,
GotConn: recorder.onGotConn,
DNSStart: recorder.onDNSStart,
DNSDone: recorder.onDNSDone,
ConnectStart: recorder.onConnectStart,
ConnectDone: recorder.onConnectDone,
TLSHandshakeStart: recorder.onTLSHandshakeStart,
TLSHandshakeDone: recorder.onTLSHandshakeDone,
WroteRequest: recorder.onWroteRequest,
GotFirstResponseByte: recorder.onGotFirstResponseByte,
}
return recorder
}
// Hooks 返回可挂到请求上的底层 trace hooks。
func (r *TraceRecorder) Hooks() *TraceHooks {
if r == nil {
return nil
}
if r.hooks == nil {
r.hooks = &TraceHooks{
GetConn: r.onGetConn,
GotConn: r.onGotConn,
DNSStart: r.onDNSStart,
DNSDone: r.onDNSDone,
ConnectStart: r.onConnectStart,
ConnectDone: r.onConnectDone,
TLSHandshakeStart: r.onTLSHandshakeStart,
TLSHandshakeDone: r.onTLSHandshakeDone,
WroteRequest: r.onWroteRequest,
GotFirstResponseByte: r.onGotFirstResponseByte,
}
}
return r.hooks
}
// Reset 清空当前摘要和内部状态。
func (r *TraceRecorder) Reset() {
if r == nil {
return
}
r.mu.Lock()
defer r.mu.Unlock()
r.resetLocked()
}
// Summary 返回当前 trace 摘要的快照。
func (r *TraceRecorder) Summary() TraceSummary {
if r == nil {
return TraceSummary{}
}
r.mu.Lock()
defer r.mu.Unlock()
return cloneTraceSummary(r.summary)
}
func (r *TraceRecorder) forkExecution() *TraceRecorder {
if r == nil {
return nil
}
return NewTraceRecorder()
}
func (r *TraceRecorder) publishSummary(summary TraceSummary) {
if r == nil {
return
}
r.mu.Lock()
defer r.mu.Unlock()
r.summary = cloneTraceSummary(summary)
}
func (r *TraceRecorder) startRequest() {
if r == nil {
return
}
r.mu.Lock()
defer r.mu.Unlock()
r.resetLocked()
r.summary.StartedAt = time.Now()
}
func (r *TraceRecorder) observePreparedRequest(req *http.Request) {
if r == nil || req == nil || req.URL == nil {
return
}
r.mu.Lock()
defer r.mu.Unlock()
r.ensureStartedLocked(time.Now())
r.summary.Method = req.Method
r.summary.URL = req.URL.String()
}
func (r *TraceRecorder) observeResponse(resp *http.Response) {
if r == nil || resp == nil {
return
}
r.mu.Lock()
defer r.mu.Unlock()
now := time.Now()
r.ensureStartedLocked(now)
r.summary.ResponseAt = now
r.summary.StatusCode = resp.StatusCode
r.summary.ResponseProto = resp.Proto
if resp.TLS != nil {
r.summary.TLS = mergeTraceTLSSummary(r.summary.TLS, *resp.TLS)
}
}
func (r *TraceRecorder) ensureStartedLocked(now time.Time) {
if r.summary.StartedAt.IsZero() {
r.summary.StartedAt = now
}
}
func (r *TraceRecorder) resetLocked() {
r.summary = TraceSummary{}
r.pendingDNS = nil
r.pendingTLSStart = time.Time{}
if len(r.pendingConnectStarts) == 0 {
r.pendingConnectStarts = nil
return
}
for key := range r.pendingConnectStarts {
delete(r.pendingConnectStarts, key)
}
r.pendingConnectStarts = nil
}
func (r *TraceRecorder) onGetConn(info TraceGetConnInfo) {
r.mu.Lock()
defer r.mu.Unlock()
r.ensureStartedLocked(time.Now())
r.summary.Conn.Addr = info.Addr
}
func (r *TraceRecorder) onGotConn(info TraceGotConnInfo) {
r.mu.Lock()
defer r.mu.Unlock()
now := time.Now()
r.ensureStartedLocked(now)
r.summary.Conn.Reused = info.Reused
r.summary.Conn.WasIdle = info.WasIdle
r.summary.Conn.IdleTime = info.IdleTime
if info.Conn != nil {
r.summary.Conn.LocalAddr = traceAddrString(info.Conn.LocalAddr())
r.summary.Conn.RemoteAddr = traceAddrString(info.Conn.RemoteAddr())
}
}
func (r *TraceRecorder) onDNSStart(info TraceDNSStartInfo) {
r.mu.Lock()
defer r.mu.Unlock()
now := time.Now()
r.ensureStartedLocked(now)
dns := TraceDNSSummary{
Host: info.Host,
StartedAt: now,
}
r.pendingDNS = append(r.pendingDNS, dns)
copyDNS := dns
r.summary.DNS = &copyDNS
}
func (r *TraceRecorder) onDNSDone(info TraceDNSDoneInfo) {
r.mu.Lock()
defer r.mu.Unlock()
now := time.Now()
r.ensureStartedLocked(now)
dns := TraceDNSSummary{
Host: "",
Addrs: traceIPAddrsToStrings(info.Addrs),
Coalesced: info.Coalesced,
CompletedAt: now,
Err: info.Err,
}
if len(r.pendingDNS) > 0 {
dns.Host = r.pendingDNS[0].Host
dns.StartedAt = r.pendingDNS[0].StartedAt
if len(r.pendingDNS) == 1 {
r.pendingDNS = nil
} else {
r.pendingDNS = append([]TraceDNSSummary(nil), r.pendingDNS[1:]...)
}
} else if r.summary.DNS != nil {
dns.Host = r.summary.DNS.Host
dns.StartedAt = r.summary.DNS.StartedAt
}
if !dns.StartedAt.IsZero() {
dns.Duration = now.Sub(dns.StartedAt)
}
r.summary.DNSEvents = append(r.summary.DNSEvents, dns)
copyDNS := dns
r.summary.DNS = &copyDNS
}
func (r *TraceRecorder) onConnectStart(info TraceConnectStartInfo) {
r.mu.Lock()
defer r.mu.Unlock()
now := time.Now()
r.ensureStartedLocked(now)
if r.pendingConnectStarts == nil {
r.pendingConnectStarts = make(map[string][]time.Time)
}
key := traceConnectKey(info.Network, info.Addr)
r.pendingConnectStarts[key] = append(r.pendingConnectStarts[key], now)
r.summary.Connect = append(r.summary.Connect, TraceConnectSummary{
Network: info.Network,
Addr: info.Addr,
StartedAt: now,
})
}
func (r *TraceRecorder) onConnectDone(info TraceConnectDoneInfo) {
r.mu.Lock()
defer r.mu.Unlock()
now := time.Now()
r.ensureStartedLocked(now)
start := time.Time{}
key := traceConnectKey(info.Network, info.Addr)
if starts := r.pendingConnectStarts[key]; len(starts) > 0 {
start = starts[0]
if len(starts) == 1 {
delete(r.pendingConnectStarts, key)
} else {
r.pendingConnectStarts[key] = starts[1:]
}
}
connect := TraceConnectSummary{
Network: info.Network,
Addr: info.Addr,
StartedAt: start,
CompletedAt: now,
Err: info.Err,
}
if !start.IsZero() {
connect.Duration = now.Sub(start)
}
for index := len(r.summary.Connect) - 1; index >= 0; index-- {
item := &r.summary.Connect[index]
if item.Network != info.Network || item.Addr != info.Addr || !item.CompletedAt.IsZero() {
continue
}
item.CompletedAt = now
item.Duration = connect.Duration
item.Err = info.Err
return
}
r.summary.Connect = append(r.summary.Connect, connect)
}
func (r *TraceRecorder) onTLSHandshakeStart(info TraceTLSHandshakeStartInfo) {
r.mu.Lock()
defer r.mu.Unlock()
now := time.Now()
r.ensureStartedLocked(now)
r.pendingTLSStart = now
r.summary.TLS = &TraceTLSSummary{
Network: info.Network,
Addr: info.Addr,
ServerName: info.ServerName,
StartedAt: now,
}
}
func (r *TraceRecorder) onTLSHandshakeDone(info TraceTLSHandshakeDoneInfo) {
r.mu.Lock()
defer r.mu.Unlock()
now := time.Now()
r.ensureStartedLocked(now)
var tlsSummary *TraceTLSSummary
if r.summary.TLS != nil {
copied := *r.summary.TLS
tlsSummary = &copied
} else {
tlsSummary = &TraceTLSSummary{}
}
if tlsSummary.Network == "" {
tlsSummary.Network = info.Network
}
if tlsSummary.Addr == "" {
tlsSummary.Addr = info.Addr
}
if tlsSummary.ServerName == "" {
tlsSummary.ServerName = info.ServerName
}
if tlsSummary.StartedAt.IsZero() {
tlsSummary.StartedAt = r.pendingTLSStart
}
tlsSummary.CompletedAt = now
if !tlsSummary.StartedAt.IsZero() {
tlsSummary.Duration = now.Sub(tlsSummary.StartedAt)
}
tlsSummary.Err = info.Err
tlsSummary = mergeTraceTLSSummary(tlsSummary, info.ConnectionState)
if tlsSummary.ServerName == "" {
tlsSummary.ServerName = info.ServerName
}
if tlsSummary.Addr == "" {
tlsSummary.Addr = info.Addr
}
if tlsSummary.Network == "" {
tlsSummary.Network = info.Network
}
r.pendingTLSStart = time.Time{}
r.summary.TLS = tlsSummary
}
func (r *TraceRecorder) onWroteRequest(info TraceWroteRequestInfo) {
r.mu.Lock()
defer r.mu.Unlock()
now := time.Now()
r.ensureStartedLocked(now)
r.summary.RequestWrittenAt = now
r.summary.RequestWriteErr = info.Err
}
func (r *TraceRecorder) onGotFirstResponseByte() {
r.mu.Lock()
defer r.mu.Unlock()
now := time.Now()
r.ensureStartedLocked(now)
r.summary.FirstResponseByteAt = now
}
func traceAddrString(addr net.Addr) string {
if addr == nil {
return ""
}
return addr.String()
}
func traceConnectKey(network, addr string) string {
return network + "\x00" + addr
}
func traceIPAddrsToStrings(addrs []net.IPAddr) []string {
if len(addrs) == 0 {
return nil
}
out := make([]string, 0, len(addrs))
for _, addr := range addrs {
out = append(out, addr.String())
}
return out
}
func mergeTraceTLSSummary(summary *TraceTLSSummary, state tls.ConnectionState) *TraceTLSSummary {
if summary == nil {
summary = &TraceTLSSummary{}
}
if state.Version != 0 {
summary.Version = state.Version
summary.VersionName = tls.VersionName(state.Version)
}
if state.CipherSuite != 0 {
summary.CipherSuite = state.CipherSuite
summary.CipherSuiteName = tls.CipherSuiteName(state.CipherSuite)
}
if state.CurveID != 0 {
summary.CurveID = state.CurveID
summary.CurveName = state.CurveID.String()
}
if state.ServerName != "" {
summary.ServerName = state.ServerName
}
if state.NegotiatedProtocol != "" {
summary.NegotiatedProtocol = state.NegotiatedProtocol
}
summary.DidResume = state.DidResume
summary.ECHAccepted = state.ECHAccepted
summary.VerifiedChains = len(state.VerifiedChains)
if len(state.PeerCertificates) > 0 {
summary.PeerCertificates = make([]TraceCertificateSummary, 0, len(state.PeerCertificates))
for _, cert := range state.PeerCertificates {
certSummary := TraceCertificateSummary{
Subject: cert.Subject.String(),
Issuer: cert.Issuer.String(),
DNSNames: append([]string(nil), cert.DNSNames...),
}
if len(cert.IPAddresses) > 0 {
certSummary.IPAddresses = make([]string, 0, len(cert.IPAddresses))
for _, ip := range cert.IPAddresses {
certSummary.IPAddresses = append(certSummary.IPAddresses, ip.String())
}
}
summary.PeerCertificates = append(summary.PeerCertificates, certSummary)
}
}
return summary
}
func cloneTraceSummary(summary TraceSummary) TraceSummary {
cloned := summary
if summary.DNS != nil {
dns := *summary.DNS
dns.Addrs = append([]string(nil), summary.DNS.Addrs...)
cloned.DNS = &dns
}
if len(summary.DNSEvents) > 0 {
cloned.DNSEvents = make([]TraceDNSSummary, 0, len(summary.DNSEvents))
for _, dns := range summary.DNSEvents {
cloned.DNSEvents = append(cloned.DNSEvents, TraceDNSSummary{
Host: dns.Host,
Addrs: append([]string(nil), dns.Addrs...),
Coalesced: dns.Coalesced,
StartedAt: dns.StartedAt,
CompletedAt: dns.CompletedAt,
Duration: dns.Duration,
Err: dns.Err,
})
}
}
if len(summary.Connect) > 0 {
cloned.Connect = append([]TraceConnectSummary(nil), summary.Connect...)
}
if summary.TLS != nil {
tlsSummary := *summary.TLS
if len(summary.TLS.PeerCertificates) > 0 {
tlsSummary.PeerCertificates = make([]TraceCertificateSummary, 0, len(summary.TLS.PeerCertificates))
for _, cert := range summary.TLS.PeerCertificates {
tlsSummary.PeerCertificates = append(tlsSummary.PeerCertificates, TraceCertificateSummary{
Subject: cert.Subject,
Issuer: cert.Issuer,
DNSNames: append([]string(nil), cert.DNSNames...),
IPAddresses: append([]string(nil), cert.IPAddresses...),
})
}
}
cloned.TLS = &tlsSummary
}
return cloned
}
func cloneTraceSummaryPtr(summary TraceSummary) *TraceSummary {
cloned := cloneTraceSummary(summary)
return &cloned
}
+488
View File
@@ -0,0 +1,488 @@
package starnet
import (
"context"
"crypto/tls"
"net"
"net/http"
"strconv"
"strings"
"testing"
)
func TestTraceRecorderCapturesHTTPSummary(t *testing.T) {
server := newIPv4TLSServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer server.Close()
recorder := NewTraceRecorder()
req := NewSimpleRequest(server.URL, http.MethodGet).
SetSkipTLSVerify(true).
SetTraceRecorder(recorder)
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
if _, err := resp.Body().Bytes(); err != nil {
t.Fatalf("Body().Bytes() error: %v", err)
}
summary := recorder.Summary()
if summary.Method != http.MethodGet {
t.Fatalf("method=%q", summary.Method)
}
if summary.URL != server.URL {
t.Fatalf("url=%q", summary.URL)
}
if summary.StatusCode != http.StatusOK {
t.Fatalf("status=%d", summary.StatusCode)
}
if summary.ResponseProto == "" {
t.Fatal("expected response proto")
}
if summary.RequestWrittenAt.IsZero() {
t.Fatal("expected request write timestamp")
}
if summary.FirstResponseByteAt.IsZero() {
t.Fatal("expected first response byte timestamp")
}
if summary.Conn.Addr == "" {
t.Fatal("expected get-conn target address")
}
if summary.TLS == nil {
t.Fatal("expected tls summary")
}
tlsSummary := summary.TLS
if tlsSummary.Version == 0 || tlsSummary.VersionName == "" {
t.Fatalf("unexpected tls version summary: %+v", tlsSummary)
}
if tlsSummary.CipherSuite == 0 || tlsSummary.CipherSuiteName == "" {
t.Fatalf("unexpected cipher suite summary: %+v", tlsSummary)
}
if tlsSummary.ServerName == "" {
t.Fatal("expected tls server name")
}
if resp.TLS == nil {
t.Fatal("expected response TLS state")
}
if tlsSummary.NegotiatedProtocol != resp.TLS.NegotiatedProtocol {
t.Fatalf("alpn=%q resp=%q", tlsSummary.NegotiatedProtocol, resp.TLS.NegotiatedProtocol)
}
if len(tlsSummary.PeerCertificates) == 0 {
t.Fatal("expected certificate summaries")
}
leaf := tlsSummary.PeerCertificates[0]
if leaf.Subject == "" || leaf.Issuer == "" {
t.Fatalf("unexpected leaf certificate summary: %+v", leaf)
}
if len(leaf.DNSNames) == 0 && len(leaf.IPAddresses) == 0 {
t.Fatalf("expected DNS or IP SANs in leaf certificate: %+v", leaf)
}
if got := req.TraceSummary(); got == nil || got.StatusCode != http.StatusOK {
t.Fatalf("request trace summary=%+v", got)
}
if got := resp.TraceSummary(); got == nil || got.StatusCode != http.StatusOK {
t.Fatalf("response trace summary=%+v", got)
}
}
func TestTraceRecorderCapturesDNSAndConnectSummary(t *testing.T) {
server := newIPv4Server(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer server.Close()
addr, err := net.ResolveTCPAddr("tcp", server.Listener.Addr().String())
if err != nil {
t.Fatalf("ResolveTCPAddr() error: %v", err)
}
recorder := NewTraceRecorder()
targetURL := "http://trace-summary.example.test:" + strconv.Itoa(addr.Port)
resp, err := NewSimpleRequest(targetURL, http.MethodGet).
SetLookupFunc(func(ctx context.Context, host string) ([]net.IPAddr, error) {
return []net.IPAddr{{IP: addr.IP}}, nil
}).
SetTraceRecorder(recorder).
Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
if _, err := resp.Body().Bytes(); err != nil {
t.Fatalf("Body().Bytes() error: %v", err)
}
summary := recorder.Summary()
if summary.DNS == nil {
t.Fatal("expected dns summary")
}
if summary.DNS.Host != "trace-summary.example.test" {
t.Fatalf("dns host=%q", summary.DNS.Host)
}
if len(summary.DNS.Addrs) == 0 {
t.Fatal("expected resolved addresses")
}
if !strings.Contains(summary.DNS.Addrs[0], addr.IP.String()) {
t.Fatalf("dns addrs=%v", summary.DNS.Addrs)
}
if summary.DNS.CompletedAt.IsZero() {
t.Fatal("expected dns completion timestamp")
}
if len(summary.Connect) == 0 {
t.Fatal("expected connect attempts")
}
connect := summary.Connect[0]
if connect.Network == "" || connect.Addr == "" {
t.Fatalf("unexpected connect summary: %+v", connect)
}
if connect.CompletedAt.IsZero() {
t.Fatalf("expected connect completion timestamp: %+v", connect)
}
}
func TestTraceRecorderUsesResponseTLSForReusedConnection(t *testing.T) {
server := newIPv4TLSServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer server.Close()
client := NewClientNoErr()
firstResp, err := client.NewSimpleRequest(server.URL, http.MethodGet).
SetSkipTLSVerify(true).
Do()
if err != nil {
t.Fatalf("first Do() error: %v", err)
}
if _, err := firstResp.Body().Bytes(); err != nil {
t.Fatalf("first Body().Bytes() error: %v", err)
}
if err := firstResp.Close(); err != nil {
t.Fatalf("first Close() error: %v", err)
}
recorder := NewTraceRecorder()
secondResp, err := client.NewSimpleRequest(server.URL, http.MethodGet).
SetSkipTLSVerify(true).
SetTraceRecorder(recorder).
Do()
if err != nil {
t.Fatalf("second Do() error: %v", err)
}
defer secondResp.Close()
if _, err := secondResp.Body().Bytes(); err != nil {
t.Fatalf("second Body().Bytes() error: %v", err)
}
summary := recorder.Summary()
if !summary.Conn.Reused {
t.Fatalf("expected reused connection summary, got %+v", summary.Conn)
}
if summary.TLS == nil || summary.TLS.Version == 0 {
t.Fatalf("expected tls summary from response fallback, got %+v", summary.TLS)
}
}
func TestTraceRecorderCoexistsWithTraceHooks(t *testing.T) {
server := newIPv4Server(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer server.Close()
recorder := NewTraceRecorder()
wroteRequest := 0
resp, err := NewSimpleRequest(server.URL, http.MethodGet).
SetTraceHooks(&TraceHooks{
WroteRequest: func(info TraceWroteRequestInfo) {
wroteRequest++
},
}).
SetTraceRecorder(recorder).
Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
if _, err := resp.Body().Bytes(); err != nil {
t.Fatalf("Body().Bytes() error: %v", err)
}
if wroteRequest == 0 {
t.Fatal("expected custom trace hook to run")
}
summary := recorder.Summary()
if summary.RequestWrittenAt.IsZero() {
t.Fatal("expected recorder to capture wrote-request event")
}
}
func TestTraceRecorderPreservesMultipleDNSEvents(t *testing.T) {
recorder := NewTraceRecorder()
hooks := recorder.Hooks()
hooks.DNSStart(TraceDNSStartInfo{Host: "target.example.test"})
hooks.DNSDone(TraceDNSDoneInfo{
Addrs: []net.IPAddr{{IP: net.ParseIP("127.0.0.1")}},
})
hooks.DNSStart(TraceDNSStartInfo{Host: "proxy.example.test"})
hooks.DNSDone(TraceDNSDoneInfo{
Addrs: []net.IPAddr{{IP: net.ParseIP("127.0.0.2")}},
})
summary := recorder.Summary()
if len(summary.DNSEvents) != 2 {
t.Fatalf("dns events=%d", len(summary.DNSEvents))
}
if summary.DNSEvents[0].Host != "target.example.test" {
t.Fatalf("first dns host=%q", summary.DNSEvents[0].Host)
}
if summary.DNSEvents[1].Host != "proxy.example.test" {
t.Fatalf("second dns host=%q", summary.DNSEvents[1].Host)
}
if summary.DNS == nil || summary.DNS.Host != "proxy.example.test" {
t.Fatalf("last dns summary=%+v", summary.DNS)
}
}
func TestTraceHooksStandardTLSPathIncludesMetadata(t *testing.T) {
server, pool := newTrustedIPv4TLSServer(t, "localhost", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer server.Close()
client := NewClientNoErr()
transport, ok := client.HTTPClient().Transport.(*Transport)
if !ok {
t.Fatalf("transport type=%T", client.HTTPClient().Transport)
}
base := newBaseHTTPTransport()
base.TLSClientConfig = &tls.Config{RootCAs: pool}
transport.SetBase(base)
targetURL := httpsURLForHost(t, server, "localhost")
var startInfo TraceTLSHandshakeStartInfo
var doneInfo TraceTLSHandshakeDoneInfo
resp, err := client.NewSimpleRequest(targetURL, http.MethodGet).
SetTraceHooks(&TraceHooks{
TLSHandshakeStart: func(info TraceTLSHandshakeStartInfo) {
startInfo = info
},
TLSHandshakeDone: func(info TraceTLSHandshakeDoneInfo) {
doneInfo = info
},
}).
Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
if _, err := resp.Body().Bytes(); err != nil {
t.Fatalf("Body().Bytes() error: %v", err)
}
wantAddr := strings.TrimPrefix(targetURL, "https://")
if startInfo.Network != "tcp" {
t.Fatalf("start network=%q", startInfo.Network)
}
if startInfo.Addr != wantAddr {
t.Fatalf("start addr=%q want=%q", startInfo.Addr, wantAddr)
}
if startInfo.ServerName != "localhost" {
t.Fatalf("start server name=%q", startInfo.ServerName)
}
if doneInfo.Network != "tcp" || doneInfo.Addr != wantAddr || doneInfo.ServerName != "localhost" {
t.Fatalf("done info=%+v", doneInfo)
}
if doneInfo.ConnectionState.Version == 0 {
t.Fatalf("done state=%+v", doneInfo.ConnectionState)
}
}
func TestTraceHooksWroteHeaderFieldCopiesValues(t *testing.T) {
var captured []string
traceState := newTraceState(&TraceHooks{
WroteHeaderField: func(info TraceWroteHeaderFieldInfo) {
captured = info.Values
},
})
trace := traceState.clientTrace()
values := []string{"a", "b"}
trace.WroteHeaderField("X-Test", values)
values[0] = "mutated"
if len(captured) != 2 {
t.Fatalf("captured=%v", captured)
}
if captured[0] != "a" {
t.Fatalf("captured=%v", captured)
}
}
func TestTraceRecorderSharedAcrossCloneKeepsPerRequestSummaries(t *testing.T) {
server := newIPv4Server(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(r.URL.Path))
}))
defer server.Close()
recorder := NewTraceRecorder()
client := NewClientNoErr(WithTraceRecorder(recorder))
req1 := client.NewSimpleRequest(server.URL+"/one", http.MethodGet)
resp1, err := req1.Do()
if err != nil {
t.Fatalf("first Do() error: %v", err)
}
defer resp1.Close()
if _, err := resp1.Body().Bytes(); err != nil {
t.Fatalf("first Body().Bytes() error: %v", err)
}
req2 := req1.Clone().SetURL(server.URL + "/two")
resp2, err := req2.Do()
if err != nil {
t.Fatalf("second Do() error: %v", err)
}
defer resp2.Close()
if _, err := resp2.Body().Bytes(); err != nil {
t.Fatalf("second Body().Bytes() error: %v", err)
}
if got := req1.TraceSummary(); got == nil || got.URL != server.URL+"/one" {
t.Fatalf("req1 trace summary=%+v", got)
}
if got := resp1.TraceSummary(); got == nil || got.URL != server.URL+"/one" {
t.Fatalf("resp1 trace summary=%+v", got)
}
if got := req2.TraceSummary(); got == nil || got.URL != server.URL+"/two" {
t.Fatalf("req2 trace summary=%+v", got)
}
if got := resp2.TraceSummary(); got == nil || got.URL != server.URL+"/two" {
t.Fatalf("resp2 trace summary=%+v", got)
}
if got := recorder.Summary(); got.URL != server.URL+"/two" {
t.Fatalf("shared recorder summary=%+v", got)
}
}
func TestResponseTraceSummaryIsStableAcrossRequestReuse(t *testing.T) {
server := newIPv4Server(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(r.URL.Path))
}))
defer server.Close()
req := NewSimpleRequest(server.URL+"/first", http.MethodGet).
SetTraceRecorder(NewTraceRecorder())
resp1, err := req.Do()
if err != nil {
t.Fatalf("first Do() error: %v", err)
}
defer resp1.Close()
if _, err := resp1.Body().Bytes(); err != nil {
t.Fatalf("first Body().Bytes() error: %v", err)
}
req.SetURL(server.URL + "/second")
resp2, err := req.Do()
if err != nil {
t.Fatalf("second Do() error: %v", err)
}
defer resp2.Close()
if _, err := resp2.Body().Bytes(); err != nil {
t.Fatalf("second Body().Bytes() error: %v", err)
}
if got := resp1.TraceSummary(); got == nil || got.URL != server.URL+"/first" {
t.Fatalf("resp1 trace summary=%+v", got)
}
if got := req.TraceSummary(); got == nil || got.URL != server.URL+"/second" {
t.Fatalf("request trace summary=%+v", got)
}
if got := resp2.TraceSummary(); got == nil || got.URL != server.URL+"/second" {
t.Fatalf("resp2 trace summary=%+v", got)
}
}
func TestTraceHooksCustomDialDoesNotInventTLSAddr(t *testing.T) {
server, pool := newTrustedIPv4TLSServer(t, "trace-custom.example.test", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer server.Close()
client := NewClientNoErr()
transport, ok := client.HTTPClient().Transport.(*Transport)
if !ok {
t.Fatalf("transport type=%T", client.HTTPClient().Transport)
}
base := newBaseHTTPTransport()
base.TLSClientConfig = &tls.Config{RootCAs: pool}
transport.SetBase(base)
targetURL := httpsURLForHost(t, server, "trace-custom.example.test")
serverAddr := server.Listener.Addr().String()
var startInfo TraceTLSHandshakeStartInfo
var doneInfo TraceTLSHandshakeDoneInfo
resp, err := client.NewSimpleRequest(targetURL, http.MethodGet).
SetDialFunc(func(ctx context.Context, network, addr string) (net.Conn, error) {
return (&net.Dialer{}).DialContext(ctx, "tcp", serverAddr)
}).
SetTraceHooks(&TraceHooks{
TLSHandshakeStart: func(info TraceTLSHandshakeStartInfo) {
startInfo = info
},
TLSHandshakeDone: func(info TraceTLSHandshakeDoneInfo) {
doneInfo = info
},
}).
Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
if _, err := resp.Body().Bytes(); err != nil {
t.Fatalf("Body().Bytes() error: %v", err)
}
if startInfo.Network != "" || startInfo.Addr != "" {
t.Fatalf("start info=%+v", startInfo)
}
if doneInfo.Network != "" || doneInfo.Addr != "" {
t.Fatalf("done info=%+v", doneInfo)
}
if startInfo.ServerName != "trace-custom.example.test" {
t.Fatalf("start server name=%q", startInfo.ServerName)
}
if doneInfo.ServerName != "trace-custom.example.test" {
t.Fatalf("done server name=%q", doneInfo.ServerName)
}
if doneInfo.ConnectionState.Version == 0 {
t.Fatalf("done state=%+v", doneInfo.ConnectionState)
}
}
+324
View File
@@ -0,0 +1,324 @@
package starnet
import (
"context"
"errors"
"net"
"net/http"
"net/http/httptest"
"strconv"
"sync"
"testing"
"time"
)
func TestTraceHooksStandardHTTPSPath(t *testing.T) {
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer server.Close()
var mu sync.Mutex
events := map[string]int{}
hooks := &TraceHooks{
GetConn: func(info TraceGetConnInfo) {
mu.Lock()
events["get_conn"]++
mu.Unlock()
},
GotConn: func(info TraceGotConnInfo) {
mu.Lock()
events["got_conn"]++
mu.Unlock()
},
TLSHandshakeStart: func(info TraceTLSHandshakeStartInfo) {
mu.Lock()
events["tls_start"]++
mu.Unlock()
},
TLSHandshakeDone: func(info TraceTLSHandshakeDoneInfo) {
mu.Lock()
events["tls_done"]++
mu.Unlock()
if info.Err != nil {
t.Errorf("unexpected tls handshake error: %v", info.Err)
}
},
WroteHeaders: func() {
mu.Lock()
events["wrote_headers"]++
mu.Unlock()
},
WroteRequest: func(info TraceWroteRequestInfo) {
mu.Lock()
events["wrote_request"]++
mu.Unlock()
if info.Err != nil {
t.Errorf("unexpected write error: %v", info.Err)
}
},
GotFirstResponseByte: func() {
mu.Lock()
events["first_byte"]++
mu.Unlock()
},
}
resp, err := NewSimpleRequest(server.URL, http.MethodGet).
SetSkipTLSVerify(true).
SetTraceHooks(hooks).
Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
mu.Lock()
defer mu.Unlock()
for _, key := range []string{"get_conn", "got_conn", "tls_start", "tls_done", "wrote_headers", "wrote_request", "first_byte"} {
if events[key] == 0 {
t.Fatalf("expected trace event %q", key)
}
}
}
func TestTraceHooksDynamicHTTPSPathDoesNotDuplicateTLSHandshake(t *testing.T) {
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
var mu sync.Mutex
tlsStartCount := 0
tlsDoneCount := 0
var lastInfo TraceTLSHandshakeDoneInfo
hooks := &TraceHooks{
TLSHandshakeStart: func(info TraceTLSHandshakeStartInfo) {
mu.Lock()
tlsStartCount++
mu.Unlock()
},
TLSHandshakeDone: func(info TraceTLSHandshakeDoneInfo) {
mu.Lock()
tlsDoneCount++
lastInfo = info
mu.Unlock()
},
}
resp, err := NewSimpleRequest(server.URL, http.MethodGet).
SetSkipTLSVerify(true).
SetDialTimeout(1500 * time.Millisecond).
SetTraceHooks(hooks).
Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
mu.Lock()
defer mu.Unlock()
if tlsStartCount != 1 {
t.Fatalf("tlsStartCount=%d", tlsStartCount)
}
if tlsDoneCount != 1 {
t.Fatalf("tlsDoneCount=%d", tlsDoneCount)
}
if lastInfo.Err != nil {
t.Fatalf("unexpected tls handshake error: %v", lastInfo.Err)
}
if lastInfo.ConnectionState.Version == 0 {
t.Fatal("expected tls connection state")
}
}
func TestTraceHooksCustomLookupFuncEmitsDNSEvents(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
addr, err := net.ResolveTCPAddr("tcp", server.Listener.Addr().String())
if err != nil {
t.Fatalf("ResolveTCPAddr() error: %v", err)
}
var mu sync.Mutex
dnsStartCount := 0
dnsDoneCount := 0
var dnsStartHost string
hooks := &TraceHooks{
DNSStart: func(info TraceDNSStartInfo) {
mu.Lock()
dnsStartCount++
dnsStartHost = info.Host
mu.Unlock()
},
DNSDone: func(info TraceDNSDoneInfo) {
mu.Lock()
dnsDoneCount++
mu.Unlock()
if info.Err != nil {
t.Errorf("unexpected dns error: %v", info.Err)
}
},
}
url := "http://trace.example.test:" + strconv.Itoa(addr.Port)
resp, err := NewSimpleRequest(url, http.MethodGet).
SetLookupFunc(func(ctx context.Context, host string) ([]net.IPAddr, error) {
return []net.IPAddr{{IP: addr.IP}}, nil
}).
SetTraceHooks(hooks).
Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
mu.Lock()
defer mu.Unlock()
if dnsStartCount != 1 {
t.Fatalf("dnsStartCount=%d", dnsStartCount)
}
if dnsDoneCount != 1 {
t.Fatalf("dnsDoneCount=%d", dnsDoneCount)
}
if dnsStartHost != "trace.example.test" {
t.Fatalf("dnsStartHost=%q", dnsStartHost)
}
}
func TestTraceHooksCustomDialFuncEmitsConnectEvents(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
var mu sync.Mutex
connectStartCount := 0
connectDoneCount := 0
hooks := &TraceHooks{
ConnectStart: func(info TraceConnectStartInfo) {
mu.Lock()
connectStartCount++
mu.Unlock()
},
ConnectDone: func(info TraceConnectDoneInfo) {
mu.Lock()
connectDoneCount++
mu.Unlock()
if info.Err != nil {
t.Errorf("unexpected connect error: %v", info.Err)
}
},
}
resp, err := NewSimpleRequest(server.URL, http.MethodGet).
SetDialFunc(func(ctx context.Context, network, addr string) (net.Conn, error) {
var dialer net.Dialer
return dialer.DialContext(context.Background(), network, addr)
}).
SetTraceHooks(hooks).
Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
mu.Lock()
defer mu.Unlock()
if connectStartCount != 1 {
t.Fatalf("connectStartCount=%d", connectStartCount)
}
if connectDoneCount != 1 {
t.Fatalf("connectDoneCount=%d", connectDoneCount)
}
}
func TestTraceHooksRetryEvents(t *testing.T) {
var hits int
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
hits++
if hits == 1 {
w.WriteHeader(http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
var mu sync.Mutex
starts := 0
dones := 0
backoffs := 0
var finalDone TraceRetryAttemptDoneInfo
hooks := &TraceHooks{
RetryAttemptStart: func(info TraceRetryAttemptStartInfo) {
mu.Lock()
starts++
mu.Unlock()
},
RetryAttemptDone: func(info TraceRetryAttemptDoneInfo) {
mu.Lock()
dones++
finalDone = info
mu.Unlock()
},
RetryBackoff: func(info TraceRetryBackoffInfo) {
mu.Lock()
backoffs++
mu.Unlock()
},
}
resp, err := NewSimpleRequest(server.URL, http.MethodGet).
SetRetry(1, WithRetryBackoff(time.Millisecond, time.Millisecond, 1), WithRetryJitter(0)).
SetTraceHooks(hooks).
Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
mu.Lock()
defer mu.Unlock()
if starts != 2 {
t.Fatalf("starts=%d", starts)
}
if dones != 2 {
t.Fatalf("dones=%d", dones)
}
if backoffs != 1 {
t.Fatalf("backoffs=%d", backoffs)
}
if finalDone.WillRetry {
t.Fatal("expected final attempt not to retry")
}
if finalDone.StatusCode != http.StatusOK {
t.Fatalf("final status=%d", finalDone.StatusCode)
}
}
func TestTraceHooksCustomLookupFuncPropagatesDNSError(t *testing.T) {
var gotErr error
hooks := &TraceHooks{
DNSDone: func(info TraceDNSDoneInfo) {
gotErr = info.Err
},
}
_, err := NewSimpleRequest("http://trace.example.test:80", http.MethodGet).
SetLookupFunc(func(ctx context.Context, host string) ([]net.IPAddr, error) {
return nil, errors.New("lookup failed")
}).
SetTraceHooks(hooks).
Do()
if err == nil {
t.Fatal("expected request error")
}
if gotErr == nil || gotErr.Error() != "lookup failed" {
t.Fatalf("gotErr=%v", gotErr)
}
}
+416
View File
@@ -0,0 +1,416 @@
package starnet
import (
"context"
"crypto/tls"
"net"
"net/http"
"net/url"
"strings"
"sync"
"time"
)
const dynamicTransportCacheMaxEntries = 64
type dynamicTransportCacheKey struct {
proxyKey string
dialTimeout time.Duration
customIPs string
customDNS string
tlsServerName string
skipVerify bool
}
// Transport 自定义 Transport(支持请求级配置)
type Transport struct {
base *http.Transport
dynamicCache map[dynamicTransportCacheKey]*http.Transport
dynamicCacheOrder []dynamicTransportCacheKey
mu sync.RWMutex
}
// RoundTrip 实现 http.RoundTripper 接口
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
t.ensureBase()
// 提取请求级别的配置
reqCtx := getRequestContext(req.Context())
traceState := getTraceState(req.Context())
execReq := req
execReqCtx := reqCtx
var targetAddrs []string
// 优先级1:完全自定义的 transport
if execReqCtx.Transport != nil {
return execReqCtx.Transport.RoundTrip(execReq)
}
var err error
execReq, execReqCtx, targetAddrs, err = prepareProxyTargetRequest(execReq, execReqCtx, traceState)
if err != nil {
return nil, err
}
// 优先级2:需要动态配置
if needsDynamicTransport(execReqCtx) {
dynamicTransport := t.getDynamicTransport(execReqCtx, traceState)
if len(targetAddrs) > 0 {
return roundTripResolvedTargets(dynamicTransport, execReq, targetAddrs)
}
return dynamicTransport.RoundTrip(execReq)
}
// 优先级3:使用基础 transport
t.mu.RLock()
baseTransport := t.base
t.mu.RUnlock()
if len(targetAddrs) > 0 {
return roundTripResolvedTargets(baseTransport, execReq, targetAddrs)
}
return baseTransport.RoundTrip(execReq)
}
func newBaseHTTPTransport() *http.Transport {
return &http.Transport{
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
}
func (t *Transport) ensureBase() {
if t.base != nil {
return
}
t.mu.Lock()
defer t.mu.Unlock()
t.ensureBaseLocked()
}
func (t *Transport) ensureBaseLocked() {
if t.base == nil {
t.base = newBaseHTTPTransport()
}
}
func (t *Transport) getDynamicTransport(rc *RequestContext, traceState *traceState) *http.Transport {
if key, ok := newDynamicTransportCacheKey(rc); ok {
return t.getOrCreateCachedDynamicTransport(key, rc)
}
return t.buildDynamicTransport(rc, traceState)
}
func (t *Transport) getOrCreateCachedDynamicTransport(key dynamicTransportCacheKey, rc *RequestContext) *http.Transport {
t.mu.RLock()
if transport := t.dynamicCache[key]; transport != nil {
t.mu.RUnlock()
return transport
}
t.mu.RUnlock()
t.mu.Lock()
defer t.mu.Unlock()
t.ensureBaseLocked()
if transport := t.dynamicCache[key]; transport != nil {
return transport
}
transport := buildDynamicTransportFromBase(t.base, rc, nil)
if t.dynamicCache == nil {
t.dynamicCache = make(map[dynamicTransportCacheKey]*http.Transport)
}
if len(t.dynamicCacheOrder) >= dynamicTransportCacheMaxEntries {
oldestKey := t.dynamicCacheOrder[0]
t.dynamicCacheOrder = t.dynamicCacheOrder[1:]
if oldest := t.dynamicCache[oldestKey]; oldest != nil {
oldest.CloseIdleConnections()
delete(t.dynamicCache, oldestKey)
}
}
t.dynamicCache[key] = transport
t.dynamicCacheOrder = append(t.dynamicCacheOrder, key)
return transport
}
func (t *Transport) resetDynamicTransportCacheLocked() {
for _, key := range t.dynamicCacheOrder {
if transport := t.dynamicCache[key]; transport != nil {
transport.CloseIdleConnections()
}
}
t.dynamicCache = nil
t.dynamicCacheOrder = nil
}
func newDynamicTransportCacheKey(rc *RequestContext) (dynamicTransportCacheKey, bool) {
if rc == nil {
return dynamicTransportCacheKey{}, false
}
if rc.Transport != nil || rc.DialFn != nil || rc.LookupIPFn != nil {
return dynamicTransportCacheKey{}, false
}
if rc.TLSConfig != nil && !rc.TLSConfigCacheable {
return dynamicTransportCacheKey{}, false
}
key := dynamicTransportCacheKey{
proxyKey: normalizeProxyCacheKey(rc.Proxy),
dialTimeout: rc.DialTimeout,
customIPs: serializeTransportCacheList(rc.CustomIP),
customDNS: serializeTransportCacheList(rc.CustomDNS),
tlsServerName: effectiveTLSServerName(rc),
}
if rc.TLSConfig != nil {
key.skipVerify = rc.TLSConfig.InsecureSkipVerify
}
return key, true
}
func normalizeProxyCacheKey(proxy string) string {
if proxy == "" {
return ""
}
proxyURL, err := parseProxyURL(proxy)
if err != nil {
return "\x00invalid:" + proxy
}
return proxyURL.String()
}
func serializeTransportCacheList(values []string) string {
if len(values) == 0 {
return ""
}
var builder strings.Builder
for _, value := range values {
builder.WriteString(value)
builder.WriteByte(0)
}
return builder.String()
}
func effectiveTLSServerName(rc *RequestContext) string {
if rc == nil {
return ""
}
if rc.TLSConfig != nil && rc.TLSConfig.ServerName != "" {
return rc.TLSConfig.ServerName
}
return rc.TLSServerName
}
// buildDynamicTransport 构建动态 Transport
func (t *Transport) buildDynamicTransport(rc *RequestContext, traceState *traceState) *http.Transport {
t.ensureBase()
t.mu.RLock()
baseTransport := t.base
t.mu.RUnlock()
return buildDynamicTransportFromBase(baseTransport, rc, traceState)
}
func buildDynamicTransportFromBase(baseTransport *http.Transport, rc *RequestContext, traceState *traceState) *http.Transport {
transport := baseTransport.Clone()
// 应用 TLS 配置(即使为 nil 也要检查 SkipVerify
if rc.TLSConfig != nil {
transport.TLSClientConfig = rc.TLSConfig
}
// 应用代理配置
if rc.Proxy != "" {
proxyURL, err := parseProxyURL(rc.Proxy)
if err != nil {
transport.Proxy = func(*http.Request) (*url.URL, error) {
return nil, err
}
} else {
transport.Proxy = http.ProxyURL(proxyURL)
}
}
// 应用自定义 Dial 函数
if rc.DialFn != nil {
if traceState != nil && traceState.hooks != nil && (traceState.hooks.ConnectStart != nil || traceState.hooks.ConnectDone != nil) {
dialFn := rc.DialFn
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
if traceState.hooks.ConnectStart != nil {
traceState.hooks.ConnectStart(TraceConnectStartInfo{Network: network, Addr: addr})
}
conn, err := dialFn(ctx, network, addr)
if traceState.hooks.ConnectDone != nil {
traceState.hooks.ConnectDone(TraceConnectDoneInfo{Network: network, Addr: addr, Err: err})
}
return conn, err
}
} else {
transport.DialContext = rc.DialFn
}
} else if len(rc.CustomIP) > 0 || len(rc.CustomDNS) > 0 || rc.DialTimeout > 0 || rc.LookupIPFn != nil {
// 使用默认 Dial 函数(会从 context 读取配置)
transport.DialContext = defaultDialFunc
transport.DialTLSContext = defaultDialTLSFunc
}
return transport
}
// Base 获取基础 Transport
func (t *Transport) Base() *http.Transport {
t.mu.RLock()
defer t.mu.RUnlock()
return t.base
}
// SetBase 设置基础 Transport
func (t *Transport) SetBase(base *http.Transport) {
t.mu.Lock()
t.base = base
t.resetDynamicTransportCacheLocked()
t.mu.Unlock()
}
func prepareProxyTargetRequest(req *http.Request, reqCtx *RequestContext, traceState *traceState) (*http.Request, *RequestContext, []string, error) {
if req == nil || req.URL == nil || reqCtx == nil {
return req, reqCtx, nil, nil
}
if reqCtx.Proxy == "" || reqCtx.DialFn != nil {
return req, reqCtx, nil, nil
}
if len(reqCtx.CustomIP) == 0 && len(reqCtx.CustomDNS) == 0 && reqCtx.LookupIPFn == nil {
return req, reqCtx, nil, nil
}
host := req.URL.Hostname()
if host == "" {
return req, reqCtx, nil, nil
}
targetAddrs, err := resolveDialAddresses(req.Context(), reqCtx, host, req.URL.Port(), traceState)
if err != nil {
return nil, nil, nil, err
}
if len(targetAddrs) == 0 {
return req, reqCtx, nil, nil
}
execReqCtx := *reqCtx
execReqCtx.CustomIP = nil
execReqCtx.CustomDNS = nil
execReqCtx.LookupIPFn = nil
if req.URL.Scheme == "https" {
execReqCtx.TLSConfig = withDefaultServerName(execReqCtx.TLSConfig, host)
if execReqCtx.TLSConfigCacheable || reqCtx.TLSConfig == nil {
execReqCtx.TLSConfigCacheable = true
}
}
execCtx := clearTargetResolutionContext(req.Context())
execReq := req.Clone(execCtx)
execReq.Host = req.Host
if len(targetAddrs) == 1 {
execReq.URL.Host = targetAddrs[0]
return execReq, &execReqCtx, nil, nil
}
return execReq, &execReqCtx, targetAddrs, nil
}
func clearTargetResolutionContext(ctx context.Context) context.Context {
if v := ctx.Value(ctxKeyRequestContext); v != nil {
if rc, ok := v.(*RequestContext); ok && rc != nil {
cloned := cloneRequestContext(rc)
cloned.CustomIP = nil
cloned.CustomDNS = nil
cloned.LookupIPFn = nil
ctx = context.WithValue(ctx, ctxKeyRequestContext, cloned)
}
}
ctx = context.WithValue(ctx, ctxKeyCustomIP, []string(nil))
ctx = context.WithValue(ctx, ctxKeyCustomDNS, []string(nil))
ctx = context.WithValue(ctx, ctxKeyLookupIP, (func(context.Context, string) ([]net.IPAddr, error))(nil))
return ctx
}
func withDefaultServerName(cfg *tls.Config, serverName string) *tls.Config {
if serverName == "" {
return cfg
}
if cfg != nil {
if cfg.ServerName != "" {
return cfg
}
cloned := cfg.Clone()
cloned.ServerName = serverName
return cloned
}
return &tls.Config{
ServerName: serverName,
NextProtos: []string{"h2", "http/1.1"},
}
}
func roundTripResolvedTargets(rt http.RoundTripper, baseReq *http.Request, targetAddrs []string) (*http.Response, error) {
if rt == nil || baseReq == nil || len(targetAddrs) == 0 {
return rt.RoundTrip(baseReq)
}
if !requestAllowsResolvedTargetFallback(baseReq) && len(targetAddrs) > 1 {
targetAddrs = targetAddrs[:1]
}
var lastErr error
for _, targetAddr := range targetAddrs {
attemptReq, err := cloneRequestForResolvedTarget(baseReq, targetAddr)
if err != nil {
return nil, err
}
resp, err := rt.RoundTrip(attemptReq)
if err == nil {
return resp, nil
}
lastErr = err
}
return nil, lastErr
}
func requestAllowsResolvedTargetFallback(req *http.Request) bool {
if req == nil {
return false
}
if !isIdempotentMethod(req.Method) {
return false
}
if req.Body == nil || req.Body == http.NoBody {
return true
}
return req.GetBody != nil
}
func cloneRequestForResolvedTarget(baseReq *http.Request, targetAddr string) (*http.Request, error) {
req := baseReq.Clone(baseReq.Context())
switch {
case baseReq.Body == nil || baseReq.Body == http.NoBody:
req.Body = baseReq.Body
case baseReq.GetBody != nil:
body, err := baseReq.GetBody()
if err != nil {
return nil, wrapError(err, "clone request body for resolved target")
}
req.Body = body
default:
req.Body = baseReq.Body
}
req.URL.Host = targetAddr
req.Host = baseReq.Host
return req, nil
}
+224
View File
@@ -0,0 +1,224 @@
package starnet
import (
"crypto/tls"
"net"
"net/http"
"strconv"
"sync"
"testing"
"time"
)
func TestTransportDynamicCacheReusesSafeProfile(t *testing.T) {
transport := &Transport{base: newBaseHTTPTransport()}
first := transport.getDynamicTransport(&RequestContext{
Proxy: "http://127.0.0.1:8080",
DialTimeout: 2 * time.Second,
CustomIP: []string{"127.0.0.1"},
TLSServerName: "cache.test",
}, nil)
second := transport.getDynamicTransport(&RequestContext{
Proxy: "http://127.0.0.1:8080",
DialTimeout: 2 * time.Second,
CustomIP: []string{"127.0.0.1"},
TLSServerName: "cache.test",
}, nil)
if first != second {
t.Fatal("expected cached dynamic transport to be reused")
}
if got := len(transport.dynamicCache); got != 1 {
t.Fatalf("dynamic cache size=%d; want 1", got)
}
}
func TestTransportDynamicCacheSeparatesTLSServerName(t *testing.T) {
transport := &Transport{base: newBaseHTTPTransport()}
first := transport.getDynamicTransport(&RequestContext{
CustomIP: []string{"127.0.0.1"},
TLSServerName: "first.test",
}, nil)
second := transport.getDynamicTransport(&RequestContext{
CustomIP: []string{"127.0.0.1"},
TLSServerName: "second.test",
}, nil)
if first == second {
t.Fatal("expected distinct tls server names to use different transports")
}
if got := len(transport.dynamicCache); got != 2 {
t.Fatalf("dynamic cache size=%d; want 2", got)
}
}
func TestTransportDynamicCacheSkipsUserTLSConfig(t *testing.T) {
transport := &Transport{base: newBaseHTTPTransport()}
reqCtx := &RequestContext{
CustomIP: []string{"127.0.0.1"},
TLSConfig: &tls.Config{InsecureSkipVerify: true},
}
first := transport.getDynamicTransport(reqCtx, nil)
second := transport.getDynamicTransport(reqCtx, nil)
if first == second {
t.Fatal("expected user tls config to bypass dynamic transport cache")
}
if got := len(transport.dynamicCache); got != 0 {
t.Fatalf("dynamic cache size=%d; want 0", got)
}
}
func TestTransportDynamicCacheResetOnDefaultTLSChange(t *testing.T) {
client := NewClientNoErr()
transport, ok := client.HTTPClient().Transport.(*Transport)
if !ok {
t.Fatalf("transport type=%T; want *Transport", client.HTTPClient().Transport)
}
reqCtx := &RequestContext{CustomIP: []string{"127.0.0.1"}}
first := transport.getDynamicTransport(reqCtx, nil)
if got := len(transport.dynamicCache); got != 1 {
t.Fatalf("dynamic cache size=%d; want 1 before reset", got)
}
client.SetDefaultSkipTLSVerify(true)
if got := len(transport.dynamicCache); got != 0 {
t.Fatalf("dynamic cache size=%d; want 0 after reset", got)
}
second := transport.getDynamicTransport(reqCtx, nil)
if first == second {
t.Fatal("expected cache reset after default tls change")
}
}
func TestDynamicTransportCacheReusesConnectionForCustomIP(t *testing.T) {
server := newIPv4Server(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer server.Close()
addr, err := net.ResolveTCPAddr("tcp", server.Listener.Addr().String())
if err != nil {
t.Fatalf("ResolveTCPAddr() error: %v", err)
}
client := NewClientNoErr()
targetURL := "http://cache-reuse.test:" + strconv.Itoa(addr.Port)
runRequest := func() bool {
var (
mu sync.Mutex
gotConn bool
reused bool
)
resp, err := client.NewSimpleRequest(targetURL, http.MethodGet).
SetCustomIP([]string{"127.0.0.1"}).
SetTraceHooks(&TraceHooks{
GotConn: func(info TraceGotConnInfo) {
mu.Lock()
gotConn = true
reused = info.Reused
mu.Unlock()
},
}).
Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
if _, err := resp.Body().Bytes(); err != nil {
t.Fatalf("Body().Bytes() error: %v", err)
}
mu.Lock()
defer mu.Unlock()
if !gotConn {
t.Fatal("expected GotConn trace event")
}
return reused
}
if runRequest() {
t.Fatal("first request unexpectedly reused a connection")
}
if !runRequest() {
t.Fatal("second request did not reuse cached dynamic transport connection")
}
transport, ok := client.HTTPClient().Transport.(*Transport)
if !ok {
t.Fatalf("transport type=%T; want *Transport", client.HTTPClient().Transport)
}
if got := len(transport.dynamicCache); got != 1 {
t.Fatalf("dynamic cache size=%d; want 1", got)
}
}
func TestPrepareProxyTargetRequestSingleTargetRewritesExecRequest(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, "https://proxy-single.test:8443/path", nil)
if err != nil {
t.Fatalf("http.NewRequest() error: %v", err)
}
req.Host = req.URL.Host
execReq, execReqCtx, targetAddrs, err := prepareProxyTargetRequest(req, &RequestContext{
Proxy: "http://127.0.0.1:8080",
CustomIP: []string{"127.0.0.1"},
}, nil)
if err != nil {
t.Fatalf("prepareProxyTargetRequest() error: %v", err)
}
if execReq == req {
t.Fatal("expected cloned request for proxy target preparation")
}
if got := execReq.URL.Host; got != "127.0.0.1:8443" {
t.Fatalf("execReq.URL.Host=%q; want %q", got, "127.0.0.1:8443")
}
if got := req.URL.Host; got != "proxy-single.test:8443" {
t.Fatalf("original req.URL.Host=%q; want %q", got, "proxy-single.test:8443")
}
if len(targetAddrs) != 0 {
t.Fatalf("targetAddrs=%v; want empty after single target rewrite", targetAddrs)
}
if execReqCtx == nil || execReqCtx.TLSConfig == nil {
t.Fatal("expected synthesized tls config for single target proxy request")
}
if got := execReqCtx.TLSConfig.ServerName; got != "proxy-single.test" {
t.Fatalf("tls server name=%q; want %q", got, "proxy-single.test")
}
}
func TestPrepareProxyTargetRequestMultiTargetPreservesFallbackList(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, "https://proxy-multi.test:9443/path", nil)
if err != nil {
t.Fatalf("http.NewRequest() error: %v", err)
}
req.Host = req.URL.Host
execReq, _, targetAddrs, err := prepareProxyTargetRequest(req, &RequestContext{
Proxy: "http://127.0.0.1:8080",
CustomIP: []string{"127.0.0.1", "127.0.0.2"},
}, nil)
if err != nil {
t.Fatalf("prepareProxyTargetRequest() error: %v", err)
}
if got := execReq.URL.Host; got != "proxy-multi.test:9443" {
t.Fatalf("execReq.URL.Host=%q; want original host", got)
}
if len(targetAddrs) != 2 {
t.Fatalf("targetAddrs=%v; want 2 targets", targetAddrs)
}
if targetAddrs[0] != "127.0.0.1:9443" || targetAddrs[1] != "127.0.0.2:9443" {
t.Fatalf("targetAddrs=%v; want ordered fallback targets", targetAddrs)
}
}
+149
View File
@@ -0,0 +1,149 @@
package starnet
import (
"context"
"crypto/tls"
"io"
"net"
"net/http"
"time"
)
// HTTP Content-Type 常量
const (
ContentTypeFormURLEncoded = "application/x-www-form-urlencoded"
ContentTypeFormData = "multipart/form-data"
ContentTypeJSON = "application/json"
ContentTypeXML = "application/xml"
ContentTypePlain = "text/plain"
ContentTypeHTML = "text/html"
ContentTypeOctetStream = "application/octet-stream"
)
// 默认配置
const (
DefaultDialTimeout = 5 * time.Second
DefaultTimeout = 10 * time.Second
DefaultUserAgent = "starnet"
DefaultFetchRespBody = false
)
// RequestFile 表示要上传的文件
type RequestFile struct {
FormName string // 表单字段名
FileName string // 文件名
FilePath string // 文件路径(如果从文件读取)
FileData io.Reader // 文件数据流
FileSize int64 // 文件大小
FileType string // MIME 类型
}
// UploadProgressFunc 文件上传进度回调函数
type UploadProgressFunc func(filename string, uploaded int64, total int64)
// NetworkConfig 网络配置
type NetworkConfig struct {
Proxy string // 代理地址
DialTimeout time.Duration // 连接超时
Timeout time.Duration // 总超时
DialFunc func(ctx context.Context, network, addr string) (net.Conn, error)
}
// TLSConfig TLS 配置
type TLSConfig struct {
Config *tls.Config // TLS 配置
SkipVerify bool // 跳过证书验证
ServerName string // 显式 TLS ServerName/SNI 覆盖
}
// DNSConfig DNS 配置
type DNSConfig struct {
CustomIP []string // 直接指定 IP(最高优先级)
CustomDNS []string // 自定义 DNS 服务器
LookupFunc func(ctx context.Context, host string) ([]net.IPAddr, error) // 自定义解析函数
}
type bodyMode uint8
const (
bodyModeUnset bodyMode = iota
bodyModeBytes
bodyModeReader
bodyModeForm
bodyModeMultipart
)
// BodyConfig 请求体配置
type BodyConfig struct {
Mode bodyMode // 当前 body 来源模式
Bytes []byte // 原始字节
Reader io.Reader // 数据流
FormData map[string][]string // 表单数据
Files []RequestFile // 文件列表
}
// RequestConfig 请求配置(内部使用)
type RequestConfig struct {
Network NetworkConfig
TLS TLSConfig
DNS DNSConfig
Body BodyConfig
Headers http.Header
Cookies []*http.Cookie
Queries map[string][]string
// 其他配置
BasicAuth [2]string // Basic 认证
Host string // 显式 Host 头覆盖
ContentLength int64 // 手动设置的 Content-Length
AutoCalcContentLength bool // 自动计算 Content-Length
MaxRespBodyBytes int64 // 响应体最大读取字节数(<=0 表示不限制)
UploadProgress UploadProgressFunc // 上传进度回调
// Transport 配置
CustomTransport bool // 是否使用自定义 Transport
Transport *http.Transport // 自定义 Transport
}
// Clone 克隆配置
func (c *RequestConfig) Clone() *RequestConfig {
return &RequestConfig{
Network: NetworkConfig{
Proxy: c.Network.Proxy,
DialTimeout: c.Network.DialTimeout,
Timeout: c.Network.Timeout,
DialFunc: c.Network.DialFunc,
},
TLS: TLSConfig{
Config: cloneTLSConfig(c.TLS.Config),
SkipVerify: c.TLS.SkipVerify,
ServerName: c.TLS.ServerName,
},
DNS: DNSConfig{
CustomIP: cloneStringSlice(c.DNS.CustomIP),
CustomDNS: cloneStringSlice(c.DNS.CustomDNS),
LookupFunc: c.DNS.LookupFunc,
},
Body: BodyConfig{
Mode: c.Body.Mode,
Bytes: cloneBytes(c.Body.Bytes),
Reader: c.Body.Reader, // Reader 不可克隆
FormData: cloneStringMapSlice(c.Body.FormData),
Files: cloneFiles(c.Body.Files),
},
Headers: cloneHeader(c.Headers),
Cookies: cloneCookies(c.Cookies),
Queries: cloneStringMapSlice(c.Queries),
BasicAuth: c.BasicAuth,
Host: c.Host,
ContentLength: c.ContentLength,
AutoCalcContentLength: c.AutoCalcContentLength,
MaxRespBodyBytes: c.MaxRespBodyBytes,
UploadProgress: c.UploadProgress,
CustomTransport: c.CustomTransport,
Transport: c.Transport, // Transport 共享
}
}
// RequestOpt 请求选项函数
type RequestOpt func(*Request) error
+223
View File
@@ -0,0 +1,223 @@
package starnet
import (
"context"
"crypto/tls"
"io"
"net/http"
"net/url"
"strings"
)
// validMethod 验证 HTTP 方法是否有效
func validMethod(method string) bool {
return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1
}
// isNotToken 检查字符是否不是 token 字符
func isNotToken(r rune) bool {
return !isTokenRune(r)
}
// isTokenRune 检查字符是否是 token 字符
func isTokenRune(r rune) bool {
i := int(r)
return i < 127 && isTokenTable[i]
}
// isTokenTable token 字符表
var isTokenTable = [127]bool{
'!': true, '#': true, '$': true, '%': true, '&': true, '\'': true, '*': true,
'+': true, '-': true, '.': true, '0': true, '1': true, '2': true, '3': true,
'4': true, '5': true, '6': true, '7': true, '8': true, '9': true, 'A': true,
'B': true, 'C': true, 'D': true, 'E': true, 'F': true, 'G': true, 'H': true,
'I': true, 'J': true, 'K': true, 'L': true, 'M': true, 'N': true, 'O': true,
'P': true, 'Q': true, 'R': true, 'S': true, 'T': true, 'U': true, 'V': true,
'W': true, 'X': true, 'Y': true, 'Z': true, '^': true, '_': true, '`': true,
'a': true, 'b': true, 'c': true, 'd': true, 'e': true, 'f': true, 'g': true,
'h': true, 'i': true, 'j': true, 'k': true, 'l': true, 'm': true, 'n': true,
'o': true, 'p': true, 'q': true, 'r': true, 's': true, 't': true, 'u': true,
'v': true, 'w': true, 'x': true, 'y': true, 'z': true, '|': true, '~': true,
}
// hasPort 检查地址是否包含端口
func hasPort(s string) bool {
return strings.LastIndex(s, ":") > strings.LastIndex(s, "]")
}
// removeEmptyPort 移除空端口
func removeEmptyPort(host string) string {
if hasPort(host) {
return strings.TrimSuffix(host, ":")
}
return host
}
// UrlEncode URL 编码
func UrlEncode(str string) string {
return url.QueryEscape(str)
}
// UrlEncodeRaw URL 编码(空格编码为 %20)
func UrlEncodeRaw(str string) string {
return strings.Replace(url.QueryEscape(str), "+", "%20", -1)
}
// UrlDecode URL 解码
func UrlDecode(str string) (string, error) {
return url.QueryUnescape(str)
}
// BuildQuery 构建查询字符串
func BuildQuery(data map[string]string) string {
query := url.Values{}
for k, v := range data {
query.Add(k, v)
}
return query.Encode()
}
// BuildPostForm 构建 POST 表单数据
func BuildPostForm(data map[string]string) []byte {
return []byte(BuildQuery(data))
}
// cloneHeader 克隆 Header
func cloneHeader(h http.Header) http.Header {
if h == nil {
return make(http.Header)
}
newHeader := make(http.Header, len(h))
for k, v := range h {
newHeader[k] = append([]string(nil), v...)
}
return newHeader
}
// cloneCookies 克隆 Cookies
func cloneCookies(cookies []*http.Cookie) []*http.Cookie {
if cookies == nil {
return nil
}
newCookies := make([]*http.Cookie, len(cookies))
for i, c := range cookies {
newCookies[i] = cloneCookie(c)
}
return newCookies
}
func cloneCookie(cookie *http.Cookie) *http.Cookie {
if cookie == nil {
return nil
}
return &http.Cookie{
Name: cookie.Name,
Value: cookie.Value,
Path: cookie.Path,
Domain: cookie.Domain,
Expires: cookie.Expires,
RawExpires: cookie.RawExpires,
MaxAge: cookie.MaxAge,
Secure: cookie.Secure,
HttpOnly: cookie.HttpOnly,
SameSite: cookie.SameSite,
Raw: cookie.Raw,
Unparsed: append([]string(nil), cookie.Unparsed...),
}
}
// cloneStringMapSlice 克隆 map[string][]string
func cloneStringMapSlice(m map[string][]string) map[string][]string {
if m == nil {
return make(map[string][]string)
}
newMap := make(map[string][]string, len(m))
for k, v := range m {
newMap[k] = append([]string(nil), v...)
}
return newMap
}
// cloneBytes 克隆字节切片
func cloneBytes(b []byte) []byte {
if b == nil {
return nil
}
newBytes := make([]byte, len(b))
copy(newBytes, b)
return newBytes
}
// cloneStringSlice 克隆字符串切片
func cloneStringSlice(s []string) []string {
if s == nil {
return nil
}
newSlice := make([]string, len(s))
copy(newSlice, s)
return newSlice
}
// cloneFiles 克隆文件列表
func cloneFiles(files []RequestFile) []RequestFile {
if files == nil {
return nil
}
newFiles := make([]RequestFile, len(files))
copy(newFiles, files)
return newFiles
}
// cloneTLSConfig 克隆 TLS 配置
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
if cfg == nil {
return nil
}
return cfg.Clone()
}
// copyWithProgress 带进度的复制
func copyWithProgress(ctx context.Context, dst io.Writer, src io.Reader, filename string, total int64, progress UploadProgressFunc) (int64, error) {
if ctx == nil {
ctx = context.Background()
}
var written int64
buf := make([]byte, 32*1024) // 32KB buffer
for {
select {
case <-ctx.Done():
return written, ctx.Err()
default:
}
nr, err := src.Read(buf)
if nr > 0 {
nw, ew := dst.Write(buf[:nr])
if nw > 0 {
written += int64(nw)
if progress != nil {
// 同步调用进度回调(不使用 goroutine)
progress(filename, written, total)
}
}
if ew != nil {
return written, ew
}
if nr != nw {
return written, io.ErrShortWrite
}
}
if err != nil {
if err == io.EOF {
if progress != nil {
// 最后一次进度回调
progress(filename, written, total)
}
return written, nil
}
return written, err
}
}
}
+284
View File
@@ -0,0 +1,284 @@
package starnet
import (
"net/http"
"testing"
"time"
)
func TestUrlEncodeRaw(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "basic string with space",
input: "hello world",
expected: "hello%20world",
},
{
name: "special characters",
input: "hello world!@#$%^&*()_+-=~`",
expected: "hello%20world%21%40%23%24%25%5E%26%2A%28%29_%2B-%3D~%60",
},
{
name: "empty string",
input: "",
expected: "",
},
{
name: "chinese characters",
input: "你好世界",
expected: "%E4%BD%A0%E5%A5%BD%E4%B8%96%E7%95%8C",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := UrlEncodeRaw(tt.input)
if result != tt.expected {
t.Errorf("UrlEncodeRaw(%q) = %q; want %q", tt.input, result, tt.expected)
}
})
}
}
func TestUrlEncode(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "space encoded as plus",
input: "hello world",
expected: "hello+world",
},
{
name: "special characters",
input: "hello world!@#$%^&*()_+-=~`",
expected: "hello+world%21%40%23%24%25%5E%26%2A%28%29_%2B-%3D~%60",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := UrlEncode(tt.input)
if result != tt.expected {
t.Errorf("UrlEncode(%q) = %q; want %q", tt.input, result, tt.expected)
}
})
}
}
func TestUrlDecode(t *testing.T) {
tests := []struct {
name string
input string
expected string
expectErr bool
}{
{
name: "basic decode",
input: "hello%20world",
expected: "hello world",
expectErr: false,
},
{
name: "plus to space",
input: "hello+world",
expected: "hello world",
expectErr: false,
},
{
name: "special characters",
input: "hello%20world%21%40%23%24%25%5E%26*%28%29_%2B-%3D~%60",
expected: "hello world!@#$%^&*()_+-=~`",
expectErr: false,
},
{
name: "invalid encoding",
input: "%zz",
expected: "",
expectErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := UrlDecode(tt.input)
if tt.expectErr {
if err == nil {
t.Errorf("UrlDecode(%q) expected error, got nil", tt.input)
}
} else {
if err != nil {
t.Errorf("UrlDecode(%q) unexpected error: %v", tt.input, err)
}
if result != tt.expected {
t.Errorf("UrlDecode(%q) = %q; want %q", tt.input, result, tt.expected)
}
}
})
}
}
func TestBuildQuery(t *testing.T) {
tests := []struct {
name string
input map[string]string
expected string
}{
{
name: "single parameter",
input: map[string]string{
"key": "value",
},
expected: "key=value",
},
{
name: "empty map",
input: map[string]string{},
expected: "",
},
{
name: "nil map",
input: nil,
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := BuildQuery(tt.input)
if result != tt.expected {
t.Errorf("BuildQuery(%v) = %q; want %q", tt.input, result, tt.expected)
}
})
}
}
func TestBuildPostForm(t *testing.T) {
tests := []struct {
name string
input map[string]string
expected []byte
}{
{
name: "basic form",
input: map[string]string{
"key1": "value1",
},
expected: []byte("key1=value1"),
},
{
name: "empty map",
input: map[string]string{},
expected: []byte(""),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := BuildPostForm(tt.input)
if string(result) != string(tt.expected) {
t.Errorf("BuildPostForm(%v) = %v; want %v", tt.input, result, tt.expected)
}
})
}
}
func TestValidMethod(t *testing.T) {
tests := []struct {
name string
method string
expected bool
}{
{"GET", "GET", true},
{"POST", "POST", true},
{"PUT", "PUT", true},
{"DELETE", "DELETE", true},
{"PATCH", "PATCH", true},
{"OPTIONS", "OPTIONS", true},
{"HEAD", "HEAD", true},
{"TRACE", "TRACE", true},
{"CONNECT", "CONNECT", true},
{"invalid with space", "GET POST", false},
{"invalid with special char", "GET<>", false},
{"empty", "", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := validMethod(tt.method)
if result != tt.expected {
t.Errorf("validMethod(%q) = %v; want %v", tt.method, result, tt.expected)
}
})
}
}
func TestCloneCookies_FullFields(t *testing.T) {
expire := time.Now().Add(2 * time.Hour)
src := []*http.Cookie{
{
Name: "sid",
Value: "abc123",
Path: "/",
Domain: "example.com",
Expires: expire,
RawExpires: expire.UTC().Format(time.RFC1123),
MaxAge: 3600,
Secure: true,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
Raw: "sid=abc123; Path=/; HttpOnly",
Unparsed: []string{"Priority=High", "Partitioned"},
},
}
got := cloneCookies(src)
if got == nil || len(got) != 1 {
t.Fatalf("cloneCookies() len=%v; want 1", len(got))
}
// 指针应不同(不是浅拷贝)
if got[0] == src[0] {
t.Fatal("cookie pointer should be different (deep copy expected)")
}
// 字段值应一致
s := src[0]
g := got[0]
if g.Name != s.Name ||
g.Value != s.Value ||
g.Path != s.Path ||
g.Domain != s.Domain ||
!g.Expires.Equal(s.Expires) ||
g.RawExpires != s.RawExpires ||
g.MaxAge != s.MaxAge ||
g.Secure != s.Secure ||
g.HttpOnly != s.HttpOnly ||
g.SameSite != s.SameSite ||
g.Raw != s.Raw {
t.Fatalf("cloned cookie fields mismatch:\n got=%+v\n src=%+v", g, s)
}
// Unparsed 内容一致
if len(g.Unparsed) != len(s.Unparsed) {
t.Fatalf("Unparsed len=%d; want %d", len(g.Unparsed), len(s.Unparsed))
}
for i := range s.Unparsed {
if g.Unparsed[i] != s.Unparsed[i] {
t.Fatalf("Unparsed[%d]=%q; want %q", i, g.Unparsed[i], s.Unparsed[i])
}
}
// 验证 Unparsed 是深拷贝(修改源不影响目标)
src[0].Unparsed[0] = "Modified=Yes"
if got[0].Unparsed[0] == "Modified=Yes" {
t.Fatal("Unparsed should be deep-copied, but was affected by source mutation")
}
}