// starnet/request_prepare_regression_test.go
//
// 336 lines
// 8.6 KiB
// Go
// Raw Permalink Normal View History

package starnet
import (
"bytes"
"context"
"errors"
"io"
"mime/multipart"
"net/http"
"strings"
"sync/atomic"
"testing"
"time"
)
// roundTripFunc adapts a plain function to the http.RoundTripper interface so
// tests can stub a transport inline.
type roundTripFunc func(*http.Request) (*http.Response, error)

// RoundTrip hands r to the wrapped function and returns whatever it returns.
func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) {
	resp, err := f(r)
	return resp, err
}
func TestRequestPreparedMutationReappliesHeadersAndBody(t *testing.T) {
req := NewSimpleRequest("http://example.com", http.MethodPost).
SetHeader("X-Test", "one").
SetBodyString("first")
req.client = &Client{client: &http.Client{
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
body, err := io.ReadAll(r.Body)
if err != nil {
return nil, err
}
_ = r.Body.Close()
return &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader(r.Header.Get("X-Test") + ":" + string(body))),
Request: r,
}, nil
}),
}}
if _, err := req.HTTPClient(); err != nil {
t.Fatalf("HTTPClient() error: %v", err)
}
req.SetHeader("X-Test", "two").SetBodyString("second")
resp, err := req.Do()
if err != nil {
t.Fatalf("Do() error: %v", err)
}
defer resp.Close()
body, err := resp.Body().String()
if err != nil {
t.Fatalf("Body().String() error: %v", err)
}
if body != "two:second" {
t.Fatalf("body=%q; want %q", body, "two:second")
}
}
func TestRequestPreparedMutationReappliesTimeout(t *testing.T) {
var attempts int32
req := NewSimpleRequest("http://example.com", http.MethodGet)
req.client = &Client{client: &http.Client{
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
if atomic.AddInt32(&attempts, 1) == 1 {
return &http.Response{
StatusCode: http.StatusNoContent,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader("")),
Request: r,
}, nil
}
select {
case <-time.After(50 * time.Millisecond):
return &http.Response{
StatusCode: http.StatusNoContent,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader("")),
Request: r,
}, nil
case <-r.Context().Done():
return nil, r.Context().Err()
}
}),
}}
resp, err := req.Do()
if err != nil {
t.Fatalf("first Do() error: %v", err)
}
_ = resp.Close()
_, err = req.SetTimeout(10 * time.Millisecond).Do()
if err == nil {
t.Fatal("second Do() succeeded; want timeout error")
}
if !IsTimeout(err) && !errors.Is(err, context.DeadlineExceeded) {
t.Fatalf("second Do() error=%v; want timeout", err)
}
}
// TestWriteFileUsesExecContextWithoutProgressHook checks that writeFile returns
// context.Canceled when handed an already-canceled context. The RequestFile
// used here carries no progress hook.
func TestWriteFileUsesExecContextWithoutProgressHook(t *testing.T) {
	req := NewSimpleRequest("http://example.com", http.MethodPost)

	pipeReader, pipeWriter := io.Pipe()
	mw := multipart.NewWriter(pipeWriter)

	// Drain the pipe in the background so writes into it never block.
	drained := make(chan struct{})
	go func() {
		defer close(drained)
		_, _ = io.Copy(io.Discard, pipeReader)
		_ = pipeReader.Close()
	}()

	ctx, cancel := context.WithCancel(context.Background())
	cancel() // canceled before writeFile ever sees it
	file := RequestFile{
		FormName: "file",
		FileName: "payload.txt",
		FileData: strings.NewReader("payload"),
		FileSize: int64(len("payload")),
	}
	err := req.writeFile(ctx, mw, file)

	// Tear the pipe down and wait for the drain goroutine before asserting.
	_ = mw.Close()
	_ = pipeWriter.Close()
	<-drained

	if !errors.Is(err, context.Canceled) {
		t.Fatalf("writeFile() error=%v; want context.Canceled", err)
	}
}
// TestCopyWithProgressHonorsCanceledContextWithoutHook checks that
// copyWithProgress reports context.Canceled for an already-canceled context
// when nil is passed as its final (progress hook) argument.
func TestCopyWithProgressHonorsCanceledContextWithoutHook(t *testing.T) {
	const payload = "payload"

	ctx, cancel := context.WithCancel(context.Background())
	cancel()

	src := strings.NewReader(payload)
	_, err := copyWithProgress(ctx, io.Discard, src, "payload.txt", int64(len(payload)), nil)
	if !errors.Is(err, context.Canceled) {
		t.Fatalf("copyWithProgress() error=%v; want context.Canceled", err)
	}
}
// TestPrepareSetsGetBodyForReplayableBodies checks that prepare() installs a
// GetBody function on the underlying *http.Request for several replayable body
// kinds, and that calling it yields the expected payload.
func TestPrepareSetsGetBodyForReplayableBodies(t *testing.T) {
	newPost := func() *Request {
		return NewSimpleRequest("http://example.com", http.MethodPost)
	}
	cases := []struct {
		name string
		req  *Request
		want string
	}{
		{"bytes", newPost().SetBody([]byte("payload")), "payload"},
		{"bytes-reader", newPost().SetBodyReader(bytes.NewReader([]byte("payload"))), "payload"},
		{"strings-reader", newPost().SetBodyReader(strings.NewReader("payload")), "payload"},
		{"form-data", newPost().AddFormData("k", "v"), "k=v"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if err := tc.req.prepare(); err != nil {
				t.Fatalf("prepare() error: %v", err)
			}
			getBody := tc.req.httpReq.GetBody
			if getBody == nil {
				t.Fatal("GetBody is nil")
			}
			rc, err := getBody()
			if err != nil {
				t.Fatalf("GetBody() error: %v", err)
			}
			defer rc.Close()

			data, err := io.ReadAll(rc)
			if err != nil {
				t.Fatalf("ReadAll() error: %v", err)
			}
			if got := string(data); got != tc.want {
				t.Fatalf("body=%q; want %q", got, tc.want)
			}
		})
	}
}
// replayRoundTripper is a stub transport that records the body of every
// request it receives. Its first RoundTrip call always fails with an error and
// every later call succeeds with a 200 "ok" response. This lets tests observe
// whether a caller replays the request body when it falls back to another
// attempt.
//
// It keeps no synchronization, so it must not be shared across goroutines.
type replayRoundTripper struct {
	attempts int      // number of RoundTrip calls that read their body successfully
	bodies   []string // body seen on each counted attempt, in call order
}

// RoundTrip drains and closes req.Body, records it, and fails only the first
// counted attempt. A nil body, which is legal for client requests, is recorded
// as "". A failure to replay the body then surfaces as a test assertion
// instead of a nil-pointer panic. A read error is returned as-is and the
// attempt is not counted.
func (rt *replayRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	var body []byte
	if req.Body != nil {
		data, err := io.ReadAll(req.Body)
		// The RoundTripper contract requires closing the body even on error.
		_ = req.Body.Close()
		if err != nil {
			return nil, err
		}
		body = data
	}
	rt.attempts++
	rt.bodies = append(rt.bodies, string(body))
	if rt.attempts == 1 {
		return nil, errors.New("first target failed")
	}
	return &http.Response{
		StatusCode: http.StatusOK,
		Header:     make(http.Header),
		Body:       io.NopCloser(strings.NewReader("ok")),
		Request:    req,
	}, nil
}
// TestRoundTripResolvedTargetsReplaysPreparedBody checks that a PUT falls back
// from a failing first target to the second one. Both attempts must carry the
// full prepared body. The replayRoundTripper stub fails its first call and
// records every body it sees.
func TestRoundTripResolvedTargetsReplaysPreparedBody(t *testing.T) {
	req := NewSimpleRequest("http://example.com/upload", http.MethodPut).
		SetBodyReader(strings.NewReader("payload"))
	if err := req.prepare(); err != nil {
		t.Fatalf("prepare() error: %v", err)
	}

	stub := &replayRoundTripper{}
	targets := []string{"127.0.0.2:80", "127.0.0.1:80"}
	resp, err := roundTripResolvedTargets(stub, req.httpReq, targets)
	if err != nil {
		t.Fatalf("roundTripResolvedTargets() error: %v", err)
	}
	defer resp.Body.Close()

	if len(stub.bodies) != 2 {
		t.Fatalf("attempt bodies=%v; want 2 attempts", stub.bodies)
	}
	for _, seen := range stub.bodies {
		if seen != "payload" {
			t.Fatalf("attempt bodies=%v; want both payload", stub.bodies)
		}
	}
}
// TestRoundTripResolvedTargetsDoesNotFallbackNonIdempotentRequest checks that
// a POST whose first target fails is not retried against the second target.
// The error must propagate after exactly one attempt, and that attempt must
// have carried the body.
func TestRoundTripResolvedTargetsDoesNotFallbackNonIdempotentRequest(t *testing.T) {
	req := NewSimpleRequest("http://example.com/upload", http.MethodPost).
		SetBodyReader(strings.NewReader("payload"))
	if err := req.prepare(); err != nil {
		t.Fatalf("prepare() error: %v", err)
	}

	stub := &replayRoundTripper{}
	targets := []string{"127.0.0.2:80", "127.0.0.1:80"}
	if _, err := roundTripResolvedTargets(stub, req.httpReq, targets); err == nil {
		t.Fatal("roundTripResolvedTargets() succeeded; want first target error")
	}

	switch {
	case len(stub.bodies) != 1:
		t.Fatalf("attempt bodies=%v; want only first target attempt", stub.bodies)
	case stub.bodies[0] != "payload":
		t.Fatalf("attempt body=%q; want payload", stub.bodies[0])
	}
}
// TestRetryReplayableReaderBody checks that a retried PUT whose body came from
// a strings.Reader resends the full payload on the retry attempt. The request
// is configured with SetRetry(1, ...) and zero backoff/jitter. The stub
// transport answers 503 first and 200 second, and every attempt must carry
// "payload".
func TestRetryReplayableReaderBody(t *testing.T) {
	var attempts int32
	req := NewSimpleRequest("http://example.com/upload", http.MethodPut).
		SetBodyReader(strings.NewReader("payload")).
		SetRetry(1, WithRetryBackoff(0, 0, 1), WithRetryJitter(0))
	req.client = &Client{client: &http.Client{
		Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
			// The transport is not guaranteed to run on the test goroutine,
			// and t.Fatalf (FailNow) may only be called from that goroutine.
			// Report mismatches with t.Errorf and fail the attempt via the
			// error return instead.
			var body []byte
			if r.Body != nil {
				data, err := io.ReadAll(r.Body)
				// Close even on a read error, per the RoundTripper contract.
				_ = r.Body.Close()
				if err != nil {
					return nil, err
				}
				body = data
			}
			if string(body) != "payload" {
				t.Errorf("body=%q; want payload", string(body))
				return nil, errors.New("unexpected request body")
			}
			if atomic.AddInt32(&attempts, 1) == 1 {
				return &http.Response{
					StatusCode: http.StatusServiceUnavailable,
					Header:     make(http.Header),
					Body:       io.NopCloser(strings.NewReader("retry")),
					Request:    r,
				}, nil
			}
			return &http.Response{
				StatusCode: http.StatusOK,
				Header:     make(http.Header),
				Body:       io.NopCloser(strings.NewReader("ok")),
				Request:    r,
			}, nil
		}),
	}}
	resp, err := req.Do()
	if err != nil {
		t.Fatalf("Do() error: %v", err)
	}
	defer resp.Close()
	if got := atomic.LoadInt32(&attempts); got != 2 {
		t.Fatalf("attempts=%d; want 2", got)
	}
}
// TestWithProxyInvalidReturnsError checks that NewRequest returns an error
// when WithProxy is given a malformed proxy URL.
func TestWithProxyInvalidReturnsError(t *testing.T) {
	const badProxy = "://bad-proxy"
	if _, err := NewRequest("http://example.com", http.MethodGet, WithProxy(badProxy)); err == nil {
		t.Fatal("NewRequest() succeeded; want invalid proxy error")
	}
}
// TestClientNewRequestWithInvalidProxyReturnsError checks that
// Client.NewRequest returns an error when WithProxy is given a malformed
// proxy URL.
func TestClientNewRequestWithInvalidProxyReturnsError(t *testing.T) {
	const badProxy = "://bad-proxy"
	c := NewClientNoErr()
	if _, err := c.NewRequest("http://example.com", http.MethodGet, WithProxy(badProxy)); err == nil {
		t.Fatal("Client.NewRequest() succeeded; want invalid proxy error")
	}
}
// TestNewClientWithInvalidProxyReturnsError checks that NewClient returns an
// error when WithProxy is given a malformed proxy URL.
func TestNewClientWithInvalidProxyReturnsError(t *testing.T) {
	const badProxy = "://bad-proxy"
	if _, err := NewClient(WithProxy(badProxy)); err == nil {
		t.Fatal("NewClient() succeeded; want invalid proxy error")
	}
}