Compare commits

..

25 Commits

Author SHA1 Message Date
9ac9b65bc5
fix(starnet): 收紧 TLS ClientHello 嗅探并补齐边界测试
- 用轻量 ClientHello 解析替代假握手式 TLS 嗅探
  - 保留截断和 max-bytes 场景下的 TLS 分类与缓冲回放能力
  - 拒绝首个 record 完整但并非 ClientHello 的伪 TLS 流量
  - 为动态 TLS 配置选择透出更完整的 ClientHello 元数据
  - 拆分 TLS 初始化失败统计为 sniff/config/plain rejected
  - 补充正常、分片、截断、限长、伪 TLS 等回归测试
2026-03-27 12:05:23 +08:00
b5bd7595a1
1. 优化ping功能
2. 新增重试机制
3. 优化错误处理逻辑
2026-03-19 16:42:45 +08:00
4568e17f06
fix: 修复核心bug并完善API
- 修复NewRequest系列函数不返回opt错误的问题
- 修复prepare()幂等性问题,支持请求重试
- 修复defaultDialTLSFunc的ServerName解析错误
- 修复Client.Clone()并发安全问题
- 补齐Client.Trace/Connect方法
- 新增Request.HTTPClient/Client方法
- 增强NewSimpleRequest错误处理的健壮性
2026-03-10 19:55:37 +08:00
1bb30514ec
bug fix:tls自定义时,没有设置servername的问题 2026-03-08 21:38:45 +08:00
50aef48d49
rewrite program 2026-03-08 20:19:40 +08:00
0e2f91eee2
fix:使用Client时,设置的参数不生效 2025-10-14 10:08:53 +08:00
b90c59d6e7
修改版本号 2025-08-21 21:40:29 +08:00
4e154cc17b
update benchmark 2025-08-21 21:37:21 +08:00
67b0025f9c
更新content-length的默认处理方式 2025-08-21 19:17:19 +08:00
c4fa62536a
为client新增部分函数 2025-08-21 15:32:19 +08:00
260ceb90ed
重构http Client部分 2025-08-21 15:02:02 +08:00
d260181adf
update 2025-08-15 15:07:51 +08:00
e3b7369e12
bug fix:nil pointer error 2025-08-13 10:16:08 +08:00
4e17fee681
bug fix 2025-07-14 18:38:31 +08:00
a8eed30db5
add http client control 2025-07-14 18:23:14 +08:00
c1eaf43058 update 2025-06-17 12:36:57 +08:00
9f5aca124d update 2025-06-17 12:09:12 +08:00
54958724e7 bug fix 2025-06-13 17:16:38 +08:00
7a17672149 update tls sniffer 2025-06-12 16:50:47 +08:00
44b807d3d1 update 2025-06-06 15:43:38 +08:00
0d847462b3 bug fix:nil pointer 2025-04-28 13:19:45 +08:00
deed4207ea bug fix 2024-08-30 23:44:49 +08:00
f6363fed07 move starqueue from starnet to stario 2024-08-18 17:18:52 +08:00
1de78f2f06 rewrite curl.go 2024-08-08 22:03:10 +08:00
d0122a9771 update go.mod & update que.go 2024-03-10 14:04:48 +08:00
54 changed files with 12388 additions and 882 deletions

4
.gitignore vendored Normal file
View File

@ -0,0 +1,4 @@
.idea
.sentrux/
agent_readme.md
target.md

201
LICENSE Normal file
View File

@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2026 starnet contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

106
README.md Normal file
View File

@ -0,0 +1,106 @@
# starnet
`starnet` is a Go network toolkit focused on practical HTTP request control, TLS sniff utilities, and ICMP ping capabilities.
## Highlights
- Request-level timeout by context (without mutating shared `http.Client` timeout)
- Fine-grained network controls: custom DNS/IP, dial timeout, proxy, TLS config
- Built-in retry with replay safety checks and configurable backoff/jitter/statuses
- Response body safety guard via max body bytes limit
- Error classification helpers (`ClassifyError`, `IsTimeout`, `IsDNS`, `IsTLS`, `IsProxy`, `IsCanceled`)
- TLS sniffer listener/dialer utilities for mixed TLS/plain traffic scenarios
- ICMP ping with IPv4/IPv6 target handling and option-based probing API
## Main Features
### HTTP Client and Request
- Fluent APIs with both `WithXxx` options and `SetXxx` chain methods
- Methods: `Get/Post/Put/Delete/Head/Patch/Options/Trace/Connect`
- Request body helpers: JSON, form data, multipart file upload, stream body
- Header/cookie/query helpers with defensive copy on key setters
- Request cloning for safe reuse in concurrent or variant calls
### Timeout and Retry
- Request timeout is applied by context deadline, not global client timeout
- Retry supports:
- max attempts
- backoff factor/base/max
- jitter
- retry status whitelist
- idempotent-only guard
- custom retry-on-error callback
- Retry keeps original request pointer in final response for consistency
### Response Handling
- `Bytes/String/JSON/Reader` helpers
- optional auto-fetch mode
- configurable max response body bytes to prevent oversized reads
### Ping Module
- `Ping`, `PingWithContext`, `Pingable`, and compatibility helper `IsIpPingable`
- `PingOptions` for count/timeout/interval/deadline/address preference/source IP/payload size
- explicit error semantics for permission/protocol/timeout/resolve failures
## Install
```bash
go get b612.me/starnet
```
## Quick Example
```go
package main
import (
"fmt"
"net/http"
"time"
"b612.me/starnet"
)
func main() {
resp, err := starnet.Get(
"https://example.com",
starnet.WithTimeout(2*time.Second),
starnet.WithRetry(2,
starnet.WithRetryBackoff(100*time.Millisecond, 1*time.Second, 2),
starnet.WithRetryJitter(0.1),
),
starnet.WithMaxRespBodyBytes(1<<20),
)
if err != nil {
fmt.Println("request failed:", starnet.ClassifyError(err), err)
return
}
defer resp.Close()
fmt.Println("status:", resp.StatusCode)
_, _ = resp.Body().Bytes()
ok, pingErr := starnet.Pingable("example.com", &starnet.PingOptions{
Count: 2,
Timeout: 2 * time.Second,
})
fmt.Println("pingable:", ok, pingErr == nil)
_ = http.MethodGet
}
```
## Stability Notes
- Raw ICMP ping may require elevated privileges on some systems.
- Integration tests that rely on external network are environment-dependent.
## License
This project is licensed under the Apache License 2.0.
See [LICENSE](./LICENSE).

1657
addon_test.go Normal file

File diff suppressed because it is too large Load Diff

197
benchmark_test.go Normal file
View File

@ -0,0 +1,197 @@
package starnet
import (
"fmt"
"net/http"
"net/http/httptest"
"testing"
)
func BenchmarkGetRequest(b *testing.B) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("OK"))
}))
defer server.Close()
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
resp, err := Get(server.URL)
if err != nil {
b.Fatalf("Get() error: %v", err)
}
resp.Body().String()
resp.Close()
}
}
func BenchmarkGetRequestWithHeaders(b *testing.B) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("OK"))
}))
defer server.Close()
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
resp, err := Get(server.URL,
WithHeader("X-Custom", "value"),
WithUserAgent("BenchmarkAgent"))
if err != nil {
b.Fatalf("Get() error: %v", err)
}
resp.Body().String()
resp.Close()
}
}
func BenchmarkPostRequest(b *testing.B) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("OK"))
}))
defer server.Close()
testData := []byte("test data for benchmark")
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
resp, err := Post(server.URL, WithBody(testData))
if err != nil {
b.Fatalf("Post() error: %v", err)
}
resp.Body().String()
resp.Close()
}
}
func BenchmarkJSONRequest(b *testing.B) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{"status":"ok"}`))
}))
defer server.Close()
type TestData struct {
Name string `json:"name"`
Value int `json:"value"`
}
data := TestData{Name: "test", Value: 123}
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
resp, err := Post(server.URL, WithJSON(data))
if err != nil {
b.Fatalf("Post() error: %v", err)
}
var result map[string]string
resp.Body().JSON(&result)
resp.Close()
}
}
func BenchmarkConcurrentRequests(b *testing.B) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("OK"))
}))
defer server.Close()
b.ResetTimer()
b.ReportAllocs()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
resp, err := Get(server.URL)
if err != nil {
b.Fatalf("Get() error: %v", err)
}
resp.Body().String()
resp.Close()
}
})
}
func BenchmarkRequestClone(b *testing.B) {
req := NewSimpleRequest("https://example.com", "GET").
SetHeader("X-Custom", "value").
AddQuery("key", "value")
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
_ = req.Clone()
}
}
func BenchmarkClientCreation(b *testing.B) {
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
_ = NewClientNoErr()
}
}
func BenchmarkRequestCreation(b *testing.B) {
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
_ = NewSimpleRequest("https://example.com", "GET")
}
}
func BenchmarkResponseBodyRead(b *testing.B) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("test response data"))
}))
defer server.Close()
// Pre-fetch response
resp, _ := Get(server.URL, WithAutoFetch(true))
defer resp.Close()
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
_, _ = resp.Body().String()
}
}
func BenchmarkDifferentResponseSizes(b *testing.B) {
sizes := []int{100, 1024, 10240, 102400} // 100B, 1KB, 10KB, 100KB
for _, size := range sizes {
responseData := make([]byte, size)
for i := 0; i < size; i++ {
responseData[i] = 'A'
}
b.Run(fmt.Sprintf("Size_%d", size), func(b *testing.B) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write(responseData)
}))
defer server.Close()
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
resp, err := Get(server.URL)
if err != nil {
b.Fatalf("Get() error: %v", err)
}
resp.Body().Bytes()
resp.Close()
}
})
}
}

145
body_test.go Normal file
View File

@ -0,0 +1,145 @@
package starnet
import (
"bytes"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func TestRequestBodyBytes(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
w.Write(body)
}))
defer server.Close()
testData := []byte("test data")
req := NewSimpleRequest(server.URL, "POST").SetBody(testData)
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
body, _ := resp.Body().Bytes()
if !bytes.Equal(body, testData) {
t.Errorf("Body = %v; want %v", body, testData)
}
}
func TestRequestBodyString(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
w.Write(body)
}))
defer server.Close()
testData := "test string data"
req := NewSimpleRequest(server.URL, "POST").SetBodyString(testData)
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
body, _ := resp.Body().String()
if body != testData {
t.Errorf("Body = %v; want %v", body, testData)
}
}
func TestRequestBodyReader(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
w.Write(body)
}))
defer server.Close()
testData := "test reader data"
reader := strings.NewReader(testData)
req := NewSimpleRequest(server.URL, "POST").SetBodyReader(reader)
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
body, _ := resp.Body().String()
if body != testData {
t.Errorf("Body = %v; want %v", body, testData)
}
}
func TestRequestJSON(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("Content-Type") != ContentTypeJSON {
t.Errorf("Content-Type = %v; want %v", r.Header.Get("Content-Type"), ContentTypeJSON)
}
var data map[string]string
json.NewDecoder(r.Body).Decode(&data)
json.NewEncoder(w).Encode(data)
}))
defer server.Close()
testData := map[string]string{
"name": "John",
"email": "john@example.com",
}
req := NewSimpleRequest(server.URL, "POST").SetJSON(testData)
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
var result map[string]string
resp.Body().JSON(&result)
if result["name"] != testData["name"] {
t.Errorf("name = %v; want %v", result["name"], testData["name"])
}
}
func TestRequestFormData(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r.ParseForm()
data := make(map[string]string)
for k, v := range r.Form {
if len(v) > 0 {
data[k] = v[0]
}
}
json.NewEncoder(w).Encode(data)
}))
defer server.Close()
req := NewSimpleRequest(server.URL, "POST").
AddFormData("name", "John").
AddFormData("email", "john@example.com")
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
var result map[string]string
resp.Body().JSON(&result)
if result["name"] != "John" {
t.Errorf("name = %v; want John", result["name"])
}
if result["email"] != "john@example.com" {
t.Errorf("email = %v; want john@example.com", result["email"])
}
}

345
client.go Normal file
View File

@ -0,0 +1,345 @@
package starnet
import (
"context"
"crypto/tls"
"fmt"
"net/http"
"sync"
"time"
)
// Client HTTP 客户端封装
type Client struct {
client *http.Client
opts []RequestOpt
mu sync.RWMutex
}
// NewClient 创建新的 Client
func NewClient(opts ...RequestOpt) (*Client, error) {
// 创建基础 Transport
baseTransport := &http.Transport{
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
httpClient := &http.Client{
Transport: &Transport{base: baseTransport},
//Timeout: DefaultTimeout,
}
// 应用选项(如果有)
if len(opts) > 0 {
// 创建临时请求以应用选项
req, err := newRequest(context.Background(), "", http.MethodGet, opts...)
if err != nil {
return nil, wrapError(err, "create client")
}
/*
// 如果选项中有自定义配置,应用到 httpClient
if req.config.Network.Timeout > 0 {
httpClient.Timeout = req.config.Network.Timeout
}
*/
// 如果有自定义 Transport
if req.config.CustomTransport && req.config.Transport != nil {
httpClient.Transport = &Transport{base: req.config.Transport}
}
}
return &Client{
client: httpClient,
opts: opts,
}, nil
}
// NewClientNoErr 创建新的 Client忽略错误
func NewClientNoErr(opts ...RequestOpt) *Client {
client, _ := NewClient(opts...)
if client == nil {
client = &Client{
client: &http.Client{},
opts: opts,
}
}
return client
}
// NewClientFromHTTP 从 http.Client 创建 Client
func NewClientFromHTTP(httpClient *http.Client) (*Client, error) {
if httpClient == nil {
return nil, ErrNilClient
}
// 确保 Transport 是我们的自定义类型
if httpClient.Transport == nil {
httpClient.Transport = &Transport{
base: &http.Transport{},
}
} else {
switch t := httpClient.Transport.(type) {
case *Transport:
// 已经是我们的类型
if t.base == nil {
t.base = &http.Transport{}
}
case *http.Transport:
// 包装标准 Transport
httpClient.Transport = &Transport{
base: t,
}
default:
return nil, fmt.Errorf("unsupported transport type: %T", t)
}
}
return &Client{
client: httpClient,
}, nil
}
// HTTPClient 获取底层 http.Client
func (c *Client) HTTPClient() *http.Client {
return c.client
}
// RequestOptions 获取默认选项(返回副本)
func (c *Client) RequestOptions() []RequestOpt {
c.mu.RLock()
defer c.mu.RUnlock()
opts := make([]RequestOpt, len(c.opts))
copy(opts, c.opts)
return opts
}
// SetOptions 设置默认选项
func (c *Client) SetOptions(opts ...RequestOpt) *Client {
c.mu.Lock()
c.opts = opts
c.mu.Unlock()
return c
}
// AddOptions 追加默认选项
func (c *Client) AddOptions(opts ...RequestOpt) *Client {
c.mu.Lock()
c.opts = append(c.opts, opts...)
c.mu.Unlock()
return c
}
// Clone 克隆 Client深拷贝
func (c *Client) Clone() *Client {
c.mu.RLock()
defer c.mu.RUnlock()
// 克隆 Transport
var transport http.RoundTripper
if c.client.Transport != nil {
switch t := c.client.Transport.(type) {
case *Transport:
transport = &Transport{
base: t.base.Clone(),
}
case *http.Transport:
transport = t.Clone()
default:
transport = c.client.Transport
}
}
return &Client{
client: &http.Client{
Transport: transport,
CheckRedirect: c.client.CheckRedirect,
Jar: c.client.Jar,
Timeout: c.client.Timeout,
},
opts: append([]RequestOpt(nil), c.opts...),
}
}
// SetDefaultTLSConfig 设置默认 TLS 配置
func (c *Client) SetDefaultTLSConfig(tlsConfig *tls.Config) *Client {
if transport, ok := c.client.Transport.(*Transport); ok {
transport.mu.Lock()
if tlsConfig != nil {
transport.base.TLSClientConfig = tlsConfig.Clone()
} else {
transport.base.TLSClientConfig = nil
}
transport.mu.Unlock()
}
return c
}
// SetDefaultSkipTLSVerify 设置默认跳过 TLS 验证
func (c *Client) SetDefaultSkipTLSVerify(skip bool) *Client {
if transport, ok := c.client.Transport.(*Transport); ok {
transport.mu.Lock()
if transport.base.TLSClientConfig == nil {
transport.base.TLSClientConfig = &tls.Config{}
} else {
transport.base.TLSClientConfig = transport.base.TLSClientConfig.Clone()
}
transport.base.TLSClientConfig.InsecureSkipVerify = skip
transport.mu.Unlock()
}
return c
}
// DisableRedirect 禁用重定向
func (c *Client) DisableRedirect() *Client {
c.client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
}
return c
}
// EnableRedirect 启用重定向
func (c *Client) EnableRedirect() *Client {
c.client.CheckRedirect = nil
return c
}
// NewRequest 创建新请求
func (c *Client) NewRequest(url, method string, opts ...RequestOpt) (*Request, error) {
return c.NewRequestWithContext(context.Background(), url, method, opts...)
}
// NewRequestWithContext 创建新请求(带 context
func (c *Client) NewRequestWithContext(ctx context.Context, url, method string, opts ...RequestOpt) (*Request, error) {
// 合并 Client 级别和请求级别的选项
c.mu.RLock()
allOpts := append(append([]RequestOpt(nil), c.opts...), opts...)
c.mu.RUnlock()
req, err := newRequest(ctx, url, method, allOpts...)
if err != nil {
return nil, err
}
req.client = c
req.httpClient = c.client
return req, nil
}
// Get 发送 GET 请求
func (c *Client) Get(url string, opts ...RequestOpt) (*Response, error) {
req, err := c.NewRequest(url, http.MethodGet, opts...)
if err != nil {
return nil, err
}
return req.Do()
}
// Post 发送 POST 请求
func (c *Client) Post(url string, opts ...RequestOpt) (*Response, error) {
req, err := c.NewRequest(url, http.MethodPost, opts...)
if err != nil {
return nil, err
}
return req.Do()
}
// Put 发送 PUT 请求
func (c *Client) Put(url string, opts ...RequestOpt) (*Response, error) {
req, err := c.NewRequest(url, http.MethodPut, opts...)
if err != nil {
return nil, err
}
return req.Do()
}
// Delete 发送 DELETE 请求
func (c *Client) Delete(url string, opts ...RequestOpt) (*Response, error) {
req, err := c.NewRequest(url, http.MethodDelete, opts...)
if err != nil {
return nil, err
}
return req.Do()
}
// Head 发送 HEAD 请求
func (c *Client) Head(url string, opts ...RequestOpt) (*Response, error) {
req, err := c.NewRequest(url, http.MethodHead, opts...)
if err != nil {
return nil, err
}
return req.Do()
}
// Patch 发送 PATCH 请求
func (c *Client) Patch(url string, opts ...RequestOpt) (*Response, error) {
req, err := c.NewRequest(url, http.MethodPatch, opts...)
if err != nil {
return nil, err
}
return req.Do()
}
// Options 发送 OPTIONS 请求
func (c *Client) Options(url string, opts ...RequestOpt) (*Response, error) {
req, err := c.NewRequest(url, http.MethodOptions, opts...)
if err != nil {
return nil, err
}
return req.Do()
}
// NewSimpleRequest 创建新请求(忽略错误,支持链式调用)
func (c *Client) NewSimpleRequest(url, method string, opts ...RequestOpt) *Request {
return c.NewSimpleRequestWithContext(context.Background(), url, method, opts...)
}
// NewSimpleRequestWithContext 创建新请求(带 context忽略错误
func (c *Client) NewSimpleRequestWithContext(ctx context.Context, url, method string, opts ...RequestOpt) *Request {
req, err := c.NewRequestWithContext(ctx, url, method, opts...)
if err != nil {
// 返回一个带错误的请求,保持与全局 NewSimpleRequest 行为一致
return &Request{
ctx: ctx,
url: url,
method: method,
err: err,
config: &RequestConfig{
Headers: make(http.Header),
Queries: make(map[string][]string),
Body: BodyConfig{
FormData: make(map[string][]string),
},
},
client: c,
httpClient: c.client,
autoFetch: DefaultFetchRespBody,
}
}
return req
}
// Trace 发送 TRACE 请求
func (c *Client) Trace(url string, opts ...RequestOpt) (*Response, error) {
req, err := c.NewRequest(url, http.MethodTrace, opts...)
if err != nil {
return nil, err
}
return req.Do()
}
// Connect 发送 CONNECT 请求
func (c *Client) Connect(url string, opts ...RequestOpt) (*Response, error) {
req, err := c.NewRequest(url, http.MethodConnect, opts...)
if err != nil {
return nil, err
}
return req.Do()
}

223
client_test.go Normal file
View File

@ -0,0 +1,223 @@
package starnet
import (
"net/http"
"net/http/httptest"
"testing"
"time"
)
func TestNewClient(t *testing.T) {
client, err := NewClient()
if err != nil {
t.Fatalf("NewClient() error: %v", err)
}
if client == nil {
t.Fatal("NewClient() returned nil")
}
}
func TestNewClientNoErr(t *testing.T) {
client := NewClientNoErr()
if client == nil {
t.Fatal("NewClientNoErr() returned nil")
}
}
func TestNewClientFromHTTP(t *testing.T) {
httpClient := &http.Client{
Timeout: 10 * time.Second,
}
client, err := NewClientFromHTTP(httpClient)
if err != nil {
t.Fatalf("NewClientFromHTTP() error: %v", err)
}
if client == nil {
t.Fatal("NewClientFromHTTP() returned nil")
}
// Test with nil client
_, err = NewClientFromHTTP(nil)
if err == nil {
t.Error("NewClientFromHTTP(nil) should return error")
}
}
func TestClientOptions(t *testing.T) {
client := NewClientNoErr()
// Set options
client.SetOptions(WithTimeout(5 * time.Second))
opts := client.RequestOptions()
if len(opts) != 1 {
t.Errorf("RequestOptions() length = %v; want 1", len(opts))
}
// Add options
client.AddOptions(WithUserAgent("TestAgent"))
opts = client.RequestOptions()
if len(opts) != 2 {
t.Errorf("RequestOptions() length = %v; want 2", len(opts))
}
}
func TestClientClone(t *testing.T) {
client := NewClientNoErr(WithTimeout(5 * time.Second))
cloned := client.Clone()
if cloned == nil {
t.Fatal("Clone() returned nil")
}
// 修改克隆的 client
cloned.SetOptions(WithTimeout(10 * time.Second))
origOpts := client.RequestOptions()
clonedOpts := cloned.RequestOptions()
// 原 client 应该还是 1 个选项
if len(origOpts) != 1 {
t.Errorf("Original client options = %v; want 1", len(origOpts))
}
// 克隆的 client 应该是 1 个选项(被 SetOptions 覆盖)
if len(clonedOpts) != 1 {
t.Errorf("Cloned client options = %v; want 1", len(clonedOpts))
}
}
func TestClientHTTPMethods(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte(r.Method))
}))
defer server.Close()
client := NewClientNoErr()
tests := []struct {
name string
method func(string, ...RequestOpt) (*Response, error)
want string
}{
{"GET", client.Get, "GET"},
{"POST", client.Post, "POST"},
{"PUT", client.Put, "PUT"},
{"DELETE", client.Delete, "DELETE"},
{"PATCH", client.Patch, "PATCH"},
{"HEAD", client.Head, "HEAD"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
resp, err := tt.method(server.URL)
if err != nil {
t.Fatalf("%s() error: %v", tt.name, err)
}
defer resp.Close()
if tt.want != "HEAD" {
body, _ := resp.Body().String()
if body != tt.want {
t.Errorf("Body = %v; want %v", body, tt.want)
}
}
})
}
}
func TestClientRedirect(t *testing.T) {
redirectCount := 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if redirectCount < 2 {
redirectCount++
http.Redirect(w, r, "/redirected", http.StatusFound)
return
}
w.WriteHeader(http.StatusOK)
w.Write([]byte("final"))
}))
defer server.Close()
// Test with redirect enabled (default)
client := NewClientNoErr()
resp, err := client.Get(server.URL)
if err != nil {
t.Fatalf("Get() error: %v", err)
}
resp.Close()
if redirectCount != 2 {
t.Errorf("Redirect count = %v; want 2", redirectCount)
}
// Test with redirect disabled
redirectCount = 0
client.DisableRedirect()
resp2, err := client.Get(server.URL)
if err != nil {
t.Fatalf("Get() error: %v", err)
}
defer resp2.Close()
if resp2.StatusCode != http.StatusFound {
t.Errorf("StatusCode = %v; want %v", resp2.StatusCode, http.StatusFound)
}
}
func TestClientTLSConfig(t *testing.T) {
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
}))
defer server.Close()
// Without skip verify (should fail with self-signed cert)
client := NewClientNoErr()
_, err := client.Get(server.URL)
if err == nil {
t.Error("Expected TLS error with self-signed cert, got nil")
}
// With skip verify
client.SetDefaultSkipTLSVerify(true)
resp, err := client.Get(server.URL)
if err != nil {
t.Fatalf("Get() with skip verify error: %v", err)
}
defer resp.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("StatusCode = %v; want %v", resp.StatusCode, http.StatusOK)
}
}
func TestClientNewSimpleRequest(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
}))
defer server.Close()
client := NewClientNoErr()
req := client.NewSimpleRequest(server.URL, "GET", WithHeader("X-Test", "v"))
if req == nil {
t.Fatal("NewSimpleRequest returned nil")
}
if req.Err() != nil {
t.Fatalf("NewSimpleRequest err: %v", req.Err())
}
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
body, _ := resp.Body().String()
if body != "OK" {
t.Errorf("Body = %v; want OK", body)
}
}

111
concurrent_test.go Normal file
View File

@ -0,0 +1,111 @@
package starnet
import (
"fmt"
"net/http"
"net/http/httptest"
"sync"
"sync/atomic"
"testing"
"time"
)
func TestConcurrentRequests(t *testing.T) {
var counter int64
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
atomic.AddInt64(&counter, 1)
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
}))
defer server.Close()
client := NewClientNoErr()
concurrency := 100
var wg sync.WaitGroup
wg.Add(concurrency)
for i := 0; i < concurrency; i++ {
go func() {
defer wg.Done()
resp, err := client.Get(server.URL)
if err != nil {
t.Errorf("Get() error: %v", err)
return
}
resp.Close()
}()
}
wg.Wait()
if atomic.LoadInt64(&counter) != int64(concurrency) {
t.Errorf("counter = %v; want %v", counter, concurrency)
}
}
func TestConcurrentClientModification(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
client := NewClientNoErr()
var wg sync.WaitGroup
wg.Add(200)
// 100 goroutines reading
for i := 0; i < 100; i++ {
go func() {
defer wg.Done()
resp, err := client.Get(server.URL)
if err != nil {
t.Errorf("Get() error: %v", err)
return
}
resp.Close()
}()
}
// 100 goroutines modifying options
for i := 0; i < 100; i++ {
go func(i int) {
defer wg.Done()
if i%2 == 0 {
client.AddOptions(WithTimeout(5 * time.Second))
} else {
_ = client.RequestOptions()
}
}(i)
}
wg.Wait()
}
func TestConcurrentRequestClone(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
baseReq := NewSimpleRequest(server.URL, "GET").SetHeader("X-Base", "value")
var wg sync.WaitGroup
wg.Add(50)
for i := 0; i < 50; i++ {
go func(i int) {
defer wg.Done()
cloned := baseReq.Clone()
// 修复:使用有效的 header 值
cloned.SetHeader("X-Index", fmt.Sprintf("%d", i))
resp, err := cloned.Do()
if err != nil {
t.Errorf("Do() error: %v", err)
return
}
resp.Close()
}(i)
}
wg.Wait()
}

145
context.go Normal file
View File

@ -0,0 +1,145 @@
package starnet
import (
"context"
"crypto/tls"
"net"
"net/http"
"time"
)
// contextKey 私有的 context key 类型(防止冲突)
type contextKey int
const (
ctxKeyTransport contextKey = iota
ctxKeyTLSConfig
ctxKeyProxy
ctxKeyCustomIP
ctxKeyCustomDNS
ctxKeyDialTimeout
ctxKeyTimeout
ctxKeyLookupIP
ctxKeyDialFunc
)
// RequestContext 从 context 中提取的请求配置
type RequestContext struct {
Transport *http.Transport
TLSConfig *tls.Config
Proxy string
CustomIP []string
CustomDNS []string
DialTimeout time.Duration
Timeout time.Duration
LookupIPFn func(ctx context.Context, host string) ([]net.IPAddr, error)
DialFn func(ctx context.Context, network, addr string) (net.Conn, error)
}
// getRequestContext 从 context 中提取请求配置
func getRequestContext(ctx context.Context) *RequestContext {
rc := &RequestContext{}
if v := ctx.Value(ctxKeyTransport); v != nil {
rc.Transport, _ = v.(*http.Transport)
}
if v := ctx.Value(ctxKeyTLSConfig); v != nil {
rc.TLSConfig, _ = v.(*tls.Config)
}
if v := ctx.Value(ctxKeyProxy); v != nil {
rc.Proxy, _ = v.(string)
}
if v := ctx.Value(ctxKeyCustomIP); v != nil {
rc.CustomIP, _ = v.([]string)
}
if v := ctx.Value(ctxKeyCustomDNS); v != nil {
rc.CustomDNS, _ = v.([]string)
}
if v := ctx.Value(ctxKeyDialTimeout); v != nil {
rc.DialTimeout, _ = v.(time.Duration)
}
if v := ctx.Value(ctxKeyTimeout); v != nil {
rc.Timeout, _ = v.(time.Duration)
}
if v := ctx.Value(ctxKeyLookupIP); v != nil {
rc.LookupIPFn, _ = v.(func(context.Context, string) ([]net.IPAddr, error))
}
if v := ctx.Value(ctxKeyDialFunc); v != nil {
rc.DialFn, _ = v.(func(context.Context, string, string) (net.Conn, error))
}
return rc
}
// needsDynamicTransport 判断是否需要动态 Transport
func needsDynamicTransport(rc *RequestContext) bool {
return rc.Transport != nil ||
rc.TLSConfig != nil ||
rc.Proxy != "" ||
rc.DialFn != nil ||
(rc.DialTimeout > 0 && rc.DialTimeout != DefaultDialTimeout) ||
len(rc.CustomIP) > 0 ||
len(rc.CustomDNS) > 0 ||
rc.LookupIPFn != nil
}
// injectRequestConfig 将请求配置注入到 context
func injectRequestConfig(ctx context.Context, config *RequestConfig) context.Context {
execCtx := ctx
// 处理 TLS 配置
var tlsConfig *tls.Config
if config.TLS.Config != nil {
tlsConfig = config.TLS.Config.Clone()
if config.TLS.SkipVerify {
tlsConfig.InsecureSkipVerify = true
}
} else if config.TLS.SkipVerify {
tlsConfig = &tls.Config{
NextProtos: []string{"h2", "http/1.1"},
InsecureSkipVerify: true,
}
}
if tlsConfig != nil {
execCtx = context.WithValue(execCtx, ctxKeyTLSConfig, tlsConfig)
}
// 注入代理
if config.Network.Proxy != "" {
execCtx = context.WithValue(execCtx, ctxKeyProxy, config.Network.Proxy)
}
// 注入自定义 IP
if len(config.DNS.CustomIP) > 0 {
execCtx = context.WithValue(execCtx, ctxKeyCustomIP, config.DNS.CustomIP)
}
// 注入自定义 DNS
if len(config.DNS.CustomDNS) > 0 {
execCtx = context.WithValue(execCtx, ctxKeyCustomDNS, config.DNS.CustomDNS)
}
// 总是注入 DialTimeout与原始代码一致
if config.Network.DialTimeout > 0 {
execCtx = context.WithValue(execCtx, ctxKeyDialTimeout, config.Network.DialTimeout)
}
// 注入 DNS 解析函数
if config.DNS.LookupFunc != nil {
execCtx = context.WithValue(execCtx, ctxKeyLookupIP, config.DNS.LookupFunc)
}
// 注入 Dial 函数
if config.Network.DialFunc != nil {
execCtx = context.WithValue(execCtx, ctxKeyDialFunc, config.Network.DialFunc)
}
// 注入自定义 Transport
if config.CustomTransport && config.Transport != nil {
execCtx = context.WithValue(execCtx, ctxKeyTransport, config.Transport)
}
return execCtx
}

463
curl.go
View File

@ -1,463 +0,0 @@
package starnet
import (
"bytes"
"context"
"crypto/rand"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"os"
"strings"
"time"
"b612.me/stario"
)
const (
HEADER_FORM_URLENCODE = `application/x-www-form-urlencoded`
HEADER_FORM_DATA = `multipart/form-data`
HEADER_JSON = `application/json`
HEADER_PLAIN = `text/plain`
)
type RequestFile struct {
UploadFile string
UploadForm map[string]string
UploadName string
}
type Request struct {
Url string
RespURL string
Method string
RecvData []byte
RecvContentLength int64
RecvIo io.Writer
RespHeader http.Header
RespCookies []*http.Cookie
RespHttpCode int
Location *url.URL
CircleBuffer *stario.StarBuffer
respReader io.ReadCloser
respOrigin *http.Response
reqOrigin *http.Request
RequestOpts
}
type RequestOpts struct {
RequestFile
PostBuffer io.Reader
Process func(float64)
Proxy string
Timeout time.Duration
DialTimeout time.Duration
ReqHeader http.Header
ReqCookies []*http.Cookie
WriteRecvData bool
SkipTLSVerify bool
CustomTransport *http.Transport
Queries map[string]string
DisableRedirect bool
TlsConfig *tls.Config
}
type RequestOpt func(opt *RequestOpts)
func WithDialTimeout(timeout time.Duration) RequestOpt {
return func(opt *RequestOpts) {
opt.DialTimeout = timeout
}
}
func WithTimeout(timeout time.Duration) RequestOpt {
return func(opt *RequestOpts) {
opt.Timeout = timeout
}
}
func WithHeader(key, val string) RequestOpt {
return func(opt *RequestOpts) {
opt.ReqHeader.Set(key, val)
}
}
func WithTlsConfig(tlscfg *tls.Config) RequestOpt {
return func(opt *RequestOpts) {
opt.TlsConfig = tlscfg
}
}
func WithHeaderMap(header map[string]string) RequestOpt {
return func(opt *RequestOpts) {
for key, val := range header {
opt.ReqHeader.Set(key, val)
}
}
}
func WithHeaderAdd(key, val string) RequestOpt {
return func(opt *RequestOpts) {
opt.ReqHeader.Add(key, val)
}
}
func WithReader(r io.Reader) RequestOpt {
return func(opt *RequestOpts) {
opt.PostBuffer = r
}
}
func WithFetchRespBody(fetch bool) RequestOpt {
return func(opt *RequestOpts) {
opt.WriteRecvData = fetch
}
}
func WithCookies(ck []*http.Cookie) RequestOpt {
return func(opt *RequestOpts) {
opt.ReqCookies = ck
}
}
func WithCookie(key, val, path string) RequestOpt {
return func(opt *RequestOpts) {
opt.ReqCookies = append(opt.ReqCookies, &http.Cookie{Name: key, Value: val, Path: path})
}
}
func WithCookieMap(header map[string]string, path string) RequestOpt {
return func(opt *RequestOpts) {
for key, val := range header {
opt.ReqCookies = append(opt.ReqCookies, &http.Cookie{Name: key, Value: val, Path: path})
}
}
}
func WithQueries(queries map[string]string) RequestOpt {
return func(opt *RequestOpts) {
opt.Queries = queries
}
}
func WithProxy(proxy string) RequestOpt {
return func(opt *RequestOpts) {
opt.Proxy = proxy
}
}
func WithProcess(fn func(float64)) RequestOpt {
return func(opt *RequestOpts) {
opt.Process = fn
}
}
func WithContentType(ct string) RequestOpt {
return func(opt *RequestOpts) {
opt.ReqHeader.Set("Content-Type", ct)
}
}
func WithUserAgent(ua string) RequestOpt {
return func(opt *RequestOpts) {
opt.ReqHeader.Set("User-Agent", ua)
}
}
func WithCustomTransport(hs *http.Transport) RequestOpt {
return func(opt *RequestOpts) {
opt.CustomTransport = hs
}
}
func WithSkipTLSVerify(skip bool) RequestOpt {
return func(opt *RequestOpts) {
opt.SkipTLSVerify = skip
}
}
func WithDisableRedirect(disable bool) RequestOpt {
return func(opt *RequestOpts) {
opt.DisableRedirect = disable
}
}
func NewRequests(url string, rawdata []byte, method string, opts ...RequestOpt) Request {
req := Request{
RequestOpts: RequestOpts{
Timeout: 30 * time.Second,
DialTimeout: 15 * time.Second,
WriteRecvData: true,
},
Url: url,
Method: method,
}
if rawdata != nil {
req.PostBuffer = bytes.NewBuffer(rawdata)
}
req.ReqHeader = make(http.Header)
if strings.ToUpper(method) == "POST" {
req.ReqHeader.Set("Content-Type", HEADER_FORM_URLENCODE)
}
req.ReqHeader.Set("User-Agent", "B612 / 1.1.0")
for _, v := range opts {
v(&req.RequestOpts)
}
if req.CustomTransport == nil {
req.CustomTransport = &http.Transport{}
}
if req.SkipTLSVerify {
if req.CustomTransport.TLSClientConfig == nil {
req.CustomTransport.TLSClientConfig = &tls.Config{}
}
req.CustomTransport.TLSClientConfig.InsecureSkipVerify = true
}
if req.TlsConfig != nil {
req.CustomTransport.TLSClientConfig = req.TlsConfig
}
req.CustomTransport.DialContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
c, err := net.DialTimeout(netw, addr, req.DialTimeout)
if err != nil {
return nil, err
}
if req.Timeout != 0 {
c.SetDeadline(time.Now().Add(req.Timeout))
}
return c, nil
}
return req
}
func (curl *Request) ResetReqHeader() {
curl.ReqHeader = make(http.Header)
}
func (curl *Request) ResetReqCookies() {
curl.ReqCookies = []*http.Cookie{}
}
func (curl *Request) AddSimpleCookie(key, value string) {
curl.ReqCookies = append(curl.ReqCookies, &http.Cookie{Name: key, Value: value, Path: "/"})
}
func (curl *Request) AddCookie(key, value, path string) {
curl.ReqCookies = append(curl.ReqCookies, &http.Cookie{Name: key, Value: value, Path: path})
}
func randomBoundary() string {
var buf [30]byte
_, err := io.ReadFull(rand.Reader, buf[:])
if err != nil {
panic(err)
}
return fmt.Sprintf("%x", buf[:])
}
func Curl(curl Request) (resps Request, err error) {
var fpsrc *os.File
if curl.RequestFile.UploadFile != "" {
fpsrc, err = os.Open(curl.UploadFile)
if err != nil {
return
}
defer fpsrc.Close()
boundary := randomBoundary()
boundarybytes := []byte("\r\n--" + boundary + "\r\n")
endbytes := []byte("\r\n--" + boundary + "--\r\n")
fpstat, _ := fpsrc.Stat()
filebig := float64(fpstat.Size())
sum, n := 0, 0
fpdst := stario.NewStarBuffer(1048576)
if curl.UploadForm != nil {
for k, v := range curl.UploadForm {
header := fmt.Sprintf("Content-Disposition: form-data; name=\"%s\";\r\nContent-Type: x-www-form-urlencoded \r\n\r\n", k)
fpdst.Write(boundarybytes)
fpdst.Write([]byte(header))
fpdst.Write([]byte(v))
}
}
header := fmt.Sprintf("Content-Disposition: form-data; name=\"%s\"; filename=\"%s\"\r\nContent-Type: application/octet-stream\r\n\r\n", curl.UploadName, fpstat.Name())
fpdst.Write(boundarybytes)
fpdst.Write([]byte(header))
go func() {
for {
bufs := make([]byte, 393213)
n, err = fpsrc.Read(bufs)
if err != nil {
if err == io.EOF {
if n != 0 {
fpdst.Write(bufs[0:n])
if curl.Process != nil {
go curl.Process(float64(sum+n) / filebig * 100)
}
}
break
}
return
}
sum += n
if curl.Process != nil {
go curl.Process(float64(sum+n) / filebig * 100)
}
fpdst.Write(bufs[0:n])
}
fpdst.Write(endbytes)
fpdst.Write(nil)
}()
curl.CircleBuffer = fpdst
curl.ReqHeader.Set("Content-Type", "multipart/form-data;boundary="+boundary)
}
req, resp, err := netcurl(curl)
if err != nil {
return Request{}, err
}
if resp.Request != nil && resp.Request.URL != nil {
curl.RespURL = resp.Request.URL.String()
}
curl.reqOrigin = req
curl.respOrigin = resp
curl.Location, _ = resp.Location()
curl.RespHttpCode = resp.StatusCode
curl.RespHeader = resp.Header
curl.RespCookies = resp.Cookies()
curl.RecvContentLength = resp.ContentLength
readFunc := func(reader io.ReadCloser, writer io.Writer) error {
lengthall := resp.ContentLength
defer reader.Close()
var lengthsum int
buf := make([]byte, 65535)
for {
n, err := reader.Read(buf)
if n != 0 {
_, err := writer.Write(buf[:n])
lengthsum += n
if curl.Process != nil {
go curl.Process(float64(lengthsum) / float64(lengthall) * 100.00)
}
if err != nil {
return err
}
}
if err != nil && err != io.EOF {
return err
} else if err == io.EOF {
return nil
}
}
}
if curl.WriteRecvData {
buf := bytes.NewBuffer([]byte{})
err = readFunc(resp.Body, buf)
if err != nil {
return
}
curl.RecvData = buf.Bytes()
} else {
curl.respReader = resp.Body
}
if curl.RecvIo != nil {
if curl.WriteRecvData {
_, err = curl.RecvIo.Write(curl.RecvData)
} else {
err = readFunc(resp.Body, curl.RecvIo)
if err != nil {
return
}
}
}
return curl, err
}
// RespBodyReader Only works when WriteRecvData set to false
func (curl *Request) RespBodyReader() io.ReadCloser {
return curl.respReader
}
func netcurl(curl Request) (*http.Request, *http.Response, error) {
var req *http.Request
var err error
if curl.Method == "" {
return nil, nil, errors.New("Error Method Not Entered")
}
if curl.PostBuffer != nil {
req, err = http.NewRequest(curl.Method, curl.Url, curl.PostBuffer)
} else if curl.CircleBuffer != nil && curl.CircleBuffer.Len() > 0 {
req, err = http.NewRequest(curl.Method, curl.Url, curl.CircleBuffer)
} else {
req, err = http.NewRequest(curl.Method, curl.Url, nil)
}
if curl.Queries != nil {
sid := req.URL.Query()
for k, v := range curl.Queries {
sid.Add(k, v)
}
req.URL.RawQuery = sid.Encode()
}
if err != nil {
return nil, nil, err
}
req.Header = curl.ReqHeader
if len(curl.ReqCookies) != 0 {
for _, v := range curl.ReqCookies {
req.AddCookie(v)
}
}
if curl.Proxy != "" {
purl, err := url.Parse(curl.Proxy)
if err != nil {
return nil, nil, err
}
curl.CustomTransport.Proxy = http.ProxyURL(purl)
}
client := &http.Client{
Transport: curl.CustomTransport,
}
if curl.DisableRedirect {
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
}
}
resp, err := client.Do(req)
return req, resp, err
}
func UrlEncodeRaw(str string) string {
strs := strings.Replace(url.QueryEscape(str), "+", "%20", -1)
return strs
}
func UrlEncode(str string) string {
return url.QueryEscape(str)
}
func UrlDecode(str string) (string, error) {
return url.QueryUnescape(str)
}
func BuildQuery(queryData map[string]string) string {
query := url.Values{}
for k, v := range queryData {
query.Add(k, v)
}
return query.Encode()
}
func BuildPostForm(queryMap map[string]string) []byte {
query := url.Values{}
for k, v := range queryMap {
query.Add(k, v)
}
return []byte(query.Encode())
}
func (r Request) Resopnse() *http.Response {
return r.respOrigin
}
func (r Request) Request() *http.Request {
return r.reqOrigin
}

147
defaults.go Normal file
View File

@ -0,0 +1,147 @@
package starnet
import (
"net/http"
"sync"
"time"
)
var (
defaultClient *Client
defaultHTTPClient *http.Client
defaultClientOnce sync.Once
defaultHTTPOnce sync.Once
defaultMu sync.RWMutex
)
// DefaultClient 获取默认 Client单例
func DefaultClient() *Client {
defaultMu.RLock()
if defaultClient != nil {
c := defaultClient
defaultMu.RUnlock()
return c
}
defaultMu.RUnlock()
defaultClientOnce.Do(func() {
c := NewClientNoErr()
defaultMu.Lock()
defaultClient = c
defaultMu.Unlock()
})
defaultMu.RLock()
c := defaultClient
defaultMu.RUnlock()
return c
}
// DefaultHTTPClient 获取默认 http.Client单例
func DefaultHTTPClient() *http.Client {
defaultMu.RLock()
if defaultHTTPClient != nil {
c := defaultHTTPClient
defaultMu.RUnlock()
return c
}
defaultMu.RUnlock()
defaultHTTPOnce.Do(func() {
c := &http.Client{
Transport: &Transport{
base: &http.Transport{
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
},
},
Timeout: 0, // 由请求级控制超时
}
defaultMu.Lock()
defaultHTTPClient = c
defaultMu.Unlock()
})
defaultMu.RLock()
c := defaultHTTPClient
defaultMu.RUnlock()
return c
}
// SetDefaultClient 设置默认 Client
func SetDefaultClient(client *Client) {
defaultMu.Lock()
defer defaultMu.Unlock()
defaultClient = client
// 标记 once 已完成,避免后续 DefaultClient() 再次初始化覆盖
defaultClientOnce.Do(func() {})
}
// SetDefaultHTTPClient 设置默认 http.Client
func SetDefaultHTTPClient(client *http.Client) {
defaultMu.Lock()
defer defaultMu.Unlock()
defaultHTTPClient = client
// 标记 once 已完成,避免后续 DefaultHTTPClient() 再次初始化覆盖
defaultHTTPOnce.Do(func() {})
}
// Get 发送 GET 请求(使用默认 Client
func Get(url string, opts ...RequestOpt) (*Response, error) {
return DefaultClient().Get(url, opts...)
}
// Post 发送 POST 请求(使用默认 Client
func Post(url string, opts ...RequestOpt) (*Response, error) {
return DefaultClient().Post(url, opts...)
}
// Put 发送 PUT 请求(使用默认 Client
func Put(url string, opts ...RequestOpt) (*Response, error) {
return DefaultClient().Put(url, opts...)
}
// Delete 发送 DELETE 请求(使用默认 Client
func Delete(url string, opts ...RequestOpt) (*Response, error) {
return DefaultClient().Delete(url, opts...)
}
// Head 发送 HEAD 请求(使用默认 Client
func Head(url string, opts ...RequestOpt) (*Response, error) {
return DefaultClient().Head(url, opts...)
}
// Patch 发送 PATCH 请求(使用默认 Client
func Patch(url string, opts ...RequestOpt) (*Response, error) {
return DefaultClient().Patch(url, opts...)
}
// Options 发送 OPTIONS 请求(使用默认 Client
func Options(url string, opts ...RequestOpt) (*Response, error) {
return DefaultClient().Options(url, opts...)
}
// Trace 发送 TRACE 请求(使用默认 Client
func Trace(url string, opts ...RequestOpt) (*Response, error) {
req, err := DefaultClient().NewRequest(url, http.MethodTrace, opts...)
if err != nil {
return nil, err
}
return req.Do()
}
// Connect 发送 CONNECT 请求(使用默认 Client
func Connect(url string, opts ...RequestOpt) (*Response, error) {
req, err := DefaultClient().NewRequest(url, http.MethodConnect, opts...)
if err != nil {
return nil, err
}
return req.Do()
}

59
defensive_copy_test.go Normal file
View File

@ -0,0 +1,59 @@
package starnet
import (
"net/http"
"testing"
)
func TestWithRawRequestNil(t *testing.T) {
_, err := NewRequest("http://example.com", "GET", WithRawRequest(nil))
if err == nil {
t.Fatal("expected error when WithRawRequest(nil)")
}
}
func TestSetHeadersDefensiveCopy(t *testing.T) {
req := NewSimpleRequest("http://example.com", "GET")
headers := http.Header{
"X-Test": []string{"v1"},
}
req.SetHeaders(headers)
headers.Set("X-Test", "v2")
if got := req.GetHeader("X-Test"); got != "v1" {
t.Fatalf("header mutated by external map change: got=%q want=%q", got, "v1")
}
}
func TestSetQueriesDefensiveCopy(t *testing.T) {
req := NewSimpleRequest("http://example.com", "GET")
queries := map[string][]string{
"k": []string{"v1"},
}
req.SetQueries(queries)
queries["k"][0] = "v2"
queries["k"] = append(queries["k"], "v3")
got := req.config.Queries["k"]
if len(got) != 1 || got[0] != "v1" {
t.Fatalf("queries mutated by external map change: got=%v want=[v1]", got)
}
}
func TestSetFormDataDefensiveCopy(t *testing.T) {
req := NewSimpleRequest("http://example.com", "POST")
form := map[string][]string{
"name": []string{"alice"},
}
req.SetFormData(form)
form["name"][0] = "bob"
form["name"] = append(form["name"], "carol")
got := req.config.Body.FormData["name"]
if len(got) != 1 || got[0] != "alice" {
t.Fatalf("form data mutated by external map change: got=%v want=[alice]", got)
}
}

160
dialer.go Normal file
View File

@ -0,0 +1,160 @@
package starnet
import (
"context"
"crypto/tls"
"fmt"
"net"
"strings"
"time"
)
// defaultDialFunc 默认 Dial 函数(支持自定义 IP 和 DNS
func defaultDialFunc(ctx context.Context, network, addr string) (net.Conn, error) {
// 提取配置
reqCtx := getRequestContext(ctx)
dialTimeout := reqCtx.DialTimeout
if dialTimeout == 0 {
dialTimeout = DefaultDialTimeout
}
// 解析地址
host, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, wrapError(err, "split host port")
}
// 获取 IP 地址列表
var addrs []string
// 优先级1直接指定的 IP
if len(reqCtx.CustomIP) > 0 {
for _, ip := range reqCtx.CustomIP {
addrs = append(addrs, net.JoinHostPort(ip, port))
}
} else {
// 优先级2DNS 解析
var ipAddrs []net.IPAddr
// 使用自定义解析函数
if reqCtx.LookupIPFn != nil {
ipAddrs, err = reqCtx.LookupIPFn(ctx, host)
} else if len(reqCtx.CustomDNS) > 0 {
// 使用自定义 DNS 服务器
dialer := &net.Dialer{Timeout: dialTimeout}
resolver := &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
var lastErr error
for _, dnsServer := range reqCtx.CustomDNS {
conn, err := dialer.DialContext(ctx, "udp", net.JoinHostPort(dnsServer, "53"))
if err != nil {
lastErr = err
continue
}
return conn, nil
}
return nil, lastErr
},
}
ipAddrs, err = resolver.LookupIPAddr(ctx, host)
} else {
// 使用默认解析器
ipAddrs, err = net.DefaultResolver.LookupIPAddr(ctx, host)
}
if err != nil {
return nil, wrapError(err, "lookup ip")
}
for _, ipAddr := range ipAddrs {
addrs = append(addrs, net.JoinHostPort(ipAddr.String(), port))
}
}
// 尝试连接所有地址
dialer := &net.Dialer{Timeout: dialTimeout}
var lastErr error
for _, addr := range addrs {
conn, err := dialer.DialContext(ctx, network, addr)
if err != nil {
lastErr = err
continue
}
return conn, nil
}
if lastErr != nil {
return nil, wrapError(lastErr, "dial all addresses failed")
}
return nil, fmt.Errorf("no addresses to dial")
}
// defaultDialTLSFunc 默认 TLS Dial 函数
func defaultDialTLSFunc(ctx context.Context, network, addr string) (net.Conn, error) {
// 先建立 TCP 连接
conn, err := defaultDialFunc(ctx, network, addr)
if err != nil {
return nil, err
}
// 提取 TLS 配置
reqCtx := getRequestContext(ctx)
tlsConfig := reqCtx.TLSConfig
if tlsConfig == nil {
tlsConfig = &tls.Config{}
}
// ← 新增:如果 ServerName 为空且没有 InsecureSkipVerify自动设置
if tlsConfig.ServerName == "" && !tlsConfig.InsecureSkipVerify {
host, _, err := net.SplitHostPort(addr)
if err != nil {
if idx := strings.LastIndex(addr, ":"); idx > 0 {
host = addr[:idx]
} else {
host = addr
}
}
tlsConfig = tlsConfig.Clone() // 避免修改原 config
tlsConfig.ServerName = host
}
// 执行 TLS 握手
if deadline, ok := ctx.Deadline(); ok {
_ = conn.SetDeadline(deadline)
defer conn.SetDeadline(time.Time{})
}
tlsConn := tls.Client(conn, tlsConfig)
if err := tlsConn.Handshake(); err != nil {
conn.Close()
return nil, wrapError(err, "tls handshake")
}
return tlsConn, nil
}
/*
// defaultProxyFunc 默认代理函数
func defaultProxyFunc(req *http.Request) (*url.URL, error) {
if req == nil {
return nil, fmt.Errorf("request is nil")
}
reqCtx := getRequestContext(req.Context())
if reqCtx.Proxy == "" {
return nil, nil
}
proxyURL, err := url.Parse(reqCtx.Proxy)
if err != nil {
return nil, wrapError(err, "parse proxy url")
}
return proxyURL, nil
}
*/

103
dns_test.go Normal file
View File

@ -0,0 +1,103 @@
package starnet
import (
"context"
"net"
"testing"
)
func TestRequestCustomIP(t *testing.T) {
customIPs := []string{"1.2.3.4", "5.6.7.8"}
req := NewSimpleRequest("http://example.com", "GET").
SetCustomIP(customIPs)
if len(req.config.DNS.CustomIP) != 2 {
t.Errorf("CustomIP length = %v; want 2", len(req.config.DNS.CustomIP))
}
for i, ip := range req.config.DNS.CustomIP {
if ip != customIPs[i] {
t.Errorf("CustomIP[%d] = %v; want %v", i, ip, customIPs[i])
}
}
}
func TestRequestCustomIPInvalid(t *testing.T) {
req := NewSimpleRequest("http://example.com", "GET").
SetCustomIP([]string{"invalid-ip"})
if req.Err() == nil {
t.Error("Expected error for invalid IP, got nil")
}
}
func TestRequestCustomDNS(t *testing.T) {
dnsServers := []string{"8.8.8.8", "1.1.1.1"}
req := NewSimpleRequest("http://example.com", "GET").
SetCustomDNS(dnsServers)
if len(req.config.DNS.CustomDNS) != 2 {
t.Errorf("CustomDNS length = %v; want 2", len(req.config.DNS.CustomDNS))
}
}
func TestRequestCustomDNSInvalid(t *testing.T) {
req := NewSimpleRequest("http://example.com", "GET").
SetCustomDNS([]string{"invalid-dns"})
if req.Err() == nil {
t.Error("Expected error for invalid DNS, got nil")
}
}
func TestRequestLookupFunc(t *testing.T) {
called := false
lookupFunc := func(ctx context.Context, host string) ([]net.IPAddr, error) {
called = true
return []net.IPAddr{
{IP: net.ParseIP("1.2.3.4")},
}, nil
}
req := NewSimpleRequest("http://example.com", "GET").
SetLookupFunc(lookupFunc)
if req.config.DNS.LookupFunc == nil {
t.Error("LookupFunc not set")
}
// Call the function to verify it works
ips, err := req.config.DNS.LookupFunc(context.Background(), "example.com")
if err != nil {
t.Errorf("LookupFunc error: %v", err)
}
if !called {
t.Error("LookupFunc was not called")
}
if len(ips) != 1 {
t.Errorf("IPs length = %v; want 1", len(ips))
}
}
func TestDNSPriority(t *testing.T) {
// CustomIP should have highest priority
req := NewSimpleRequest("http://example.com", "GET").
SetCustomIP([]string{"1.2.3.4"}).
SetCustomDNS([]string{"8.8.8.8"}).
SetLookupFunc(func(ctx context.Context, host string) ([]net.IPAddr, error) {
return []net.IPAddr{{IP: net.ParseIP("5.6.7.8")}}, nil
})
// CustomIP should be set
if len(req.config.DNS.CustomIP) == 0 {
t.Error("CustomIP should be set")
}
// Others should also be set (but CustomIP takes priority in actual use)
if len(req.config.DNS.CustomDNS) == 0 {
t.Error("CustomDNS should be set")
}
if req.config.DNS.LookupFunc == nil {
t.Error("LookupFunc should be set")
}
}

257
errors.go Normal file
View File

@ -0,0 +1,257 @@
package starnet
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"net"
"net/url"
"strings"
)
var (
// ErrInvalidMethod 无效的 HTTP 方法
ErrInvalidMethod = errors.New("starnet: invalid HTTP method")
// ErrInvalidURL 无效的 URL
ErrInvalidURL = errors.New("starnet: invalid URL")
// ErrInvalidIP 无效的 IP 地址
ErrInvalidIP = errors.New("starnet: invalid IP address")
// ErrInvalidDNS 无效的 DNS 服务器
ErrInvalidDNS = errors.New("starnet: invalid DNS server")
// ErrNilClient HTTP Client 为 nil
ErrNilClient = errors.New("starnet: http client is nil")
// ErrNilReader Reader 为 nil
ErrNilReader = errors.New("starnet: reader is nil")
// ErrFileNotFound 文件不存在
ErrFileNotFound = errors.New("starnet: file not found")
// ErrRequestNotPrepared 请求未准备好
ErrRequestNotPrepared = errors.New("starnet: request not prepared")
// ErrBodyAlreadyConsumed Body 已被消费
ErrBodyAlreadyConsumed = errors.New("starnet: response body already consumed")
// ErrRespBodyTooLarge 响应体超过允许上限
ErrRespBodyTooLarge = errors.New("starnet: response body too large")
// ErrPingInvalidTimeout ping 超时参数无效
ErrPingInvalidTimeout = errors.New("starnet: invalid ping timeout")
// ErrPingPermissionDenied ping 需要更高权限raw socket
ErrPingPermissionDenied = errors.New("starnet: ping permission denied")
// ErrPingProtocolUnsupported ping 协议/地址族不受当前平台支持
ErrPingProtocolUnsupported = errors.New("starnet: ping protocol unsupported")
// ErrPingNoResolvedTarget ping 目标无法解析为可用地址
ErrPingNoResolvedTarget = errors.New("starnet: ping target not resolved")
)
// wrapError 包装错误,添加上下文信息
func wrapError(err error, format string, args ...interface{}) error {
if err == nil {
return nil
}
msg := fmt.Sprintf(format, args...)
return fmt.Errorf("%s: %w", msg, err)
}
var (
// ErrNilConn indicates a nil net.Conn argument.
ErrNilConn = errors.New("starnet: nil connection")
// ErrTLSSniffFailed indicates TLS sniffing/parsing failed before handshake setup.
ErrTLSSniffFailed = errors.New("starnet: tls sniff failed")
// ErrTLSConfigSelectionFailed indicates dynamic TLS config selection failed.
ErrTLSConfigSelectionFailed = errors.New("starnet: tls config selection failed")
// ErrNonTLSNotAllowed indicates plain TCP was detected while non-TLS is forbidden.
ErrNonTLSNotAllowed = errors.New("starnet: non-TLS connection not allowed")
// ErrNotTLS indicates caller asked for TLS-only object but conn is plain TCP.
ErrNotTLS = errors.New("starnet: connection is not TLS")
// ErrNoTLSConfig indicates TLS was detected but no usable TLS config is available.
ErrNoTLSConfig = errors.New("starnet: no TLS config available")
)
// ErrorKind is a normalized high-level category for request errors.
type ErrorKind string
const (
ErrorKindNone ErrorKind = "none"
ErrorKindCanceled ErrorKind = "canceled"
ErrorKindTimeout ErrorKind = "timeout"
ErrorKindDNS ErrorKind = "dns"
ErrorKindTLS ErrorKind = "tls"
ErrorKindProxy ErrorKind = "proxy"
ErrorKindOther ErrorKind = "other"
)
// IsCanceled reports whether err is a cancellation-related error.
func IsCanceled(err error) bool {
if err == nil {
return false
}
if errors.Is(err, context.Canceled) {
return true
}
msg := strings.ToLower(err.Error())
return strings.Contains(msg, "context canceled") ||
strings.Contains(msg, "operation was canceled") ||
strings.Contains(msg, "request canceled")
}
// ClassifyError maps low-level errors to a stable category for business handling.
func ClassifyError(err error) ErrorKind {
if err == nil {
return ErrorKindNone
}
if IsCanceled(err) {
return ErrorKindCanceled
}
if IsProxy(err) {
return ErrorKindProxy
}
if IsDNS(err) {
return ErrorKindDNS
}
if IsTLS(err) {
return ErrorKindTLS
}
if IsTimeout(err) {
return ErrorKindTimeout
}
return ErrorKindOther
}
// IsTimeout reports whether err is a timeout-related error.
func IsTimeout(err error) bool {
if err == nil {
return false
}
if errors.Is(err, context.DeadlineExceeded) {
return true
}
var uerr *url.Error
if errors.As(err, &uerr) && uerr.Timeout() {
return true
}
var nerr net.Error
if errors.As(err, &nerr) && nerr.Timeout() {
return true
}
msg := strings.ToLower(err.Error())
return strings.Contains(msg, "timeout") || strings.Contains(msg, "deadline exceeded")
}
// IsDNS reports whether err is a DNS resolution related error.
func IsDNS(err error) bool {
if err == nil {
return false
}
var derr *net.DNSError
if errors.As(err, &derr) {
return true
}
msg := strings.ToLower(err.Error())
if strings.Contains(msg, "no such host") ||
strings.Contains(msg, "server misbehaving") ||
strings.Contains(msg, "temporary failure in name resolution") {
return true
}
return strings.Contains(msg, "lookup ") &&
(strings.Contains(msg, "dns") || strings.Contains(msg, "i/o timeout"))
}
// IsTLS reports whether err is TLS/Certificate related.
func IsTLS(err error) bool {
if err == nil {
return false
}
if errors.Is(err, ErrNotTLS) || errors.Is(err, ErrNoTLSConfig) || errors.Is(err, ErrNonTLSNotAllowed) ||
errors.Is(err, ErrTLSSniffFailed) || errors.Is(err, ErrTLSConfigSelectionFailed) {
return true
}
var recErr tls.RecordHeaderError
if errors.As(err, &recErr) {
return true
}
var uaErr x509.UnknownAuthorityError
if errors.As(err, &uaErr) {
return true
}
var hnErr x509.HostnameError
if errors.As(err, &hnErr) {
return true
}
var certErr x509.CertificateInvalidError
if errors.As(err, &certErr) {
return true
}
var rootsErr x509.SystemRootsError
if errors.As(err, &rootsErr) {
return true
}
msg := strings.ToLower(err.Error())
return strings.Contains(msg, "tls:") || strings.Contains(msg, "x509:")
}
// IsProxy reports whether err is proxy related.
func IsProxy(err error) bool {
if err == nil {
return false
}
if isProxyMessage(strings.ToLower(err.Error())) {
return true
}
var uerr *url.Error
if errors.As(err, &uerr) {
if strings.Contains(strings.ToLower(uerr.Op), "proxy") {
return true
}
if uerr.Err != nil && isProxyMessage(strings.ToLower(uerr.Err.Error())) {
return true
}
}
var opErr *net.OpError
if errors.As(err, &opErr) && strings.Contains(strings.ToLower(opErr.Op), "proxy") {
return true
}
return false
}
func isProxyMessage(msg string) bool {
return strings.Contains(msg, "proxyconnect") ||
strings.Contains(msg, "proxy error") ||
strings.Contains(msg, "proxy authentication required") ||
strings.Contains(msg, "proxy: unknown scheme") ||
strings.Contains(msg, "socks connect") ||
strings.Contains(msg, "socks5")
}

116
errors_classify_test.go Normal file
View File

@ -0,0 +1,116 @@
package starnet
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"net/url"
"testing"
)
type timeoutErr struct{}
func (timeoutErr) Error() string { return "i/o timeout" }
func (timeoutErr) Timeout() bool { return true }
func (timeoutErr) Temporary() bool { return true }
func TestIsTimeout(t *testing.T) {
if !IsTimeout(context.DeadlineExceeded) {
t.Fatal("context deadline should be timeout")
}
uerr := &url.Error{
Op: "Get",
URL: "http://example.com",
Err: timeoutErr{},
}
if !IsTimeout(uerr) {
t.Fatal("url timeout error should be timeout")
}
if !IsTimeout(fmt.Errorf("wrapped: %w", uerr)) {
t.Fatal("wrapped timeout should be timeout")
}
if IsTimeout(errors.New("plain error")) {
t.Fatal("plain error must not be timeout")
}
}
func TestIsDNS(t *testing.T) {
dnsErr := &net.DNSError{
Err: "no such host",
Name: "example.invalid",
IsNotFound: true,
}
if !IsDNS(dnsErr) {
t.Fatal("dns error should be dns")
}
if !IsDNS(fmt.Errorf("wrapped: %w", dnsErr)) {
t.Fatal("wrapped dns error should be dns")
}
if !IsDNS(errors.New("lookup example.invalid: no such host")) {
t.Fatal("lookup no such host should be dns")
}
if IsDNS(errors.New("connection reset by peer")) {
t.Fatal("non dns error should not be dns")
}
}
func TestIsTLS(t *testing.T) {
tlsErr := tls.RecordHeaderError{Msg: "first record does not look like a TLS handshake"}
if !IsTLS(tlsErr) {
t.Fatal("tls record header error should be tls")
}
if !IsTLS(fmt.Errorf("wrapped: %w", tlsErr)) {
t.Fatal("wrapped tls error should be tls")
}
if !IsTLS(errors.New("x509: certificate signed by unknown authority")) {
t.Fatal("x509 error text should be tls")
}
if !IsTLS(ErrNotTLS) {
t.Fatal("ErrNotTLS should be tls related")
}
if IsTLS(errors.New("plain error")) {
t.Fatal("plain error should not be tls")
}
}
func TestIsProxy(t *testing.T) {
raw := errors.New("proxyconnect tcp: dial tcp 127.0.0.1:8080: connect: connection refused")
if !IsProxy(raw) {
t.Fatal("proxyconnect error should be proxy")
}
uerr := &url.Error{
Op: "Get",
URL: "http://example.com",
Err: raw,
}
if !IsProxy(uerr) {
t.Fatal("wrapped proxy error should be proxy")
}
opErr := &net.OpError{
Op: "proxyconnect",
Net: "tcp",
Err: errors.New("connect failed"),
}
if !IsProxy(opErr) {
t.Fatal("net.OpError proxyconnect should be proxy")
}
if IsProxy(errors.New("dial tcp 127.0.0.1:8080: connect: connection refused")) {
t.Fatal("non proxy dial error should not be proxy")
}
}

49
errors_kind_test.go Normal file
View File

@ -0,0 +1,49 @@
package starnet
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"testing"
)
func TestIsCanceled(t *testing.T) {
if !IsCanceled(context.Canceled) {
t.Fatal("context canceled should be canceled")
}
if !IsCanceled(fmt.Errorf("wrapped: %w", context.Canceled)) {
t.Fatal("wrapped context canceled should be canceled")
}
if IsCanceled(context.DeadlineExceeded) {
t.Fatal("deadline exceeded must not be canceled")
}
}
func TestClassifyError(t *testing.T) {
dnsErr := &net.DNSError{Err: "no such host", Name: "example.invalid", IsNotFound: true}
tlsErr := tls.RecordHeaderError{Msg: "bad tls record"}
tests := []struct {
name string
err error
want ErrorKind
}{
{name: "nil", err: nil, want: ErrorKindNone},
{name: "canceled", err: context.Canceled, want: ErrorKindCanceled},
{name: "proxy", err: errors.New("proxyconnect tcp: dial tcp 127.0.0.1:8080: i/o timeout"), want: ErrorKindProxy},
{name: "dns", err: dnsErr, want: ErrorKindDNS},
{name: "tls", err: tlsErr, want: ErrorKindTLS},
{name: "timeout", err: context.DeadlineExceeded, want: ErrorKindTimeout},
{name: "other", err: errors.New("boom"), want: ErrorKindOther},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := ClassifyError(tt.err); got != tt.want {
t.Fatalf("ClassifyError()=%s want=%s err=%v", got, tt.want, tt.err)
}
})
}
}

256
example_test.go Normal file
View File

@ -0,0 +1,256 @@
package starnet_test
import (
"fmt"
"net/http"
"net/http/httptest"
"sync/atomic"
"time"
"b612.me/starnet"
)
func ExampleGet() {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("Hello, World!"))
}))
defer server.Close()
resp, err := starnet.Get(server.URL)
if err != nil {
panic(err)
}
defer resp.Close()
body, _ := resp.Body().String()
fmt.Println(body)
// Output: Hello, World!
}
func ExamplePost() {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("Posted"))
}))
defer server.Close()
resp, err := starnet.Post(server.URL,
starnet.WithBodyString("test data"))
if err != nil {
panic(err)
}
defer resp.Close()
body, _ := resp.Body().String()
fmt.Println(body)
// Output: Posted
}
func ExampleNewSimpleRequest() {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("OK"))
}))
defer server.Close()
req := starnet.NewSimpleRequest(server.URL, "GET").
SetHeader("X-Custom", "value").
AddQuery("name", "test")
resp, err := req.Do()
if err != nil {
panic(err)
}
defer resp.Close()
fmt.Println(resp.StatusCode)
// Output: 200
}
func ExampleClient_Get() {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("Client GET"))
}))
defer server.Close()
client := starnet.NewClientNoErr(
starnet.WithTimeout(10*time.Second),
starnet.WithUserAgent("MyApp/1.0"),
)
resp, err := client.Get(server.URL)
if err != nil {
panic(err)
}
defer resp.Close()
body, _ := resp.Body().String()
fmt.Println(body)
// Output: Client GET
}
func ExampleRequest_SetJSON() {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{"status":"ok"}`))
}))
defer server.Close()
type User struct {
Name string `json:"name"`
Email string `json:"email"`
}
user := User{Name: "John", Email: "john@example.com"}
resp, err := starnet.NewSimpleRequest(server.URL, "POST").
SetJSON(user).
Do()
if err != nil {
panic(err)
}
defer resp.Close()
var result map[string]string
resp.Body().JSON(&result)
fmt.Println(result["status"])
// Output: ok
}
func ExampleRequest_AddFormData() {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r.ParseForm()
fmt.Fprintf(w, "name=%s", r.FormValue("name"))
}))
defer server.Close()
resp, err := starnet.NewSimpleRequest(server.URL, "POST").
AddFormData("name", "John").
AddFormData("age", "30").
Do()
if err != nil {
panic(err)
}
defer resp.Close()
body, _ := resp.Body().String()
fmt.Println(body)
// Output: name=John
}
func ExampleRequest_SetSkipTLSVerify() {
// This example shows how to skip TLS verification
// Useful for testing with self-signed certificates
req := starnet.NewSimpleRequest("https://self-signed.example.com", "GET").
SetSkipTLSVerify(true)
// In a real scenario, you would call req.Do()
fmt.Println(req.Method())
// Output: GET
}
func ExampleRequest_Clone() {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("OK"))
}))
defer server.Close()
baseReq := starnet.NewSimpleRequest(server.URL, "GET").
SetHeader("X-API-Key", "secret")
// Clone and modify
req1 := baseReq.Clone().AddQuery("page", "1")
req2 := baseReq.Clone().AddQuery("page", "2")
resp1, _ := req1.Do()
resp2, _ := req2.Do()
defer resp1.Close()
defer resp2.Close()
fmt.Println(resp1.StatusCode, resp2.StatusCode)
// Output: 200 200
}
func ExampleClient_SetDefaultSkipTLSVerify() {
client := starnet.NewClientNoErr()
client.SetDefaultSkipTLSVerify(true)
// All requests from this client will skip TLS verification
// unless overridden at request level
fmt.Println("Client configured")
// Output: Client configured
}
func ExampleWithTimeout() {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(100 * time.Millisecond)
w.Write([]byte("OK"))
}))
defer server.Close()
resp, err := starnet.Get(server.URL,
starnet.WithTimeout(200*time.Millisecond))
if err != nil {
panic(err)
}
defer resp.Close()
fmt.Println(resp.StatusCode)
// Output: 200
}
func ExampleWithRetry() {
var hits int32
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
n := atomic.AddInt32(&hits, 1)
if n <= 2 {
w.WriteHeader(http.StatusServiceUnavailable)
return
}
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
resp, err := starnet.Get(server.URL,
starnet.WithRetry(2,
starnet.WithRetryBackoff(0, 0, 1),
starnet.WithRetryJitter(0),
),
)
if err != nil {
panic(err)
}
defer resp.Close()
fmt.Println(resp.StatusCode, atomic.LoadInt32(&hits))
// Output: 200 3
}
func ExampleRequest_SetRetry() {
var hits int32
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
n := atomic.AddInt32(&hits, 1)
if n == 1 {
w.WriteHeader(http.StatusServiceUnavailable)
return
}
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
resp, err := starnet.NewSimpleRequest(server.URL, http.MethodPost).
SetBodyString("hello"). // 可重放 body 才能安全重试
SetRetry(1).
SetRetryIdempotentOnly(false).
SetRetryBackoff(0, 0, 1).
SetRetryJitter(0).
Do()
if err != nil {
panic(err)
}
defer resp.Close()
fmt.Println(resp.StatusCode, atomic.LoadInt32(&hits))
// Output: 200 2
}

172
file_upload_test.go Normal file
View File

@ -0,0 +1,172 @@
package starnet
import (
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
)
func TestRequestAddFileStream(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
err := r.ParseMultipartForm(10 << 20) // 10 MB
if err != nil {
t.Fatalf("ParseMultipartForm error: %v", err)
}
file, header, err := r.FormFile("file")
if err != nil {
t.Fatalf("FormFile error: %v", err)
}
defer file.Close()
content, _ := io.ReadAll(file)
w.Write([]byte(header.Filename + ":" + string(content)))
}))
defer server.Close()
fileContent := "test file content"
reader := strings.NewReader(fileContent)
req := NewSimpleRequest(server.URL, "POST").
AddFileStream("file", "test.txt", int64(len(fileContent)), reader)
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
body, _ := resp.Body().String()
expected := "test.txt:" + fileContent
if body != expected {
t.Errorf("Body = %v; want %v", body, expected)
}
}
func TestRequestAddFileWithFormData(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
err := r.ParseMultipartForm(10 << 20)
if err != nil {
t.Fatalf("ParseMultipartForm error: %v", err)
}
// Check form field
name := r.FormValue("name")
if name != "John" {
t.Errorf("name = %v; want John", name)
}
// Check file
file, header, err := r.FormFile("file")
if err != nil {
t.Fatalf("FormFile error: %v", err)
}
defer file.Close()
w.Write([]byte("OK:" + header.Filename))
}))
defer server.Close()
fileContent := "file data"
reader := strings.NewReader(fileContent)
req := NewSimpleRequest(server.URL, "POST").
AddFormData("name", "John").
AddFileStream("file", "document.txt", int64(len(fileContent)), reader)
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
body, _ := resp.Body().String()
if !strings.Contains(body, "document.txt") {
t.Errorf("Body should contain filename, got: %v", body)
}
}
func TestRequestUploadProgress(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r.ParseMultipartForm(10 << 20)
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
progressCalled := false
var lastUploaded int64
fileContent := strings.Repeat("a", 1024*10) // 10KB
reader := strings.NewReader(fileContent)
req := NewSimpleRequest(server.URL, "POST").
SetUploadProgress(func(filename string, uploaded, total int64) {
progressCalled = true
lastUploaded = uploaded
if filename != "test.txt" {
t.Errorf("filename = %v; want test.txt", filename)
}
}).
AddFileStream("file", "test.txt", int64(len(fileContent)), reader)
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
if !progressCalled {
t.Error("Progress callback was not called")
}
if lastUploaded != int64(len(fileContent)) {
t.Errorf("lastUploaded = %v; want %v", lastUploaded, len(fileContent))
}
}
// TestRequestAddFileFromDisk tests uploading a real file from disk
func TestRequestAddFileFromDisk(t *testing.T) {
// Create a temporary file
tmpDir := t.TempDir()
tmpFile := filepath.Join(tmpDir, "test.txt")
fileContent := []byte("test file content from disk")
err := os.WriteFile(tmpFile, fileContent, 0644)
if err != nil {
t.Fatalf("WriteFile error: %v", err)
}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
err := r.ParseMultipartForm(10 << 20)
if err != nil {
t.Fatalf("ParseMultipartForm error: %v", err)
}
file, header, err := r.FormFile("file")
if err != nil {
t.Fatalf("FormFile error: %v", err)
}
defer file.Close()
content, _ := io.ReadAll(file)
w.Write([]byte(header.Filename + ":" + string(content)))
}))
defer server.Close()
req := NewSimpleRequest(server.URL, "POST").AddFile("file", tmpFile)
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
body, _ := resp.Body().String()
if !strings.Contains(body, string(fileContent)) {
t.Errorf("Body should contain file content, got: %v", body)
}
}

2
go.mod
View File

@ -1,5 +1,3 @@
module b612.me/starnet module b612.me/starnet
go 1.16 go 1.16
require b612.me/stario v0.0.8

13
go.sum
View File

@ -1,13 +0,0 @@
b612.me/stario v0.0.8 h1:kaA4pszAKLZJm2D9JmiuYSpgjTeE3VaO74vm+H0vBGM=
b612.me/stario v0.0.8/go.mod h1:or4ssWcxQSjMeu+hRKEgtp0X517b3zdlEOAms8Qscvw=
golang.org/x/crypto v0.0.0-20220313003712-b769efc7c000 h1:SL+8VVnkqyshUSz5iNnXtrBQzvFF2SkROm6t5RczFAE=
golang.org/x/crypto v0.0.0-20220313003712-b769efc7c000/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1 h1:SrN+KX8Art/Sf4HNj6Zcz06G7VEz+7w9tdXTPOZ7+l4=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1 h1:v+OssWQX+hTHEmOBgwxdZxK4zHq3yOs8F9J7mk0PY8E=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=

140
header_test.go Normal file
View File

@ -0,0 +1,140 @@
package starnet
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
)
func TestRequestHeaders(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
headers := make(map[string]string)
for k, v := range r.Header {
if len(v) > 0 {
headers[k] = v[0]
}
}
json.NewEncoder(w).Encode(headers)
}))
defer server.Close()
req := NewSimpleRequest(server.URL, "GET").
SetHeader("X-Custom-Header", "value1").
AddHeader("X-Multi-Header", "value1").
AddHeader("X-Multi-Header", "value2")
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
var headers map[string]string
resp.Body().JSON(&headers)
if headers["X-Custom-Header"] != "value1" {
t.Errorf("X-Custom-Header = %v; want value1", headers["X-Custom-Header"])
}
}
func TestRequestCookies(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
cookies := make(map[string]string)
for _, cookie := range r.Cookies() {
cookies[cookie.Name] = cookie.Value
}
json.NewEncoder(w).Encode(cookies)
}))
defer server.Close()
req := NewSimpleRequest(server.URL, "GET").
AddSimpleCookie("session", "abc123").
AddSimpleCookie("user", "john")
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
var cookies map[string]string
resp.Body().JSON(&cookies)
if cookies["session"] != "abc123" {
t.Errorf("session cookie = %v; want abc123", cookies["session"])
}
if cookies["user"] != "john" {
t.Errorf("user cookie = %v; want john", cookies["user"])
}
}
func TestRequestUserAgent(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(r.UserAgent()))
}))
defer server.Close()
req := NewSimpleRequest(server.URL, "GET").
SetUserAgent("CustomAgent/1.0")
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
body, _ := resp.Body().String()
if body != "CustomAgent/1.0" {
t.Errorf("User-Agent = %v; want CustomAgent/1.0", body)
}
}
func TestRequestBearerToken(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
auth := r.Header.Get("Authorization")
w.Write([]byte(auth))
}))
defer server.Close()
req := NewSimpleRequest(server.URL, "GET").
SetBearerToken("mytoken123")
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
body, _ := resp.Body().String()
expected := "Bearer mytoken123"
if body != expected {
t.Errorf("Authorization = %v; want %v", body, expected)
}
}
func TestRequestBasicAuth(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
username, password, ok := r.BasicAuth()
if !ok {
w.WriteHeader(http.StatusUnauthorized)
return
}
w.Write([]byte(username + ":" + password))
}))
defer server.Close()
req := NewSimpleRequest(server.URL, "GET").
SetBasicAuth("user", "pass")
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
body, _ := resp.Body().String()
if body != "user:pass" {
t.Errorf("BasicAuth = %v; want user:pass", body)
}
}

258
integration_test.go Normal file
View File

@ -0,0 +1,258 @@
package starnet
import (
"os"
"testing"
"time"
)
// 这些测试使用 httpbin.org 作为测试服务
// 可以通过环境变量 STARNET_INTEGRATION_TEST=1 来启用
func skipIfNoIntegration(t *testing.T) {
if os.Getenv("STARNET_INTEGRATION_TEST") != "1" {
t.Skip("Skipping integration test. Set STARNET_INTEGRATION_TEST=1 to run")
}
}
func TestIntegrationHTTPBinGet(t *testing.T) {
skipIfNoIntegration(t)
resp, err := Get("https://httpbin.org/get",
WithQuery("name", "starnet"),
WithQuery("version", "1.0"))
if err != nil {
t.Fatalf("Get() error: %v", err)
}
defer resp.Close()
if resp.StatusCode != 200 {
t.Errorf("StatusCode = %v; want 200", resp.StatusCode)
}
var result map[string]interface{}
err = resp.Body().JSON(&result)
if err != nil {
t.Fatalf("JSON() error: %v", err)
}
args, ok := result["args"].(map[string]interface{})
if !ok {
t.Fatal("args not found in response")
}
if args["name"] != "starnet" {
t.Errorf("args[name] = %v; want starnet", args["name"])
}
}
func TestIntegrationHTTPBinPost(t *testing.T) {
skipIfNoIntegration(t)
type PostData struct {
Name string `json:"name"`
Email string `json:"email"`
}
data := PostData{
Name: "John Doe",
Email: "john@example.com",
}
resp, err := Post("https://httpbin.org/post", WithJSON(data))
if err != nil {
t.Fatalf("Post() error: %v", err)
}
defer resp.Close()
if resp.StatusCode != 200 {
t.Errorf("StatusCode = %v; want 200", resp.StatusCode)
}
var result map[string]interface{}
err = resp.Body().JSON(&result)
if err != nil {
t.Fatalf("JSON() error: %v", err)
}
jsonData, ok := result["json"].(map[string]interface{})
if !ok {
t.Fatal("json not found in response")
}
if jsonData["name"] != data.Name {
t.Errorf("name = %v; want %v", jsonData["name"], data.Name)
}
}
func TestIntegrationHTTPBinHeaders(t *testing.T) {
skipIfNoIntegration(t)
resp, err := Get("https://httpbin.org/headers",
WithHeader("X-Custom-Header", "test-value"),
WithUserAgent("Starnet-Test/1.0"))
if err != nil {
t.Fatalf("Get() error: %v", err)
}
defer resp.Close()
var result map[string]interface{}
err = resp.Body().JSON(&result)
if err != nil {
t.Fatalf("JSON() error: %v", err)
}
headers, ok := result["headers"].(map[string]interface{})
if !ok {
t.Fatal("headers not found in response")
}
if headers["X-Custom-Header"] != "test-value" {
t.Errorf("X-Custom-Header = %v; want test-value", headers["X-Custom-Header"])
}
}
func TestIntegrationHTTPBinBasicAuth(t *testing.T) {
skipIfNoIntegration(t)
resp, err := Get("https://httpbin.org/basic-auth/user/passwd",
WithBasicAuth("user", "passwd"))
if err != nil {
t.Fatalf("Get() error: %v", err)
}
defer resp.Close()
if resp.StatusCode != 200 {
t.Errorf("StatusCode = %v; want 200", resp.StatusCode)
}
var result map[string]interface{}
err = resp.Body().JSON(&result)
if err != nil {
t.Fatalf("JSON() error: %v", err)
}
if result["authenticated"] != true {
t.Error("authenticated should be true")
}
}
func TestIntegrationHTTPBinDelay(t *testing.T) {
skipIfNoIntegration(t)
// Test timeout
start := time.Now()
_, err := Get("https://httpbin.org/delay/3",
WithTimeout(1*time.Second))
elapsed := time.Since(start)
if err == nil {
t.Error("Expected timeout error, got nil")
}
if elapsed > 2*time.Second {
t.Errorf("Timeout took too long: %v", elapsed)
}
}
func TestIntegrationHTTPBinRedirect(t *testing.T) {
skipIfNoIntegration(t)
// Test with redirect enabled
client := NewClientNoErr()
resp, err := client.Get("https://httpbin.org/redirect/2")
if err != nil {
t.Fatalf("Get() error: %v", err)
}
defer resp.Close()
if resp.StatusCode != 200 {
t.Errorf("StatusCode = %v; want 200 (after redirect)", resp.StatusCode)
}
// Test with redirect disabled
client.DisableRedirect()
resp2, err := client.Get("https://httpbin.org/redirect/2")
if err != nil {
t.Fatalf("Get() error: %v", err)
}
defer resp2.Close()
if resp2.StatusCode != 302 {
t.Errorf("StatusCode = %v; want 302 (redirect disabled)", resp2.StatusCode)
}
}
func TestIntegrationHTTPBinCookies(t *testing.T) {
skipIfNoIntegration(t)
// 创建一个禁用重定向的 Client
client := NewClientNoErr()
client.DisableRedirect()
resp, err := client.Get("https://httpbin.org/cookies/set?name=value")
if err != nil {
t.Fatalf("Get() error: %v", err)
}
defer resp.Close()
// 现在应该能获取到 Set-Cookie
cookies := resp.Cookies()
if len(cookies) == 0 {
t.Error("Expected cookies in response")
}
// 验证 cookie
found := false
for _, cookie := range cookies {
if cookie.Name == "name" && cookie.Value == "value" {
found = true
break
}
}
if !found {
t.Error("Expected cookie 'name=value' not found")
}
}
func TestIntegrationHTTPBinUserAgent(t *testing.T) {
skipIfNoIntegration(t)
customUA := "Starnet-Integration-Test/1.0"
resp, err := Get("https://httpbin.org/user-agent",
WithUserAgent(customUA))
if err != nil {
t.Fatalf("Get() error: %v", err)
}
defer resp.Close()
var result map[string]interface{}
err = resp.Body().JSON(&result)
if err != nil {
t.Fatalf("JSON() error: %v", err)
}
if result["user-agent"] != customUA {
t.Errorf("user-agent = %v; want %v", result["user-agent"], customUA)
}
}
func TestIntegrationHTTPBinGzip(t *testing.T) {
skipIfNoIntegration(t)
resp, err := Get("https://httpbin.org/gzip")
if err != nil {
t.Fatalf("Get() error: %v", err)
}
defer resp.Close()
var result map[string]interface{}
err = resp.Body().JSON(&result)
if err != nil {
t.Fatalf("JSON() error: %v", err)
}
if result["gzipped"] != true {
t.Error("Response should be gzipped")
}
}

405
options.go Normal file
View File

@ -0,0 +1,405 @@
package starnet
import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"os"
"time"
)
// WithTimeout 设置请求总超时时间
// timeout > 0: 为本次请求注入 context 超时
// timeout = 0: 不额外设置请求总超时
// timeout < 0: 禁用 starnet 默认总超时
func WithTimeout(timeout time.Duration) RequestOpt {
return func(r *Request) error {
r.config.Network.Timeout = timeout
return nil
}
}
// WithDialTimeout 设置连接超时时间
func WithDialTimeout(timeout time.Duration) RequestOpt {
return func(r *Request) error {
r.config.Network.DialTimeout = timeout
return nil
}
}
// WithProxy 设置代理
func WithProxy(proxy string) RequestOpt {
return func(r *Request) error {
r.config.Network.Proxy = proxy
return nil
}
}
// WithDialFunc 设置自定义 Dial 函数
func WithDialFunc(fn func(ctx context.Context, network, addr string) (net.Conn, error)) RequestOpt {
return func(r *Request) error {
r.config.Network.DialFunc = fn
return nil
}
}
// WithTLSConfig 设置 TLS 配置
func WithTLSConfig(tlsConfig *tls.Config) RequestOpt {
return func(r *Request) error {
r.config.TLS.Config = tlsConfig
return nil
}
}
// WithSkipTLSVerify 设置是否跳过 TLS 验证
func WithSkipTLSVerify(skip bool) RequestOpt {
return func(r *Request) error {
r.config.TLS.SkipVerify = skip
return nil
}
}
// WithCustomIP 设置自定义 IP
func WithCustomIP(ips []string) RequestOpt {
return func(r *Request) error {
for _, ip := range ips {
if net.ParseIP(ip) == nil {
return wrapError(ErrInvalidIP, "ip: %s", ip)
}
}
r.config.DNS.CustomIP = ips
return nil
}
}
// WithAddCustomIP 添加自定义 IP
func WithAddCustomIP(ip string) RequestOpt {
return func(r *Request) error {
if net.ParseIP(ip) == nil {
return wrapError(ErrInvalidIP, "ip: %s", ip)
}
r.config.DNS.CustomIP = append(r.config.DNS.CustomIP, ip)
return nil
}
}
// WithCustomDNS 设置自定义 DNS 服务器
func WithCustomDNS(dnsServers []string) RequestOpt {
return func(r *Request) error {
for _, dns := range dnsServers {
if net.ParseIP(dns) == nil {
return wrapError(ErrInvalidDNS, "dns: %s", dns)
}
}
r.config.DNS.CustomDNS = dnsServers
return nil
}
}
// WithAddCustomDNS 添加自定义 DNS 服务器
func WithAddCustomDNS(dns string) RequestOpt {
return func(r *Request) error {
if net.ParseIP(dns) == nil {
return wrapError(ErrInvalidDNS, "dns: %s", dns)
}
r.config.DNS.CustomDNS = append(r.config.DNS.CustomDNS, dns)
return nil
}
}
// WithLookupFunc 设置自定义 DNS 解析函数
func WithLookupFunc(fn func(ctx context.Context, host string) ([]net.IPAddr, error)) RequestOpt {
return func(r *Request) error {
r.config.DNS.LookupFunc = fn
return nil
}
}
// WithHeader 设置 Header
func WithHeader(key, value string) RequestOpt {
return func(r *Request) error {
r.config.Headers.Set(key, value)
return nil
}
}
// WithHeaders 批量设置 Headers
func WithHeaders(headers map[string]string) RequestOpt {
return func(r *Request) error {
for k, v := range headers {
r.config.Headers.Set(k, v)
}
return nil
}
}
// WithContentType 设置 Content-Type
func WithContentType(contentType string) RequestOpt {
return func(r *Request) error {
r.config.Headers.Set("Content-Type", contentType)
return nil
}
}
// WithUserAgent 设置 User-Agent
func WithUserAgent(userAgent string) RequestOpt {
return func(r *Request) error {
r.config.Headers.Set("User-Agent", userAgent)
return nil
}
}
// WithBearerToken 设置 Bearer Token
func WithBearerToken(token string) RequestOpt {
return func(r *Request) error {
r.config.Headers.Set("Authorization", "Bearer "+token)
return nil
}
}
// WithBasicAuth 设置 Basic 认证
func WithBasicAuth(username, password string) RequestOpt {
return func(r *Request) error {
r.config.BasicAuth = [2]string{username, password}
return nil
}
}
// WithCookie 添加 Cookie
func WithCookie(name, value, path string) RequestOpt {
return func(r *Request) error {
r.config.Cookies = append(r.config.Cookies, &http.Cookie{
Name: name,
Value: value,
Path: path,
})
return nil
}
}
// WithSimpleCookie 添加简单 Cookiepath 为 /
func WithSimpleCookie(name, value string) RequestOpt {
return func(r *Request) error {
r.config.Cookies = append(r.config.Cookies, &http.Cookie{
Name: name,
Value: value,
Path: "/",
})
return nil
}
}
// WithCookies 批量添加 Cookies
func WithCookies(cookies map[string]string) RequestOpt {
return func(r *Request) error {
for name, value := range cookies {
r.config.Cookies = append(r.config.Cookies, &http.Cookie{
Name: name,
Value: value,
Path: "/",
})
}
return nil
}
}
// WithBody 设置请求体(字节)
func WithBody(body []byte) RequestOpt {
return func(r *Request) error {
r.config.Body.Bytes = body
r.config.Body.Reader = nil
return nil
}
}
// WithBodyString 设置请求体(字符串)
func WithBodyString(body string) RequestOpt {
return func(r *Request) error {
r.config.Body.Bytes = []byte(body)
r.config.Body.Reader = nil
return nil
}
}
// WithBodyReader 设置请求体Reader
func WithBodyReader(reader io.Reader) RequestOpt {
return func(r *Request) error {
r.config.Body.Reader = reader
r.config.Body.Bytes = nil
return nil
}
}
// WithJSON 设置 JSON 请求体
func WithJSON(v interface{}) RequestOpt {
return func(r *Request) error {
data, err := json.Marshal(v)
if err != nil {
return wrapError(err, "marshal json")
}
r.config.Headers.Set("Content-Type", ContentTypeJSON)
r.config.Body.Bytes = data
r.config.Body.Reader = nil
return nil
}
}
// WithFormData 设置表单数据
func WithFormData(data map[string][]string) RequestOpt {
return func(r *Request) error {
r.config.Body.FormData = data
return nil
}
}
// WithFormDataMap 设置表单数据(简化版)
func WithFormDataMap(data map[string]string) RequestOpt {
return func(r *Request) error {
for k, v := range data {
r.config.Body.FormData[k] = []string{v}
}
return nil
}
}
// WithAddFormData 添加表单数据
func WithAddFormData(key, value string) RequestOpt {
return func(r *Request) error {
r.config.Body.FormData[key] = append(r.config.Body.FormData[key], value)
return nil
}
}
// WithFile 添加文件
func WithFile(formName, filePath string) RequestOpt {
return func(r *Request) error {
stat, err := os.Stat(filePath)
if err != nil {
return wrapError(ErrFileNotFound, "file: %s", filePath)
}
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
FormName: formName,
FileName: stat.Name(),
FilePath: filePath,
FileSize: stat.Size(),
FileType: ContentTypeOctetStream,
})
return nil
}
}
// WithFileStream 添加文件流
func WithFileStream(formName, fileName string, size int64, reader io.Reader) RequestOpt {
return func(r *Request) error {
if reader == nil {
return ErrNilReader
}
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
FormName: formName,
FileName: fileName,
FileData: reader,
FileSize: size,
FileType: ContentTypeOctetStream,
})
return nil
}
}
// WithQuery 添加查询参数
func WithQuery(key, value string) RequestOpt {
return func(r *Request) error {
r.config.Queries[key] = append(r.config.Queries[key], value)
return nil
}
}
// WithQueries 批量添加查询参数
func WithQueries(queries map[string]string) RequestOpt {
return func(r *Request) error {
for k, v := range queries {
r.config.Queries[k] = append(r.config.Queries[k], v)
}
return nil
}
}
// WithContentLength 设置 Content-Length
func WithContentLength(length int64) RequestOpt {
return func(r *Request) error {
r.config.ContentLength = length
return nil
}
}
// WithAutoCalcContentLength 设置是否自动计算 Content-Length
func WithAutoCalcContentLength(auto bool) RequestOpt {
return func(r *Request) error {
r.config.AutoCalcContentLength = auto
return nil
}
}
// WithUploadProgress 设置文件上传进度回调
func WithUploadProgress(fn UploadProgressFunc) RequestOpt {
return func(r *Request) error {
r.config.UploadProgress = fn
return nil
}
}
// WithTransport 设置自定义 Transport
func WithTransport(transport *http.Transport) RequestOpt {
return func(r *Request) error {
r.config.Transport = transport
r.config.CustomTransport = true
return nil
}
}
// WithAutoFetch 设置是否自动获取响应体
func WithAutoFetch(auto bool) RequestOpt {
return func(r *Request) error {
r.autoFetch = auto
return nil
}
}
// WithMaxRespBodyBytes 设置响应体最大读取字节数(<=0 表示不限制)
func WithMaxRespBodyBytes(maxBytes int64) RequestOpt {
return func(r *Request) error {
if maxBytes < 0 {
return fmt.Errorf("max response body bytes must be >= 0")
}
r.config.MaxRespBodyBytes = maxBytes
return nil
}
}
// WithRawRequest 设置原始请求
func WithRawRequest(httpReq *http.Request) RequestOpt {
return func(r *Request) error {
if httpReq == nil {
return fmt.Errorf("httpReq cannot be nil")
}
r.httpReq = httpReq
r.doRaw = true
return nil
}
}
// WithContext 设置 context
func WithContext(ctx context.Context) RequestOpt {
return func(r *Request) error {
r.ctx = ctx
r.httpReq = r.httpReq.WithContext(ctx)
return nil
}
}

234
options_test.go Normal file
View File

@ -0,0 +1,234 @@
package starnet
import (
"context"
"encoding/json"
"io"
"net"
"net/http"
"net/http/httptest"
"os"
"strings"
"sync/atomic"
"testing"
"time"
)
func TestWithJSONOpt(t *testing.T) {
type payload struct {
Name string `json:"name"`
Age int `json:"age"`
}
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if ct := r.Header.Get("Content-Type"); ct != ContentTypeJSON {
t.Fatalf("content-type=%s", ct)
}
var p payload
if err := json.NewDecoder(r.Body).Decode(&p); err != nil {
t.Fatalf("decode err: %v", err)
}
if p.Name != "alice" || p.Age != 18 {
t.Fatalf("payload mismatch: %+v", p)
}
w.WriteHeader(http.StatusOK)
}))
defer s.Close()
resp, err := Post(s.URL, WithJSON(payload{Name: "alice", Age: 18}))
if err != nil {
t.Fatalf("Post error: %v", err)
}
resp.Close()
}
func TestWithFileOpt(t *testing.T) {
// temp file + cleanup
f, err := os.CreateTemp("", "starnet-upload-*.txt")
if err != nil {
t.Fatal(err)
}
defer os.Remove(f.Name())
_, _ = f.WriteString("hello-file")
_ = f.Close()
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseMultipartForm(10 << 20); err != nil {
t.Fatalf("parse form err: %v", err)
}
file, header, err := r.FormFile("file")
if err != nil {
t.Fatalf("form file err: %v", err)
}
defer file.Close()
b, _ := io.ReadAll(file)
if header.Filename == "" || string(b) != "hello-file" {
t.Fatalf("upload mismatch filename=%q body=%q", header.Filename, string(b))
}
w.WriteHeader(http.StatusOK)
}))
defer s.Close()
resp, err := Post(s.URL, WithFile("file", f.Name()))
if err != nil {
t.Fatalf("Post error: %v", err)
}
resp.Close()
}
func TestWithFileStreamOpt(t *testing.T) {
content := "stream-content"
reader := strings.NewReader(content)
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseMultipartForm(10 << 20); err != nil {
t.Fatalf("parse form err: %v", err)
}
file, header, err := r.FormFile("up")
if err != nil {
t.Fatalf("form file err: %v", err)
}
defer file.Close()
b, _ := io.ReadAll(file)
if header.Filename != "a.txt" || string(b) != content {
t.Fatalf("upload mismatch filename=%q body=%q", header.Filename, string(b))
}
w.WriteHeader(http.StatusOK)
}))
defer s.Close()
resp, err := Post(s.URL, WithFileStream("up", "a.txt", int64(len(content)), reader))
if err != nil {
t.Fatalf("Post error: %v", err)
}
resp.Close()
}
func TestWithQueryOpt(t *testing.T) {
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Query().Get("k") != "v" {
t.Fatalf("query mismatch: %v", r.URL.Query())
}
w.WriteHeader(http.StatusOK)
}))
defer s.Close()
resp, err := Get(s.URL, WithQuery("k", "v"))
if err != nil {
t.Fatalf("Get error: %v", err)
}
resp.Close()
}
func TestWithUploadProgressOpt(t *testing.T) {
var called int32
var last int64
content := strings.Repeat("x", 4096)
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_ = r.ParseMultipartForm(10 << 20)
w.WriteHeader(http.StatusOK)
}))
defer s.Close()
resp, err := Post(s.URL,
WithUploadProgress(func(filename string, uploaded, total int64) {
atomic.StoreInt32(&called, 1)
last = uploaded
}),
WithFileStream("f", "p.txt", int64(len(content)), strings.NewReader(content)),
)
if err != nil {
t.Fatalf("Post error: %v", err)
}
resp.Close()
if atomic.LoadInt32(&called) == 0 {
t.Fatal("progress not called")
}
if last != int64(len(content)) {
t.Fatalf("last uploaded=%d want=%d", last, len(content))
}
}
func TestWithTransportOpt(t *testing.T) {
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer s.Close()
resp, err := Get(s.URL, WithTransport(&http.Transport{}))
if err != nil {
t.Fatalf("Get error: %v", err)
}
resp.Close()
}
func TestWithContextOpt(t *testing.T) {
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(200 * time.Millisecond)
w.WriteHeader(http.StatusOK)
}))
defer s.Close()
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
defer cancel()
_, err := Get(s.URL, WithContext(ctx))
if err == nil {
t.Fatal("expected context timeout error")
}
}
func TestWithCustomDNSOpt_ConfigApplied(t *testing.T) {
req := NewSimpleRequest("http://example.com", "GET", WithCustomDNS([]string{"8.8.8.8", "1.1.1.1"}))
if req.Err() != nil {
t.Fatalf("unexpected err: %v", req.Err())
}
if len(req.config.DNS.CustomDNS) != 2 {
t.Fatalf("custom dns len=%d", len(req.config.DNS.CustomDNS))
}
}
func TestWithAddCustomIPOpt(t *testing.T) {
req := NewSimpleRequest("http://example.com", "GET", WithAddCustomIP("1.2.3.4"))
if req.Err() != nil {
t.Fatalf("unexpected err: %v", req.Err())
}
if len(req.config.DNS.CustomIP) != 1 || req.config.DNS.CustomIP[0] != "1.2.3.4" {
t.Fatalf("custom ip mismatch: %v", req.config.DNS.CustomIP)
}
}
func TestWithCustomIPOpt(t *testing.T) {
req := NewSimpleRequest("http://example.com", "GET", WithCustomIP([]string{"1.1.1.1", "8.8.8.8"}))
if req.Err() != nil {
t.Fatalf("unexpected err: %v", req.Err())
}
if len(req.config.DNS.CustomIP) != 2 {
t.Fatalf("custom ip len=%d", len(req.config.DNS.CustomIP))
}
}
func TestWithDialFuncOpt(t *testing.T) {
called := int32(0)
fn := func(ctx context.Context, network, addr string) (net.Conn, error) {
atomic.StoreInt32(&called, 1)
return nil, io.EOF
}
req := NewSimpleRequest("http://example.com", "GET", WithDialFunc(fn))
if req.config.Network.DialFunc == nil {
t.Fatal("dial func not set")
}
_, _ = req.config.Network.DialFunc(context.Background(), "tcp", "x:1")
if atomic.LoadInt32(&called) == 0 {
t.Fatal("dial func not called")
}
}
func TestWithDialTimeoutOpt(t *testing.T) {
req := NewSimpleRequest("http://example.com", "GET", WithDialTimeout(123*time.Millisecond))
if req.config.Network.DialTimeout != 123*time.Millisecond {
t.Fatalf("dial timeout=%v", req.config.Network.DialTimeout)
}
}

522
ping.go
View File

@ -1,12 +1,31 @@
package starnet package starnet
import ( import (
"bytes" "context"
"encoding/binary" "encoding/binary"
"errors"
"fmt"
"net" "net"
"os"
"strings"
"sync/atomic"
"time" "time"
) )
const (
icmpTypeEchoReplyV4 = 0
icmpTypeEchoRequestV4 = 8
icmpTypeEchoRequestV6 = 128
icmpTypeEchoReplyV6 = 129
icmpHeaderLen = 8
icmpReadBufSz = 1500
defaultPingAttemptTimeout = 2 * time.Second
defaultPingableCount = 3
maxPingPayloadSize = 65499 // 65507 - ICMP header(8)
)
type ICMP struct { type ICMP struct {
Type uint8 Type uint8
Code uint8 Code uint8
@ -15,52 +34,126 @@ type ICMP struct {
SequenceNum uint16 SequenceNum uint16
} }
func getICMP(seq uint16) ICMP { type pingSocketSpec struct {
network string
family int
requestType uint8
replyType uint8
}
// PingOptions controls ping probing behavior.
type PingOptions struct {
Count int // ping attempts for Pingable, default 3
Timeout time.Duration // per-attempt timeout, default 2s
Interval time.Duration // delay between attempts, default 0
Deadline time.Time // overall deadline for Pingable/PingWithContext
PreferIPv4 bool // prefer IPv4 targets
PreferIPv6 bool // prefer IPv6 targets
SourceIP net.IP // optional source IP for raw socket bind
PayloadSize int // ICMP payload bytes, default 0
}
type PingResult struct {
Duration time.Duration
RecvCount int
RemoteIP string
}
var pingIdentifierSeed uint32
func nextPingIdentifier() uint16 {
pid := uint32(os.Getpid() & 0xffff)
n := atomic.AddUint32(&pingIdentifierSeed, 1)
return uint16((pid + n) & 0xffff)
}
func pingPayload(size int) []byte {
if size <= 0 {
return nil
}
payload := make([]byte, size)
for i := 0; i < len(payload); i++ {
payload[i] = byte(i)
}
return payload
}
func getICMP(seq, identifier uint16, typ uint8, payload []byte) ICMP {
icmp := ICMP{ icmp := ICMP{
Type: 8, Type: typ,
Code: 0, Code: 0,
CheckSum: 0, CheckSum: 0,
Identifier: 0, Identifier: identifier,
SequenceNum: seq, SequenceNum: seq,
} }
var buffer bytes.Buffer buf := marshalICMPPacket(icmp, payload)
binary.Write(&buffer, binary.BigEndian, icmp) icmp.CheckSum = checkSum(buf)
icmp.CheckSum = checkSum(buffer.Bytes())
buffer.Reset()
return icmp return icmp
} }
func sendICMPRequest(icmp ICMP, destAddr *net.IPAddr, timeout time.Duration) (PingResult, error) { func sendICMPRequest(ctx context.Context, icmp ICMP, payload []byte, destAddr *net.IPAddr, sourceIP net.IP, spec pingSocketSpec, timeout time.Duration) (PingResult, error) {
var res PingResult var res PingResult
if ctx == nil {
ctx = context.Background()
}
if err := ctx.Err(); err != nil {
return res, wrapError(err, "ping context done")
}
if destAddr == nil || destAddr.IP == nil {
return res, fmt.Errorf("destination ip is nil")
}
res.RemoteIP = destAddr.String() res.RemoteIP = destAddr.String()
conn, err := net.DialIP("ip:icmp", nil, destAddr)
localAddr, err := localIPAddrForFamily(sourceIP, spec.family)
if err != nil { if err != nil {
return res, err return res, err
} }
defer conn.Close()
var buffer bytes.Buffer
binary.Write(&buffer, binary.BigEndian, icmp)
if _, err := conn.Write(buffer.Bytes()); err != nil { conn, err := net.DialIP(spec.network, localAddr, destAddr)
return res, err if err != nil {
return res, normalizePingDialError(err)
}
defer conn.Close()
packet := marshalICMPPacket(icmp, payload)
if _, err := conn.Write(packet); err != nil {
return res, wrapError(err, "ping write request")
} }
tStart := time.Now() tStart := time.Now()
deadline := tStart.Add(timeout)
conn.SetReadDeadline((time.Now().Add(timeout))) if d, ok := ctx.Deadline(); ok && d.Before(deadline) {
deadline = d
recv := make([]byte, 1024) }
res.RecvCount, err = conn.Read(recv) if err := conn.SetReadDeadline(deadline); err != nil {
return res, wrapError(err, "ping set read deadline")
if err != nil {
return res, err
} }
tEnd := time.Now() doneCh := make(chan struct{})
res.Duration = tEnd.Sub(tStart) go func() {
select {
case <-ctx.Done():
_ = conn.SetReadDeadline(time.Now())
case <-doneCh:
}
}()
defer close(doneCh)
return res, err recv := make([]byte, icmpReadBufSz)
for {
n, err := conn.Read(recv)
if err != nil {
if ctx.Err() != nil {
return res, wrapError(ctx.Err(), "ping context done")
}
return res, wrapError(err, "ping read reply")
}
if isExpectedEchoReply(recv[:n], spec.family, spec.replyType, icmp.Identifier, icmp.SequenceNum) {
res.RecvCount = n
res.Duration = time.Since(tStart)
return res, nil
}
}
} }
func checkSum(data []byte) uint16 { func checkSum(data []byte) uint16 {
@ -75,36 +168,375 @@ func checkSum(data []byte) uint16 {
length -= 2 length -= 2
} }
if length > 0 { if length > 0 {
sum += uint32(data[index]) sum += uint32(data[index]) << 8
}
for sum>>16 != 0 {
sum = (sum & 0xffff) + (sum >> 16)
} }
sum += (sum >> 16)
return uint16(^sum) return uint16(^sum)
} }
type PingResult struct { func marshalICMP(icmp ICMP) []byte {
Duration time.Duration return marshalICMPPacket(icmp, nil)
RecvCount int
RemoteIP string
} }
func Ping(ip string, seq int, timeout time.Duration) (PingResult, error) { func marshalICMPPacket(icmp ICMP, payload []byte) []byte {
var res PingResult buf := make([]byte, icmpHeaderLen+len(payload))
ipAddr, err := net.ResolveIPAddr("ip", ip) buf[0] = icmp.Type
if err != nil { buf[1] = icmp.Code
return res, err binary.BigEndian.PutUint16(buf[2:], icmp.CheckSum)
} binary.BigEndian.PutUint16(buf[4:], icmp.Identifier)
icmp := getICMP(uint16(seq)) binary.BigEndian.PutUint16(buf[6:], icmp.SequenceNum)
return sendICMPRequest(icmp, ipAddr, timeout) copy(buf[icmpHeaderLen:], payload)
return buf
} }
func IsIpPingable(ip string, timeout time.Duration, retryLimit int) bool { func isExpectedEchoReply(packet []byte, family int, expectedType uint8, identifier, seq uint16) bool {
for i := 0; i < retryLimit; i++ { for _, off := range candidateICMPOffsets(packet, family) {
_, err := Ping(ip, 29, timeout) if off < 0 || off+icmpHeaderLen > len(packet) {
if err != nil { continue
}
if packet[off] != expectedType || packet[off+1] != 0 {
continue
}
if binary.BigEndian.Uint16(packet[off+4:off+6]) != identifier {
continue
}
if binary.BigEndian.Uint16(packet[off+6:off+8]) != seq {
continue continue
} }
return true return true
} }
return false return false
} }
func candidateICMPOffsets(packet []byte, family int) []int {
offsets := []int{0}
if len(packet) == 0 {
return offsets
}
ver := packet[0] >> 4
if ver == 4 && len(packet) >= 20 {
ihl := int(packet[0]&0x0f) * 4
if ihl >= 20 && ihl <= len(packet)-icmpHeaderLen {
offsets = append(offsets, ihl)
}
} else if ver == 6 && len(packet) >= 40+icmpHeaderLen {
offsets = append(offsets, 40)
}
// 某些平台/内核可能回包含链路层头部,保守再尝试常见偏移。
if family == 4 && len(packet) >= 20+icmpHeaderLen {
offsets = append(offsets, 20)
}
if family == 6 && len(packet) >= 40+icmpHeaderLen {
offsets = append(offsets, 40)
}
return dedupOffsets(offsets)
}
func dedupOffsets(offsets []int) []int {
if len(offsets) <= 1 {
return offsets
}
m := make(map[int]struct{}, len(offsets))
out := make([]int, 0, len(offsets))
for _, off := range offsets {
if _, ok := m[off]; ok {
continue
}
m[off] = struct{}{}
out = append(out, off)
}
return out
}
func socketSpecForIP(ip net.IP) (pingSocketSpec, error) {
if ip == nil {
return pingSocketSpec{}, wrapError(ErrInvalidIP, "ip is nil")
}
if ip4 := ip.To4(); ip4 != nil {
return pingSocketSpec{
network: "ip4:icmp",
family: 4,
requestType: icmpTypeEchoRequestV4,
replyType: icmpTypeEchoReplyV4,
}, nil
}
if ip16 := ip.To16(); ip16 != nil {
return pingSocketSpec{
network: "ip6:ipv6-icmp",
family: 6,
requestType: icmpTypeEchoRequestV6,
replyType: icmpTypeEchoReplyV6,
}, nil
}
return pingSocketSpec{}, wrapError(ErrInvalidIP, "invalid ip: %q", ip.String())
}
func localIPAddrForFamily(sourceIP net.IP, family int) (*net.IPAddr, error) {
if sourceIP == nil {
return nil, nil
}
if sourceIP.To16() == nil {
return nil, wrapError(ErrInvalidIP, "invalid source ip: %q", sourceIP.String())
}
if family == 4 && sourceIP.To4() == nil {
return nil, wrapError(ErrInvalidIP, "source ip family mismatch with IPv4 target")
}
if family == 6 && sourceIP.To4() != nil {
return nil, wrapError(ErrInvalidIP, "source ip family mismatch with IPv6 target")
}
return &net.IPAddr{IP: sourceIP}, nil
}
func resolvePingTargets(host string, preferIPv4, preferIPv6 bool) ([]*net.IPAddr, error) {
if parsed := net.ParseIP(host); parsed != nil {
return []*net.IPAddr{{IP: parsed}}, nil
}
var targets []*net.IPAddr
var err4 error
var err6 error
if ip4, e := net.ResolveIPAddr("ip4", host); e == nil && ip4 != nil && ip4.IP != nil {
targets = append(targets, ip4)
} else {
err4 = e
}
if ip6, e := net.ResolveIPAddr("ip6", host); e == nil && ip6 != nil && ip6.IP != nil {
targets = append(targets, ip6)
} else {
err6 = e
}
if len(targets) > 0 {
return orderPingTargets(targets, preferIPv4, preferIPv6), nil
}
if err4 != nil {
return nil, err4
}
if err6 != nil {
return nil, err6
}
return nil, ErrPingNoResolvedTarget
}
func orderPingTargets(targets []*net.IPAddr, preferIPv4, preferIPv6 bool) []*net.IPAddr {
if len(targets) <= 1 || preferIPv4 == preferIPv6 {
return targets
}
ordered := make([]*net.IPAddr, 0, len(targets))
if preferIPv4 {
for _, t := range targets {
if t != nil && t.IP != nil && t.IP.To4() != nil {
ordered = append(ordered, t)
}
}
for _, t := range targets {
if t != nil && t.IP != nil && t.IP.To4() == nil {
ordered = append(ordered, t)
}
}
return ordered
}
for _, t := range targets {
if t != nil && t.IP != nil && t.IP.To4() == nil {
ordered = append(ordered, t)
}
}
for _, t := range targets {
if t != nil && t.IP != nil && t.IP.To4() != nil {
ordered = append(ordered, t)
}
}
return ordered
}
func normalizePingDialError(err error) error {
if err == nil {
return nil
}
msg := strings.ToLower(err.Error())
if errors.Is(err, os.ErrPermission) ||
strings.Contains(msg, "operation not permitted") ||
strings.Contains(msg, "permission denied") {
return fmt.Errorf("%w: %v", ErrPingPermissionDenied, err)
}
if strings.Contains(msg, "unknown network") ||
strings.Contains(msg, "protocol not available") ||
strings.Contains(msg, "address family not supported by protocol") ||
strings.Contains(msg, "socket type not supported") {
return fmt.Errorf("%w: %v", ErrPingProtocolUnsupported, err)
}
return wrapError(err, "ping dial")
}
func normalizePingOptions(opts *PingOptions, defaultCount int, defaultTimeout time.Duration) (PingOptions, error) {
out := PingOptions{
Count: defaultCount,
Timeout: defaultTimeout,
Interval: 0,
PayloadSize: 0,
}
if opts != nil {
out = *opts
if out.Count == 0 {
out.Count = defaultCount
}
if out.Timeout == 0 {
out.Timeout = defaultTimeout
}
}
if out.Count < 0 {
return out, fmt.Errorf("ping count must be >= 0")
}
if out.Timeout <= 0 {
return out, wrapError(ErrPingInvalidTimeout, "timeout must be > 0")
}
if out.Interval < 0 {
return out, fmt.Errorf("ping interval must be >= 0")
}
if out.PayloadSize < 0 || out.PayloadSize > maxPingPayloadSize {
return out, fmt.Errorf("ping payload size must be in [0,%d]", maxPingPayloadSize)
}
if out.SourceIP != nil && out.SourceIP.To16() == nil {
return out, wrapError(ErrInvalidIP, "invalid source ip: %q", out.SourceIP.String())
}
return out, nil
}
func pingOnceWithOptions(ctx context.Context, host string, seq int, opts PingOptions) (PingResult, error) {
var res PingResult
if ctx == nil {
ctx = context.Background()
}
if err := ctx.Err(); err != nil {
return res, wrapError(err, "ping context done")
}
targets, err := resolvePingTargets(host, opts.PreferIPv4, opts.PreferIPv6)
if err != nil {
return res, wrapError(err, "resolve ping target")
}
payload := pingPayload(opts.PayloadSize)
var lastErr error
for _, target := range targets {
spec, err := socketSpecForIP(target.IP)
if err != nil {
lastErr = err
continue
}
icmp := getICMP(uint16(seq), nextPingIdentifier(), spec.requestType, payload)
resp, err := sendICMPRequest(ctx, icmp, payload, target, opts.SourceIP, spec, opts.Timeout)
if err == nil {
return resp, nil
}
// 权限问题通常与地址族无关,继续重试意义不大。
if errors.Is(err, ErrPingPermissionDenied) {
return res, err
}
lastErr = err
}
if lastErr != nil {
return res, wrapError(lastErr, "ping all resolved targets failed")
}
return res, ErrPingNoResolvedTarget
}
// PingWithContext sends one ICMP echo request with context cancel support.
func PingWithContext(ctx context.Context, host string, seq int, timeout time.Duration) (PingResult, error) {
opts, err := normalizePingOptions(&PingOptions{
Count: 1,
Timeout: timeout,
}, 1, timeout)
if err != nil {
return PingResult{}, err
}
if !opts.Deadline.IsZero() {
var cancel context.CancelFunc
ctx, cancel = context.WithDeadline(ctx, opts.Deadline)
defer cancel()
}
return pingOnceWithOptions(ctx, host, seq, opts)
}
// Ping sends one ICMP echo request.
func Ping(ip string, seq int, timeout time.Duration) (PingResult, error) {
return PingWithContext(context.Background(), ip, seq, timeout)
}
// Pingable checks host reachability with retry options.
func Pingable(host string, opts *PingOptions) (bool, error) {
cfg, err := normalizePingOptions(opts, defaultPingableCount, defaultPingAttemptTimeout)
if err != nil {
return false, err
}
ctx := context.Background()
if !cfg.Deadline.IsZero() {
var cancel context.CancelFunc
ctx, cancel = context.WithDeadline(ctx, cfg.Deadline)
defer cancel()
}
var lastErr error
for i := 0; i < cfg.Count; i++ {
_, err := pingOnceWithOptions(ctx, host, 29+i, cfg)
if err == nil {
return true, nil
}
lastErr = err
if errors.Is(err, ErrPingPermissionDenied) || errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
break
}
if i < cfg.Count-1 && cfg.Interval > 0 {
timer := time.NewTimer(cfg.Interval)
select {
case <-ctx.Done():
timer.Stop()
return false, wrapError(ctx.Err(), "pingable context done")
case <-timer.C:
}
}
}
if lastErr == nil {
lastErr = ErrPingNoResolvedTarget
}
return false, lastErr
}
// IsIpPingable keeps backward-compatible bool-only behavior.
func IsIpPingable(ip string, timeout time.Duration, retryLimit int) bool {
if retryLimit <= 0 {
return false
}
ok, _ := Pingable(ip, &PingOptions{
Count: retryLimit,
Timeout: timeout,
})
return ok
}

214
ping_logic_test.go Normal file
View File

@ -0,0 +1,214 @@
package starnet
import (
"context"
"errors"
"net"
"testing"
"time"
)
func buildICMPPacket(typ uint8, identifier, seq uint16) []byte {
icmp := ICMP{
Type: typ,
Code: 0,
CheckSum: 0,
Identifier: identifier,
SequenceNum: seq,
}
buf := marshalICMPPacket(icmp, nil)
cs := checkSum(buf)
buf[2] = byte(cs >> 8)
buf[3] = byte(cs)
return buf
}
func TestNextPingIdentifierChanges(t *testing.T) {
id1 := nextPingIdentifier()
id2 := nextPingIdentifier()
if id1 == id2 {
t.Fatalf("identifier should change between calls: %d == %d", id1, id2)
}
}
func TestIsExpectedEchoReplyIPv4(t *testing.T) {
identifier := uint16(0x1234)
seq := uint16(0x0102)
reply := buildICMPPacket(icmpTypeEchoReplyV4, identifier, seq)
if !isExpectedEchoReply(reply, 4, icmpTypeEchoReplyV4, identifier, seq) {
t.Fatal("expected IPv4 echo reply to match")
}
if isExpectedEchoReply(reply, 4, icmpTypeEchoReplyV4, identifier, seq+1) {
t.Fatal("mismatched sequence should not match")
}
}
func TestIsExpectedEchoReplyIPv4WithIPHeader(t *testing.T) {
identifier := uint16(0x1111)
seq := uint16(0x2222)
ipHeader := make([]byte, 20)
ipHeader[0] = 0x45 // version=4, ihl=5
reply := buildICMPPacket(icmpTypeEchoReplyV4, identifier, seq)
packet := append(ipHeader, reply...)
if !isExpectedEchoReply(packet, 4, icmpTypeEchoReplyV4, identifier, seq) {
t.Fatal("expected IPv4 packet with header to match")
}
}
func TestIsExpectedEchoReplyIPv6WithHeader(t *testing.T) {
identifier := uint16(0xabcd)
seq := uint16(0x00ff)
ipv6Header := make([]byte, 40)
ipv6Header[0] = 0x60 // version=6
reply := buildICMPPacket(icmpTypeEchoReplyV6, identifier, seq)
packet := append(ipv6Header, reply...)
if !isExpectedEchoReply(packet, 6, icmpTypeEchoReplyV6, identifier, seq) {
t.Fatal("expected IPv6 packet with header to match")
}
}
func TestPingInvalidTimeout(t *testing.T) {
_, err := Ping("127.0.0.1", 1, 0)
if err == nil {
t.Fatal("expected error for non-positive timeout")
}
if !errors.Is(err, ErrPingInvalidTimeout) {
t.Fatalf("expected ErrPingInvalidTimeout, got: %v", err)
}
}
func TestIsIPPingableInvalidRetry(t *testing.T) {
if IsIpPingable("127.0.0.1", time.Millisecond, 0) {
t.Fatal("retryLimit=0 should return false")
}
}
func TestSocketSpecForIP(t *testing.T) {
v4, err := socketSpecForIP(net.ParseIP("127.0.0.1"))
if err != nil {
t.Fatalf("unexpected v4 error: %v", err)
}
if v4.network != "ip4:icmp" || v4.family != 4 || v4.requestType != icmpTypeEchoRequestV4 || v4.replyType != icmpTypeEchoReplyV4 {
t.Fatalf("unexpected v4 spec: %+v", v4)
}
v6, err := socketSpecForIP(net.ParseIP("::1"))
if err != nil {
t.Fatalf("unexpected v6 error: %v", err)
}
if v6.network != "ip6:ipv6-icmp" || v6.family != 6 || v6.requestType != icmpTypeEchoRequestV6 || v6.replyType != icmpTypeEchoReplyV6 {
t.Fatalf("unexpected v6 spec: %+v", v6)
}
_, err = socketSpecForIP(nil)
if err == nil {
t.Fatal("expected error for nil ip")
}
if !errors.Is(err, ErrInvalidIP) {
t.Fatalf("expected ErrInvalidIP, got: %v", err)
}
}
func TestResolvePingTargetsLiteral(t *testing.T) {
v4, err := resolvePingTargets("127.0.0.1", false, false)
if err != nil {
t.Fatalf("unexpected v4 resolve error: %v", err)
}
if len(v4) != 1 || v4[0] == nil || v4[0].IP == nil || v4[0].IP.To4() == nil {
t.Fatalf("unexpected v4 targets: %+v", v4)
}
v6, err := resolvePingTargets("::1", false, false)
if err != nil {
t.Fatalf("unexpected v6 resolve error: %v", err)
}
if len(v6) != 1 || v6[0] == nil || v6[0].IP == nil || v6[0].IP.To16() == nil || v6[0].IP.To4() != nil {
t.Fatalf("unexpected v6 targets: %+v", v6)
}
}
func TestNormalizePingDialError(t *testing.T) {
perr := normalizePingDialError(errors.New("socket: operation not permitted"))
if !errors.Is(perr, ErrPingPermissionDenied) {
t.Fatalf("expected ErrPingPermissionDenied, got: %v", perr)
}
uerr := normalizePingDialError(errors.New("unknown network ip6:ipv6-icmp"))
if !errors.Is(uerr, ErrPingProtocolUnsupported) {
t.Fatalf("expected ErrPingProtocolUnsupported, got: %v", uerr)
}
}
func TestOrderPingTargets(t *testing.T) {
targets := []*net.IPAddr{
{IP: net.ParseIP("::1")},
{IP: net.ParseIP("127.0.0.1")},
}
v4First := orderPingTargets(targets, true, false)
if v4First[0].IP.To4() == nil {
t.Fatalf("expected IPv4 first, got: %v", v4First[0].IP)
}
v6First := orderPingTargets(targets, false, true)
if v6First[0].IP.To4() != nil {
t.Fatalf("expected IPv6 first, got: %v", v6First[0].IP)
}
}
func TestNormalizePingOptions(t *testing.T) {
opts, err := normalizePingOptions(nil, 3, 2*time.Second)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if opts.Count != 3 || opts.Timeout != 2*time.Second {
t.Fatalf("unexpected defaults: %+v", opts)
}
_, err = normalizePingOptions(&PingOptions{Count: -1}, 3, 2*time.Second)
if err == nil {
t.Fatal("expected error for negative count")
}
_, err = normalizePingOptions(&PingOptions{Timeout: -1}, 3, 2*time.Second)
if err == nil {
t.Fatal("expected error for negative timeout")
}
_, err = normalizePingOptions(&PingOptions{PayloadSize: maxPingPayloadSize + 1}, 3, 2*time.Second)
if err == nil {
t.Fatal("expected error for too large payload")
}
}
func TestPingWithContextCanceled(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err := PingWithContext(ctx, "127.0.0.1", 1, time.Second)
if err == nil {
t.Fatal("expected canceled error")
}
if !errors.Is(err, context.Canceled) {
t.Fatalf("expected context.Canceled, got: %v", err)
}
}
func TestPingableInvalidOptions(t *testing.T) {
_, err := Pingable("127.0.0.1", &PingOptions{Count: -1})
if err == nil {
t.Fatal("expected invalid count error")
}
_, err = Pingable("127.0.0.1", &PingOptions{Interval: -1})
if err == nil {
t.Fatal("expected invalid interval error")
}
}

50
proxy_test.go Normal file
View File

@ -0,0 +1,50 @@
package starnet
import (
"net/http"
"net/http/httptest"
"testing"
)
func TestRequestProxy(t *testing.T) {
// Create a proxy server
proxyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Proxy received the request
w.Header().Set("X-Proxied", "true")
w.WriteHeader(http.StatusOK)
w.Write([]byte("proxied"))
}))
defer proxyServer.Close()
// Note: This is a simplified test. Real proxy testing requires more setup
req := NewSimpleRequest("http://example.com", "GET").
SetProxy(proxyServer.URL)
// Just verify the proxy is set in config
if req.config.Network.Proxy != proxyServer.URL {
t.Errorf("Proxy = %v; want %v", req.config.Network.Proxy, proxyServer.URL)
}
}
func TestClientLevelProxy(t *testing.T) {
proxyURL := "http://proxy.example.com:8080"
client := NewClientNoErr(WithProxy(proxyURL))
req, _ := client.NewRequest("http://example.com", "GET")
if req.config.Network.Proxy != proxyURL {
t.Errorf("Proxy = %v; want %v", req.config.Network.Proxy, proxyURL)
}
}
func TestRequestLevelProxyOverride(t *testing.T) {
clientProxy := "http://client-proxy.com:8080"
requestProxy := "http://request-proxy.com:8080"
client := NewClientNoErr(WithProxy(clientProxy))
req, _ := client.NewRequest("http://example.com", "GET", WithProxy(requestProxy))
// Request level should override client level
if req.config.Network.Proxy != requestProxy {
t.Errorf("Proxy = %v; want %v", req.config.Network.Proxy, requestProxy)
}
}

317
que.go
View File

@ -1,317 +0,0 @@
package starnet
import (
"bytes"
"context"
"encoding/binary"
"errors"
"os"
"sync"
"time"
)
// 识别头
var header = []byte{11, 27, 19, 96, 12, 25, 02, 20}
// MsgQueue 为基本的信息单位
type MsgQueue struct {
ID uint16
Msg []byte
Conn interface{}
}
// StarQueue 为流数据中的消息队列分发
type StarQueue struct {
count int64
Encode bool
Reserve uint16
Msgid uint16
MsgPool chan MsgQueue
UnFinMsg sync.Map
LastID int //= -1
ctx context.Context
cancel context.CancelFunc
duration time.Duration
EncodeFunc func([]byte) []byte
DecodeFunc func([]byte) []byte
//restoreMu sync.Mutex
}
func NewQueueCtx(ctx context.Context, count int64) *StarQueue {
var que StarQueue
que.Encode = false
que.count = count
que.MsgPool = make(chan MsgQueue, count)
if ctx == nil {
que.ctx, que.cancel = context.WithCancel(context.Background())
} else {
que.ctx, que.cancel = context.WithCancel(ctx)
}
que.duration = 0
return &que
}
func NewQueueWithCount(count int64) *StarQueue {
return NewQueueCtx(nil, count)
}
// NewQueue 建立一个新消息队列
func NewQueue() *StarQueue {
return NewQueueWithCount(32)
}
// Uint32ToByte 4位uint32转byte
func Uint32ToByte(src uint32) []byte {
res := make([]byte, 4)
res[3] = uint8(src)
res[2] = uint8(src >> 8)
res[1] = uint8(src >> 16)
res[0] = uint8(src >> 24)
return res
}
// ByteToUint32 byte转4位uint32
func ByteToUint32(src []byte) uint32 {
var res uint32
buffer := bytes.NewBuffer(src)
binary.Read(buffer, binary.BigEndian, &res)
return res
}
// Uint16ToByte 2位uint16转byte
func Uint16ToByte(src uint16) []byte {
res := make([]byte, 2)
res[1] = uint8(src)
res[0] = uint8(src >> 8)
return res
}
// ByteToUint16 用于byte转uint16
func ByteToUint16(src []byte) uint16 {
var res uint16
buffer := bytes.NewBuffer(src)
binary.Read(buffer, binary.BigEndian, &res)
return res
}
// BuildMessage 生成编码后的信息用于发送
func (que *StarQueue) BuildMessage(src []byte) []byte {
var buff bytes.Buffer
que.Msgid++
if que.Encode {
src = que.EncodeFunc(src)
}
length := uint32(len(src))
buff.Write(header)
buff.Write(Uint32ToByte(length))
buff.Write(Uint16ToByte(que.Msgid))
buff.Write(src)
return buff.Bytes()
}
// BuildHeader 生成编码后的Header用于发送
func (que *StarQueue) BuildHeader(length uint32) []byte {
var buff bytes.Buffer
que.Msgid++
buff.Write(header)
buff.Write(Uint32ToByte(length))
buff.Write(Uint16ToByte(que.Msgid))
return buff.Bytes()
}
type unFinMsg struct {
ID uint16
LengthRecv uint32
// HeaderMsg 信息头应当为14位8位识别码+4位长度码+2位id
HeaderMsg []byte
RecvMsg []byte
}
func (que *StarQueue) push2list(msg MsgQueue) {
que.MsgPool <- msg
}
// ParseMessage 用于解析收到的msg信息
func (que *StarQueue) ParseMessage(msg []byte, conn interface{}) error {
return que.parseMessage(msg, conn)
}
// parseMessage 用于解析收到的msg信息
func (que *StarQueue) parseMessage(msg []byte, conn interface{}) error {
tmp, ok := que.UnFinMsg.Load(conn)
if ok { //存在未完成的信息
lastMsg := tmp.(*unFinMsg)
headerLen := len(lastMsg.HeaderMsg)
if headerLen < 14 { //未完成头标题
//传输的数据不能填充header头
if len(msg) < 14-headerLen {
//加入header头并退出
lastMsg.HeaderMsg = bytesMerge(lastMsg.HeaderMsg, msg)
que.UnFinMsg.Store(conn, lastMsg)
return nil
}
//获取14字节完整的header
header := msg[0 : 14-headerLen]
lastMsg.HeaderMsg = bytesMerge(lastMsg.HeaderMsg, header)
//检查收到的header是否为认证header
//若不是,丢弃并重新来过
if !checkHeader(lastMsg.HeaderMsg[0:8]) {
que.UnFinMsg.Delete(conn)
if len(msg) == 0 {
return nil
}
return que.parseMessage(msg, conn)
}
//获得本数据包长度
lastMsg.LengthRecv = ByteToUint32(lastMsg.HeaderMsg[8:12])
//获得本数据包ID
lastMsg.ID = ByteToUint16(lastMsg.HeaderMsg[12:14])
//存入列表
que.UnFinMsg.Store(conn, lastMsg)
msg = msg[14-headerLen:]
if uint32(len(msg)) < lastMsg.LengthRecv {
lastMsg.RecvMsg = msg
que.UnFinMsg.Store(conn, lastMsg)
return nil
}
if uint32(len(msg)) >= lastMsg.LengthRecv {
lastMsg.RecvMsg = msg[0:lastMsg.LengthRecv]
if que.Encode {
lastMsg.RecvMsg = que.DecodeFunc(lastMsg.RecvMsg)
}
msg = msg[lastMsg.LengthRecv:]
storeMsg := MsgQueue{
ID: lastMsg.ID,
Msg: lastMsg.RecvMsg,
Conn: conn,
}
//que.restoreMu.Lock()
que.push2list(storeMsg)
//que.restoreMu.Unlock()
que.UnFinMsg.Delete(conn)
return que.parseMessage(msg, conn)
}
} else {
lastID := int(lastMsg.LengthRecv) - len(lastMsg.RecvMsg)
if lastID < 0 {
que.UnFinMsg.Delete(conn)
return que.parseMessage(msg, conn)
}
if len(msg) >= lastID {
lastMsg.RecvMsg = bytesMerge(lastMsg.RecvMsg, msg[0:lastID])
if que.Encode {
lastMsg.RecvMsg = que.DecodeFunc(lastMsg.RecvMsg)
}
storeMsg := MsgQueue{
ID: lastMsg.ID,
Msg: lastMsg.RecvMsg,
Conn: conn,
}
que.push2list(storeMsg)
que.UnFinMsg.Delete(conn)
if len(msg) == lastID {
return nil
}
msg = msg[lastID:]
return que.parseMessage(msg, conn)
}
lastMsg.RecvMsg = bytesMerge(lastMsg.RecvMsg, msg)
que.UnFinMsg.Store(conn, lastMsg)
return nil
}
}
if len(msg) == 0 {
return nil
}
var start int
if start = searchHeader(msg); start == -1 {
return errors.New("data format error")
}
msg = msg[start:]
lastMsg := unFinMsg{}
que.UnFinMsg.Store(conn, &lastMsg)
return que.parseMessage(msg, conn)
}
func checkHeader(msg []byte) bool {
if len(msg) != 8 {
return false
}
for k, v := range msg {
if v != header[k] {
return false
}
}
return true
}
func searchHeader(msg []byte) int {
if len(msg) < 8 {
return 0
}
for k, v := range msg {
find := 0
if v == header[0] {
for k2, v2 := range header {
if msg[k+k2] == v2 {
find++
} else {
break
}
}
if find == 8 {
return k
}
}
}
return -1
}
func bytesMerge(src ...[]byte) []byte {
var buff bytes.Buffer
for _, v := range src {
buff.Write(v)
}
return buff.Bytes()
}
// Restore 获取收到的信息
func (que *StarQueue) Restore() (MsgQueue, error) {
if que.duration.Seconds() == 0 {
que.duration = 86400 * time.Second
}
for {
select {
case <-que.ctx.Done():
return MsgQueue{}, errors.New("Stoped By External Function Call")
case <-time.After(que.duration):
if que.duration != 0 {
return MsgQueue{}, os.ErrDeadlineExceeded
}
case data, ok := <-que.MsgPool:
if !ok {
return MsgQueue{}, os.ErrClosed
}
return data, nil
}
}
}
// RestoreOne 获取收到的一个信息
//兼容性修改
func (que *StarQueue) RestoreOne() (MsgQueue, error) {
return que.Restore()
}
// Stop 立即停止Restore
func (que *StarQueue) Stop() {
que.cancel()
}
// RestoreDuration Restore最大超时时间
func (que *StarQueue) RestoreDuration(tm time.Duration) {
que.duration = tm
}
func (que *StarQueue) RestoreChan() <-chan MsgQueue {
return que.MsgPool
}

View File

@ -1,42 +0,0 @@
package starnet
import (
"fmt"
"testing"
"time"
)
func Test_QueSpeed(t *testing.T) {
que := NewQueueWithCount(0)
stop := make(chan struct{}, 1)
que.RestoreDuration(time.Second * 10)
var count int64
go func() {
for {
select {
case <-stop:
//fmt.Println(count)
return
default:
}
_, err := que.RestoreOne()
if err == nil {
count++
}
}
}()
cp := 0
stoped := time.After(time.Second * 10)
data := que.BuildMessage([]byte("hello"))
for {
select {
case <-stoped:
fmt.Println(count, cp)
stop <- struct{}{}
return
default:
que.ParseMessage(data, "lala")
cp++
}
}
}

98
query_test.go Normal file
View File

@ -0,0 +1,98 @@
package starnet
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
)
func TestRequestQuery(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
query := r.URL.Query()
result := make(map[string][]string)
for k, v := range query {
result[k] = v
}
json.NewEncoder(w).Encode(result)
}))
defer server.Close()
req := NewSimpleRequest(server.URL, "GET").
AddQuery("name", "John").
AddQuery("age", "30").
AddQuery("tags", "go").
AddQuery("tags", "http")
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
var result map[string][]string
resp.Body().JSON(&result)
if len(result["name"]) != 1 || result["name"][0] != "John" {
t.Errorf("name = %v; want [John]", result["name"])
}
if len(result["tags"]) != 2 {
t.Errorf("tags length = %v; want 2", len(result["tags"]))
}
}
func TestRequestSetQuery(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
query := r.URL.Query()
w.Write([]byte(query.Get("key")))
}))
defer server.Close()
req := NewSimpleRequest(server.URL, "GET").
SetQuery("key", "value1").
SetQuery("key", "value2") // Should overwrite
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
body, _ := resp.Body().String()
if body != "value2" {
t.Errorf("query value = %v; want value2", body)
}
}
func TestRequestDeleteQuery(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
query := r.URL.Query()
result := make(map[string]string)
for k := range query {
result[k] = query.Get(k)
}
json.NewEncoder(w).Encode(result)
}))
defer server.Close()
req := NewSimpleRequest(server.URL, "GET").
AddQuery("keep", "yes").
AddQuery("delete", "no").
DeleteQuery("delete")
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
var result map[string]string
resp.Body().JSON(&result)
if _, exists := result["delete"]; exists {
t.Error("delete query should not exist")
}
if result["keep"] != "yes" {
t.Errorf("keep = %v; want yes", result["keep"])
}
}

427
request.go Normal file
View File

@ -0,0 +1,427 @@
package starnet
import (
"context"
"fmt"
"net/http"
"net/url"
"strings"
)
// Request HTTP 请求
type Request struct {
ctx context.Context
execCtx context.Context // 执行时的 context注入了配置
cancel context.CancelFunc
url string
method string
err error // 累积的错误
config *RequestConfig
client *Client
httpClient *http.Client
httpReq *http.Request
retry *retryPolicy
applied bool // 是否已应用配置
doRaw bool // 是否使用原始请求(不修改)
autoFetch bool // 是否自动获取响应体
}
// newRequest 创建新请求(内部使用)
func newRequest(ctx context.Context, urlStr string, method string, opts ...RequestOpt) (*Request, error) {
if method == "" {
method = http.MethodGet
}
method = strings.ToUpper(method)
// 创建 http.Request
httpReq, err := http.NewRequestWithContext(ctx, method, urlStr, nil)
if err != nil {
return nil, wrapError(err, "create http request")
}
// 初始化配置
config := &RequestConfig{
Network: NetworkConfig{
DialTimeout: DefaultDialTimeout,
Timeout: DefaultTimeout,
},
Headers: make(http.Header),
Queries: make(map[string][]string),
Body: BodyConfig{
FormData: make(map[string][]string),
},
}
// 设置默认 User-Agent
config.Headers.Set("User-Agent", DefaultUserAgent)
// POST 请求默认 Content-Type
if method == http.MethodPost {
config.Headers.Set("Content-Type", ContentTypeFormURLEncoded)
}
req := &Request{
ctx: ctx,
url: urlStr,
method: method,
config: config,
httpReq: httpReq,
autoFetch: DefaultFetchRespBody,
}
// 应用选项
for _, opt := range opts {
if opt != nil {
if err := opt(req); err != nil {
req.err = err
return req, nil // 不返回错误,累积到 req.err
}
}
}
return req, nil
}
// NewRequest 创建新请求
func NewRequest(url, method string, opts ...RequestOpt) (*Request, error) {
req, err := newRequest(context.Background(), url, method, opts...)
if err != nil {
return nil, err
}
if req.err != nil {
return nil, req.err
}
return req, nil
}
// NewRequestWithContext 创建新请求(带 context
func NewRequestWithContext(ctx context.Context, url, method string, opts ...RequestOpt) (*Request, error) {
req, err := newRequest(ctx, url, method, opts...)
if err != nil {
return nil, err
}
// 新增
if req.err != nil {
return nil, req.err
}
return req, nil
}
// NewSimpleRequest 创建新请求(忽略错误,支持链式调用)
func NewSimpleRequest(url, method string, opts ...RequestOpt) *Request {
req, err := newRequest(context.Background(), url, method, opts...)
if err != nil {
// 返回一个带错误的请求
return &Request{
ctx: context.Background(),
url: url,
method: method,
err: err,
config: &RequestConfig{
Headers: make(http.Header),
Queries: make(map[string][]string),
Body: BodyConfig{
FormData: make(map[string][]string),
},
},
}
}
return req
}
// NewSimpleRequestWithContext 创建新请求(带 context忽略错误
func NewSimpleRequestWithContext(ctx context.Context, url, method string, opts ...RequestOpt) *Request {
req, err := newRequest(ctx, url, method, opts...)
if err != nil {
return &Request{
ctx: ctx,
url: url,
method: method,
err: err,
config: &RequestConfig{
Headers: make(http.Header),
Queries: make(map[string][]string),
Body: BodyConfig{
FormData: make(map[string][]string),
},
},
}
}
return req
}
// Clone 克隆请求
func (r *Request) Clone() *Request {
cloned := &Request{
ctx: r.ctx,
url: r.url,
method: r.method,
err: r.err,
config: r.config.Clone(),
client: r.client,
httpClient: r.httpClient,
retry: cloneRetryPolicy(r.retry),
applied: false, // 重置应用状态
doRaw: r.doRaw,
autoFetch: r.autoFetch,
}
// 重新创建 http.Request
if !r.doRaw {
cloned.httpReq, _ = http.NewRequestWithContext(cloned.ctx, cloned.method, cloned.url, nil)
} else {
cloned.httpReq = r.httpReq
}
return cloned
}
// Err 获取累积的错误
func (r *Request) Err() error {
return r.err
}
// Context 获取 context
func (r *Request) Context() context.Context {
return r.ctx
}
// SetContext 设置 context
func (r *Request) SetContext(ctx context.Context) *Request {
if r.err != nil {
return r
}
r.ctx = ctx
r.httpReq = r.httpReq.WithContext(ctx)
return r
}
// Method 获取 HTTP 方法
func (r *Request) Method() string {
return r.method
}
// SetMethod 设置 HTTP 方法
func (r *Request) SetMethod(method string) *Request {
if r.err != nil {
return r
}
method = strings.ToUpper(method)
if !validMethod(method) {
r.err = wrapError(ErrInvalidMethod, "method: %s", method)
return r
}
r.method = method
r.httpReq.Method = method
return r
}
// URL 获取 URL
func (r *Request) URL() string {
return r.url
}
// SetURL 设置 URL
func (r *Request) SetURL(urlStr string) *Request {
if r.err != nil {
return r
}
if r.doRaw {
r.err = fmt.Errorf("cannot set URL when using raw request")
return r
}
u, err := url.Parse(urlStr)
if err != nil {
r.err = wrapError(ErrInvalidURL, "url: %s", urlStr)
return r
}
r.url = urlStr
u.Host = removeEmptyPort(u.Host)
r.httpReq.Host = u.Host
r.httpReq.URL = u
// 更新 TLS ServerName
if r.config.TLS.Config != nil {
r.config.TLS.Config.ServerName = u.Hostname()
}
return r
}
// RawRequest 获取底层 http.Request
func (r *Request) RawRequest() *http.Request {
return r.httpReq
}
// SetRawRequest 设置底层 http.Request启用原始模式
func (r *Request) SetRawRequest(httpReq *http.Request) *Request {
if r.err != nil {
return r
}
r.httpReq = httpReq
r.doRaw = true
if httpReq == nil {
r.err = fmt.Errorf("httpReq cannot be nil")
return r
}
return r
}
// EnableRawMode 启用原始模式(不修改请求)
func (r *Request) EnableRawMode() *Request {
r.doRaw = true
return r
}
// DisableRawMode 禁用原始模式
func (r *Request) DisableRawMode() *Request {
r.doRaw = false
return r
}
// SetAutoFetch 设置是否自动获取响应体
func (r *Request) SetAutoFetch(auto bool) *Request {
r.autoFetch = auto
return r
}
// HTTPClient 获取底层 http.Client只读
func (r *Request) HTTPClient() (*http.Client, error) {
if r.err != nil {
return nil, r.err
}
if r.httpClient != nil {
return r.httpClient, nil
}
// 如果还没构建,先准备
if err := r.prepare(); err != nil {
return nil, err
}
return r.httpClient, nil
}
// Client 获取关联的 Client只读
func (r *Request) Client() *Client {
return r.client
}
// Do 执行请求
func (r *Request) Do() (*Response, error) {
// 检查累积的错误
if r.err != nil {
return nil, r.err
}
if r.hasRetryPolicy() {
return r.doWithRetry()
}
return r.doOnce()
}
func (r *Request) doOnce() (*Response, error) {
// 准备请求
if err := r.prepare(); err != nil {
return nil, wrapError(err, "prepare request")
}
// 执行请求
httpResp, err := r.httpClient.Do(r.httpReq)
if err != nil {
if r.cancel != nil {
r.cancel()
r.cancel = nil
}
return &Response{
Response: &http.Response{},
request: r,
httpClient: r.httpClient,
body: &Body{},
}, wrapError(err, "do request")
}
rawBody := httpResp.Body
if r.cancel != nil {
rawBody = &cancelReadCloser{
ReadCloser: httpResp.Body,
cancel: r.cancel,
}
}
// 创建响应
resp := &Response{
Response: httpResp,
request: r,
httpClient: r.httpClient,
cancel: r.cancel,
body: &Body{
raw: rawBody,
maxBytes: r.config.MaxRespBodyBytes,
},
}
r.cancel = nil
// 自动获取响应体
if r.autoFetch {
if err := resp.body.readAll(); err != nil {
_ = resp.Close()
return resp, err
}
}
return resp, nil
}
// Get 发送 GET 请求
func (r *Request) Get() (*Response, error) {
return r.SetMethod(http.MethodGet).Do()
}
// Post 发送 POST 请求
func (r *Request) Post() (*Response, error) {
return r.SetMethod(http.MethodPost).Do()
}
// Put 发送 PUT 请求
func (r *Request) Put() (*Response, error) {
return r.SetMethod(http.MethodPut).Do()
}
// Delete 发送 DELETE 请求
func (r *Request) Delete() (*Response, error) {
return r.SetMethod(http.MethodDelete).Do()
}
// Head 发送 HEAD 请求
func (r *Request) Head() (*Response, error) {
return r.SetMethod(http.MethodHead).Do()
}
// Patch 发送 PATCH 请求
func (r *Request) Patch() (*Response, error) {
return r.SetMethod(http.MethodPatch).Do()
}
// Options 发送 OPTIONS 请求
func (r *Request) Options() (*Response, error) {
return r.SetMethod(http.MethodOptions).Do()
}
// Trace 发送 TRACE 请求
func (r *Request) Trace() (*Response, error) {
return r.SetMethod(http.MethodTrace).Do()
}
// Connect 发送 CONNECT 请求
func (r *Request) Connect() (*Response, error) {
return r.SetMethod(http.MethodConnect).Do()
}

448
request_body.go Normal file
View File

@ -0,0 +1,448 @@
package starnet
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"mime/multipart"
"net/http"
"net/url"
"os"
"strings"
)
// SetBody 设置请求体(字节)
func (r *Request) SetBody(body []byte) *Request {
if r.err != nil {
return r
}
if r.doRaw {
return r
}
r.config.Body.Bytes = body
r.config.Body.Reader = nil
return r
}
// SetBodyReader 设置请求体Reader
func (r *Request) SetBodyReader(reader io.Reader) *Request {
if r.err != nil {
return r
}
if r.doRaw {
return r
}
r.config.Body.Reader = reader
r.config.Body.Bytes = nil
return r
}
// SetBodyString 设置请求体(字符串)
func (r *Request) SetBodyString(body string) *Request {
return r.SetBody([]byte(body))
}
// SetJSON 设置 JSON 请求体
func (r *Request) SetJSON(v interface{}) *Request {
if r.err != nil {
return r
}
data, err := json.Marshal(v)
if err != nil {
r.err = wrapError(err, "marshal json")
return r
}
return r.SetContentType(ContentTypeJSON).SetBody(data)
}
// SetFormData 设置表单数据(覆盖)
func (r *Request) SetFormData(data map[string][]string) *Request {
if r.err != nil {
return r
}
if r.doRaw {
return r
}
r.config.Body.FormData = cloneStringMapSlice(data)
return r
}
// AddFormData 添加表单数据
func (r *Request) AddFormData(key, value string) *Request {
if r.err != nil {
return r
}
if r.doRaw {
return r
}
r.config.Body.FormData[key] = append(r.config.Body.FormData[key], value)
return r
}
// AddFormDataMap 批量添加表单数据
func (r *Request) AddFormDataMap(data map[string]string) *Request {
if r.err != nil {
return r
}
if r.doRaw {
return r
}
for k, v := range data {
r.config.Body.FormData[k] = append(r.config.Body.FormData[k], v)
}
return r
}
// AddFile 添加文件(从路径)
func (r *Request) AddFile(formName, filePath string) *Request {
if r.err != nil {
return r
}
stat, err := os.Stat(filePath)
if err != nil {
r.err = wrapError(ErrFileNotFound, "file: %s", filePath)
return r
}
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
FormName: formName,
FileName: stat.Name(),
FilePath: filePath,
FileSize: stat.Size(),
FileType: ContentTypeOctetStream,
})
return r
}
// AddFileWithName 添加文件(指定文件名)
func (r *Request) AddFileWithName(formName, filePath, fileName string) *Request {
if r.err != nil {
return r
}
stat, err := os.Stat(filePath)
if err != nil {
r.err = wrapError(ErrFileNotFound, "file: %s", filePath)
return r
}
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
FormName: formName,
FileName: fileName,
FilePath: filePath,
FileSize: stat.Size(),
FileType: ContentTypeOctetStream,
})
return r
}
// AddFileWithType 添加文件(指定 MIME 类型)
func (r *Request) AddFileWithType(formName, filePath, fileType string) *Request {
if r.err != nil {
return r
}
stat, err := os.Stat(filePath)
if err != nil {
r.err = wrapError(ErrFileNotFound, "file: %s", filePath)
return r
}
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
FormName: formName,
FileName: stat.Name(),
FilePath: filePath,
FileSize: stat.Size(),
FileType: fileType,
})
return r
}
// AddFileStream 添加文件流
func (r *Request) AddFileStream(formName, fileName string, size int64, reader io.Reader) *Request {
if r.err != nil {
return r
}
if reader == nil {
r.err = ErrNilReader
return r
}
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
FormName: formName,
FileName: fileName,
FileData: reader,
FileSize: size,
FileType: ContentTypeOctetStream,
})
return r
}
// AddFileStreamWithType 添加文件流(指定 MIME 类型)
func (r *Request) AddFileStreamWithType(formName, fileName, fileType string, size int64, reader io.Reader) *Request {
if r.err != nil {
return r
}
if reader == nil {
r.err = ErrNilReader
return r
}
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
FormName: formName,
FileName: fileName,
FileData: reader,
FileSize: size,
FileType: fileType,
})
return r
}
// applyBody 应用请求体
func (r *Request) applyBody() error {
// 优先级Reader > Bytes > Files > FormData
// 1. Reader
if r.config.Body.Reader != nil {
r.httpReq.Body = io.NopCloser(r.config.Body.Reader)
// 尝试获取长度
switch v := r.config.Body.Reader.(type) {
case *bytes.Buffer:
r.httpReq.ContentLength = int64(v.Len())
case *bytes.Reader:
r.httpReq.ContentLength = int64(v.Len())
case *strings.Reader:
r.httpReq.ContentLength = int64(v.Len())
}
return nil
}
// 2. Bytes
if len(r.config.Body.Bytes) > 0 {
r.httpReq.Body = io.NopCloser(bytes.NewReader(r.config.Body.Bytes))
r.httpReq.ContentLength = int64(len(r.config.Body.Bytes))
return nil
}
// 3. Filesmultipart/form-data
if len(r.config.Body.Files) > 0 {
return r.applyMultipartBody()
}
// 4. FormDataapplication/x-www-form-urlencoded
if len(r.config.Body.FormData) > 0 {
values := url.Values{}
for k, vs := range r.config.Body.FormData {
for _, v := range vs {
values.Add(k, v)
}
}
encoded := values.Encode()
r.httpReq.Body = io.NopCloser(strings.NewReader(encoded))
r.httpReq.ContentLength = int64(len(encoded))
return nil
}
return nil
}
// applyMultipartBody 应用 multipart 请求体
func (r *Request) applyMultipartBody() error {
pr, pw := io.Pipe()
writer := multipart.NewWriter(pw)
// 设置 Content-Type
r.httpReq.Header.Set("Content-Type", writer.FormDataContentType())
r.httpReq.Body = pr
// 在 goroutine 中写入数据
go func() {
defer pw.Close()
defer writer.Close()
// 写入表单字段
for k, vs := range r.config.Body.FormData {
for _, v := range vs {
if err := writer.WriteField(k, v); err != nil {
pw.CloseWithError(wrapError(err, "write form field"))
return
}
}
}
// 写入文件
for _, file := range r.config.Body.Files {
if err := r.writeFile(writer, file); err != nil {
pw.CloseWithError(err)
return
}
}
}()
return nil
}
// writeFile 写入文件到 multipart writer
func (r *Request) writeFile(writer *multipart.Writer, file RequestFile) error {
// 创建文件字段
part, err := writer.CreateFormFile(file.FormName, file.FileName)
if err != nil {
return wrapError(err, "create form file")
}
// 获取文件数据源
var reader io.Reader
if file.FileData != nil {
reader = file.FileData
} else if file.FilePath != "" {
f, err := os.Open(file.FilePath)
if err != nil {
return wrapError(err, "open file")
}
defer f.Close()
reader = f
} else {
return ErrNilReader
}
// 复制文件数据(带进度)
if r.config.UploadProgress != nil {
_, err = copyWithProgress(r.ctx, part, reader, file.FileName, file.FileSize, r.config.UploadProgress)
} else {
_, err = io.Copy(part, reader)
}
if err != nil {
return wrapError(err, "copy file data")
}
return nil
}
// prepare 准备请求(应用配置)
func (r *Request) prepare() error {
if r.applied {
return nil
}
// 即使 raw 模式也要确保有 httpClient
if r.httpClient == nil {
var err error
r.httpClient, err = r.buildHTTPClient()
if err != nil {
return err // ← 失败时不设置 applied
}
}
if r.httpReq == nil {
return fmt.Errorf("http request is nil")
}
// 原始模式不修改请求内容
if !r.doRaw {
// 应用查询参数
if len(r.config.Queries) > 0 {
q := r.httpReq.URL.Query()
for k, values := range r.config.Queries {
for _, v := range values {
q.Add(k, v)
}
}
r.httpReq.URL.RawQuery = q.Encode()
}
// 应用 Headers
for k, values := range r.config.Headers {
for _, v := range values {
r.httpReq.Header.Add(k, v)
}
}
// 应用 Cookies
for _, cookie := range r.config.Cookies {
r.httpReq.AddCookie(cookie)
}
// 应用 Basic Auth
if r.config.BasicAuth[0] != "" || r.config.BasicAuth[1] != "" {
r.httpReq.SetBasicAuth(r.config.BasicAuth[0], r.config.BasicAuth[1])
}
// 应用请求体
if err := r.applyBody(); err != nil {
return err
}
// 应用 Content-Length
if r.config.ContentLength > 0 {
r.httpReq.ContentLength = r.config.ContentLength
} else if r.config.ContentLength < 0 {
r.httpReq.ContentLength = 0
}
// 自动计算 Content-Length
if r.config.AutoCalcContentLength && r.httpReq.Body != nil {
data, err := io.ReadAll(r.httpReq.Body)
if err != nil {
return wrapError(err, "read body for content length")
}
r.httpReq.ContentLength = int64(len(data))
r.httpReq.Body = io.NopCloser(bytes.NewBuffer(data))
}
// 设置 TLS ServerName如果有 TLS Config
if r.config.TLS.Config != nil && r.httpReq.URL != nil {
r.config.TLS.Config.ServerName = r.httpReq.URL.Hostname()
}
}
execCtx := r.ctx
if !r.doRaw {
// raw 模式下不注入请求级网络配置,只应用 context/超时。
execCtx = injectRequestConfig(execCtx, r.config)
}
// 请求级总超时通过 context 控制,避免污染共享 http.Client。
if r.config.Network.Timeout > 0 {
execCtx, r.cancel = context.WithTimeout(execCtx, r.config.Network.Timeout)
}
r.execCtx = execCtx
r.httpReq = r.httpReq.WithContext(r.execCtx)
r.applied = true
return nil
}
// buildHTTPClient 构建 HTTP Client
func (r *Request) buildHTTPClient() (*http.Client, error) {
// 优先使用请求关联的 Client
if r.client != nil {
return r.client.HTTPClient(), nil
}
// 自定义 Transport
if r.config.CustomTransport && r.config.Transport != nil {
return &http.Client{
Transport: &Transport{base: r.config.Transport},
Timeout: 0,
}, nil
}
// 默认全局 client
return DefaultHTTPClient(), nil
}

282
request_config.go Normal file
View File

@ -0,0 +1,282 @@
package starnet
import (
"context"
"crypto/tls"
"fmt"
"net"
"net/http"
"time"
)
// SetTimeout 设置请求总超时时间
// timeout > 0: 为本次请求注入 context 超时
// timeout = 0: 不额外设置请求总超时
// timeout < 0: 禁用 starnet 默认总超时
func (r *Request) SetTimeout(timeout time.Duration) *Request {
if r.err != nil {
return r
}
r.config.Network.Timeout = timeout
return r
}
// SetDialTimeout 设置连接超时时间
func (r *Request) SetDialTimeout(timeout time.Duration) *Request {
if r.err != nil {
return r
}
r.config.Network.DialTimeout = timeout
return r
}
// SetProxy 设置代理
func (r *Request) SetProxy(proxy string) *Request {
if r.err != nil {
return r
}
r.config.Network.Proxy = proxy
return r
}
// SetDialFunc 设置自定义 Dial 函数
func (r *Request) SetDialFunc(fn func(ctx context.Context, network, addr string) (net.Conn, error)) *Request {
if r.err != nil {
return r
}
r.config.Network.DialFunc = fn
return r
}
// SetTLSConfig 设置 TLS 配置
func (r *Request) SetTLSConfig(tlsConfig *tls.Config) *Request {
if r.err != nil {
return r
}
r.config.TLS.Config = tlsConfig
return r
}
// SetSkipTLSVerify 设置是否跳过 TLS 验证
func (r *Request) SetSkipTLSVerify(skip bool) *Request {
if r.err != nil {
return r
}
r.config.TLS.SkipVerify = skip
return r
}
// SetCustomIP 设置自定义 IP直接指定 IP跳过 DNS
func (r *Request) SetCustomIP(ips []string) *Request {
if r.err != nil {
return r
}
// 验证 IP 格式
for _, ip := range ips {
if net.ParseIP(ip) == nil {
r.err = wrapError(ErrInvalidIP, "ip: %s", ip)
return r
}
}
r.config.DNS.CustomIP = ips
return r
}
// AddCustomIP 添加自定义 IP
func (r *Request) AddCustomIP(ip string) *Request {
if r.err != nil {
return r
}
if net.ParseIP(ip) == nil {
r.err = wrapError(ErrInvalidIP, "ip: %s", ip)
return r
}
r.config.DNS.CustomIP = append(r.config.DNS.CustomIP, ip)
return r
}
// SetCustomDNS 设置自定义 DNS 服务器
func (r *Request) SetCustomDNS(dnsServers []string) *Request {
if r.err != nil {
return r
}
// 验证 DNS 服务器格式
for _, dns := range dnsServers {
if net.ParseIP(dns) == nil {
r.err = wrapError(ErrInvalidDNS, "dns: %s", dns)
return r
}
}
r.config.DNS.CustomDNS = dnsServers
return r
}
// AddCustomDNS 添加自定义 DNS 服务器
func (r *Request) AddCustomDNS(dns string) *Request {
if r.err != nil {
return r
}
if net.ParseIP(dns) == nil {
r.err = wrapError(ErrInvalidDNS, "dns: %s", dns)
return r
}
r.config.DNS.CustomDNS = append(r.config.DNS.CustomDNS, dns)
return r
}
// SetLookupFunc 设置自定义 DNS 解析函数
func (r *Request) SetLookupFunc(fn func(ctx context.Context, host string) ([]net.IPAddr, error)) *Request {
if r.err != nil {
return r
}
r.config.DNS.LookupFunc = fn
return r
}
// SetBasicAuth 设置 Basic 认证
func (r *Request) SetBasicAuth(username, password string) *Request {
if r.err != nil {
return r
}
r.config.BasicAuth = [2]string{username, password}
return r
}
// SetContentLength 设置 Content-Length
func (r *Request) SetContentLength(length int64) *Request {
if r.err != nil {
return r
}
r.config.ContentLength = length
return r
}
// SetAutoCalcContentLength 设置是否自动计算 Content-Length
// 警告:启用后会将整个 body 读入内存
func (r *Request) SetAutoCalcContentLength(auto bool) *Request {
if r.err != nil {
return r
}
if r.doRaw {
r.err = fmt.Errorf("cannot set auto calc content length in raw mode")
return r
}
r.config.AutoCalcContentLength = auto
return r
}
// SetTransport 设置自定义 Transport
func (r *Request) SetTransport(transport *http.Transport) *Request {
if r.err != nil {
return r
}
r.config.Transport = transport
r.config.CustomTransport = true
return r
}
// SetUploadProgress 设置文件上传进度回调
func (r *Request) SetUploadProgress(fn UploadProgressFunc) *Request {
if r.err != nil {
return r
}
r.config.UploadProgress = fn
return r
}
// SetMaxRespBodyBytes 设置响应体最大读取字节数(<=0 表示不限制)
func (r *Request) SetMaxRespBodyBytes(maxBytes int64) *Request {
if r.err != nil {
return r
}
if maxBytes < 0 {
r.err = fmt.Errorf("max response body bytes must be >= 0")
return r
}
r.config.MaxRespBodyBytes = maxBytes
return r
}
// AddQuery 添加查询参数
func (r *Request) AddQuery(key, value string) *Request {
if r.err != nil {
return r
}
r.config.Queries[key] = append(r.config.Queries[key], value)
return r
}
// SetQuery 设置查询参数(覆盖)
func (r *Request) SetQuery(key, value string) *Request {
if r.err != nil {
return r
}
r.config.Queries[key] = []string{value}
return r
}
// SetQueries 设置所有查询参数(覆盖)
func (r *Request) SetQueries(queries map[string][]string) *Request {
if r.err != nil {
return r
}
r.config.Queries = cloneStringMapSlice(queries)
return r
}
// AddQueries 批量添加查询参数
func (r *Request) AddQueries(queries map[string]string) *Request {
if r.err != nil {
return r
}
for k, v := range queries {
r.config.Queries[k] = append(r.config.Queries[k], v)
}
return r
}
// DeleteQuery 删除查询参数
func (r *Request) DeleteQuery(key string) *Request {
if r.err != nil {
return r
}
delete(r.config.Queries, key)
return r
}
// DeleteQueryValue 删除查询参数的特定值
func (r *Request) DeleteQueryValue(key, value string) *Request {
if r.err != nil {
return r
}
values, ok := r.config.Queries[key]
if !ok {
return r
}
newValues := make([]string, 0, len(values))
for _, v := range values {
if v != value {
newValues = append(newValues, v)
}
}
if len(newValues) == 0 {
delete(r.config.Queries, key)
} else {
r.config.Queries[key] = newValues
}
return r
}

180
request_header.go Normal file
View File

@ -0,0 +1,180 @@
package starnet
import (
"net/http"
)
// SetHeader 设置 Header覆盖
func (r *Request) SetHeader(key, value string) *Request {
if r.err != nil {
return r
}
if r.doRaw {
return r
}
r.config.Headers.Set(key, value)
return r
}
// AddHeader 添加 Header
func (r *Request) AddHeader(key, value string) *Request {
if r.err != nil {
return r
}
if r.doRaw {
return r
}
r.config.Headers.Add(key, value)
return r
}
// SetHeaders 设置所有 Headers覆盖
func (r *Request) SetHeaders(headers http.Header) *Request {
if r.err != nil {
return r
}
if r.doRaw {
return r
}
r.config.Headers = cloneHeader(headers)
return r
}
// AddHeaders 批量添加 Headers
func (r *Request) AddHeaders(headers map[string]string) *Request {
if r.err != nil {
return r
}
if r.doRaw {
return r
}
for k, v := range headers {
r.config.Headers.Add(k, v)
}
return r
}
// DeleteHeader 删除 Header
func (r *Request) DeleteHeader(key string) *Request {
if r.err != nil {
return r
}
if r.doRaw {
return r
}
r.config.Headers.Del(key)
return r
}
// GetHeader 获取 Header
func (r *Request) GetHeader(key string) string {
return r.config.Headers.Get(key)
}
// Headers 获取所有 Headers
func (r *Request) Headers() http.Header {
return r.config.Headers
}
// SetContentType 设置 Content-Type
func (r *Request) SetContentType(contentType string) *Request {
return r.SetHeader("Content-Type", contentType)
}
// SetUserAgent 设置 User-Agent
func (r *Request) SetUserAgent(userAgent string) *Request {
return r.SetHeader("User-Agent", userAgent)
}
// SetReferer 设置 Referer
func (r *Request) SetReferer(referer string) *Request {
return r.SetHeader("Referer", referer)
}
// SetBearerToken 设置 Bearer Token
func (r *Request) SetBearerToken(token string) *Request {
return r.SetHeader("Authorization", "Bearer "+token)
}
// AddCookie 添加 Cookie
func (r *Request) AddCookie(cookie *http.Cookie) *Request {
if r.err != nil {
return r
}
if r.doRaw {
return r
}
r.config.Cookies = append(r.config.Cookies, cookie)
return r
}
// AddSimpleCookie 添加简单 Cookiepath 为 /
func (r *Request) AddSimpleCookie(name, value string) *Request {
return r.AddCookie(&http.Cookie{
Name: name,
Value: value,
Path: "/",
})
}
// AddCookieKV 添加 Cookie指定 path
func (r *Request) AddCookieKV(name, value, path string) *Request {
return r.AddCookie(&http.Cookie{
Name: name,
Value: value,
Path: path,
})
}
// SetCookies 设置所有 Cookies覆盖
func (r *Request) SetCookies(cookies []*http.Cookie) *Request {
if r.err != nil {
return r
}
if r.doRaw {
return r
}
r.config.Cookies = cookies
return r
}
// AddCookies 批量添加 Cookies
func (r *Request) AddCookies(cookies map[string]string) *Request {
if r.err != nil {
return r
}
if r.doRaw {
return r
}
for name, value := range cookies {
r.config.Cookies = append(r.config.Cookies, &http.Cookie{
Name: name,
Value: value,
Path: "/",
})
}
return r
}
// Cookies 获取所有 Cookies
func (r *Request) Cookies() []*http.Cookie {
return r.config.Cookies
}
// ResetHeaders 重置所有 Headers
func (r *Request) ResetHeaders() *Request {
if r.err != nil {
return r
}
r.config.Headers = make(http.Header)
return r
}
// ResetCookies 重置所有 Cookies
func (r *Request) ResetCookies() *Request {
if r.err != nil {
return r
}
r.config.Cookies = []*http.Cookie{}
return r
}

View File

@ -0,0 +1,43 @@
package starnet
import (
"net/http"
"net/http/httptest"
"testing"
)
func TestRequestConvenienceMethods(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Method", r.Method)
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer server.Close()
tests := []struct {
name string
method string
do func(r *Request) (*Response, error)
}{
{name: "Head", method: http.MethodHead, do: (*Request).Head},
{name: "Patch", method: http.MethodPatch, do: (*Request).Patch},
{name: "Options", method: http.MethodOptions, do: (*Request).Options},
{name: "Trace", method: http.MethodTrace, do: (*Request).Trace},
{name: "Connect", method: http.MethodConnect, do: (*Request).Connect},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := NewSimpleRequest(server.URL, http.MethodGet)
resp, err := tt.do(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer resp.Close()
if got := resp.Header.Get("X-Method"); got != tt.method {
t.Fatalf("method=%s want=%s", got, tt.method)
}
})
}
}

172
request_test.go Normal file
View File

@ -0,0 +1,172 @@
package starnet
import (
"context"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
)
func TestNewSimpleRequest(t *testing.T) {
tests := []struct {
name string
url string
method string
expectErr bool
}{
{
name: "valid GET request",
url: "https://example.com",
method: "GET",
expectErr: false,
},
{
name: "valid POST request",
url: "https://example.com",
method: "POST",
expectErr: false,
},
{
name: "invalid URL",
url: "://invalid",
method: "GET",
expectErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req, err := NewRequest(tt.url, tt.method)
if tt.expectErr {
if err == nil && req.Err() == nil {
t.Errorf("NewRequest() expected error, got nil")
}
} else {
if err != nil {
t.Errorf("NewRequest() unexpected error: %v", err)
}
if req.Method() != strings.ToUpper(tt.method) {
t.Errorf("Method = %v; want %v", req.Method(), tt.method)
}
}
})
}
}
func TestRequestMethods(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Method", r.Method)
w.WriteHeader(http.StatusOK)
w.Write([]byte(r.Method))
}))
defer server.Close()
methods := []string{"GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"}
for _, method := range methods {
t.Run(method, func(t *testing.T) {
req := NewSimpleRequest(server.URL, method)
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("StatusCode = %v; want %v", resp.StatusCode, http.StatusOK)
}
if method != "HEAD" {
body, _ := resp.Body().String()
if body != method {
t.Errorf("Body = %v; want %v", body, method)
}
}
})
}
}
func TestRequestSetMethod(t *testing.T) {
req := NewSimpleRequest("https://example.com", "GET")
req.SetMethod("POST")
if req.Method() != "POST" {
t.Errorf("Method = %v; want POST", req.Method())
}
req.SetMethod("invalid method!")
if req.Err() == nil {
t.Error("SetMethod with invalid method should set error")
}
}
func TestRequestSetURL(t *testing.T) {
req := NewSimpleRequest("https://example.com", "GET")
req.SetURL("https://newexample.com")
if req.URL() != "https://newexample.com" {
t.Errorf("URL = %v; want https://newexample.com", req.URL())
}
req2 := NewSimpleRequest("https://example.com", "GET")
req2.SetURL("://invalid")
if req2.Err() == nil {
t.Error("SetURL with invalid URL should set error")
}
}
func TestRequestClone(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("X-Test") != "value" {
w.WriteHeader(http.StatusBadRequest)
return
}
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
}))
defer server.Close()
req := NewSimpleRequest(server.URL, "GET").
SetHeader("X-Test", "value")
// 第一次请求
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
resp.Close()
// 克隆请求
cloned := req.Clone()
cloned.SetHeader("X-Extra", "extra")
// 克隆的请求应该也能成功
resp2, err := cloned.Do()
if err != nil {
t.Fatalf("Cloned Do() error: %v", err)
}
defer resp2.Close()
if resp2.StatusCode != http.StatusOK {
t.Errorf("Cloned request StatusCode = %v; want %v", resp2.StatusCode, http.StatusOK)
}
}
func TestRequestContext(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(100 * time.Millisecond)
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
defer cancel()
req := NewSimpleRequest(server.URL, "GET").SetContext(ctx)
_, err := req.Do()
if err == nil {
t.Error("Expected timeout error, got nil")
}
}

199
response.go Normal file
View File

@ -0,0 +1,199 @@
package starnet
import (
"bytes"
"encoding/json"
"io"
"net/http"
"sync"
)
// Response HTTP 响应
type Response struct {
*http.Response
request *Request
httpClient *http.Client
cancel func()
body *Body
}
// Body 响应体
type Body struct {
raw io.ReadCloser
data []byte
consumed bool
maxBytes int64
mu sync.Mutex
}
type cancelReadCloser struct {
io.ReadCloser
cancel func()
once sync.Once
}
func (c *cancelReadCloser) Close() error {
err := c.ReadCloser.Close()
c.once.Do(func() {
if c.cancel != nil {
c.cancel()
}
})
return err
}
// Request 获取原始请求
func (r *Response) Request() *Request {
return r.request
}
// Body 获取响应体
func (r *Response) Body() *Body {
return r.body
}
// Close 关闭响应体
func (r *Response) Close() error {
if r == nil {
return nil
}
if r.body != nil && r.body.raw != nil {
return r.body.raw.Close()
}
if r.cancel != nil {
r.cancel()
r.cancel = nil
}
return nil
}
// CloseWithClient 关闭响应体并关闭空闲连接
func (r *Response) CloseWithClient() error {
if r == nil {
return nil
}
if r.httpClient != nil {
r.httpClient.CloseIdleConnections()
}
return r.Close()
}
// readAll 读取所有数据
func (b *Body) readAll() error {
b.mu.Lock()
defer b.mu.Unlock()
if b.consumed {
return nil
}
if b.raw == nil {
b.consumed = true
return nil
}
reader := io.Reader(b.raw)
if b.maxBytes > 0 {
reader = io.LimitReader(b.raw, b.maxBytes+1)
}
data, err := io.ReadAll(reader)
if err != nil {
return wrapError(err, "read response body")
}
if b.maxBytes > 0 && int64(len(data)) > b.maxBytes {
b.consumed = true
_ = b.raw.Close()
return wrapError(ErrRespBodyTooLarge, "response body exceeds max bytes: %d > %d", len(data), b.maxBytes)
}
b.data = data
b.consumed = true
_ = b.raw.Close()
return nil
}
// Bytes 获取响应体字节
func (b *Body) Bytes() ([]byte, error) {
if err := b.readAll(); err != nil {
return nil, err
}
return b.data, nil
}
// String 获取响应体字符串
func (b *Body) String() (string, error) {
data, err := b.Bytes()
if err != nil {
return "", err
}
return string(data), nil
}
// JSON 解析 JSON 响应
func (b *Body) JSON(v interface{}) error {
data, err := b.Bytes()
if err != nil {
return err
}
return json.Unmarshal(data, v)
}
// Reader 获取 Reader只能调用一次
func (b *Body) Reader() (io.ReadCloser, error) {
b.mu.Lock()
defer b.mu.Unlock()
if b.consumed {
if b.data != nil {
// 已读取,返回缓存数据的 Reader
return io.NopCloser(bytes.NewReader(b.data)), nil
}
return nil, ErrBodyAlreadyConsumed
}
b.consumed = true
return b.raw, nil
}
// IsConsumed 检查是否已消费
func (b *Body) IsConsumed() bool {
b.mu.Lock()
defer b.mu.Unlock()
return b.consumed
}
// Close 关闭 Body
func (b *Body) Close() error {
b.mu.Lock()
defer b.mu.Unlock()
if b.raw != nil {
return b.raw.Close()
}
return nil
}
// MustBytes 获取响应体字节(忽略错误,失败返回 nil
func (b *Body) MustBytes() []byte {
data, err := b.Bytes()
if err != nil {
return nil
}
return data
}
// MustString 获取响应体字符串(忽略错误,失败返回空串)
func (b *Body) MustString() string {
s, err := b.String()
if err != nil {
return ""
}
return s
}
// Unmarshal 解析 JSON 响应(兼容旧 API
func (b *Body) Unmarshal(v interface{}) error {
return b.JSON(v)
}

84
response_limit_test.go Normal file
View File

@ -0,0 +1,84 @@
package starnet
import (
"errors"
"net/http"
"net/http/httptest"
"testing"
)
func TestWithMaxRespBodyBytes(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte("123456"))
}))
defer server.Close()
resp, err := Get(server.URL, WithMaxRespBodyBytes(4))
if err != nil {
t.Fatalf("unexpected request error: %v", err)
}
defer resp.Close()
_, err = resp.Body().Bytes()
if err == nil {
t.Fatal("expected body too large error")
}
if !errors.Is(err, ErrRespBodyTooLarge) {
t.Fatalf("expected ErrRespBodyTooLarge, got: %v", err)
}
}
func TestSetMaxRespBodyBytes(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte("1234"))
}))
defer server.Close()
resp, err := NewSimpleRequest(server.URL, http.MethodGet).
SetMaxRespBodyBytes(4).
Do()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer resp.Close()
body, err := resp.Body().String()
if err != nil {
t.Fatalf("unexpected read error: %v", err)
}
if body != "1234" {
t.Fatalf("body=%q want=1234", body)
}
}
func TestSetMaxRespBodyBytesWithAutoFetch(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte("123456"))
}))
defer server.Close()
_, err := NewSimpleRequest(server.URL, http.MethodGet).
SetAutoFetch(true).
SetMaxRespBodyBytes(4).
Do()
if err == nil {
t.Fatal("expected body too large error with auto fetch")
}
if !errors.Is(err, ErrRespBodyTooLarge) {
t.Fatalf("expected ErrRespBodyTooLarge, got: %v", err)
}
}
func TestSetMaxRespBodyBytesInvalid(t *testing.T) {
req := NewSimpleRequest("http://example.com", http.MethodGet).SetMaxRespBodyBytes(-1)
if req.Err() == nil {
t.Fatal("expected error for negative max bytes")
}
}
func TestWithMaxRespBodyBytesInvalid(t *testing.T) {
_, err := NewRequest("http://example.com", http.MethodGet, WithMaxRespBodyBytes(-1))
if err == nil {
t.Fatal("expected error for negative max bytes")
}
}

179
response_test.go Normal file
View File

@ -0,0 +1,179 @@
package starnet
import (
"io"
"net/http"
"net/http/httptest"
"testing"
)
func TestResponseBody(t *testing.T) {
testData := "test response data"
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(testData))
}))
defer server.Close()
resp, err := Get(server.URL)
if err != nil {
t.Fatalf("Get() error: %v", err)
}
defer resp.Close()
// Test String()
body, err := resp.Body().String()
if err != nil {
t.Fatalf("Body().String() error: %v", err)
}
if body != testData {
t.Errorf("Body = %v; want %v", body, testData)
}
// Test multiple reads (should work because body is cached)
body2, err := resp.Body().String()
if err != nil {
t.Fatalf("Second Body().String() error: %v", err)
}
if body2 != testData {
t.Errorf("Second Body = %v; want %v", body2, testData)
}
}
func TestResponseJSON(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{"name":"John","age":30}`))
}))
defer server.Close()
resp, err := Get(server.URL)
if err != nil {
t.Fatalf("Get() error: %v", err)
}
defer resp.Close()
var result struct {
Name string `json:"name"`
Age int `json:"age"`
}
err = resp.Body().JSON(&result)
if err != nil {
t.Fatalf("Body().JSON() error: %v", err)
}
if result.Name != "John" {
t.Errorf("Name = %v; want John", result.Name)
}
if result.Age != 30 {
t.Errorf("Age = %v; want 30", result.Age)
}
}
func TestResponseBytes(t *testing.T) {
testData := []byte("binary data")
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write(testData)
}))
defer server.Close()
resp, err := Get(server.URL)
if err != nil {
t.Fatalf("Get() error: %v", err)
}
defer resp.Close()
body, err := resp.Body().Bytes()
if err != nil {
t.Fatalf("Body().Bytes() error: %v", err)
}
if string(body) != string(testData) {
t.Errorf("Body = %v; want %v", body, testData)
}
}
func TestResponseReader(t *testing.T) {
testData := "stream data"
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(testData))
}))
defer server.Close()
resp, err := Get(server.URL)
if err != nil {
t.Fatalf("Get() error: %v", err)
}
defer resp.Close()
reader, err := resp.Body().Reader()
if err != nil {
t.Fatalf("Body().Reader() error: %v", err)
}
defer reader.Close()
body, err := io.ReadAll(reader)
if err != nil {
t.Fatalf("ReadAll() error: %v", err)
}
if string(body) != testData {
t.Errorf("Body = %v; want %v", string(body), testData)
}
}
func TestResponseAutoFetch(t *testing.T) {
testData := "auto fetch data"
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(testData))
}))
defer server.Close()
// With auto fetch
resp, err := Get(server.URL, WithAutoFetch(true))
if err != nil {
t.Fatalf("Get() error: %v", err)
}
defer resp.Close()
if !resp.Body().IsConsumed() {
t.Error("Body should be consumed with auto fetch")
}
body, _ := resp.Body().String()
if body != testData {
t.Errorf("Body = %v; want %v", body, testData)
}
}
func TestResponseStatusCode(t *testing.T) {
tests := []struct {
name string
statusCode int
}{
{"OK", http.StatusOK},
{"Created", http.StatusCreated},
{"BadRequest", http.StatusBadRequest},
{"NotFound", http.StatusNotFound},
{"InternalServerError", http.StatusInternalServerError},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(tt.statusCode)
}))
defer server.Close()
resp, err := Get(server.URL)
if err != nil {
t.Fatalf("Get() error: %v", err)
}
defer resp.Close()
if resp.StatusCode != tt.statusCode {
t.Errorf("StatusCode = %v; want %v", resp.StatusCode, tt.statusCode)
}
})
}
}

423
retry.go Normal file
View File

@ -0,0 +1,423 @@
package starnet
import (
"context"
"errors"
"fmt"
"io"
"math"
"math/rand"
"net"
"net/http"
"time"
)
type RetryOpt func(*retryPolicy) error
type retryPolicy struct {
maxRetries int
baseDelay time.Duration
maxDelay time.Duration
factor float64
jitter float64
idempotentOnly bool
statuses map[int]struct{}
onError func(error) bool
}
func cloneRetryPolicy(p *retryPolicy) *retryPolicy {
if p == nil {
return nil
}
cloned := &retryPolicy{
maxRetries: p.maxRetries,
baseDelay: p.baseDelay,
maxDelay: p.maxDelay,
factor: p.factor,
jitter: p.jitter,
idempotentOnly: p.idempotentOnly,
onError: p.onError,
}
if p.statuses != nil {
cloned.statuses = make(map[int]struct{}, len(p.statuses))
for code := range p.statuses {
cloned.statuses[code] = struct{}{}
}
}
return cloned
}
func defaultRetryPolicy(max int) *retryPolicy {
return &retryPolicy{
maxRetries: max,
baseDelay: 100 * time.Millisecond,
maxDelay: 2 * time.Second,
factor: 2.0,
jitter: 0.1,
idempotentOnly: true,
statuses: map[int]struct{}{
http.StatusRequestTimeout: {},
http.StatusTooEarly: {},
http.StatusTooManyRequests: {},
http.StatusInternalServerError: {},
http.StatusBadGateway: {},
http.StatusServiceUnavailable: {},
http.StatusGatewayTimeout: {},
},
}
}
func buildRetryPolicy(max int, opts ...RetryOpt) (*retryPolicy, error) {
if max < 0 {
return nil, fmt.Errorf("max retry must be >= 0")
}
if max == 0 {
return nil, nil
}
policy := defaultRetryPolicy(max)
for _, opt := range opts {
if opt == nil {
continue
}
if err := opt(policy); err != nil {
return nil, err
}
}
return policy, nil
}
func WithRetry(max int, opts ...RetryOpt) RequestOpt {
return func(r *Request) error {
policy, err := buildRetryPolicy(max, opts...)
if err != nil {
return err
}
r.retry = policy
return nil
}
}
func (r *Request) SetRetry(max int, opts ...RetryOpt) *Request {
if r.err != nil {
return r
}
policy, err := buildRetryPolicy(max, opts...)
if err != nil {
r.err = err
return r
}
r.retry = policy
return r
}
func (r *Request) DisableRetry() *Request {
if r.err != nil {
return r
}
r.retry = nil
return r
}
func (r *Request) applyRetryOpt(opt RetryOpt) *Request {
if r.err != nil {
return r
}
if opt == nil {
return r
}
if r.retry == nil {
r.err = fmt.Errorf("retry policy is not enabled, call SetRetry first")
return r
}
if err := opt(r.retry); err != nil {
r.err = err
}
return r
}
func (r *Request) SetRetryBackoff(base, max time.Duration, factor float64) *Request {
return r.applyRetryOpt(WithRetryBackoff(base, max, factor))
}
func (r *Request) SetRetryJitter(ratio float64) *Request {
return r.applyRetryOpt(WithRetryJitter(ratio))
}
func (r *Request) SetRetryStatuses(codes ...int) *Request {
return r.applyRetryOpt(WithRetryStatuses(codes...))
}
func (r *Request) SetRetryIdempotentOnly(enabled bool) *Request {
return r.applyRetryOpt(WithRetryIdempotentOnly(enabled))
}
func (r *Request) SetRetryOnError(fn func(error) bool) *Request {
return r.applyRetryOpt(WithRetryOnError(fn))
}
func WithRetryBackoff(base, max time.Duration, factor float64) RetryOpt {
return func(p *retryPolicy) error {
if base < 0 {
return fmt.Errorf("retry base delay must be >= 0")
}
if max < 0 {
return fmt.Errorf("retry max delay must be >= 0")
}
if factor <= 0 {
return fmt.Errorf("retry factor must be > 0")
}
p.baseDelay = base
p.maxDelay = max
p.factor = factor
return nil
}
}
func WithRetryJitter(ratio float64) RetryOpt {
return func(p *retryPolicy) error {
if ratio < 0 || ratio > 1 {
return fmt.Errorf("retry jitter ratio must be in [0,1]")
}
p.jitter = ratio
return nil
}
}
func WithRetryStatuses(codes ...int) RetryOpt {
return func(p *retryPolicy) error {
statuses := make(map[int]struct{}, len(codes))
for _, code := range codes {
if code < 100 || code > 999 {
return fmt.Errorf("invalid retry status code: %d", code)
}
statuses[code] = struct{}{}
}
p.statuses = statuses
return nil
}
}
func WithRetryIdempotentOnly(enabled bool) RetryOpt {
return func(p *retryPolicy) error {
p.idempotentOnly = enabled
return nil
}
}
func WithRetryOnError(fn func(error) bool) RetryOpt {
return func(p *retryPolicy) error {
p.onError = fn
return nil
}
}
func (r *Request) hasRetryPolicy() bool {
return r.retry != nil && r.retry.maxRetries > 0
}
func (r *Request) doWithRetry() (*Response, error) {
policy := cloneRetryPolicy(r.retry)
if policy == nil || policy.maxRetries <= 0 {
return r.doOnce()
}
if !policy.canRetryRequest(r) {
return r.doOnce()
}
retryCtx := r.ctx
retryCancel := func() {}
if r.config.Network.Timeout > 0 {
retryCtx, retryCancel = context.WithTimeout(r.ctx, r.config.Network.Timeout)
}
defer retryCancel()
maxAttempts := policy.maxRetries + 1
var lastResp *Response
var lastErr error
for attempt := 0; attempt < maxAttempts; attempt++ {
attemptReq, err := r.newRetryAttempt(retryCtx)
if err != nil {
return nil, wrapError(err, "build retry attempt")
}
resp, err := attemptReq.doOnce()
if resp != nil {
resp.request = r
}
if !policy.shouldRetry(resp, err, attempt, maxAttempts, retryCtx) {
return resp, err
}
lastResp = resp
lastErr = err
if lastResp != nil {
_ = lastResp.Close()
}
delay := policy.nextDelay(attempt)
if delay <= 0 {
continue
}
timer := time.NewTimer(delay)
select {
case <-retryCtx.Done():
timer.Stop()
return lastResp, wrapError(retryCtx.Err(), "retry context done")
case <-timer.C:
}
}
return lastResp, lastErr
}
func (r *Request) newRetryAttempt(ctx context.Context) (*Request, error) {
attempt := r.Clone()
attempt.retry = nil
attempt.cancel = nil
attempt.applied = false
attempt.execCtx = nil
attempt.ctx = ctx
// 共享总超时上下文后,避免每次 attempt 再创建一次 timeout context。
if attempt.config != nil && attempt.config.Network.Timeout > 0 {
attempt.config.Network.Timeout = 0
}
if !attempt.doRaw {
attempt.httpReq = attempt.httpReq.WithContext(ctx)
return attempt, nil
}
if r.httpReq == nil {
return nil, fmt.Errorf("http request is nil")
}
raw := r.httpReq.Clone(ctx)
if r.httpReq.GetBody != nil {
body, err := r.httpReq.GetBody()
if err != nil {
return nil, wrapError(err, "get raw request body")
}
raw.Body = body
} else if r.httpReq.Body != nil && r.httpReq.Body != http.NoBody {
return nil, fmt.Errorf("raw request body is not replayable")
}
attempt.httpReq = raw
return attempt, nil
}
func (p *retryPolicy) canRetryRequest(r *Request) bool {
if p.idempotentOnly && !isIdempotentMethod(r.method) {
return false
}
return isReplayableRequest(r)
}
func isIdempotentMethod(method string) bool {
switch method {
case http.MethodGet, http.MethodHead, http.MethodPut, http.MethodDelete, http.MethodOptions, http.MethodTrace:
return true
default:
return false
}
}
func isReplayableRequest(r *Request) bool {
if r == nil {
return false
}
if r.doRaw {
if r.httpReq == nil {
return false
}
if r.httpReq.Body == nil || r.httpReq.Body == http.NoBody {
return true
}
return r.httpReq.GetBody != nil
}
if r.config == nil {
return false
}
// Reader / stream body 通常不可重放,保守地不重试。
if r.config.Body.Reader != nil {
return false
}
for _, f := range r.config.Body.Files {
if f.FileData != nil || f.FilePath == "" {
return false
}
}
return true
}
func (p *retryPolicy) shouldRetry(resp *Response, err error, attempt, maxAttempts int, ctx context.Context) bool {
if attempt >= maxAttempts-1 {
return false
}
if ctx != nil && ctx.Err() != nil {
return false
}
if err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return false
}
if p.onError != nil {
return p.onError(err)
}
return isRetryableError(err)
}
if resp == nil || resp.Response == nil {
return false
}
_, ok := p.statuses[resp.StatusCode]
return ok
}
func isRetryableError(err error) bool {
var netErr net.Error
if errors.As(err, &netErr) {
if netErr.Timeout() {
return true
}
if netErr.Temporary() {
return true
}
}
return errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF)
}
func (p *retryPolicy) nextDelay(attempt int) time.Duration {
if p.baseDelay <= 0 {
return 0
}
delay := time.Duration(float64(p.baseDelay) * math.Pow(p.factor, float64(attempt)))
if p.maxDelay > 0 && delay > p.maxDelay {
delay = p.maxDelay
}
if p.jitter <= 0 {
return delay
}
low := 1 - p.jitter
if low < 0 {
low = 0
}
high := 1 + p.jitter
scale := low + rand.Float64()*(high-low)
return time.Duration(float64(delay) * scale)
}

298
retry_test.go Normal file
View File

@ -0,0 +1,298 @@
package starnet
import (
"context"
"errors"
"io"
"net"
"net/http"
"net/http/httptest"
"strings"
"sync/atomic"
"testing"
"time"
)
func TestWithRetrySmokeGet(t *testing.T) {
var hits int32
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
n := atomic.AddInt32(&hits, 1)
if n <= 2 {
w.WriteHeader(http.StatusServiceUnavailable)
return
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer s.Close()
resp, err := Get(s.URL,
WithRetry(2,
WithRetryBackoff(0, 0, 1),
WithRetryJitter(0),
),
)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer resp.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("status=%d want=%d", resp.StatusCode, http.StatusOK)
}
if atomic.LoadInt32(&hits) != 3 {
t.Fatalf("hits=%d want=3", hits)
}
}
func TestWithRetryResponseRequestPointerStable(t *testing.T) {
var hits int32
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
n := atomic.AddInt32(&hits, 1)
if n == 1 {
w.WriteHeader(http.StatusServiceUnavailable)
return
}
w.WriteHeader(http.StatusOK)
}))
defer s.Close()
req := NewSimpleRequest(s.URL, http.MethodGet).
SetRetry(1, WithRetryBackoff(0, 0, 1), WithRetryJitter(0))
resp, err := req.Do()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer resp.Close()
if resp.Request() != req {
t.Fatal("response request pointer should point to original request")
}
}
func TestWithRetryNoRetryForNonReplayableBodyReader(t *testing.T) {
var hits int32
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
atomic.AddInt32(&hits, 1)
w.WriteHeader(http.StatusServiceUnavailable)
}))
defer s.Close()
req := NewSimpleRequest(s.URL, http.MethodPost).
SetBodyReader(strings.NewReader("payload")).
SetRetry(3,
WithRetryIdempotentOnly(false),
WithRetryBackoff(0, 0, 1),
WithRetryJitter(0),
)
resp, err := req.Do()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer resp.Close()
if resp.StatusCode != http.StatusServiceUnavailable {
t.Fatalf("status=%d want=%d", resp.StatusCode, http.StatusServiceUnavailable)
}
if atomic.LoadInt32(&hits) != 1 {
t.Fatalf("hits=%d want=1", hits)
}
}
func TestWithRetryPostWhenIdempotentDisabled(t *testing.T) {
var hits int32
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
n := atomic.AddInt32(&hits, 1)
_, _ = io.Copy(io.Discard, r.Body)
if n == 1 {
w.WriteHeader(http.StatusServiceUnavailable)
return
}
w.WriteHeader(http.StatusOK)
}))
defer s.Close()
req := NewSimpleRequest(s.URL, http.MethodPost).
SetBodyString("hello").
SetRetry(1,
WithRetryIdempotentOnly(false),
WithRetryBackoff(0, 0, 1),
WithRetryJitter(0),
)
resp, err := req.Do()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer resp.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("status=%d want=%d", resp.StatusCode, http.StatusOK)
}
if atomic.LoadInt32(&hits) != 2 {
t.Fatalf("hits=%d want=2", hits)
}
}
func TestWithRetryRawWithoutGetBodyNoRetry(t *testing.T) {
var hits int32
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
atomic.AddInt32(&hits, 1)
w.WriteHeader(http.StatusServiceUnavailable)
}))
defer s.Close()
rawReq, _ := http.NewRequest(http.MethodPost, s.URL, io.MultiReader(strings.NewReader("raw")))
if rawReq.GetBody != nil {
t.Fatal("raw request GetBody should be nil in this test")
}
req := NewSimpleRequest("", http.MethodPost, WithRawRequest(rawReq)).
SetRetry(2,
WithRetryIdempotentOnly(false),
WithRetryBackoff(0, 0, 1),
WithRetryJitter(0),
)
resp, err := req.Do()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer resp.Close()
if atomic.LoadInt32(&hits) != 1 {
t.Fatalf("hits=%d want=1", hits)
}
}
func TestWithRetryRespectsTotalTimeoutBudget(t *testing.T) {
var hits int32
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
atomic.AddInt32(&hits, 1)
time.Sleep(80 * time.Millisecond)
w.WriteHeader(http.StatusServiceUnavailable)
}))
defer s.Close()
req := NewSimpleRequest(s.URL, http.MethodGet).
SetTimeout(120*time.Millisecond).
SetRetry(3,
WithRetryBackoff(0, 0, 1),
WithRetryJitter(0),
)
_, err := req.Do()
if err == nil {
t.Fatal("expected timeout error")
}
if !errors.Is(err, context.DeadlineExceeded) {
t.Fatalf("expected context deadline exceeded, got: %v", err)
}
if h := atomic.LoadInt32(&hits); h > 2 {
t.Fatalf("hits=%d want<=2 under tight timeout budget", h)
}
}
func TestSetRetryInvalidMax(t *testing.T) {
req := NewSimpleRequest("http://example.com", http.MethodGet).SetRetry(-1)
if req.Err() == nil {
t.Fatal("expected error for negative retry max")
}
}
func TestSetRetrySeriesSmokeGet(t *testing.T) {
var hits int32
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
n := atomic.AddInt32(&hits, 1)
if n <= 2 {
w.WriteHeader(http.StatusTooManyRequests)
return
}
w.WriteHeader(http.StatusOK)
}))
defer s.Close()
req := NewSimpleRequest(s.URL, http.MethodGet).
SetRetry(2).
SetRetryBackoff(0, 0, 1).
SetRetryJitter(0).
SetRetryStatuses(http.StatusTooManyRequests)
resp, err := req.Do()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer resp.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("status=%d want=%d", resp.StatusCode, http.StatusOK)
}
if h := atomic.LoadInt32(&hits); h != 3 {
t.Fatalf("hits=%d want=3", h)
}
}
func TestSetRetryIdempotentOnlyWithPost(t *testing.T) {
var hits int32
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
n := atomic.AddInt32(&hits, 1)
_, _ = io.Copy(io.Discard, r.Body)
if n == 1 {
w.WriteHeader(http.StatusServiceUnavailable)
return
}
w.WriteHeader(http.StatusOK)
}))
defer s.Close()
req := NewSimpleRequest(s.URL, http.MethodPost).
SetBodyString("hello").
SetRetry(1).
SetRetryIdempotentOnly(false).
SetRetryBackoff(0, 0, 1).
SetRetryJitter(0)
resp, err := req.Do()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer resp.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("status=%d want=%d", resp.StatusCode, http.StatusOK)
}
if h := atomic.LoadInt32(&hits); h != 2 {
t.Fatalf("hits=%d want=2", h)
}
}
func TestSetRetryOnErrorOverridesDefault(t *testing.T) {
var dials int32
dialErr := errors.New("dial failed")
req := NewSimpleRequest("http://example.com", http.MethodGet).
SetDialFunc(func(ctx context.Context, network, addr string) (net.Conn, error) {
atomic.AddInt32(&dials, 1)
return nil, dialErr
}).
SetRetry(1).
SetRetryBackoff(0, 0, 1).
SetRetryJitter(0).
SetRetryOnError(func(err error) bool {
return true
})
_, err := req.Do()
if err == nil {
t.Fatal("expected error")
}
if h := atomic.LoadInt32(&dials); h != 2 {
t.Fatalf("dial attempts=%d want=2", h)
}
}
func TestSetRetryOptionRequireEnableRetry(t *testing.T) {
req := NewSimpleRequest("http://example.com", http.MethodGet).SetRetryBackoff(10*time.Millisecond, 100*time.Millisecond, 2)
if req.Err() == nil {
t.Fatal("expected error when setting retry options before SetRetry")
}
if !strings.Contains(req.Err().Error(), "call SetRetry first") {
t.Fatalf("unexpected error: %v", req.Err())
}
}

115
timeout_refactor_test.go Normal file
View File

@ -0,0 +1,115 @@
package starnet
import (
"io"
"net"
"net/http"
"net/http/httptest"
"testing"
"time"
)
func TestRequestTimeoutDoesNotMutateClientTimeout(t *testing.T) {
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(200 * time.Millisecond)
w.WriteHeader(http.StatusOK)
}))
defer s.Close()
client := NewClientNoErr()
baseTimeout := client.HTTPClient().Timeout
_, err := client.Get(s.URL, WithTimeout(80*time.Millisecond))
if err == nil {
t.Fatal("expected request timeout error")
}
if client.HTTPClient().Timeout != baseTimeout {
t.Fatalf("client timeout mutated: got=%v want=%v", client.HTTPClient().Timeout, baseTimeout)
}
resp, err := client.Get(s.URL, WithTimeout(400*time.Millisecond))
if err != nil {
t.Fatalf("second request should succeed with larger timeout: %v", err)
}
defer resp.Close()
}
func TestRequestTimeoutNoLingeringConnDeadline(t *testing.T) {
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.Copy(io.Discard, r.Body)
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer s.Close()
client := NewClientNoErr()
resp1, err := client.Post(s.URL, WithBodyString("first"), WithTimeout(120*time.Millisecond))
if err != nil {
t.Fatalf("first request should succeed: %v", err)
}
_ = resp1.Close()
// 如果请求超时依赖连接级绝对 deadline经过该等待后复用连接会出现误超时。
time.Sleep(220 * time.Millisecond)
resp2, err := client.Post(s.URL, WithBodyString("second"), WithTimeout(1*time.Second))
if err != nil {
t.Fatalf("second request should not be affected by previous timeout window: %v", err)
}
_ = resp2.Close()
}
func TestConnReadDeadlineTimeoutAndRecover(t *testing.T) {
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen error: %v", err)
}
defer ln.Close()
done := make(chan struct{})
go func() {
defer close(done)
conn, err := ln.Accept()
if err != nil {
return
}
defer conn.Close()
time.Sleep(180 * time.Millisecond)
_, _ = conn.Write([]byte("x"))
}()
c, err := Dial("tcp", ln.Addr().String())
if err != nil {
t.Fatalf("dial error: %v", err)
}
defer c.Close()
if err := c.SetReadDeadline(time.Now().Add(60 * time.Millisecond)); err != nil {
t.Fatalf("set read deadline error: %v", err)
}
buf := make([]byte, 1)
_, err = c.Read(buf)
if err == nil {
t.Fatal("expected read timeout error")
}
if ne, ok := err.(net.Error); !ok || !ne.Timeout() {
t.Fatalf("expected net timeout error, got: %v", err)
}
if err := c.SetReadDeadline(time.Time{}); err != nil {
t.Fatalf("clear read deadline error: %v", err)
}
if _, err := io.ReadFull(c, buf); err != nil {
t.Fatalf("read after clearing deadline should succeed: %v", err)
}
if string(buf) != "x" {
t.Fatalf("unexpected payload: %q", string(buf))
}
<-done
}

66
timeout_test.go Normal file
View File

@ -0,0 +1,66 @@
package starnet
import (
"net/http"
"net/http/httptest"
"testing"
"time"
)
func TestRequestTimeout(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(200 * time.Millisecond)
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
// Should timeout
req := NewSimpleRequest(server.URL, "GET").SetTimeout(100 * time.Millisecond)
_, err := req.Do()
if err == nil {
t.Error("Expected timeout error, got nil")
}
// Should succeed
req2 := NewSimpleRequest(server.URL, "GET").SetTimeout(300 * time.Millisecond)
resp, err := req2.Do()
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if resp != nil {
resp.Close()
}
}
func TestRequestDialTimeout(t *testing.T) {
// Use a non-routable IP to test dial timeout
req := NewSimpleRequest("http://192.0.2.1:80", "GET").
SetDialTimeout(100 * time.Millisecond)
start := time.Now()
_, err := req.Do()
elapsed := time.Since(start)
if err == nil {
t.Error("Expected dial timeout error, got nil")
}
// Should timeout within reasonable time (not wait forever)
if elapsed > 2*time.Second {
t.Errorf("Dial timeout took too long: %v", elapsed)
}
}
func TestClientTimeout(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(200 * time.Millisecond)
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
client := NewClientNoErr(WithTimeout(100 * time.Millisecond))
_, err := client.Get(server.URL)
if err == nil {
t.Error("Expected timeout error, got nil")
}
}

229
tls_test.go Normal file
View File

@ -0,0 +1,229 @@
package starnet
import (
"context"
"crypto/tls"
"net/http"
"net/http/httptest"
"testing"
"time"
)
func TestRequestSkipTLSVerify(t *testing.T) {
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
}))
defer server.Close()
// Without skip verify (should fail)
req := NewSimpleRequest(server.URL, "GET")
_, err := req.Do()
if err == nil {
t.Error("Expected TLS error without skip verify, got nil")
}
// With skip verify (should succeed)
req2 := NewSimpleRequest(server.URL, "GET").SetSkipTLSVerify(true)
resp, err := req2.Do()
if err != nil {
t.Fatalf("Do() with skip verify error: %v", err)
}
defer resp.Close()
body, _ := resp.Body().String()
if body != "OK" {
t.Errorf("Body = %v; want OK", body)
}
}
func TestRequestCustomTLSConfig(t *testing.T) {
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
}))
defer server.Close()
tlsConfig := &tls.Config{
InsecureSkipVerify: true,
MinVersion: tls.VersionTLS12,
}
req := NewSimpleRequest(server.URL, "GET").SetTLSConfig(tlsConfig)
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("StatusCode = %v; want %v", resp.StatusCode, http.StatusOK)
}
}
func TestClientDefaultTLSConfig(t *testing.T) {
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
client := NewClientNoErr()
client.SetDefaultSkipTLSVerify(true)
resp, err := client.Get(server.URL)
if err != nil {
t.Fatalf("Get() error: %v", err)
}
defer resp.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("StatusCode = %v; want %v", resp.StatusCode, http.StatusOK)
}
}
func TestRequestLevelTLSOverride(t *testing.T) {
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
// Client level: skip verify = false
client := NewClientNoErr()
client.SetDefaultSkipTLSVerify(false)
// Request level: skip verify = true (should override)
resp, err := client.Get(server.URL, WithSkipTLSVerify(true))
if err != nil {
t.Fatalf("Get() error: %v", err)
}
defer resp.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("StatusCode = %v; want %v", resp.StatusCode, http.StatusOK)
}
}
func TestRequestTls(t *testing.T) {
resp, err := NewSimpleRequest("https://www.b612.me", "GET").Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("StatusCode = %v; want %v", resp.StatusCode, http.StatusOK)
}
t.Logf("Response: %v", resp.Body().MustString())
client, err := NewClient()
if err != nil {
t.Fatalf("NewClient() error: %v", err)
}
resp, err = client.NewSimpleRequest("https://www.b612.me", "GET",
WithHeader("hello", "world"),
WithContext(context.Background()),
WithBearerToken("ddddddd")).Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("StatusCode = %v; want %v", resp.StatusCode, http.StatusOK)
}
t.Logf("Response: %v", resp.Body().MustString())
}
func TestTLSWithProxyPath(t *testing.T) {
client, err := NewClient()
if err != nil {
t.Fatal(err)
}
req, err := client.NewRequest("https://registry-1.docker.io/v2/", "GET",
WithTimeout(10*time.Second),
WithProxy("http://127.0.0.1:29992"),
)
if err != nil {
t.Fatal(err)
}
resp, err := req.Do()
if err != nil {
t.Fatalf("Do error: %v", err)
}
defer resp.Close()
t.Log(resp.Status)
}
func TestTLSWithProxyBug(t *testing.T) {
client, err := NewClient()
if err != nil {
t.Fatal(err)
}
// 关键:使用 WithProxy 触发 needsDynamicTransport
// 即使 proxy 是空串或无效地址,只要设置了就会走 buildDynamicTransport 分支
req, err := client.NewRequest("https://registry-1.docker.io/v2/", "GET",
WithTimeout(10*time.Second),
WithProxy("http://127.0.0.1:29992"), // 随便一个 proxy 地址,触发动态 transport
)
if err != nil {
t.Fatal(err)
}
resp, err := req.Do()
if err != nil {
// 修复前会报tls: either ServerName or InsecureSkipVerify must be specified
t.Fatalf("Do error: %v", err)
}
defer resp.Close()
t.Logf("Status: %s", resp.Status)
}
// 更精准的复现:直接测试有问题的分支
func TestTLSDialWithoutServerName(t *testing.T) {
client, err := NewClient()
if err != nil {
t.Fatal(err)
}
// 使用 WithCustomIP 也能触发 defaultDialTLSFunc
req, err := client.NewRequest("https://www.google.com", "GET",
WithTimeout(10*time.Second),
WithCustomIP([]string{"142.250.185.46"}), // Google 的一个 IP
)
if err != nil {
t.Fatal(err)
}
resp, err := req.Do()
if err != nil {
t.Fatalf("Do error: %v", err)
}
defer resp.Close()
t.Logf("Status: %s", resp.Status)
}
// 最小复现:只要触发 needsDynamicTransport 即可
func TestMinimalTLSBug(t *testing.T) {
client, err := NewClient()
if err != nil {
t.Fatal(err)
}
// WithDialTimeout 也会触发动态 transport
req, err := client.NewRequest("https://www.baidu.com", "GET",
WithDialTimeout(5*time.Second),
)
if err != nil {
t.Fatal(err)
}
resp, err := req.Do()
if err != nil {
// 修复前必现tls handshake: tls: either ServerName or InsecureSkipVerify must be specified
t.Fatalf("Do error: %v", err)
}
defer resp.Close()
t.Logf("Status: %s", resp.Status)
}

90
tlsconfig.go Normal file
View File

@ -0,0 +1,90 @@
package starnet
import (
"crypto/tls"
"net"
"time"
)
// GetConfigForClientFunc selects TLS config by hostname/SNI.
type GetConfigForClientFunc func(hostname string) (*tls.Config, error)
// ClientHelloMeta carries sniffed TLS metadata and connection context.
type ClientHelloMeta struct {
ServerName string
LocalAddr net.Addr
RemoteAddr net.Addr
SupportedProtos []string
SupportedVersions []uint16
CipherSuites []uint16
}
// Clone returns a detached copy safe for callers to mutate.
func (m *ClientHelloMeta) Clone() *ClientHelloMeta {
if m == nil {
return nil
}
out := *m
if m.SupportedProtos != nil {
out.SupportedProtos = append([]string(nil), m.SupportedProtos...)
}
if m.SupportedVersions != nil {
out.SupportedVersions = append([]uint16(nil), m.SupportedVersions...)
}
if m.CipherSuites != nil {
out.CipherSuites = append([]uint16(nil), m.CipherSuites...)
}
return &out
}
// GetConfigForClientHelloFunc selects TLS config by sniffed TLS metadata.
type GetConfigForClientHelloFunc func(hello *ClientHelloMeta) (*tls.Config, error)
// ListenerConfig controls listener behavior.
type ListenerConfig struct {
// BaseTLSConfig is used for TLS when dynamic selection returns nil.
BaseTLSConfig *tls.Config
// GetConfigForClient selects TLS config for a hostname/SNI.
// Deprecated: prefer GetConfigForClientHello for richer context.
GetConfigForClient GetConfigForClientFunc
// GetConfigForClientHello selects TLS config for sniffed TLS metadata.
GetConfigForClientHello GetConfigForClientHelloFunc
// AllowNonTLS allows plain TCP fallback.
AllowNonTLS bool
// SniffTimeout bounds protocol sniffing time. 0 means no timeout.
SniffTimeout time.Duration
// MaxClientHelloBytes limits buffered sniff data.
// If <= 0, default 64KiB.
MaxClientHelloBytes int
// Logger is optional.
Logger Logger
}
// DefaultListenerConfig returns a conservative default config.
func DefaultListenerConfig() ListenerConfig {
return ListenerConfig{
AllowNonTLS: false,
SniffTimeout: 5 * time.Second,
MaxClientHelloBytes: 64 * 1024,
}
}
// TLSDefaults returns a TLS config baseline.
// Caller should set Certificates / GetCertificate as needed.
func TLSDefaults() *tls.Config {
return &tls.Config{
MinVersion: tls.VersionTLS12,
}
}
// DialConfig controls dialing behavior.
type DialConfig struct {
Timeout time.Duration
LocalAddr net.Addr
}

833
tlssniffer.go Normal file
View File

@ -0,0 +1,833 @@
package starnet
import (
"bytes"
"context"
"crypto/tls"
"encoding/binary"
"errors"
"io"
"net"
"sync"
"time"
)
// replayConn replays buffered bytes first, then reads from live conn.
type replayConn struct {
reader io.Reader
conn net.Conn
}
func newReplayConn(buffered io.Reader, conn net.Conn) *replayConn {
return &replayConn{
reader: io.MultiReader(buffered, conn),
conn: conn,
}
}
func (c *replayConn) Read(p []byte) (int, error) { return c.reader.Read(p) }
func (c *replayConn) Write(p []byte) (int, error) { return c.conn.Write(p) }
func (c *replayConn) Close() error { return c.conn.Close() }
func (c *replayConn) LocalAddr() net.Addr { return c.conn.LocalAddr() }
func (c *replayConn) RemoteAddr() net.Addr { return c.conn.RemoteAddr() }
func (c *replayConn) SetDeadline(t time.Time) error { return c.conn.SetDeadline(t) }
func (c *replayConn) SetReadDeadline(t time.Time) error { return c.conn.SetReadDeadline(t) }
func (c *replayConn) SetWriteDeadline(t time.Time) error { return c.conn.SetWriteDeadline(t) }
// SniffResult describes protocol sniffing result.
type SniffResult struct {
IsTLS bool
ClientHello *ClientHelloMeta
Buffer *bytes.Buffer
}
// Sniffer detects protocol and metadata from initial bytes.
type Sniffer interface {
Sniff(conn net.Conn, maxBytes int) (SniffResult, error)
}
// TLSSniffer is the default sniffer implementation.
type TLSSniffer struct{}
// Sniff detects TLS and extracts SNI when possible.
func (s TLSSniffer) Sniff(conn net.Conn, maxBytes int) (SniffResult, error) {
if maxBytes <= 0 {
maxBytes = 64 * 1024
}
var buf bytes.Buffer
limited := &io.LimitedReader{R: conn, N: int64(maxBytes)}
meta, isTLS := sniffClientHello(limited, &buf, conn)
out := SniffResult{
IsTLS: isTLS,
Buffer: bytes.NewBuffer(append([]byte(nil), buf.Bytes()...)),
}
if isTLS {
out.ClientHello = meta
}
return out, nil
}
func sniffClientHello(r io.Reader, buf *bytes.Buffer, conn net.Conn) (*ClientHelloMeta, bool) {
meta := &ClientHelloMeta{
LocalAddr: conn.LocalAddr(),
RemoteAddr: conn.RemoteAddr(),
}
header, complete := readTLSRecordHeader(r, buf)
if len(header) < 3 {
return nil, false
}
isTLS := header[0] == 0x16 && header[1] == 0x03
if !isTLS {
return nil, false
}
if len(header) < 5 || !complete {
return meta, true
}
recordLen := int(binary.BigEndian.Uint16(header[3:5]))
recordBody, bodyOK := readBufferedBytes(r, buf, recordLen)
if !bodyOK {
return meta, true
}
if len(recordBody) < 4 || recordBody[0] != 0x01 {
return nil, false
}
helloLen := int(recordBody[1])<<16 | int(recordBody[2])<<8 | int(recordBody[3])
helloBytes := append([]byte(nil), recordBody[4:]...)
for len(helloBytes) < helloLen {
nextHeader, nextOK := readTLSRecordHeader(r, buf)
if len(nextHeader) < 5 || !nextOK {
return meta, true
}
if nextHeader[0] != 0x16 || nextHeader[1] != 0x03 {
return meta, true
}
nextLen := int(binary.BigEndian.Uint16(nextHeader[3:5]))
nextBody, nextBodyOK := readBufferedBytes(r, buf, nextLen)
if !nextBodyOK {
return meta, true
}
helloBytes = append(helloBytes, nextBody...)
}
parseClientHelloBody(meta, helloBytes[:helloLen])
return meta, true
}
func readTLSRecordHeader(r io.Reader, buf *bytes.Buffer) ([]byte, bool) {
return readBufferedBytes(r, buf, 5)
}
func readBufferedBytes(r io.Reader, buf *bytes.Buffer, n int) ([]byte, bool) {
if n <= 0 {
return nil, true
}
tmp := make([]byte, n)
readN, err := io.ReadFull(r, tmp)
if readN > 0 {
buf.Write(tmp[:readN])
}
return append([]byte(nil), tmp[:readN]...), err == nil
}
func parseClientHelloBody(meta *ClientHelloMeta, body []byte) {
if meta == nil || len(body) < 34 {
return
}
offset := 2 + 32
sessionIDLen := int(body[offset])
offset++
if offset+sessionIDLen > len(body) {
return
}
offset += sessionIDLen
if offset+2 > len(body) {
return
}
cipherSuitesLen := int(binary.BigEndian.Uint16(body[offset : offset+2]))
offset += 2
if offset+cipherSuitesLen > len(body) {
return
}
for i := 0; i+1 < cipherSuitesLen; i += 2 {
meta.CipherSuites = append(meta.CipherSuites, binary.BigEndian.Uint16(body[offset+i:offset+i+2]))
}
offset += cipherSuitesLen
if offset >= len(body) {
return
}
compressionMethodsLen := int(body[offset])
offset++
if offset+compressionMethodsLen > len(body) {
return
}
offset += compressionMethodsLen
if offset+2 > len(body) {
return
}
extensionsLen := int(binary.BigEndian.Uint16(body[offset : offset+2]))
offset += 2
if offset+extensionsLen > len(body) {
return
}
parseClientHelloExtensions(meta, body[offset:offset+extensionsLen])
}
func parseClientHelloExtensions(meta *ClientHelloMeta, exts []byte) {
for offset := 0; offset+4 <= len(exts); {
extType := binary.BigEndian.Uint16(exts[offset : offset+2])
extLen := int(binary.BigEndian.Uint16(exts[offset+2 : offset+4]))
offset += 4
if offset+extLen > len(exts) {
return
}
extData := exts[offset : offset+extLen]
offset += extLen
switch extType {
case 0:
parseServerNameExtension(meta, extData)
case 16:
parseALPNExtension(meta, extData)
case 43:
parseSupportedVersionsExtension(meta, extData)
}
}
}
func parseServerNameExtension(meta *ClientHelloMeta, data []byte) {
if len(data) < 2 {
return
}
listLen := int(binary.BigEndian.Uint16(data[:2]))
if listLen == 0 || 2+listLen > len(data) {
return
}
list := data[2 : 2+listLen]
for offset := 0; offset+3 <= len(list); {
nameType := list[offset]
nameLen := int(binary.BigEndian.Uint16(list[offset+1 : offset+3]))
offset += 3
if offset+nameLen > len(list) {
return
}
if nameType == 0 {
meta.ServerName = string(list[offset : offset+nameLen])
return
}
offset += nameLen
}
}
func parseALPNExtension(meta *ClientHelloMeta, data []byte) {
if len(data) < 2 {
return
}
listLen := int(binary.BigEndian.Uint16(data[:2]))
if listLen == 0 || 2+listLen > len(data) {
return
}
list := data[2 : 2+listLen]
for offset := 0; offset < len(list); {
nameLen := int(list[offset])
offset++
if offset+nameLen > len(list) {
return
}
meta.SupportedProtos = append(meta.SupportedProtos, string(list[offset:offset+nameLen]))
offset += nameLen
}
}
func parseSupportedVersionsExtension(meta *ClientHelloMeta, data []byte) {
if len(data) < 1 {
return
}
listLen := int(data[0])
if listLen == 0 || 1+listLen > len(data) {
return
}
list := data[1 : 1+listLen]
for offset := 0; offset+1 < len(list); offset += 2 {
meta.SupportedVersions = append(meta.SupportedVersions, binary.BigEndian.Uint16(list[offset:offset+2]))
}
}
// Conn wraps net.Conn with lazy protocol initialization.
type Conn struct {
net.Conn
once sync.Once
initErr error
closeOnce sync.Once
isTLS bool
tlsConn *tls.Conn
plainConn net.Conn
clientHello *ClientHelloMeta
baseTLSConfig *tls.Config
getConfigForClient GetConfigForClientFunc
getConfigForClientHello GetConfigForClientHelloFunc
allowNonTLS bool
sniffer Sniffer
sniffTimeout time.Duration
maxClientHello int
logger Logger
stats *Stats
skipSniff bool
}
func newConn(raw net.Conn, cfg ListenerConfig, stats *Stats) *Conn {
return &Conn{
Conn: raw,
plainConn: raw,
baseTLSConfig: cfg.BaseTLSConfig,
getConfigForClient: cfg.GetConfigForClient,
getConfigForClientHello: cfg.GetConfigForClientHello,
allowNonTLS: cfg.AllowNonTLS,
sniffer: TLSSniffer{},
sniffTimeout: cfg.SniffTimeout,
maxClientHello: cfg.MaxClientHelloBytes,
logger: cfg.Logger,
stats: stats,
}
}
func (c *Conn) init() {
c.once.Do(func() {
if c.skipSniff {
return
}
if c.baseTLSConfig == nil && c.getConfigForClient == nil && c.getConfigForClientHello == nil {
c.isTLS = false
return
}
if c.sniffTimeout > 0 {
_ = c.Conn.SetReadDeadline(time.Now().Add(c.sniffTimeout))
}
res, err := c.sniffer.Sniff(c.Conn, c.maxClientHello)
if c.sniffTimeout > 0 {
_ = c.Conn.SetReadDeadline(time.Time{})
}
if err != nil {
c.initErr = errors.Join(ErrTLSSniffFailed, err)
c.failSniff(err)
return
}
c.isTLS = res.IsTLS
c.clientHello = res.ClientHello
if c.isTLS {
if c.stats != nil {
c.stats.incTLSDetected()
}
tlsCfg, errCfg := c.selectTLSConfig()
if errCfg != nil {
c.initErr = errors.Join(ErrTLSConfigSelectionFailed, errCfg)
c.failTLSConfigSelection(errCfg)
return
}
rc := newReplayConn(bytes.NewBuffer(res.Buffer.Bytes()), c.Conn)
c.tlsConn = tls.Server(rc, tlsCfg)
return
}
if c.stats != nil {
c.stats.incPlainDetected()
}
if !c.allowNonTLS {
c.initErr = ErrNonTLSNotAllowed
c.failPlainRejected()
return
}
c.plainConn = newReplayConn(bytes.NewBuffer(res.Buffer.Bytes()), c.Conn)
})
}
func (c *Conn) failAndClose(format string, v ...interface{}) {
if c.logger != nil {
c.logger.Printf("starnet: "+format, v...)
}
_ = c.Close()
}
func (c *Conn) failSniff(err error) {
if c.stats != nil {
c.stats.incSniffFailures()
}
c.failAndClose("tls sniff failed: %v", err)
}
func (c *Conn) failTLSConfigSelection(err error) {
if c.stats != nil {
c.stats.incTLSConfigFailures()
}
c.failAndClose("tls config selection failed: %v", err)
}
func (c *Conn) failPlainRejected() {
if c.stats != nil {
c.stats.incPlainRejected()
}
c.failAndClose("plain tcp rejected")
}
func (c *Conn) selectTLSConfig() (*tls.Config, error) {
var selected *tls.Config
if c.getConfigForClientHello != nil {
cfg, err := c.getConfigForClientHello(c.clientHello.Clone())
if err != nil {
return nil, err
}
if cfg != nil {
selected = cfg
}
}
if selected == nil && c.getConfigForClient != nil {
cfg, err := c.getConfigForClient(c.serverName())
if err != nil {
return nil, err
}
if cfg != nil {
selected = cfg
}
}
composed := composeServerTLSConfig(c.baseTLSConfig, selected)
if composed != nil {
return composed, nil
}
return nil, ErrNoTLSConfig
}
// Hostname returns sniffed SNI hostname (if any).
func (c *Conn) Hostname() string {
c.init()
return c.serverName()
}
// ClientHello returns sniffed TLS metadata (if any).
func (c *Conn) ClientHello() *ClientHelloMeta {
c.init()
return c.clientHello.Clone()
}
func (c *Conn) serverName() string {
if c.clientHello == nil {
return ""
}
return c.clientHello.ServerName
}
func composeServerTLSConfig(base, selected *tls.Config) *tls.Config {
if base == nil {
return selected
}
if selected == nil {
return base
}
out := base.Clone()
applyServerTLSOverrides(out, selected)
return out
}
func applyServerTLSOverrides(dst, src *tls.Config) {
if dst == nil || src == nil {
return
}
if src.Rand != nil {
dst.Rand = src.Rand
}
if src.Time != nil {
dst.Time = src.Time
}
if len(src.Certificates) > 0 {
dst.Certificates = append([]tls.Certificate(nil), src.Certificates...)
}
if len(src.NameToCertificate) > 0 {
m := make(map[string]*tls.Certificate, len(src.NameToCertificate))
for k, v := range src.NameToCertificate {
m[k] = v
}
dst.NameToCertificate = m
}
if src.GetCertificate != nil {
dst.GetCertificate = src.GetCertificate
}
if src.GetClientCertificate != nil {
dst.GetClientCertificate = src.GetClientCertificate
}
if src.GetConfigForClient != nil {
dst.GetConfigForClient = src.GetConfigForClient
}
if src.VerifyPeerCertificate != nil {
dst.VerifyPeerCertificate = src.VerifyPeerCertificate
}
if src.VerifyConnection != nil {
dst.VerifyConnection = src.VerifyConnection
}
if src.RootCAs != nil {
dst.RootCAs = src.RootCAs
}
if len(src.NextProtos) > 0 {
dst.NextProtos = append([]string(nil), src.NextProtos...)
}
if src.ServerName != "" {
dst.ServerName = src.ServerName
}
if src.ClientAuth > dst.ClientAuth {
dst.ClientAuth = src.ClientAuth
}
if src.ClientCAs != nil {
dst.ClientCAs = src.ClientCAs
}
if src.InsecureSkipVerify {
dst.InsecureSkipVerify = true
}
if len(src.CipherSuites) > 0 {
dst.CipherSuites = append([]uint16(nil), src.CipherSuites...)
}
if src.PreferServerCipherSuites {
dst.PreferServerCipherSuites = true
}
if src.SessionTicketsDisabled {
dst.SessionTicketsDisabled = true
}
if src.SessionTicketKey != ([32]byte{}) {
dst.SessionTicketKey = src.SessionTicketKey
}
if src.ClientSessionCache != nil {
dst.ClientSessionCache = src.ClientSessionCache
}
if src.UnwrapSession != nil {
dst.UnwrapSession = src.UnwrapSession
}
if src.WrapSession != nil {
dst.WrapSession = src.WrapSession
}
if src.MinVersion != 0 && (dst.MinVersion == 0 || src.MinVersion > dst.MinVersion) {
dst.MinVersion = src.MinVersion
}
if src.MaxVersion != 0 && (dst.MaxVersion == 0 || src.MaxVersion < dst.MaxVersion) {
dst.MaxVersion = src.MaxVersion
}
if len(src.CurvePreferences) > 0 {
dst.CurvePreferences = append([]tls.CurveID(nil), src.CurvePreferences...)
}
if src.DynamicRecordSizingDisabled {
dst.DynamicRecordSizingDisabled = true
}
if src.Renegotiation != 0 {
dst.Renegotiation = src.Renegotiation
}
if src.KeyLogWriter != nil {
dst.KeyLogWriter = src.KeyLogWriter
}
if len(src.EncryptedClientHelloConfigList) > 0 {
dst.EncryptedClientHelloConfigList = append([]byte(nil), src.EncryptedClientHelloConfigList...)
}
if src.EncryptedClientHelloRejectionVerify != nil {
dst.EncryptedClientHelloRejectionVerify = src.EncryptedClientHelloRejectionVerify
}
if src.GetEncryptedClientHelloKeys != nil {
dst.GetEncryptedClientHelloKeys = src.GetEncryptedClientHelloKeys
}
if len(src.EncryptedClientHelloKeys) > 0 {
dst.EncryptedClientHelloKeys = append([]tls.EncryptedClientHelloKey(nil), src.EncryptedClientHelloKeys...)
}
}
func (c *Conn) IsTLS() bool {
c.init()
return c.initErr == nil && c.isTLS
}
func (c *Conn) TLSConn() (*tls.Conn, error) {
c.init()
if c.initErr != nil {
return nil, c.initErr
}
if !c.isTLS || c.tlsConn == nil {
return nil, ErrNotTLS
}
return c.tlsConn, nil
}
func (c *Conn) Read(b []byte) (int, error) {
c.init()
if c.initErr != nil {
return 0, c.initErr
}
if c.isTLS {
return c.tlsConn.Read(b)
}
return c.plainConn.Read(b)
}
func (c *Conn) Write(b []byte) (int, error) {
c.init()
if c.initErr != nil {
return 0, c.initErr
}
if c.isTLS {
return c.tlsConn.Write(b)
}
return c.plainConn.Write(b)
}
func (c *Conn) Close() error {
var err error
c.closeOnce.Do(func() {
if c.tlsConn != nil {
err = c.tlsConn.Close()
} else {
err = c.Conn.Close()
}
if c.stats != nil {
c.stats.incClosed()
}
})
return err
}
func (c *Conn) SetDeadline(t time.Time) error {
c.init()
if c.initErr != nil {
return c.initErr
}
if c.isTLS && c.tlsConn != nil {
return c.tlsConn.SetDeadline(t)
}
return c.plainConn.SetDeadline(t)
}
func (c *Conn) SetReadDeadline(t time.Time) error {
c.init()
if c.initErr != nil {
return c.initErr
}
if c.isTLS && c.tlsConn != nil {
return c.tlsConn.SetReadDeadline(t)
}
return c.plainConn.SetReadDeadline(t)
}
func (c *Conn) SetWriteDeadline(t time.Time) error {
c.init()
if c.initErr != nil {
return c.initErr
}
if c.isTLS && c.tlsConn != nil {
return c.tlsConn.SetWriteDeadline(t)
}
return c.plainConn.SetWriteDeadline(t)
}
// Listener wraps net.Listener and returns starnet.Conn from Accept.
type Listener struct {
net.Listener
mu sync.RWMutex
cfg ListenerConfig
stats Stats
}
// Listen creates a plain listener config (no TLS detection).
func Listen(network, address string) (*Listener, error) {
ln, err := net.Listen(network, address)
if err != nil {
return nil, err
}
cfg := DefaultListenerConfig()
cfg.AllowNonTLS = true
cfg.BaseTLSConfig = nil
cfg.GetConfigForClient = nil
return &Listener{Listener: ln, cfg: cfg}, nil
}
// ListenWithConfig creates a listener with full config.
func ListenWithConfig(network, address string, cfg ListenerConfig) (*Listener, error) {
ln, err := net.Listen(network, address)
if err != nil {
return nil, err
}
return &Listener{Listener: ln, cfg: normalizeConfig(cfg)}, nil
}
// ListenWithListenConfig creates listener using net.ListenConfig.
func ListenWithListenConfig(lc net.ListenConfig, network, address string, cfg ListenerConfig) (*Listener, error) {
ln, err := lc.Listen(context.Background(), network, address)
if err != nil {
return nil, err
}
return &Listener{Listener: ln, cfg: normalizeConfig(cfg)}, nil
}
// ListenTLS creates TLS listener from cert/key paths.
func ListenTLS(network, address, certFile, keyFile string, allowNonTLS bool) (*Listener, error) {
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return nil, err
}
cfg := DefaultListenerConfig()
cfg.AllowNonTLS = allowNonTLS
cfg.BaseTLSConfig = TLSDefaults()
cfg.BaseTLSConfig.Certificates = []tls.Certificate{cert}
return ListenWithConfig(network, address, cfg)
}
func normalizeConfig(cfg ListenerConfig) ListenerConfig {
out := DefaultListenerConfig()
out.AllowNonTLS = cfg.AllowNonTLS
out.SniffTimeout = cfg.SniffTimeout
out.MaxClientHelloBytes = cfg.MaxClientHelloBytes
out.BaseTLSConfig = cfg.BaseTLSConfig
out.GetConfigForClient = cfg.GetConfigForClient
out.GetConfigForClientHello = cfg.GetConfigForClientHello
out.Logger = cfg.Logger
if out.MaxClientHelloBytes <= 0 {
out.MaxClientHelloBytes = 64 * 1024
}
return out
}
// SetConfig atomically replaces listener config for new accepted connections.
func (l *Listener) SetConfig(cfg ListenerConfig) {
l.mu.Lock()
l.cfg = normalizeConfig(cfg)
l.mu.Unlock()
}
// Config returns a copy of current config.
func (l *Listener) Config() ListenerConfig {
l.mu.RLock()
cfg := l.cfg
l.mu.RUnlock()
return cfg
}
// Stats returns current counters snapshot.
func (l *Listener) Stats() StatsSnapshot {
return l.stats.Snapshot()
}
func (l *Listener) Accept() (net.Conn, error) {
raw, err := l.Listener.Accept()
if err != nil {
return nil, err
}
l.stats.incAccepted()
l.mu.RLock()
cfg := l.cfg
l.mu.RUnlock()
return newConn(raw, cfg, &l.stats), nil
}
// AcceptContext supports cancellation by closing accepted conn when ctx is done early.
func (l *Listener) AcceptContext(ctx context.Context) (net.Conn, error) {
type result struct {
c net.Conn
err error
}
ch := make(chan result, 1)
go func() {
c, err := l.Accept()
ch <- result{c: c, err: err}
}()
select {
case <-ctx.Done():
return nil, ctx.Err()
case r := <-ch:
return r.c, r.err
}
}
// Dial creates a plain TCP starnet.Conn.
func Dial(network, address string) (*Conn, error) {
raw, err := net.Dial(network, address)
if err != nil {
return nil, err
}
cfg := DefaultListenerConfig()
cfg.AllowNonTLS = true
cfg.BaseTLSConfig = nil
cfg.GetConfigForClient = nil
c := newConn(raw, cfg, nil)
c.isTLS = false
return c, nil
}
// DialWithConfig dials with net.Dialer options.
func DialWithConfig(network, address string, dc DialConfig) (*Conn, error) {
d := net.Dialer{
Timeout: dc.Timeout,
LocalAddr: dc.LocalAddr,
}
raw, err := d.Dial(network, address)
if err != nil {
return nil, err
}
cfg := DefaultListenerConfig()
cfg.AllowNonTLS = true
c := newConn(raw, cfg, nil)
c.isTLS = false
return c, nil
}
// DialTLSWithConfig creates a TLS client connection wrapper.
func DialTLSWithConfig(network, address string, tlsCfg *tls.Config, timeout time.Duration) (*Conn, error) {
d := net.Dialer{Timeout: timeout}
raw, err := d.Dial(network, address)
if err != nil {
return nil, err
}
tc := tls.Client(raw, tlsCfg)
return &Conn{
Conn: raw,
plainConn: raw,
isTLS: true,
tlsConn: tc,
initErr: nil,
allowNonTLS: false,
skipSniff: true,
}, nil
}
// DialTLS creates TLS client conn from cert/key paths.
func DialTLS(network, address, certFile, keyFile string) (*Conn, error) {
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return nil, err
}
cfg := TLSDefaults()
cfg.Certificates = []tls.Certificate{cert}
return DialTLSWithConfig(network, address, cfg, 0)
}
func WrapListener(listener net.Listener, cfg ListenerConfig) (*Listener, error) {
if listener == nil {
return nil, ErrNilConn
}
return &Listener{
Listener: listener,
cfg: normalizeConfig(cfg),
}, nil
}

1210
tlssniffer_test.go Normal file

File diff suppressed because it is too large Load Diff

55
tlsstats.go Normal file
View File

@ -0,0 +1,55 @@
package starnet
import "sync/atomic"
// StatsSnapshot is a read-only copy of runtime counters.
type StatsSnapshot struct {
Accepted uint64
TLSDetected uint64
PlainDetected uint64
InitFailures uint64
SniffFailures uint64
TLSConfigFailures uint64
PlainRejected uint64
Closed uint64
}
// Stats provides lock-free counters.
type Stats struct {
accepted uint64
tlsDetected uint64
plainDetected uint64
initFailures uint64
sniffFailures uint64
tlsConfigFailures uint64
plainRejected uint64
closed uint64
}
func (s *Stats) incAccepted() { atomic.AddUint64(&s.accepted, 1) }
func (s *Stats) incTLSDetected() { atomic.AddUint64(&s.tlsDetected, 1) }
func (s *Stats) incPlainDetected() { atomic.AddUint64(&s.plainDetected, 1) }
func (s *Stats) incInitFailures() { atomic.AddUint64(&s.initFailures, 1) }
func (s *Stats) incClosed() { atomic.AddUint64(&s.closed, 1) }
func (s *Stats) incSniffFailures() { atomic.AddUint64(&s.sniffFailures, 1); s.incInitFailures() }
func (s *Stats) incTLSConfigFailures() { atomic.AddUint64(&s.tlsConfigFailures, 1); s.incInitFailures() }
func (s *Stats) incPlainRejected() { atomic.AddUint64(&s.plainRejected, 1); s.incInitFailures() }
// Snapshot returns a stable view of counters.
func (s *Stats) Snapshot() StatsSnapshot {
return StatsSnapshot{
Accepted: atomic.LoadUint64(&s.accepted),
TLSDetected: atomic.LoadUint64(&s.tlsDetected),
PlainDetected: atomic.LoadUint64(&s.plainDetected),
InitFailures: atomic.LoadUint64(&s.initFailures),
SniffFailures: atomic.LoadUint64(&s.sniffFailures),
TLSConfigFailures: atomic.LoadUint64(&s.tlsConfigFailures),
PlainRejected: atomic.LoadUint64(&s.plainRejected),
Closed: atomic.LoadUint64(&s.closed),
}
}
// Logger is a minimal logging abstraction.
type Logger interface {
Printf(format string, v ...interface{})
}

97
transport.go Normal file
View File

@ -0,0 +1,97 @@
package starnet
import (
"net/http"
"net/url"
"sync"
"time"
)
// Transport 自定义 Transport支持请求级配置
type Transport struct {
base *http.Transport
mu sync.RWMutex
}
// RoundTrip 实现 http.RoundTripper 接口
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
// 确保 base 已初始化
if t.base == nil {
t.mu.Lock()
if t.base == nil {
t.base = &http.Transport{
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
}
t.mu.Unlock()
}
// 提取请求级别的配置
reqCtx := getRequestContext(req.Context())
// 优先级1完全自定义的 transport
if reqCtx.Transport != nil {
return reqCtx.Transport.RoundTrip(req)
}
// 优先级2需要动态配置
if needsDynamicTransport(reqCtx) {
dynamicTransport := t.buildDynamicTransport(reqCtx)
return dynamicTransport.RoundTrip(req)
}
// 优先级3使用基础 transport
t.mu.RLock()
defer t.mu.RUnlock()
return t.base.RoundTrip(req)
}
// buildDynamicTransport 构建动态 Transport
func (t *Transport) buildDynamicTransport(rc *RequestContext) *http.Transport {
t.mu.RLock()
transport := t.base.Clone()
t.mu.RUnlock()
// 应用 TLS 配置(即使为 nil 也要检查 SkipVerify
if rc.TLSConfig != nil {
transport.TLSClientConfig = rc.TLSConfig
}
// 应用代理配置
if rc.Proxy != "" {
proxyURL, err := url.Parse(rc.Proxy)
if err == nil {
transport.Proxy = http.ProxyURL(proxyURL)
}
}
// 应用自定义 Dial 函数
if rc.DialFn != nil {
transport.DialContext = rc.DialFn
} else if len(rc.CustomIP) > 0 || len(rc.CustomDNS) > 0 || rc.DialTimeout > 0 || rc.LookupIPFn != nil {
// 使用默认 Dial 函数(会从 context 读取配置)
transport.DialContext = defaultDialFunc
transport.DialTLSContext = defaultDialTLSFunc
}
return transport
}
// Base 获取基础 Transport
func (t *Transport) Base() *http.Transport {
t.mu.RLock()
defer t.mu.RUnlock()
return t.base
}
// SetBase 设置基础 Transport
func (t *Transport) SetBase(base *http.Transport) {
t.mu.Lock()
t.base = base
t.mu.Unlock()
}

133
types.go Normal file
View File

@ -0,0 +1,133 @@
package starnet
import (
"context"
"crypto/tls"
"io"
"net"
"net/http"
"time"
)
// HTTP Content-Type 常量
const (
ContentTypeFormURLEncoded = "application/x-www-form-urlencoded"
ContentTypeFormData = "multipart/form-data"
ContentTypeJSON = "application/json"
ContentTypeXML = "application/xml"
ContentTypePlain = "text/plain"
ContentTypeHTML = "text/html"
ContentTypeOctetStream = "application/octet-stream"
)
// 默认配置
const (
DefaultDialTimeout = 5 * time.Second
DefaultTimeout = 10 * time.Second
DefaultUserAgent = "Starnet/1.0.0"
DefaultFetchRespBody = false
)
// RequestFile 表示要上传的文件
type RequestFile struct {
FormName string // 表单字段名
FileName string // 文件名
FilePath string // 文件路径(如果从文件读取)
FileData io.Reader // 文件数据流
FileSize int64 // 文件大小
FileType string // MIME 类型
}
// UploadProgressFunc 文件上传进度回调函数
type UploadProgressFunc func(filename string, uploaded int64, total int64)
// NetworkConfig 网络配置
type NetworkConfig struct {
Proxy string // 代理地址
DialTimeout time.Duration // 连接超时
Timeout time.Duration // 总超时
DialFunc func(ctx context.Context, network, addr string) (net.Conn, error)
}
// TLSConfig TLS 配置
type TLSConfig struct {
Config *tls.Config // TLS 配置
SkipVerify bool // 跳过证书验证
}
// DNSConfig DNS 配置
type DNSConfig struct {
CustomIP []string // 直接指定 IP最高优先级
CustomDNS []string // 自定义 DNS 服务器
LookupFunc func(ctx context.Context, host string) ([]net.IPAddr, error) // 自定义解析函数
}
// BodyConfig 请求体配置
type BodyConfig struct {
Bytes []byte // 原始字节
Reader io.Reader // 数据流
FormData map[string][]string // 表单数据
Files []RequestFile // 文件列表
}
// RequestConfig 请求配置(内部使用)
type RequestConfig struct {
Network NetworkConfig
TLS TLSConfig
DNS DNSConfig
Body BodyConfig
Headers http.Header
Cookies []*http.Cookie
Queries map[string][]string
// 其他配置
BasicAuth [2]string // Basic 认证
ContentLength int64 // 手动设置的 Content-Length
AutoCalcContentLength bool // 自动计算 Content-Length
MaxRespBodyBytes int64 // 响应体最大读取字节数(<=0 表示不限制)
UploadProgress UploadProgressFunc // 上传进度回调
// Transport 配置
CustomTransport bool // 是否使用自定义 Transport
Transport *http.Transport // 自定义 Transport
}
// Clone 克隆配置
func (c *RequestConfig) Clone() *RequestConfig {
return &RequestConfig{
Network: NetworkConfig{
Proxy: c.Network.Proxy,
DialTimeout: c.Network.DialTimeout,
Timeout: c.Network.Timeout,
DialFunc: c.Network.DialFunc,
},
TLS: TLSConfig{
Config: cloneTLSConfig(c.TLS.Config),
SkipVerify: c.TLS.SkipVerify,
},
DNS: DNSConfig{
CustomIP: cloneStringSlice(c.DNS.CustomIP),
CustomDNS: cloneStringSlice(c.DNS.CustomDNS),
LookupFunc: c.DNS.LookupFunc,
},
Body: BodyConfig{
Bytes: cloneBytes(c.Body.Bytes),
Reader: c.Body.Reader, // Reader 不可克隆
FormData: cloneStringMapSlice(c.Body.FormData),
Files: cloneFiles(c.Body.Files),
},
Headers: cloneHeader(c.Headers),
Cookies: cloneCookies(c.Cookies),
Queries: cloneStringMapSlice(c.Queries),
BasicAuth: c.BasicAuth,
ContentLength: c.ContentLength,
AutoCalcContentLength: c.AutoCalcContentLength,
MaxRespBodyBytes: c.MaxRespBodyBytes,
UploadProgress: c.UploadProgress,
CustomTransport: c.CustomTransport,
Transport: c.Transport, // Transport 共享
}
}
// RequestOpt 请求选项函数
type RequestOpt func(*Request) error

212
utils.go Normal file
View File

@ -0,0 +1,212 @@
package starnet
import (
"context"
"crypto/tls"
"io"
"net/http"
"net/url"
"strings"
)
// validMethod 验证 HTTP 方法是否有效
func validMethod(method string) bool {
return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1
}
// isNotToken 检查字符是否不是 token 字符
func isNotToken(r rune) bool {
return !isTokenRune(r)
}
// isTokenRune 检查字符是否是 token 字符
func isTokenRune(r rune) bool {
i := int(r)
return i < 127 && isTokenTable[i]
}
// isTokenTable token 字符表
var isTokenTable = [127]bool{
'!': true, '#': true, '$': true, '%': true, '&': true, '\'': true, '*': true,
'+': true, '-': true, '.': true, '0': true, '1': true, '2': true, '3': true,
'4': true, '5': true, '6': true, '7': true, '8': true, '9': true, 'A': true,
'B': true, 'C': true, 'D': true, 'E': true, 'F': true, 'G': true, 'H': true,
'I': true, 'J': true, 'K': true, 'L': true, 'M': true, 'N': true, 'O': true,
'P': true, 'Q': true, 'R': true, 'S': true, 'T': true, 'U': true, 'V': true,
'W': true, 'X': true, 'Y': true, 'Z': true, '^': true, '_': true, '`': true,
'a': true, 'b': true, 'c': true, 'd': true, 'e': true, 'f': true, 'g': true,
'h': true, 'i': true, 'j': true, 'k': true, 'l': true, 'm': true, 'n': true,
'o': true, 'p': true, 'q': true, 'r': true, 's': true, 't': true, 'u': true,
'v': true, 'w': true, 'x': true, 'y': true, 'z': true, '|': true, '~': true,
}
// hasPort 检查地址是否包含端口
func hasPort(s string) bool {
return strings.LastIndex(s, ":") > strings.LastIndex(s, "]")
}
// removeEmptyPort 移除空端口
func removeEmptyPort(host string) string {
if hasPort(host) {
return strings.TrimSuffix(host, ":")
}
return host
}
// UrlEncode URL 编码
func UrlEncode(str string) string {
return url.QueryEscape(str)
}
// UrlEncodeRaw URL 编码(空格编码为 %20
func UrlEncodeRaw(str string) string {
return strings.Replace(url.QueryEscape(str), "+", "%20", -1)
}
// UrlDecode URL 解码
func UrlDecode(str string) (string, error) {
return url.QueryUnescape(str)
}
// BuildQuery 构建查询字符串
func BuildQuery(data map[string]string) string {
query := url.Values{}
for k, v := range data {
query.Add(k, v)
}
return query.Encode()
}
// BuildPostForm 构建 POST 表单数据
func BuildPostForm(data map[string]string) []byte {
return []byte(BuildQuery(data))
}
// cloneHeader 克隆 Header
func cloneHeader(h http.Header) http.Header {
if h == nil {
return make(http.Header)
}
newHeader := make(http.Header, len(h))
for k, v := range h {
newHeader[k] = append([]string(nil), v...)
}
return newHeader
}
// cloneCookies 克隆 Cookies
func cloneCookies(cookies []*http.Cookie) []*http.Cookie {
if cookies == nil {
return nil
}
newCookies := make([]*http.Cookie, len(cookies))
for i, c := range cookies {
newCookies[i] = &http.Cookie{
Name: c.Name,
Value: c.Value,
Path: c.Path,
Domain: c.Domain,
Expires: c.Expires,
RawExpires: c.RawExpires,
MaxAge: c.MaxAge,
Secure: c.Secure,
HttpOnly: c.HttpOnly,
SameSite: c.SameSite,
Raw: c.Raw,
Unparsed: append([]string(nil), c.Unparsed...),
}
}
return newCookies
}
// cloneStringMapSlice 克隆 map[string][]string
func cloneStringMapSlice(m map[string][]string) map[string][]string {
if m == nil {
return make(map[string][]string)
}
newMap := make(map[string][]string, len(m))
for k, v := range m {
newMap[k] = append([]string(nil), v...)
}
return newMap
}
// cloneBytes 克隆字节切片
func cloneBytes(b []byte) []byte {
if b == nil {
return nil
}
newBytes := make([]byte, len(b))
copy(newBytes, b)
return newBytes
}
// cloneStringSlice 克隆字符串切片
func cloneStringSlice(s []string) []string {
if s == nil {
return nil
}
newSlice := make([]string, len(s))
copy(newSlice, s)
return newSlice
}
// cloneFiles 克隆文件列表
func cloneFiles(files []RequestFile) []RequestFile {
if files == nil {
return nil
}
newFiles := make([]RequestFile, len(files))
copy(newFiles, files)
return newFiles
}
// cloneTLSConfig 克隆 TLS 配置
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
if cfg == nil {
return nil
}
return cfg.Clone()
}
// copyWithProgress 带进度的复制
func copyWithProgress(ctx context.Context, dst io.Writer, src io.Reader, filename string, total int64, progress UploadProgressFunc) (int64, error) {
if progress == nil {
return io.Copy(dst, src)
}
var written int64
buf := make([]byte, 32*1024) // 32KB buffer
for {
select {
case <-ctx.Done():
return written, ctx.Err()
default:
}
nr, err := src.Read(buf)
if nr > 0 {
nw, ew := dst.Write(buf[:nr])
if nw > 0 {
written += int64(nw)
// 同步调用进度回调(不使用 goroutine
progress(filename, written, total)
}
if ew != nil {
return written, ew
}
if nr != nw {
return written, io.ErrShortWrite
}
}
if err != nil {
if err == io.EOF {
// 最后一次进度回调
progress(filename, written, total)
return written, nil
}
return written, err
}
}
}

284
utils_test.go Normal file
View File

@ -0,0 +1,284 @@
package starnet
import (
"net/http"
"testing"
"time"
)
func TestUrlEncodeRaw(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "basic string with space",
input: "hello world",
expected: "hello%20world",
},
{
name: "special characters",
input: "hello world!@#$%^&*()_+-=~`",
expected: "hello%20world%21%40%23%24%25%5E%26%2A%28%29_%2B-%3D~%60",
},
{
name: "empty string",
input: "",
expected: "",
},
{
name: "chinese characters",
input: "你好世界",
expected: "%E4%BD%A0%E5%A5%BD%E4%B8%96%E7%95%8C",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := UrlEncodeRaw(tt.input)
if result != tt.expected {
t.Errorf("UrlEncodeRaw(%q) = %q; want %q", tt.input, result, tt.expected)
}
})
}
}
func TestUrlEncode(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "space encoded as plus",
input: "hello world",
expected: "hello+world",
},
{
name: "special characters",
input: "hello world!@#$%^&*()_+-=~`",
expected: "hello+world%21%40%23%24%25%5E%26%2A%28%29_%2B-%3D~%60",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := UrlEncode(tt.input)
if result != tt.expected {
t.Errorf("UrlEncode(%q) = %q; want %q", tt.input, result, tt.expected)
}
})
}
}
func TestUrlDecode(t *testing.T) {
tests := []struct {
name string
input string
expected string
expectErr bool
}{
{
name: "basic decode",
input: "hello%20world",
expected: "hello world",
expectErr: false,
},
{
name: "plus to space",
input: "hello+world",
expected: "hello world",
expectErr: false,
},
{
name: "special characters",
input: "hello%20world%21%40%23%24%25%5E%26*%28%29_%2B-%3D~%60",
expected: "hello world!@#$%^&*()_+-=~`",
expectErr: false,
},
{
name: "invalid encoding",
input: "%zz",
expected: "",
expectErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := UrlDecode(tt.input)
if tt.expectErr {
if err == nil {
t.Errorf("UrlDecode(%q) expected error, got nil", tt.input)
}
} else {
if err != nil {
t.Errorf("UrlDecode(%q) unexpected error: %v", tt.input, err)
}
if result != tt.expected {
t.Errorf("UrlDecode(%q) = %q; want %q", tt.input, result, tt.expected)
}
}
})
}
}
func TestBuildQuery(t *testing.T) {
tests := []struct {
name string
input map[string]string
expected string
}{
{
name: "single parameter",
input: map[string]string{
"key": "value",
},
expected: "key=value",
},
{
name: "empty map",
input: map[string]string{},
expected: "",
},
{
name: "nil map",
input: nil,
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := BuildQuery(tt.input)
if result != tt.expected {
t.Errorf("BuildQuery(%v) = %q; want %q", tt.input, result, tt.expected)
}
})
}
}
func TestBuildPostForm(t *testing.T) {
tests := []struct {
name string
input map[string]string
expected []byte
}{
{
name: "basic form",
input: map[string]string{
"key1": "value1",
},
expected: []byte("key1=value1"),
},
{
name: "empty map",
input: map[string]string{},
expected: []byte(""),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := BuildPostForm(tt.input)
if string(result) != string(tt.expected) {
t.Errorf("BuildPostForm(%v) = %v; want %v", tt.input, result, tt.expected)
}
})
}
}
func TestValidMethod(t *testing.T) {
tests := []struct {
name string
method string
expected bool
}{
{"GET", "GET", true},
{"POST", "POST", true},
{"PUT", "PUT", true},
{"DELETE", "DELETE", true},
{"PATCH", "PATCH", true},
{"OPTIONS", "OPTIONS", true},
{"HEAD", "HEAD", true},
{"TRACE", "TRACE", true},
{"CONNECT", "CONNECT", true},
{"invalid with space", "GET POST", false},
{"invalid with special char", "GET<>", false},
{"empty", "", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := validMethod(tt.method)
if result != tt.expected {
t.Errorf("validMethod(%q) = %v; want %v", tt.method, result, tt.expected)
}
})
}
}
func TestCloneCookies_FullFields(t *testing.T) {
expire := time.Now().Add(2 * time.Hour)
src := []*http.Cookie{
{
Name: "sid",
Value: "abc123",
Path: "/",
Domain: "example.com",
Expires: expire,
RawExpires: expire.UTC().Format(time.RFC1123),
MaxAge: 3600,
Secure: true,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
Raw: "sid=abc123; Path=/; HttpOnly",
Unparsed: []string{"Priority=High", "Partitioned"},
},
}
got := cloneCookies(src)
if got == nil || len(got) != 1 {
t.Fatalf("cloneCookies() len=%v; want 1", len(got))
}
// 指针应不同(不是浅拷贝)
if got[0] == src[0] {
t.Fatal("cookie pointer should be different (deep copy expected)")
}
// 字段值应一致
s := src[0]
g := got[0]
if g.Name != s.Name ||
g.Value != s.Value ||
g.Path != s.Path ||
g.Domain != s.Domain ||
!g.Expires.Equal(s.Expires) ||
g.RawExpires != s.RawExpires ||
g.MaxAge != s.MaxAge ||
g.Secure != s.Secure ||
g.HttpOnly != s.HttpOnly ||
g.SameSite != s.SameSite ||
g.Raw != s.Raw {
t.Fatalf("cloned cookie fields mismatch:\n got=%+v\n src=%+v", g, s)
}
// Unparsed 内容一致
if len(g.Unparsed) != len(s.Unparsed) {
t.Fatalf("Unparsed len=%d; want %d", len(g.Unparsed), len(s.Unparsed))
}
for i := range s.Unparsed {
if g.Unparsed[i] != s.Unparsed[i] {
t.Fatalf("Unparsed[%d]=%q; want %q", i, g.Unparsed[i], s.Unparsed[i])
}
}
// 验证 Unparsed 是深拷贝(修改源不影响目标)
src[0].Unparsed[0] = "Modified=Yes"
if got[0].Unparsed[0] == "Modified=Yes" {
t.Fatal("Unparsed should be deep-copied, but was affected by source mutation")
}
}