package starnet

import (
	"bytes"
	"context"
	"errors"
	"io"
	"mime/multipart"
	"net/http"
	"strings"
	"sync/atomic"
	"testing"
	"time"
)

// roundTripFunc adapts an ordinary function to the http.RoundTripper
// interface so tests can stub the network layer inline.
type roundTripFunc func(*http.Request) (*http.Response, error)

// RoundTrip implements http.RoundTripper by invoking fn.
func (fn roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
	return fn(req)
}

// TestRequestPreparedMutationReappliesHeadersAndBody checks that header and
// body changes made after HTTPClient() has been called are what the next Do()
// actually sends. The stub transport echoes "<X-Test header>:<request body>".
func TestRequestPreparedMutationReappliesHeadersAndBody(t *testing.T) {
	req := NewSimpleRequest("http://example.com", http.MethodPost).
		SetHeader("X-Test", "one").
		SetBodyString("first")
	req.client = &Client{client: &http.Client{
		Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
			body, err := io.ReadAll(r.Body)
			if err != nil {
				return nil, err
			}
			_ = r.Body.Close()
			return &http.Response{
				StatusCode: http.StatusOK,
				Header:     make(http.Header),
				Body:       io.NopCloser(strings.NewReader(r.Header.Get("X-Test") + ":" + string(body))),
				Request:    r,
			}, nil
		}),
	}}

	// Presumably this prepares the underlying *http.Request; the mutations
	// below must still be honoured afterwards.
	if _, err := req.HTTPClient(); err != nil {
		t.Fatalf("HTTPClient() error: %v", err)
	}

	req.SetHeader("X-Test", "two").SetBodyString("second")
	resp, err := req.Do()
	if err != nil {
		t.Fatalf("Do() error: %v", err)
	}
	defer resp.Close()

	body, err := resp.Body().String()
	if err != nil {
		t.Fatalf("Body().String() error: %v", err)
	}
	if body != "two:second" {
		t.Fatalf("body=%q; want %q", body, "two:second")
	}
}

// TestRequestPreparedMutationReappliesTimeout checks that a timeout set after
// the request has already been executed once is applied to the next Do().
// The stub transport answers the first attempt immediately; later attempts
// wait 50ms unless the request context finishes first.
func TestRequestPreparedMutationReappliesTimeout(t *testing.T) {
	var attempts int32
	req := NewSimpleRequest("http://example.com", http.MethodGet)
	req.client = &Client{client: &http.Client{
		Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
			if atomic.AddInt32(&attempts, 1) == 1 {
				return &http.Response{
					StatusCode: http.StatusNoContent,
					Header:     make(http.Header),
					Body:       io.NopCloser(strings.NewReader("")),
					Request:    r,
				}, nil
			}
			select {
			case <-time.After(50 * time.Millisecond):
				return &http.Response{
					StatusCode: http.StatusNoContent,
					Header:     make(http.Header),
					Body:       io.NopCloser(strings.NewReader("")),
					Request:    r,
				}, nil
			case <-r.Context().Done():
				return nil, r.Context().Err()
			}
		}),
	}}

	resp, err := req.Do()
	if err != nil {
		t.Fatalf("first Do() error: %v", err)
	}
	_ = resp.Close()

	resp, err = req.SetTimeout(10 * time.Millisecond).Do()
	if err == nil {
		// Release the unexpected response before failing.
		_ = resp.Close()
		t.Fatal("second Do() succeeded; want timeout error")
	}
	if !IsTimeout(err) && !errors.Is(err, context.DeadlineExceeded) {
		t.Fatalf("second Do() error=%v; want timeout", err)
	}
}

// TestWriteFileUsesExecContextWithoutProgressHook checks that writeFile
// returns context.Canceled when handed an already-canceled context, on a
// request created by NewSimpleRequest with no progress hook configured here.
func TestWriteFileUsesExecContextWithoutProgressHook(t *testing.T) {
	req := NewSimpleRequest("http://example.com", http.MethodPost)
	pr, pw := io.Pipe()
	writer := multipart.NewWriter(pw)

	// Drain the read side so writes into the multipart writer never block.
	done := make(chan struct{})
	go func() {
		_, _ = io.Copy(io.Discard, pr)
		_ = pr.Close()
		close(done)
	}()

	ctx, cancel := context.WithCancel(context.Background())
	cancel()

	err := req.writeFile(ctx, writer, RequestFile{
		FormName: "file",
		FileName: "payload.txt",
		FileData: strings.NewReader("payload"),
		FileSize: int64(len("payload")),
	})
	_ = writer.Close()
	_ = pw.Close()
	<-done

	if !errors.Is(err, context.Canceled) {
		t.Fatalf("writeFile() error=%v; want context.Canceled", err)
	}
}

// TestCopyWithProgressHonorsCanceledContextWithoutHook checks that
// copyWithProgress returns context.Canceled for an already-canceled context
// when its final (hook) argument is nil.
func TestCopyWithProgressHonorsCanceledContextWithoutHook(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	cancel()

	_, err := copyWithProgress(ctx, io.Discard, strings.NewReader("payload"), "payload.txt", int64(len("payload")), nil)
	if !errors.Is(err, context.Canceled) {
		t.Fatalf("copyWithProgress() error=%v; want context.Canceled", err)
	}
}

// TestPrepareSetsGetBodyForReplayableBodies checks that prepare() installs a
// GetBody function for byte-slice, *bytes.Reader, *strings.Reader and
// form-data bodies, and that GetBody yields the expected payload.
func TestPrepareSetsGetBodyForReplayableBodies(t *testing.T) {
	tests := []struct {
		name string
		req  *Request
		want string
	}{
		{
			name: "bytes",
			req:  NewSimpleRequest("http://example.com", http.MethodPost).SetBody([]byte("payload")),
			want: "payload",
		},
		{
			name: "bytes-reader",
			req:  NewSimpleRequest("http://example.com", http.MethodPost).SetBodyReader(bytes.NewReader([]byte("payload"))),
			want: "payload",
		},
		{
			name: "strings-reader",
			req:  NewSimpleRequest("http://example.com", http.MethodPost).SetBodyReader(strings.NewReader("payload")),
			want: "payload",
		},
		{
			name: "form-data",
			req:  NewSimpleRequest("http://example.com", http.MethodPost).AddFormData("k", "v"),
			want: "k=v",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if err := tt.req.prepare(); err != nil {
				t.Fatalf("prepare() error: %v", err)
			}
			if tt.req.httpReq.GetBody == nil {
				t.Fatal("GetBody is nil")
			}

			body, err := tt.req.httpReq.GetBody()
			if err != nil {
				t.Fatalf("GetBody() error: %v", err)
			}
			defer body.Close()

			data, err := io.ReadAll(body)
			if err != nil {
				t.Fatalf("ReadAll() error: %v", err)
			}
			if string(data) != tt.want {
				t.Fatalf("body=%q; want %q", string(data), tt.want)
			}
		})
	}
}

// replayRoundTripper is a stub http.RoundTripper that records every request
// body it receives, fails the first attempt and returns 200 "ok" afterwards.
// Its fields are not guarded, so it is only suitable for sequential use.
type replayRoundTripper struct {
	attempts int
	bodies   []string
}

// RoundTrip implements http.RoundTripper.
func (rt *replayRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	body, err := io.ReadAll(req.Body)
	if err != nil {
		return nil, err
	}
	_ = req.Body.Close()

	rt.attempts++
	rt.bodies = append(rt.bodies, string(body))
	if rt.attempts == 1 {
		return nil, errors.New("first target failed")
	}
	return &http.Response{
		StatusCode: http.StatusOK,
		Header:     make(http.Header),
		Body:       io.NopCloser(strings.NewReader("ok")),
		Request:    req,
	}, nil
}

// TestRoundTripResolvedTargetsReplaysPreparedBody checks that an idempotent
// (PUT) request falls back to the second resolved target after the first
// fails, and that both attempts carry the full prepared body.
func TestRoundTripResolvedTargetsReplaysPreparedBody(t *testing.T) {
	req := NewSimpleRequest("http://example.com/upload", http.MethodPut).
		SetBodyReader(strings.NewReader("payload"))
	if err := req.prepare(); err != nil {
		t.Fatalf("prepare() error: %v", err)
	}

	rt := &replayRoundTripper{}
	resp, err := roundTripResolvedTargets(rt, req.httpReq, []string{"127.0.0.2:80", "127.0.0.1:80"})
	if err != nil {
		t.Fatalf("roundTripResolvedTargets() error: %v", err)
	}
	defer resp.Body.Close()

	if len(rt.bodies) != 2 {
		t.Fatalf("attempt bodies=%v; want 2 attempts", rt.bodies)
	}
	if rt.bodies[0] != "payload" || rt.bodies[1] != "payload" {
		t.Fatalf("attempt bodies=%v; want both payload", rt.bodies)
	}
}

// TestRoundTripResolvedTargetsDoesNotFallbackNonIdempotentRequest checks that
// a POST request is attempted only against the first resolved target and the
// first target's error is returned rather than retried elsewhere.
func TestRoundTripResolvedTargetsDoesNotFallbackNonIdempotentRequest(t *testing.T) {
	req := NewSimpleRequest("http://example.com/upload", http.MethodPost).
		SetBodyReader(strings.NewReader("payload"))
	if err := req.prepare(); err != nil {
		t.Fatalf("prepare() error: %v", err)
	}

	rt := &replayRoundTripper{}
	resp, err := roundTripResolvedTargets(rt, req.httpReq, []string{"127.0.0.2:80", "127.0.0.1:80"})
	if err == nil {
		// Release the unexpected response before failing.
		if resp != nil && resp.Body != nil {
			_ = resp.Body.Close()
		}
		t.Fatal("roundTripResolvedTargets() succeeded; want first target error")
	}
	if len(rt.bodies) != 1 {
		t.Fatalf("attempt bodies=%v; want only first target attempt", rt.bodies)
	}
	if rt.bodies[0] != "payload" {
		t.Fatalf("attempt body=%q; want payload", rt.bodies[0])
	}
}

// TestRetryReplayableReaderBody checks that with SetRetry(1, ...) a request
// whose body comes from a *strings.Reader is retried after a 503, and that
// every attempt sends the full body (two attempts in total).
func TestRetryReplayableReaderBody(t *testing.T) {
	var attempts int32
	req := NewSimpleRequest("http://example.com/upload", http.MethodPut).
		SetBodyReader(strings.NewReader("payload")).
		SetRetry(1, WithRetryBackoff(0, 0, 1), WithRetryJitter(0))
	req.client = &Client{client: &http.Client{
		Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
			body, err := io.ReadAll(r.Body)
			if err != nil {
				return nil, err
			}
			_ = r.Body.Close()
			if string(body) != "payload" {
				// t.Fatalf may only be called from the test goroutine, and the
				// transport may run elsewhere; report and fail the attempt.
				t.Errorf("body=%q; want payload", string(body))
				return nil, errors.New("unexpected request body")
			}
			if atomic.AddInt32(&attempts, 1) == 1 {
				return &http.Response{
					StatusCode: http.StatusServiceUnavailable,
					Header:     make(http.Header),
					Body:       io.NopCloser(strings.NewReader("retry")),
					Request:    r,
				}, nil
			}
			return &http.Response{
				StatusCode: http.StatusOK,
				Header:     make(http.Header),
				Body:       io.NopCloser(strings.NewReader("ok")),
				Request:    r,
			}, nil
		}),
	}}

	resp, err := req.Do()
	if err != nil {
		t.Fatalf("Do() error: %v", err)
	}
	defer resp.Close()

	if got := atomic.LoadInt32(&attempts); got != 2 {
		t.Fatalf("attempts=%d; want 2", got)
	}
}

// TestWithProxyInvalidReturnsError checks that NewRequest rejects an
// unparsable proxy URL.
func TestWithProxyInvalidReturnsError(t *testing.T) {
	_, err := NewRequest("http://example.com", http.MethodGet, WithProxy("://bad-proxy"))
	if err == nil {
		t.Fatal("NewRequest() succeeded; want invalid proxy error")
	}
}

// TestClientNewRequestWithInvalidProxyReturnsError checks that
// Client.NewRequest rejects an unparsable proxy URL.
func TestClientNewRequestWithInvalidProxyReturnsError(t *testing.T) {
	client := NewClientNoErr()
	_, err := client.NewRequest("http://example.com", http.MethodGet, WithProxy("://bad-proxy"))
	if err == nil {
		t.Fatal("Client.NewRequest() succeeded; want invalid proxy error")
	}
}

// TestNewClientWithInvalidProxyReturnsError checks that NewClient rejects an
// unparsable proxy URL.
func TestNewClientWithInvalidProxyReturnsError(t *testing.T) {
	_, err := NewClient(WithProxy("://bad-proxy"))
	if err == nil {
		t.Fatal("NewClient() succeeded; want invalid proxy error")
	}
}