package starnet

import (
	"fmt"
	"io"
	"net/http"
	"os"
	"path/filepath"
	"strings"
	"sync/atomic"
	"testing"
)

// TestRequestDoTwiceRebuildsExecutionState verifies that calling Do twice on
// the same Request rebuilds the outgoing *http.Request each time. On both
// attempts the transport must see:
//   - a live context;
//   - exactly one X-Test header;
//   - exactly one "q" query value;
//   - the full "payload" body.
//
// Each Do call must produce exactly one round trip.
func TestRequestDoTwiceRebuildsExecutionState(t *testing.T) {
	var attempts int32

	req := NewSimpleRequest("http://example.com/path", http.MethodPost).
		SetHeader("X-Test", "one").
		AddQuery("q", "v").
		SetBodyReader(strings.NewReader("payload"))

	// fail records a test error and hands the same error back to the client.
	// t.Fatalf is deliberately not used inside the transport: FailNow may only
	// be called from the goroutine running the test, and RoundTrip is not
	// guaranteed to run there. Returning the error makes Do fail, which then
	// stops the test on the test goroutine.
	fail := func(err error) (*http.Response, error) {
		t.Error(err)
		return nil, err
	}

	req.client = &Client{client: &http.Client{
		Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
			// The http.RoundTripper contract requires the request body to be
			// closed on every path, including early failures.
			if r.Body != nil {
				defer r.Body.Close()
			}

			if err := r.Context().Err(); err != nil {
				return fail(fmt.Errorf("request context already done: %w", err))
			}
			if values := r.Header.Values("X-Test"); len(values) != 1 || values[0] != "one" {
				return fail(fmt.Errorf("header values=%v", values))
			}
			if values := r.URL.Query()["q"]; len(values) != 1 || values[0] != "v" {
				return fail(fmt.Errorf("query values=%v", values))
			}
			if r.Body == nil {
				return fail(fmt.Errorf("request body is nil"))
			}

			body, err := io.ReadAll(r.Body)
			if err != nil {
				return nil, err
			}
			if string(body) != "payload" {
				return fail(fmt.Errorf("body=%q", string(body)))
			}

			// Only fully validated requests count as successful attempts.
			n := atomic.AddInt32(&attempts, 1)
			return &http.Response{
				StatusCode: http.StatusOK,
				Header:     make(http.Header),
				Body:       io.NopCloser(strings.NewReader(fmt.Sprintf("ok-%d", n))),
				Request:    r,
			}, nil
		}),
	}}

	resp1, err := req.Do()
	if err != nil {
		t.Fatalf("first Do() error: %v", err)
	}
	if err := resp1.Close(); err != nil {
		t.Fatalf("first Close() error: %v", err)
	}
	if got := atomic.LoadInt32(&attempts); got != 1 {
		t.Fatalf("attempts after first Do()=%d; want 1", got)
	}

	resp2, err := req.Do()
	if err != nil {
		t.Fatalf("second Do() error: %v", err)
	}
	defer func() {
		if err := resp2.Close(); err != nil {
			t.Errorf("second Close() error: %v", err)
		}
	}()

	if got := atomic.LoadInt32(&attempts); got != 2 {
		t.Fatalf("attempts=%d; want 2", got)
	}
}

// TestRequestPrepareRawDynamicPathInjectsAggregatedRequestContext verifies
// prepare on a Request built from a caller-supplied *http.Request. After
// prepare, execCtx must hold a *RequestContext under ctxKeyRequestContext.
// That context must carry the configured proxy, the custom IP list, a TLS
// config with InsecureSkipVerify set, and the TLS server-name override.
func TestRequestPrepareRawDynamicPathInjectsAggregatedRequestContext(t *testing.T) {
	rawReq, err := http.NewRequest(http.MethodGet, "https://example.com/resource", nil)
	if err != nil {
		t.Fatalf("NewRequest() error: %v", err)
	}

	req := NewSimpleRequest("", http.MethodGet).
		SetRawRequest(rawReq).
		SetProxy("http://proxy.example:8080").
		SetCustomIP([]string{"127.0.0.1"}).
		SetSkipTLSVerify(true).
		SetTLSServerName("override.example")

	if err := req.prepare(); err != nil {
		t.Fatalf("prepare() error: %v", err)
	}

	raw := req.execCtx.Value(ctxKeyRequestContext)
	rc, ok := raw.(*RequestContext)
	if !ok || rc == nil {
		t.Fatalf("expected request context, got %#v", raw)
	}
	if rc.Proxy != "http://proxy.example:8080" {
		t.Fatalf("proxy=%q", rc.Proxy)
	}
	if len(rc.CustomIP) != 1 || rc.CustomIP[0] != "127.0.0.1" {
		t.Fatalf("custom ip=%v", rc.CustomIP)
	}
	if rc.TLSConfig == nil || !rc.TLSConfig.InsecureSkipVerify {
		t.Fatalf("tls config=%#v", rc.TLSConfig)
	}
	if rc.TLSServerName != "override.example" {
		t.Fatalf("tls server name=%q", rc.TLSServerName)
	}
}

// TestRequestSetFormDataOverridesBytesBody verifies that SetFormData called
// after SetBodyString does three things:
//   - switches the body mode to bodyModeForm;
//   - clears any previously stored reader, bytes and files;
//   - makes the prepared request's GetBody yield the encoded form "k=v".
func TestRequestSetFormDataOverridesBytesBody(t *testing.T) {
	req := NewSimpleRequest("http://example.com", http.MethodPost).
		SetBodyString("stale").
		SetFormData(map[string][]string{"k": {"v"}})

	if req.config.Body.Mode != bodyModeForm {
		t.Fatalf("body mode=%v", req.config.Body.Mode)
	}
	if req.config.Body.Reader != nil || req.config.Body.Bytes != nil || len(req.config.Body.Files) != 0 {
		t.Fatalf("unexpected stale body state: %#v", req.config.Body)
	}

	if err := req.prepare(); err != nil {
		t.Fatalf("prepare() error: %v", err)
	}
	// Guard so a missing GetBody is reported as a failure, not a nil-func panic.
	if req.httpReq.GetBody == nil {
		t.Fatal("prepared request has nil GetBody")
	}

	body, err := req.httpReq.GetBody()
	if err != nil {
		t.Fatalf("GetBody() error: %v", err)
	}
	defer body.Close()

	data, err := io.ReadAll(body)
	if err != nil {
		t.Fatalf("ReadAll() error: %v", err)
	}
	if string(data) != "k=v" {
		t.Fatalf("body=%q; want k=v", string(data))
	}
}

// TestRequestAddFileClearsPreviousBytesBody verifies that AddFile called after
// SetJSON does three things:
//   - switches the body mode to bodyModeMultipart;
//   - clears the previously stored reader/bytes;
//   - makes the prepared request a multipart/form-data upload containing the
//     file content and none of the stale JSON.
func TestRequestAddFileClearsPreviousBytesBody(t *testing.T) {
	tmpDir := t.TempDir()
	filePath := filepath.Join(tmpDir, "payload.txt")
	if err := os.WriteFile(filePath, []byte("file-body"), 0o644); err != nil {
		t.Fatalf("WriteFile() error: %v", err)
	}

	req := NewSimpleRequest("http://example.com", http.MethodPost).
		SetJSON(map[string]string{"old": "json-only"}).
		AddFile("file", filePath)

	if req.config.Body.Mode != bodyModeMultipart {
		t.Fatalf("body mode=%v", req.config.Body.Mode)
	}
	if req.config.Body.Reader != nil || req.config.Body.Bytes != nil {
		t.Fatalf("unexpected stale simple body state: %#v", req.config.Body)
	}

	if err := req.prepare(); err != nil {
		t.Fatalf("prepare() error: %v", err)
	}
	if ct := req.httpReq.Header.Get("Content-Type"); !strings.Contains(ct, "multipart/form-data") {
		t.Fatalf("content-type=%q", ct)
	}
	if req.httpReq.Body == nil {
		t.Fatal("prepared request has nil Body")
	}
	// Close the body even on success; a streamed multipart body may hold an
	// open file or pipe.
	defer req.httpReq.Body.Close()

	data, err := io.ReadAll(req.httpReq.Body)
	if err != nil {
		t.Fatalf("ReadAll() error: %v", err)
	}
	if !strings.Contains(string(data), "file-body") {
		t.Fatalf("multipart body missing file content: %q", string(data))
	}
	if strings.Contains(string(data), "json-only") {
		t.Fatalf("multipart body still contains stale json: %q", string(data))
	}
}