package starnet

import (
	"context"
	"encoding/json"
	"io"
	"net"
	"net/http"
	"net/http/httptest"
	"os"
	"strings"
	"sync/atomic"
	"testing"
	"time"
)

// NOTE: the HTTP handlers in these tests run on server goroutines, where
// t.Fatal/t.FailNow must not be called. Handlers therefore report problems with
// t.Errorf and fail the request with http.Error instead.

// TestWithJSONOpt verifies that WithJSON sends a JSON-encoded body with the
// ContentTypeJSON content type.
func TestWithJSONOpt(t *testing.T) {
	type payload struct {
		Name string `json:"name"`
		Age  int    `json:"age"`
	}
	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if ct := r.Header.Get("Content-Type"); ct != ContentTypeJSON {
			t.Errorf("content-type=%s", ct)
			http.Error(w, "bad content-type", http.StatusBadRequest)
			return
		}
		var p payload
		if err := json.NewDecoder(r.Body).Decode(&p); err != nil {
			t.Errorf("decode err: %v", err)
			http.Error(w, "bad body", http.StatusBadRequest)
			return
		}
		if p.Name != "alice" || p.Age != 18 {
			t.Errorf("payload mismatch: %+v", p)
			http.Error(w, "payload mismatch", http.StatusBadRequest)
			return
		}
		w.WriteHeader(http.StatusOK)
	}))
	defer s.Close()

	resp, err := Post(s.URL, WithJSON(payload{Name: "alice", Age: 18}))
	if err != nil {
		t.Fatalf("Post error: %v", err)
	}
	resp.Close()
}

// TestWithFileOpt verifies that WithFile uploads a file from disk as the
// multipart form field "file".
func TestWithFileOpt(t *testing.T) {
	// The file lives in t.TempDir, which the testing package removes when the
	// test finishes, so no manual cleanup is needed.
	f, err := os.CreateTemp(t.TempDir(), "starnet-upload-*.txt")
	if err != nil {
		t.Fatal(err)
	}
	if _, err := f.WriteString("hello-file"); err != nil {
		_ = f.Close() // best effort; the write error is the one worth reporting
		t.Fatalf("write temp file: %v", err)
	}
	if err := f.Close(); err != nil {
		t.Fatalf("close temp file: %v", err)
	}

	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if err := r.ParseMultipartForm(10 << 20); err != nil {
			t.Errorf("parse form err: %v", err)
			http.Error(w, "bad form", http.StatusBadRequest)
			return
		}
		file, header, err := r.FormFile("file")
		if err != nil {
			t.Errorf("form file err: %v", err)
			http.Error(w, "missing file", http.StatusBadRequest)
			return
		}
		defer file.Close()
		b, err := io.ReadAll(file)
		if err != nil {
			t.Errorf("read upload err: %v", err)
			http.Error(w, "read upload", http.StatusBadRequest)
			return
		}
		if header.Filename == "" || string(b) != "hello-file" {
			t.Errorf("upload mismatch filename=%q body=%q", header.Filename, string(b))
			http.Error(w, "upload mismatch", http.StatusBadRequest)
			return
		}
		w.WriteHeader(http.StatusOK)
	}))
	defer s.Close()

	resp, err := Post(s.URL, WithFile("file", f.Name()))
	if err != nil {
		t.Fatalf("Post error: %v", err)
	}
	resp.Close()
}

// TestWithFileStreamOpt verifies that WithFileStream uploads reader content
// under the given field name and filename.
func TestWithFileStreamOpt(t *testing.T) {
	content := "stream-content"
	reader := strings.NewReader(content)
	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if err := r.ParseMultipartForm(10 << 20); err != nil {
			t.Errorf("parse form err: %v", err)
			http.Error(w, "bad form", http.StatusBadRequest)
			return
		}
		file, header, err := r.FormFile("up")
		if err != nil {
			t.Errorf("form file err: %v", err)
			http.Error(w, "missing file", http.StatusBadRequest)
			return
		}
		defer file.Close()
		b, err := io.ReadAll(file)
		if err != nil {
			t.Errorf("read upload err: %v", err)
			http.Error(w, "read upload", http.StatusBadRequest)
			return
		}
		if header.Filename != "a.txt" || string(b) != content {
			t.Errorf("upload mismatch filename=%q body=%q", header.Filename, string(b))
			http.Error(w, "upload mismatch", http.StatusBadRequest)
			return
		}
		w.WriteHeader(http.StatusOK)
	}))
	defer s.Close()

	resp, err := Post(s.URL, WithFileStream("up", "a.txt", int64(len(content)), reader))
	if err != nil {
		t.Fatalf("Post error: %v", err)
	}
	resp.Close()
}

// TestWithQueryOpt verifies that WithQuery adds a query parameter to the URL.
func TestWithQueryOpt(t *testing.T) {
	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Query().Get("k") != "v" {
			t.Errorf("query mismatch: %v", r.URL.Query())
			http.Error(w, "query mismatch", http.StatusBadRequest)
			return
		}
		w.WriteHeader(http.StatusOK)
	}))
	defer s.Close()

	resp, err := Get(s.URL, WithQuery("k", "v"))
	if err != nil {
		t.Fatalf("Get error: %v", err)
	}
	resp.Close()
}

// TestWithUploadProgressOpt verifies that the WithUploadProgress callback is
// invoked and that its last reported byte count equals the upload size.
func TestWithUploadProgressOpt(t *testing.T) {
	var called int32
	var last int64
	content := strings.Repeat("x", 4096)
	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Consume the whole upload so every byte passes through the client's
		// progress tracking before the response is sent.
		_ = r.ParseMultipartForm(10 << 20)
		w.WriteHeader(http.StatusOK)
	}))
	defer s.Close()

	resp, err := Post(s.URL,
		WithUploadProgress(func(filename string, uploaded, total int64) {
			// The callback may run on a goroutine other than the test's, so
			// both values are published atomically to avoid a data race.
			atomic.StoreInt32(&called, 1)
			atomic.StoreInt64(&last, uploaded)
		}),
		WithFileStream("f", "p.txt", int64(len(content)), strings.NewReader(content)),
	)
	if err != nil {
		t.Fatalf("Post error: %v", err)
	}
	resp.Close()

	if atomic.LoadInt32(&called) == 0 {
		t.Fatal("progress not called")
	}
	if got := atomic.LoadInt64(&last); got != int64(len(content)) {
		t.Fatalf("last uploaded=%d want=%d", got, len(content))
	}
}

// TestWithTransportOpt verifies that a request succeeds through a caller
// supplied transport.
func TestWithTransportOpt(t *testing.T) {
	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	defer s.Close()

	tr := &http.Transport{}
	// The test owns this transport, so it releases the pooled connections itself.
	defer tr.CloseIdleConnections()

	resp, err := Get(s.URL, WithTransport(tr))
	if err != nil {
		t.Fatalf("Get error: %v", err)
	}
	resp.Close()
}

// TestWithContextOpt verifies that a context deadline passed via WithContext
// aborts a request whose server answers more slowly than the deadline.
func TestWithContextOpt(t *testing.T) {
	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Outlast the client's 50ms deadline, but stop as soon as the client
		// disconnects: s.Close waits for in-flight handlers, so an unconditional
		// sleep would stall the test for the full delay.
		select {
		case <-time.After(200 * time.Millisecond):
		case <-r.Context().Done():
			return
		}
		w.WriteHeader(http.StatusOK)
	}))
	defer s.Close()

	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
	defer cancel()
	resp, err := Get(s.URL, WithContext(ctx))
	if err == nil {
		resp.Close()
		t.Fatal("expected context timeout error")
	}
}

// TestWithCustomDNSOpt_ConfigApplied verifies that WithCustomDNS stores the
// given resolver addresses in the request config.
func TestWithCustomDNSOpt_ConfigApplied(t *testing.T) {
	req := NewSimpleRequest("http://example.com", "GET", WithCustomDNS([]string{"8.8.8.8", "1.1.1.1"}))
	if req.Err() != nil {
		t.Fatalf("unexpected err: %v", req.Err())
	}
	if len(req.config.DNS.CustomDNS) != 2 {
		t.Fatalf("custom dns len=%d", len(req.config.DNS.CustomDNS))
	}
}

// TestWithAddCustomIPOpt verifies that WithAddCustomIP records a single IP in
// the request config.
func TestWithAddCustomIPOpt(t *testing.T) {
	req := NewSimpleRequest("http://example.com", "GET", WithAddCustomIP("1.2.3.4"))
	if req.Err() != nil {
		t.Fatalf("unexpected err: %v", req.Err())
	}
	if len(req.config.DNS.CustomIP) != 1 || req.config.DNS.CustomIP[0] != "1.2.3.4" {
		t.Fatalf("custom ip mismatch: %v", req.config.DNS.CustomIP)
	}
}

// TestWithCustomIPOpt verifies that WithCustomIP stores the given IP list in
// the request config.
func TestWithCustomIPOpt(t *testing.T) {
	req := NewSimpleRequest("http://example.com", "GET", WithCustomIP([]string{"1.1.1.1", "8.8.8.8"}))
	if req.Err() != nil {
		t.Fatalf("unexpected err: %v", req.Err())
	}
	if len(req.config.DNS.CustomIP) != 2 {
		t.Fatalf("custom ip len=%d", len(req.config.DNS.CustomIP))
	}
}

// TestWithDialFuncOpt verifies that WithDialFunc installs a dial function that
// invokes the caller's function.
func TestWithDialFuncOpt(t *testing.T) {
	var called int32
	fn := func(ctx context.Context, network, addr string) (net.Conn, error) {
		atomic.StoreInt32(&called, 1)
		return nil, io.EOF
	}
	req := NewSimpleRequest("http://example.com", "GET", WithDialFunc(fn))
	if req.Err() != nil {
		t.Fatalf("unexpected err: %v", req.Err())
	}
	if req.config.Network.DialFunc == nil {
		t.Fatal("dial func not set")
	}
	// The stub always fails; only the side effect on called matters here.
	_, _ = req.config.Network.DialFunc(context.Background(), "tcp", "x:1")
	if atomic.LoadInt32(&called) == 0 {
		t.Fatal("dial func not called")
	}
}

// TestWithDialTimeoutOpt verifies that WithDialTimeout stores the timeout in
// the request config.
func TestWithDialTimeoutOpt(t *testing.T) {
	req := NewSimpleRequest("http://example.com", "GET", WithDialTimeout(123*time.Millisecond))
	if req.Err() != nil {
		t.Fatalf("unexpected err: %v", req.Err())
	}
	if req.config.Network.DialTimeout != 123*time.Millisecond {
		t.Fatalf("dial timeout=%v", req.config.Network.DialTimeout)
	}
}