package starnet import ( "context" "net/http" "net/http/httptest" "strings" "testing" "time" ) func TestNewSimpleRequest(t *testing.T) { tests := []struct { name string url string method string expectErr bool }{ { name: "valid GET request", url: "https://example.com", method: "GET", expectErr: false, }, { name: "valid POST request", url: "https://example.com", method: "POST", expectErr: false, }, { name: "invalid URL", url: "://invalid", method: "GET", expectErr: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { req, err := NewRequest(tt.url, tt.method) if tt.expectErr { if err == nil && req.Err() == nil { t.Errorf("NewRequest() expected error, got nil") } } else { if err != nil { t.Errorf("NewRequest() unexpected error: %v", err) } if req.Method() != strings.ToUpper(tt.method) { t.Errorf("Method = %v; want %v", req.Method(), tt.method) } } }) } } func TestRequestMethods(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Method", r.Method) w.WriteHeader(http.StatusOK) w.Write([]byte(r.Method)) })) defer server.Close() methods := []string{"GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"} for _, method := range methods { t.Run(method, func(t *testing.T) { req := NewSimpleRequest(server.URL, method) resp, err := req.Do() if err != nil { t.Fatalf("Do() error: %v", err) } defer resp.Close() if resp.StatusCode != http.StatusOK { t.Errorf("StatusCode = %v; want %v", resp.StatusCode, http.StatusOK) } if method != "HEAD" { body, _ := resp.Body().String() if body != method { t.Errorf("Body = %v; want %v", body, method) } } }) } } func TestRequestSetMethod(t *testing.T) { req := NewSimpleRequest("https://example.com", "GET") req.SetMethod("POST") if req.Method() != "POST" { t.Errorf("Method = %v; want POST", req.Method()) } req.SetMethod("invalid method!") if req.Err() == nil { t.Error("SetMethod with invalid method should set error") } } func TestRequestSetURL(t *testing.T) { req := NewSimpleRequest("https://example.com", "GET") req.SetURL("https://newexample.com") if req.URL() != "https://newexample.com" { t.Errorf("URL = %v; want https://newexample.com", req.URL()) } req2 := NewSimpleRequest("https://example.com", "GET") req2.SetURL("://invalid") if req2.Err() == nil { t.Error("SetURL with invalid URL should set error") } } func TestRequestClone(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Header.Get("X-Test") != "value" { w.WriteHeader(http.StatusBadRequest) return } w.WriteHeader(http.StatusOK) w.Write([]byte("OK")) })) defer server.Close() req := NewSimpleRequest(server.URL, "GET"). SetHeader("X-Test", "value") // 第一次请求 resp, err := req.Do() if err != nil { t.Fatalf("Do() error: %v", err) } resp.Close() // 克隆请求 cloned := req.Clone() cloned.SetHeader("X-Extra", "extra") // 克隆的请求应该也能成功 resp2, err := cloned.Do() if err != nil { t.Fatalf("Cloned Do() error: %v", err) } defer resp2.Close() if resp2.StatusCode != http.StatusOK { t.Errorf("Cloned request StatusCode = %v; want %v", resp2.StatusCode, http.StatusOK) } } func TestRequestContext(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { time.Sleep(100 * time.Millisecond) w.WriteHeader(http.StatusOK) })) defer server.Close() ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) defer cancel() req := NewSimpleRequest(server.URL, "GET").SetContext(ctx) _, err := req.Do() if err == nil { t.Error("Expected timeout error, got nil") } }