package starnet

import (
	"io"
	"net/http"
	"net/url"
	"strings"
	"sync/atomic"
	"testing"
)

// stateRoundTripperFunc adapts an ordinary function to http.RoundTripper so
// the tests below can stub the transport in-process.
type stateRoundTripperFunc func(*http.Request) (*http.Response, error)

// RoundTrip implements http.RoundTripper by delegating to the wrapped function.
func (f stateRoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
	return f(req)
}

// TestSetContextNilUsesBackground executes a request after SetContext(nil) and
// checks that neither the outgoing *http.Request nor the request itself
// reports a nil context.
func TestSetContextNilUsesBackground(t *testing.T) {
	// Stub transport: asserts on the outgoing context and answers 200 "ok".
	transport := stateRoundTripperFunc(func(r *http.Request) (*http.Response, error) {
		if r.Context() == nil {
			t.Fatal("request context is nil")
		}
		okResp := &http.Response{
			StatusCode: http.StatusOK,
			Header:     make(http.Header),
			Body:       io.NopCloser(strings.NewReader("ok")),
			Request:    r,
		}
		return okResp, nil
	})

	req := NewSimpleRequest("http://example.com", http.MethodGet)
	req.client = &Client{client: &http.Client{Transport: transport}}

	resp, err := req.SetContext(nil).Do()
	if err != nil {
		t.Fatalf("Do() error: %v", err)
	}
	defer resp.Close()

	if req.Context() == nil {
		t.Fatal("request Context() is nil")
	}
}

// TestWithContextNilRetryPathDoesNotPanic builds a request with
// WithContext(nil), lets the stub transport answer 503 on the first hit and
// 200 afterwards, and expects exactly two transport hits with one retry
// configured.
func TestWithContextNilRetryPathDoesNotPanic(t *testing.T) {
	var hits int32

	req, err := NewRequest("http://example.com", http.MethodGet, WithContext(nil))
	if err != nil {
		t.Fatalf("NewRequest() error: %v", err)
	}

	transport := stateRoundTripperFunc(func(r *http.Request) (*http.Response, error) {
		if r.Context() == nil {
			t.Fatal("retry request context is nil")
		}

		// Only the very first hit is answered with 503; later hits succeed.
		status, payload := http.StatusOK, "ok"
		if atomic.AddInt32(&hits, 1) == 1 {
			status, payload = http.StatusServiceUnavailable, "retry"
		}

		return &http.Response{
			StatusCode: status,
			Header:     make(http.Header),
			Body:       io.NopCloser(strings.NewReader(payload)),
			Request:    r,
		}, nil
	})
	req.client = &Client{client: &http.Client{Transport: transport}}

	resp, err := req.
		SetTimeout(DefaultTimeout).
		SetRetry(1, WithRetryBackoff(0, 0, 1), WithRetryJitter(0)).
		Do()
	if err != nil {
		t.Fatalf("Do() error: %v", err)
	}
	defer resp.Close()

	got := atomic.LoadInt32(&hits)
	if got != 2 {
		t.Fatalf("hits=%d; want 2", got)
	}
}

// TestCloneRawRequestCreatesIndependentCopy clones a request carrying a raw
// *http.Request and checks that the clone holds a different pointer, that
// header edits on the clone do not reach the original, and that the clone's
// GetBody yields the original payload.
func TestCloneRawRequestCreatesIndependentCopy(t *testing.T) {
	rawReq, err := http.NewRequest(http.MethodPost, "http://example.com/upload", strings.NewReader("payload"))
	if err != nil {
		t.Fatalf("NewRequest() error: %v", err)
	}
	rawReq.Header.Set("X-Test", "one")

	cloned := NewSimpleRequest("", http.MethodPost).SetRawRequest(rawReq).Clone()
	if cloneErr := cloned.Err(); cloneErr != nil {
		t.Fatalf("Clone() err = %v", cloneErr)
	}
	if cloned.RawRequest() == rawReq {
		t.Fatal("raw request pointer reused")
	}

	// Mutating the clone's header must leave the original header untouched.
	cloned.RawRequest().Header.Set("X-Test", "two")
	if original := rawReq.Header.Get("X-Test"); original != "one" {
		t.Fatalf("original header mutated: %q", original)
	}

	body, err := cloned.RawRequest().GetBody()
	if err != nil {
		t.Fatalf("GetBody() error: %v", err)
	}
	defer body.Close()

	data, err := io.ReadAll(body)
	if err != nil {
		t.Fatalf("ReadAll() error: %v", err)
	}
	if text := string(data); text != "payload" {
		t.Fatalf("body=%q; want payload", text)
	}
}

// TestCloneRawRequestWithNonReplayableBodyFailsExplicitly clones a request
// whose raw *http.Request has a Body but no GetBody, and expects Clone to
// record an error mentioning "non-replayable".
func TestCloneRawRequestWithNonReplayableBodyFailsExplicitly(t *testing.T) {
	// Built by hand (not via http.NewRequest) so that GetBody stays unset.
	rawReq := &http.Request{
		Method: http.MethodPost,
		URL:    mustParseURL(t, "http://example.com/upload"),
		Header: make(http.Header),
		Body:   io.NopCloser(io.MultiReader(strings.NewReader("payload"))),
	}

	cloned := NewSimpleRequest("", http.MethodPost).SetRawRequest(rawReq).Clone()

	switch cloneErr := cloned.Err(); {
	case cloneErr == nil:
		t.Fatal("Clone() should fail for non-replayable raw body")
	case !strings.Contains(cloneErr.Error(), "non-replayable"):
		t.Fatalf("Clone() err=%v; want non-replayable body error", cloneErr)
	}
}

// TestDisableRawModeAfterSetRawRequestReturnsError calls DisableRawMode after
// SetRawRequest and expects an error containing "cannot disable raw mode"
// while the request's doRaw flag stays set.
func TestDisableRawModeAfterSetRawRequestReturnsError(t *testing.T) {
	rawReq, err := http.NewRequest(http.MethodGet, "http://example.com", nil)
	if err != nil {
		t.Fatalf("NewRequest() error: %v", err)
	}

	req := NewSimpleRequest("", http.MethodGet).SetRawRequest(rawReq).DisableRawMode()

	switch reqErr := req.Err(); {
	case reqErr == nil:
		t.Fatal("DisableRawMode() should set error")
	case !strings.Contains(reqErr.Error(), "cannot disable raw mode"):
		t.Fatalf("DisableRawMode() err=%v", reqErr)
	}

	if !req.doRaw {
		t.Fatal("request should remain in raw mode")
	}
}

// mustParseURL parses raw with url.Parse and fails the calling test
// immediately if parsing returns an error.
func mustParseURL(t *testing.T, raw string) *url.URL {
	t.Helper()

	u, err := url.Parse(raw)
	if err == nil {
		return u
	}

	t.Fatalf("url.Parse() error: %v", err)
	return nil // not reached: Fatalf stops the test goroutine
}