173 lines
3.9 KiB
Go
173 lines
3.9 KiB
Go
package starnet
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
func TestNewSimpleRequest(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
url string
|
|
method string
|
|
expectErr bool
|
|
}{
|
|
{
|
|
name: "valid GET request",
|
|
url: "https://example.com",
|
|
method: "GET",
|
|
expectErr: false,
|
|
},
|
|
{
|
|
name: "valid POST request",
|
|
url: "https://example.com",
|
|
method: "POST",
|
|
expectErr: false,
|
|
},
|
|
{
|
|
name: "invalid URL",
|
|
url: "://invalid",
|
|
method: "GET",
|
|
expectErr: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
req, err := NewRequest(tt.url, tt.method)
|
|
if tt.expectErr {
|
|
if err == nil && req.Err() == nil {
|
|
t.Errorf("NewRequest() expected error, got nil")
|
|
}
|
|
} else {
|
|
if err != nil {
|
|
t.Errorf("NewRequest() unexpected error: %v", err)
|
|
}
|
|
if req.Method() != strings.ToUpper(tt.method) {
|
|
t.Errorf("Method = %v; want %v", req.Method(), tt.method)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestRequestMethods(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("X-Method", r.Method)
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte(r.Method))
|
|
}))
|
|
defer server.Close()
|
|
|
|
methods := []string{"GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"}
|
|
|
|
for _, method := range methods {
|
|
t.Run(method, func(t *testing.T) {
|
|
req := NewSimpleRequest(server.URL, method)
|
|
resp, err := req.Do()
|
|
if err != nil {
|
|
t.Fatalf("Do() error: %v", err)
|
|
}
|
|
defer resp.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
t.Errorf("StatusCode = %v; want %v", resp.StatusCode, http.StatusOK)
|
|
}
|
|
|
|
if method != "HEAD" {
|
|
body, _ := resp.Body().String()
|
|
if body != method {
|
|
t.Errorf("Body = %v; want %v", body, method)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestRequestSetMethod(t *testing.T) {
|
|
req := NewSimpleRequest("https://example.com", "GET")
|
|
|
|
req.SetMethod("POST")
|
|
if req.Method() != "POST" {
|
|
t.Errorf("Method = %v; want POST", req.Method())
|
|
}
|
|
|
|
req.SetMethod("invalid method!")
|
|
if req.Err() == nil {
|
|
t.Error("SetMethod with invalid method should set error")
|
|
}
|
|
}
|
|
|
|
func TestRequestSetURL(t *testing.T) {
|
|
req := NewSimpleRequest("https://example.com", "GET")
|
|
|
|
req.SetURL("https://newexample.com")
|
|
if req.URL() != "https://newexample.com" {
|
|
t.Errorf("URL = %v; want https://newexample.com", req.URL())
|
|
}
|
|
|
|
req2 := NewSimpleRequest("https://example.com", "GET")
|
|
req2.SetURL("://invalid")
|
|
if req2.Err() == nil {
|
|
t.Error("SetURL with invalid URL should set error")
|
|
}
|
|
}
|
|
|
|
func TestRequestClone(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.Header.Get("X-Test") != "value" {
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
return
|
|
}
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte("OK"))
|
|
}))
|
|
defer server.Close()
|
|
|
|
req := NewSimpleRequest(server.URL, "GET").
|
|
SetHeader("X-Test", "value")
|
|
|
|
// 第一次请求
|
|
resp, err := req.Do()
|
|
if err != nil {
|
|
t.Fatalf("Do() error: %v", err)
|
|
}
|
|
resp.Close()
|
|
|
|
// 克隆请求
|
|
cloned := req.Clone()
|
|
cloned.SetHeader("X-Extra", "extra")
|
|
|
|
// 克隆的请求应该也能成功
|
|
resp2, err := cloned.Do()
|
|
if err != nil {
|
|
t.Fatalf("Cloned Do() error: %v", err)
|
|
}
|
|
defer resp2.Close()
|
|
|
|
if resp2.StatusCode != http.StatusOK {
|
|
t.Errorf("Cloned request StatusCode = %v; want %v", resp2.StatusCode, http.StatusOK)
|
|
}
|
|
}
|
|
|
|
func TestRequestContext(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
time.Sleep(100 * time.Millisecond)
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
defer server.Close()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
|
defer cancel()
|
|
|
|
req := NewSimpleRequest(server.URL, "GET").SetContext(ctx)
|
|
_, err := req.Do()
|
|
if err == nil {
|
|
t.Error("Expected timeout error, got nil")
|
|
}
|
|
}
|