starnet/request_test.go
2026-03-08 20:19:40 +08:00

173 lines
3.9 KiB
Go

package starnet
import (
"context"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
)
// TestNewSimpleRequest exercises NewRequest with valid and invalid inputs.
// For invalid input the failure may be reported either through the returned
// error or through the request's Err(); either one is accepted.
func TestNewSimpleRequest(t *testing.T) {
	tests := []struct {
		name      string
		url       string
		method    string
		expectErr bool
	}{
		{
			name:      "valid GET request",
			url:       "https://example.com",
			method:    "GET",
			expectErr: false,
		},
		{
			name:      "valid POST request",
			url:       "https://example.com",
			method:    "POST",
			expectErr: false,
		},
		{
			name:      "invalid URL",
			url:       "://invalid",
			method:    "GET",
			expectErr: true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			req, err := NewRequest(tt.url, tt.method)
			if tt.expectErr {
				// err is tested first so req.Err() is only reached when
				// NewRequest reported success.
				if err == nil && req.Err() == nil {
					t.Errorf("NewRequest() expected error, got nil")
				}
				return
			}
			if err != nil {
				// Fatal rather than Error: req cannot be relied on when err is
				// non-nil, and the Method() call below would dereference it.
				t.Fatalf("NewRequest() unexpected error: %v", err)
			}
			if rerr := req.Err(); rerr != nil {
				t.Fatalf("NewRequest() request carries unexpected error: %v", rerr)
			}
			want := strings.ToUpper(tt.method)
			if got := req.Method(); got != want {
				t.Errorf("Method = %v; want %v", got, want)
			}
		})
	}
}
// TestRequestMethods sends one request per HTTP method to a test server that
// echoes the request method back in the response body, and checks that each
// method reaches the server unchanged.
func TestRequestMethods(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("X-Method", r.Method)
		w.WriteHeader(http.StatusOK)
		// A failed write surfaces on the client as a body mismatch, so the
		// error is intentionally not handled here.
		_, _ = w.Write([]byte(r.Method))
	}))
	defer server.Close()

	methods := []string{"GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"}
	for _, method := range methods {
		t.Run(method, func(t *testing.T) {
			req := NewSimpleRequest(server.URL, method)
			resp, err := req.Do()
			if err != nil {
				t.Fatalf("Do() error: %v", err)
			}
			defer resp.Close()

			if resp.StatusCode != http.StatusOK {
				t.Errorf("StatusCode = %v; want %v", resp.StatusCode, http.StatusOK)
			}

			// HEAD responses carry no body, so there is nothing to compare.
			if method == http.MethodHead {
				return
			}
			body, err := resp.Body().String()
			if err != nil {
				t.Fatalf("Body().String() error: %v", err)
			}
			if body != method {
				t.Errorf("Body = %v; want %v", body, method)
			}
		})
	}
}
// TestRequestSetMethod checks that SetMethod replaces the method of an
// existing request, and that a malformed method name is recorded in Err().
func TestRequestSetMethod(t *testing.T) {
	r := NewSimpleRequest("https://example.com", "GET")

	r.SetMethod("POST")
	if got := r.Method(); got != "POST" {
		t.Errorf("Method = %v; want POST", got)
	}

	r.SetMethod("invalid method!")
	if r.Err() == nil {
		t.Error("SetMethod with invalid method should set error")
	}
}
// TestRequestSetURL checks that SetURL replaces the target URL of a request,
// and that a malformed URL is reported through Err() on a separate request.
func TestRequestSetURL(t *testing.T) {
	const replacement = "https://newexample.com"

	valid := NewSimpleRequest("https://example.com", "GET")
	valid.SetURL(replacement)
	if got := valid.URL(); got != replacement {
		t.Errorf("URL = %v; want https://newexample.com", got)
	}

	broken := NewSimpleRequest("https://example.com", "GET")
	broken.SetURL("://invalid")
	if broken.Err() == nil {
		t.Error("SetURL with invalid URL should set error")
	}
}
// TestRequestClone verifies that a cloned request keeps the headers of the
// request it was cloned from. The test server answers 400 Bad Request to any
// request whose X-Test header is not "value", and 200 OK otherwise.
func TestRequestClone(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Header.Get("X-Test") != "value" {
			w.WriteHeader(http.StatusBadRequest)
			return
		}
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte("OK"))
	}))
	defer server.Close()

	req := NewSimpleRequest(server.URL, "GET").
		SetHeader("X-Test", "value")

	// First request: confirm the original succeeds, so that a failure on the
	// clone below can be attributed to Clone rather than to the setup.
	resp, err := req.Do()
	if err != nil {
		t.Fatalf("Do() error: %v", err)
	}
	status := resp.StatusCode
	resp.Close()
	if status != http.StatusOK {
		t.Fatalf("Original request StatusCode = %v; want %v", status, http.StatusOK)
	}

	// Clone the request and add an extra header to the clone only.
	cloned := req.Clone()
	cloned.SetHeader("X-Extra", "extra")

	// The clone must still carry X-Test, so it should succeed as well.
	resp2, err := cloned.Do()
	if err != nil {
		t.Fatalf("Cloned Do() error: %v", err)
	}
	defer resp2.Close()
	if resp2.StatusCode != http.StatusOK {
		t.Errorf("Cloned request StatusCode = %v; want %v", resp2.StatusCode, http.StatusOK)
	}
}
// TestRequestContext verifies that a context deadline attached with SetContext
// aborts a request whose response would arrive after that deadline.
func TestRequestContext(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Respond only after 100ms, well past the client's 50ms deadline.
		// Returning as soon as the client goes away keeps server.Close (which
		// waits for outstanding handlers) from sitting out the full delay.
		select {
		case <-time.After(100 * time.Millisecond):
			w.WriteHeader(http.StatusOK)
		case <-r.Context().Done():
		}
	}))
	defer server.Close()

	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
	defer cancel()

	req := NewSimpleRequest(server.URL, "GET").SetContext(ctx)
	resp, err := req.Do()
	if err == nil {
		// Unexpected success: release the response before reporting.
		resp.Close()
		t.Error("Expected timeout error, got nil")
	}
}