starnet/request_state_boundary_test.go

169 lines
4.3 KiB
Go
Raw Permalink Normal View History

package starnet
import (
"io"
"net/http"
"net/url"
"strings"
"sync/atomic"
"testing"
)
// stateRoundTripperFunc adapts a plain function to the http.RoundTripper
// interface so tests can stub an http.Client transport inline.
type stateRoundTripperFunc func(*http.Request) (*http.Response, error)

// Compile-time check that stateRoundTripperFunc satisfies http.RoundTripper.
var _ http.RoundTripper = stateRoundTripperFunc(nil)

// RoundTrip hands req to the wrapped function and returns its response and
// error unchanged.
func (f stateRoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
	resp, err := f(req)
	return resp, err
}
// TestSetContextNilUsesBackground verifies that SetContext(nil) leaves the
// request usable. Do() must succeed, the stub transport must be reached
// exactly once with a context that is not already done, and req.Context()
// must be non-nil afterwards.
func TestSetContextNilUsesBackground(t *testing.T) {
	var hits int32
	req := NewSimpleRequest("http://example.com", http.MethodGet)
	req.client = &Client{client: &http.Client{
		Transport: stateRoundTripperFunc(func(r *http.Request) (*http.Response, error) {
			atomic.AddInt32(&hits, 1)
			// (*http.Request).Context never returns nil (it falls back to
			// context.Background), so a nil check here could never fire.
			// Instead, assert the context handed to the transport is live.
			// Use t.Errorf rather than t.Fatal: FailNow must run on the test
			// goroutine, and this callback is not guaranteed to run there.
			if ctxErr := r.Context().Err(); ctxErr != nil {
				t.Errorf("request context already done: %v", ctxErr)
			}
			return &http.Response{
				StatusCode: http.StatusOK,
				Header:     make(http.Header),
				Body:       io.NopCloser(strings.NewReader("ok")),
				Request:    r,
			}, nil
		}),
	}}
	resp, err := req.SetContext(nil).Do()
	if err != nil {
		t.Fatalf("Do() error: %v", err)
	}
	defer resp.Close()
	// Guard against Do() "succeeding" without ever reaching the transport.
	if got := atomic.LoadInt32(&hits); got != 1 {
		t.Fatalf("hits=%d; want 1", got)
	}
	if req.Context() == nil {
		t.Fatal("request Context() is nil")
	}
}
// TestWithContextNilRetryPathDoesNotPanic verifies that a request built with
// WithContext(nil) survives the retry path. The first attempt returns 503 and
// the second returns 200. Do() must succeed after exactly two transport hits,
// and each attempt must see a context that is not already done.
func TestWithContextNilRetryPathDoesNotPanic(t *testing.T) {
	var hits int32
	req, err := NewRequest("http://example.com", http.MethodGet, WithContext(nil))
	if err != nil {
		t.Fatalf("NewRequest() error: %v", err)
	}
	req.client = &Client{client: &http.Client{
		Transport: stateRoundTripperFunc(func(r *http.Request) (*http.Response, error) {
			// (*http.Request).Context never returns nil, so a nil check would be
			// vacuous; check that the attempt's context is live instead. Use
			// t.Errorf, not t.Fatal: FailNow is only valid on the test goroutine
			// and the transport is not guaranteed to be called there.
			if ctxErr := r.Context().Err(); ctxErr != nil {
				t.Errorf("retry request context already done: %v", ctxErr)
			}
			status, body := http.StatusOK, "ok"
			if atomic.AddInt32(&hits, 1) == 1 {
				// Fail the first attempt so the retry path is exercised.
				status, body = http.StatusServiceUnavailable, "retry"
			}
			return &http.Response{
				StatusCode: status,
				Header:     make(http.Header),
				Body:       io.NopCloser(strings.NewReader(body)),
				Request:    r,
			}, nil
		}),
	}}
	// One retry; the zero backoff/jitter arguments presumably remove any delay
	// between attempts so the test stays fast.
	resp, err := req.
		SetTimeout(DefaultTimeout).
		SetRetry(1, WithRetryBackoff(0, 0, 1), WithRetryJitter(0)).
		Do()
	if err != nil {
		t.Fatalf("Do() error: %v", err)
	}
	defer resp.Close()
	if got := atomic.LoadInt32(&hits); got != 2 {
		t.Fatalf("hits=%d; want 2", got)
	}
}
// TestCloneRawRequestCreatesIndependentCopy verifies that Clone() on a
// raw-mode request produces a distinct *http.Request whose header can be
// mutated without affecting the original, and whose GetBody replays the
// original payload.
func TestCloneRawRequestCreatesIndependentCopy(t *testing.T) {
	rawReq, err := http.NewRequest(http.MethodPost, "http://example.com/upload", strings.NewReader("payload"))
	if err != nil {
		t.Fatalf("NewRequest() error: %v", err)
	}
	rawReq.Header.Set("X-Test", "one")
	req := NewSimpleRequest("", http.MethodPost).SetRawRequest(rawReq)
	cloned := req.Clone()
	if cloned.Err() != nil {
		t.Fatalf("Clone() err = %v", cloned.Err())
	}
	clonedRaw := cloned.RawRequest()
	// Fail cleanly instead of panicking on a nil dereference below.
	if clonedRaw == nil {
		t.Fatal("Clone() produced nil raw request")
	}
	if clonedRaw == rawReq {
		t.Fatal("raw request pointer reused")
	}
	clonedRaw.Header.Set("X-Test", "two")
	if got := rawReq.Header.Get("X-Test"); got != "one" {
		t.Fatalf("original header mutated: %q", got)
	}
	// Calling a nil GetBody would panic and abort the whole test binary.
	if clonedRaw.GetBody == nil {
		t.Fatal("cloned raw request has nil GetBody")
	}
	body, err := clonedRaw.GetBody()
	if err != nil {
		t.Fatalf("GetBody() error: %v", err)
	}
	defer body.Close()
	data, err := io.ReadAll(body)
	if err != nil {
		t.Fatalf("ReadAll() error: %v", err)
	}
	if string(data) != "payload" {
		t.Fatalf("body=%q; want payload", string(data))
	}
}
// TestCloneRawRequestWithNonReplayableBodyFailsExplicitly checks that cloning a
// raw-mode request whose body cannot be replayed yields an error, and that the
// error message mentions "non-replayable".
func TestCloneRawRequestWithNonReplayableBodyFailsExplicitly(t *testing.T) {
	// io.MultiReader hides the concrete *strings.Reader type, and the request
	// below sets no GetBody, so the body offers no way to be re-read.
	opaqueBody := io.NopCloser(io.MultiReader(strings.NewReader("payload")))
	rawReq := &http.Request{
		Method: http.MethodPost,
		URL:    mustParseURL(t, "http://example.com/upload"),
		Header: make(http.Header),
		Body:   opaqueBody,
	}
	cloneErr := NewSimpleRequest("", http.MethodPost).SetRawRequest(rawReq).Clone().Err()
	switch {
	case cloneErr == nil:
		t.Fatal("Clone() should fail for non-replayable raw body")
	case !strings.Contains(cloneErr.Error(), "non-replayable"):
		t.Fatalf("Clone() err=%v; want non-replayable body error", cloneErr)
	}
}
// TestDisableRawModeAfterSetRawRequestReturnsError checks that calling
// DisableRawMode after SetRawRequest records an error containing
// "cannot disable raw mode" and leaves the request in raw mode (doRaw stays
// true).
func TestDisableRawModeAfterSetRawRequestReturnsError(t *testing.T) {
	raw, buildErr := http.NewRequest(http.MethodGet, "http://example.com", nil)
	if buildErr != nil {
		t.Fatalf("NewRequest() error: %v", buildErr)
	}
	req := NewSimpleRequest("", http.MethodGet).
		SetRawRequest(raw).
		DisableRawMode()

	gotErr := req.Err()
	if gotErr == nil {
		t.Fatal("DisableRawMode() should set error")
	}
	if msg := gotErr.Error(); !strings.Contains(msg, "cannot disable raw mode") {
		t.Fatalf("DisableRawMode() err=%v", gotErr)
	}
	if stillRaw := req.doRaw; !stillRaw {
		t.Fatal("request should remain in raw mode")
	}
}
// mustParseURL parses raw with url.Parse and returns the result. If parsing
// fails, it aborts the calling test via t.Fatalf. It is marked as a helper, so
// failures are reported at the caller's line.
func mustParseURL(t *testing.T, raw string) *url.URL {
	t.Helper()
	u, parseErr := url.Parse(raw)
	if parseErr == nil {
		return u
	}
	t.Fatalf("url.Parse() error: %v", parseErr)
	return nil // unreachable: Fatalf stops the test goroutine
}