- 分离 Request 的配置态与执行态,修复二次 Do、raw 模式网络配置失效和 body 来源互斥问题 - 新增 starnet trace 抽象,补齐 DNS/连接/TLS/重试事件,并优化动态 transport 缓存与代理解析路径 - 收紧非法代理为 fail-fast,多目标回退仅限幂等请求,修复 Host/TLS/SNI 等语义边界 - 补充防御性拷贝、专项回归测试、本地代理/TLS 用例与 README 行为说明
336 lines
8.6 KiB
Go
336 lines
8.6 KiB
Go
package starnet
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"errors"
|
|
"io"
|
|
"mime/multipart"
|
|
"net/http"
|
|
"strings"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
// roundTripFunc adapts a plain function to the http.RoundTripper interface so
// a test can supply its transport as an inline closure.
type roundTripFunc func(*http.Request) (*http.Response, error)

// RoundTrip passes req to the wrapped function and returns its result unchanged.
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
	resp, err := f(req)
	return resp, err
}
|
|
|
|
func TestRequestPreparedMutationReappliesHeadersAndBody(t *testing.T) {
|
|
req := NewSimpleRequest("http://example.com", http.MethodPost).
|
|
SetHeader("X-Test", "one").
|
|
SetBodyString("first")
|
|
req.client = &Client{client: &http.Client{
|
|
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
|
body, err := io.ReadAll(r.Body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
_ = r.Body.Close()
|
|
return &http.Response{
|
|
StatusCode: http.StatusOK,
|
|
Header: make(http.Header),
|
|
Body: io.NopCloser(strings.NewReader(r.Header.Get("X-Test") + ":" + string(body))),
|
|
Request: r,
|
|
}, nil
|
|
}),
|
|
}}
|
|
|
|
if _, err := req.HTTPClient(); err != nil {
|
|
t.Fatalf("HTTPClient() error: %v", err)
|
|
}
|
|
|
|
req.SetHeader("X-Test", "two").SetBodyString("second")
|
|
|
|
resp, err := req.Do()
|
|
if err != nil {
|
|
t.Fatalf("Do() error: %v", err)
|
|
}
|
|
defer resp.Close()
|
|
|
|
body, err := resp.Body().String()
|
|
if err != nil {
|
|
t.Fatalf("Body().String() error: %v", err)
|
|
}
|
|
if body != "two:second" {
|
|
t.Fatalf("body=%q; want %q", body, "two:second")
|
|
}
|
|
}
|
|
|
|
func TestRequestPreparedMutationReappliesTimeout(t *testing.T) {
|
|
var attempts int32
|
|
req := NewSimpleRequest("http://example.com", http.MethodGet)
|
|
req.client = &Client{client: &http.Client{
|
|
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
|
if atomic.AddInt32(&attempts, 1) == 1 {
|
|
return &http.Response{
|
|
StatusCode: http.StatusNoContent,
|
|
Header: make(http.Header),
|
|
Body: io.NopCloser(strings.NewReader("")),
|
|
Request: r,
|
|
}, nil
|
|
}
|
|
select {
|
|
case <-time.After(50 * time.Millisecond):
|
|
return &http.Response{
|
|
StatusCode: http.StatusNoContent,
|
|
Header: make(http.Header),
|
|
Body: io.NopCloser(strings.NewReader("")),
|
|
Request: r,
|
|
}, nil
|
|
case <-r.Context().Done():
|
|
return nil, r.Context().Err()
|
|
}
|
|
}),
|
|
}}
|
|
|
|
resp, err := req.Do()
|
|
if err != nil {
|
|
t.Fatalf("first Do() error: %v", err)
|
|
}
|
|
_ = resp.Close()
|
|
|
|
_, err = req.SetTimeout(10 * time.Millisecond).Do()
|
|
if err == nil {
|
|
t.Fatal("second Do() succeeded; want timeout error")
|
|
}
|
|
if !IsTimeout(err) && !errors.Is(err, context.DeadlineExceeded) {
|
|
t.Fatalf("second Do() error=%v; want timeout", err)
|
|
}
|
|
}
|
|
|
|
func TestWriteFileUsesExecContextWithoutProgressHook(t *testing.T) {
|
|
req := NewSimpleRequest("http://example.com", http.MethodPost)
|
|
|
|
pr, pw := io.Pipe()
|
|
writer := multipart.NewWriter(pw)
|
|
done := make(chan struct{})
|
|
go func() {
|
|
_, _ = io.Copy(io.Discard, pr)
|
|
_ = pr.Close()
|
|
close(done)
|
|
}()
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
cancel()
|
|
|
|
err := req.writeFile(ctx, writer, RequestFile{
|
|
FormName: "file",
|
|
FileName: "payload.txt",
|
|
FileData: strings.NewReader("payload"),
|
|
FileSize: int64(len("payload")),
|
|
})
|
|
_ = writer.Close()
|
|
_ = pw.Close()
|
|
<-done
|
|
|
|
if !errors.Is(err, context.Canceled) {
|
|
t.Fatalf("writeFile() error=%v; want context.Canceled", err)
|
|
}
|
|
}
|
|
|
|
func TestCopyWithProgressHonorsCanceledContextWithoutHook(t *testing.T) {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
cancel()
|
|
|
|
_, err := copyWithProgress(ctx, io.Discard, strings.NewReader("payload"), "payload.txt", int64(len("payload")), nil)
|
|
if !errors.Is(err, context.Canceled) {
|
|
t.Fatalf("copyWithProgress() error=%v; want context.Canceled", err)
|
|
}
|
|
}
|
|
|
|
func TestPrepareSetsGetBodyForReplayableBodies(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
req *Request
|
|
want string
|
|
}{
|
|
{
|
|
name: "bytes",
|
|
req: NewSimpleRequest("http://example.com", http.MethodPost).SetBody([]byte("payload")),
|
|
want: "payload",
|
|
},
|
|
{
|
|
name: "bytes-reader",
|
|
req: NewSimpleRequest("http://example.com", http.MethodPost).SetBodyReader(bytes.NewReader([]byte("payload"))),
|
|
want: "payload",
|
|
},
|
|
{
|
|
name: "strings-reader",
|
|
req: NewSimpleRequest("http://example.com", http.MethodPost).SetBodyReader(strings.NewReader("payload")),
|
|
want: "payload",
|
|
},
|
|
{
|
|
name: "form-data",
|
|
req: NewSimpleRequest("http://example.com", http.MethodPost).AddFormData("k", "v"),
|
|
want: "k=v",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
if err := tt.req.prepare(); err != nil {
|
|
t.Fatalf("prepare() error: %v", err)
|
|
}
|
|
if tt.req.httpReq.GetBody == nil {
|
|
t.Fatal("GetBody is nil")
|
|
}
|
|
|
|
body, err := tt.req.httpReq.GetBody()
|
|
if err != nil {
|
|
t.Fatalf("GetBody() error: %v", err)
|
|
}
|
|
defer body.Close()
|
|
|
|
data, err := io.ReadAll(body)
|
|
if err != nil {
|
|
t.Fatalf("ReadAll() error: %v", err)
|
|
}
|
|
if string(data) != tt.want {
|
|
t.Fatalf("body=%q; want %q", string(data), tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// replayRoundTripper is a stub http.RoundTripper that records the body of each
// request it receives. The first recorded attempt fails with an error; every
// later attempt returns a 200 response whose body is "ok".
//
// The fields are plain values with no synchronization, so a single instance
// should only be driven from one goroutine at a time.
type replayRoundTripper struct {
	attempts int      // number of requests whose body was read successfully
	bodies   []string // request body seen on each counted attempt, in order
}

// RoundTrip reads and records the request body, then fails the first attempt
// and succeeds on all later ones.
//
// A request with a nil Body is recorded as an empty body. If reading the body
// fails, the read error is returned and the attempt is not counted. The body
// is closed on every path, as the http.RoundTripper contract requires.
func (rt *replayRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	var body []byte
	if req.Body != nil {
		// Register the close before reading so the error path closes it too.
		defer func() { _ = req.Body.Close() }()

		var err error
		if body, err = io.ReadAll(req.Body); err != nil {
			return nil, err
		}
	}

	rt.attempts++
	rt.bodies = append(rt.bodies, string(body))
	if rt.attempts == 1 {
		return nil, errors.New("first target failed")
	}

	return &http.Response{
		StatusCode: http.StatusOK,
		Header:     make(http.Header),
		Body:       io.NopCloser(strings.NewReader("ok")),
		Request:    req,
	}, nil
}
|
|
|
|
func TestRoundTripResolvedTargetsReplaysPreparedBody(t *testing.T) {
|
|
req := NewSimpleRequest("http://example.com/upload", http.MethodPut).
|
|
SetBodyReader(strings.NewReader("payload"))
|
|
|
|
if err := req.prepare(); err != nil {
|
|
t.Fatalf("prepare() error: %v", err)
|
|
}
|
|
|
|
rt := &replayRoundTripper{}
|
|
resp, err := roundTripResolvedTargets(rt, req.httpReq, []string{"127.0.0.2:80", "127.0.0.1:80"})
|
|
if err != nil {
|
|
t.Fatalf("roundTripResolvedTargets() error: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if len(rt.bodies) != 2 {
|
|
t.Fatalf("attempt bodies=%v; want 2 attempts", rt.bodies)
|
|
}
|
|
if rt.bodies[0] != "payload" || rt.bodies[1] != "payload" {
|
|
t.Fatalf("attempt bodies=%v; want both payload", rt.bodies)
|
|
}
|
|
}
|
|
|
|
func TestRoundTripResolvedTargetsDoesNotFallbackNonIdempotentRequest(t *testing.T) {
|
|
req := NewSimpleRequest("http://example.com/upload", http.MethodPost).
|
|
SetBodyReader(strings.NewReader("payload"))
|
|
|
|
if err := req.prepare(); err != nil {
|
|
t.Fatalf("prepare() error: %v", err)
|
|
}
|
|
|
|
rt := &replayRoundTripper{}
|
|
_, err := roundTripResolvedTargets(rt, req.httpReq, []string{"127.0.0.2:80", "127.0.0.1:80"})
|
|
if err == nil {
|
|
t.Fatal("roundTripResolvedTargets() succeeded; want first target error")
|
|
}
|
|
if len(rt.bodies) != 1 {
|
|
t.Fatalf("attempt bodies=%v; want only first target attempt", rt.bodies)
|
|
}
|
|
if rt.bodies[0] != "payload" {
|
|
t.Fatalf("attempt body=%q; want payload", rt.bodies[0])
|
|
}
|
|
}
|
|
|
|
func TestRetryReplayableReaderBody(t *testing.T) {
|
|
var attempts int32
|
|
req := NewSimpleRequest("http://example.com/upload", http.MethodPut).
|
|
SetBodyReader(strings.NewReader("payload")).
|
|
SetRetry(1, WithRetryBackoff(0, 0, 1), WithRetryJitter(0))
|
|
req.client = &Client{client: &http.Client{
|
|
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
|
body, err := io.ReadAll(r.Body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
_ = r.Body.Close()
|
|
if string(body) != "payload" {
|
|
t.Fatalf("body=%q; want payload", string(body))
|
|
}
|
|
|
|
if atomic.AddInt32(&attempts, 1) == 1 {
|
|
return &http.Response{
|
|
StatusCode: http.StatusServiceUnavailable,
|
|
Header: make(http.Header),
|
|
Body: io.NopCloser(strings.NewReader("retry")),
|
|
Request: r,
|
|
}, nil
|
|
}
|
|
|
|
return &http.Response{
|
|
StatusCode: http.StatusOK,
|
|
Header: make(http.Header),
|
|
Body: io.NopCloser(strings.NewReader("ok")),
|
|
Request: r,
|
|
}, nil
|
|
}),
|
|
}}
|
|
|
|
resp, err := req.Do()
|
|
if err != nil {
|
|
t.Fatalf("Do() error: %v", err)
|
|
}
|
|
defer resp.Close()
|
|
|
|
if got := atomic.LoadInt32(&attempts); got != 2 {
|
|
t.Fatalf("attempts=%d; want 2", got)
|
|
}
|
|
}
|
|
|
|
func TestWithProxyInvalidReturnsError(t *testing.T) {
|
|
_, err := NewRequest("http://example.com", http.MethodGet, WithProxy("://bad-proxy"))
|
|
if err == nil {
|
|
t.Fatal("NewRequest() succeeded; want invalid proxy error")
|
|
}
|
|
}
|
|
|
|
func TestClientNewRequestWithInvalidProxyReturnsError(t *testing.T) {
|
|
client := NewClientNoErr()
|
|
|
|
_, err := client.NewRequest("http://example.com", http.MethodGet, WithProxy("://bad-proxy"))
|
|
if err == nil {
|
|
t.Fatal("Client.NewRequest() succeeded; want invalid proxy error")
|
|
}
|
|
}
|
|
|
|
func TestNewClientWithInvalidProxyReturnsError(t *testing.T) {
|
|
_, err := NewClient(WithProxy("://bad-proxy"))
|
|
if err == nil {
|
|
t.Fatal("NewClient() succeeded; want invalid proxy error")
|
|
}
|
|
}
|