starnet/options_test.go

235 lines
6.1 KiB
Go
Raw Normal View History

2026-03-08 20:19:40 +08:00
package starnet
import (
"context"
"encoding/json"
"io"
"net"
"net/http"
"net/http/httptest"
"os"
"strings"
"sync/atomic"
"testing"
"time"
)
// TestWithJSONOpt verifies that WithJSON sends the value as a JSON request
// body and sets the Content-Type header to ContentTypeJSON.
//
// The handler runs on the test server's goroutine, not the test goroutine, so
// it reports problems with t.Errorf and an error response instead of
// t.Fatalf (FailNow must only be called from the goroutine running the test).
func TestWithJSONOpt(t *testing.T) {
	type payload struct {
		Name string `json:"name"`
		Age  int    `json:"age"`
	}
	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if ct := r.Header.Get("Content-Type"); ct != ContentTypeJSON {
			t.Errorf("content-type=%s", ct)
			http.Error(w, "bad content-type", http.StatusBadRequest)
			return
		}
		var p payload
		if err := json.NewDecoder(r.Body).Decode(&p); err != nil {
			t.Errorf("decode err: %v", err)
			http.Error(w, "bad body", http.StatusBadRequest)
			return
		}
		if p.Name != "alice" || p.Age != 18 {
			t.Errorf("payload mismatch: %+v", p)
			http.Error(w, "payload mismatch", http.StatusBadRequest)
			return
		}
		w.WriteHeader(http.StatusOK)
	}))
	defer s.Close()
	resp, err := Post(s.URL, WithJSON(payload{Name: "alice", Age: 18}))
	if err != nil {
		t.Fatalf("Post error: %v", err)
	}
	resp.Close()
}
// TestWithFileOpt verifies that WithFile uploads a file from disk as the
// multipart form part named "file", with a non-empty filename and the file's
// exact contents.
//
// Handler failures use t.Errorf plus an error response because t.Fatalf must
// not be called from the test server's goroutine.
func TestWithFileOpt(t *testing.T) {
	// Create the fixture inside t.TempDir so it is removed automatically.
	f, err := os.CreateTemp(t.TempDir(), "starnet-upload-*.txt")
	if err != nil {
		t.Fatal(err)
	}
	if _, err := f.WriteString("hello-file"); err != nil {
		_ = f.Close()
		t.Fatalf("write temp file: %v", err)
	}
	// Close before uploading so the contents are flushed and the file can be
	// reopened (and removed) on every platform.
	if err := f.Close(); err != nil {
		t.Fatalf("close temp file: %v", err)
	}
	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if err := r.ParseMultipartForm(10 << 20); err != nil {
			t.Errorf("parse form err: %v", err)
			http.Error(w, "bad form", http.StatusBadRequest)
			return
		}
		file, header, err := r.FormFile("file")
		if err != nil {
			t.Errorf("form file err: %v", err)
			http.Error(w, "missing file", http.StatusBadRequest)
			return
		}
		defer file.Close()
		b, err := io.ReadAll(file)
		if err != nil {
			t.Errorf("read upload err: %v", err)
			http.Error(w, "read failed", http.StatusBadRequest)
			return
		}
		if header.Filename == "" || string(b) != "hello-file" {
			t.Errorf("upload mismatch filename=%q body=%q", header.Filename, string(b))
			http.Error(w, "upload mismatch", http.StatusBadRequest)
			return
		}
		w.WriteHeader(http.StatusOK)
	}))
	defer s.Close()
	resp, err := Post(s.URL, WithFile("file", f.Name()))
	if err != nil {
		t.Fatalf("Post error: %v", err)
	}
	resp.Close()
}
// TestWithFileStreamOpt verifies that WithFileStream uploads an io.Reader as
// the multipart part "up" with the given filename ("a.txt") and contents.
//
// Handler failures use t.Errorf plus an error response because t.Fatalf must
// not be called from the test server's goroutine.
func TestWithFileStreamOpt(t *testing.T) {
	content := "stream-content"
	reader := strings.NewReader(content)
	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if err := r.ParseMultipartForm(10 << 20); err != nil {
			t.Errorf("parse form err: %v", err)
			http.Error(w, "bad form", http.StatusBadRequest)
			return
		}
		file, header, err := r.FormFile("up")
		if err != nil {
			t.Errorf("form file err: %v", err)
			http.Error(w, "missing file", http.StatusBadRequest)
			return
		}
		defer file.Close()
		b, err := io.ReadAll(file)
		if err != nil {
			t.Errorf("read upload err: %v", err)
			http.Error(w, "read failed", http.StatusBadRequest)
			return
		}
		if header.Filename != "a.txt" || string(b) != content {
			t.Errorf("upload mismatch filename=%q body=%q", header.Filename, string(b))
			http.Error(w, "upload mismatch", http.StatusBadRequest)
			return
		}
		w.WriteHeader(http.StatusOK)
	}))
	defer s.Close()
	resp, err := Post(s.URL, WithFileStream("up", "a.txt", int64(len(content)), reader))
	if err != nil {
		t.Fatalf("Post error: %v", err)
	}
	resp.Close()
}
// TestWithQueryOpt verifies that WithQuery adds the key/value pair k=v to the
// request URL's query string.
//
// The handler reports a mismatch with t.Errorf and an error response because
// t.Fatalf must not be called from the test server's goroutine.
func TestWithQueryOpt(t *testing.T) {
	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Query().Get("k") != "v" {
			t.Errorf("query mismatch: %v", r.URL.Query())
			http.Error(w, "query mismatch", http.StatusBadRequest)
			return
		}
		w.WriteHeader(http.StatusOK)
	}))
	defer s.Close()
	resp, err := Get(s.URL, WithQuery("k", "v"))
	if err != nil {
		t.Fatalf("Get error: %v", err)
	}
	resp.Close()
}
// TestWithUploadProgressOpt verifies that the WithUploadProgress callback is
// invoked during a streamed multipart upload and that its final reported
// "uploaded" value equals the full content length.
func TestWithUploadProgressOpt(t *testing.T) {
	// Both observations are accessed atomically. The callback may run on a
	// goroutine other than the test goroutine, and a plain write to last
	// would race with the read after Post returns.
	var called int32
	var last int64
	content := strings.Repeat("x", 4096)
	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Consume the multipart body so the client completes its upload; the
		// parse result itself is not what this test checks.
		_ = r.ParseMultipartForm(10 << 20)
		w.WriteHeader(http.StatusOK)
	}))
	defer s.Close()
	resp, err := Post(s.URL,
		WithUploadProgress(func(filename string, uploaded, total int64) {
			atomic.StoreInt32(&called, 1)
			atomic.StoreInt64(&last, uploaded)
		}),
		WithFileStream("f", "p.txt", int64(len(content)), strings.NewReader(content)),
	)
	if err != nil {
		t.Fatalf("Post error: %v", err)
	}
	resp.Close()
	if atomic.LoadInt32(&called) == 0 {
		t.Fatal("progress not called")
	}
	if got := atomic.LoadInt64(&last); got != int64(len(content)) {
		t.Fatalf("last uploaded=%d want=%d", got, len(content))
	}
}
// TestWithTransportOpt verifies that a request configured with a caller-
// supplied *http.Transport via WithTransport completes successfully.
func TestWithTransportOpt(t *testing.T) {
	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	defer s.Close()
	// A dedicated transport keeps its own idle-connection pool; release it
	// when the test ends (deferred after s.Close, so it runs first).
	tr := &http.Transport{}
	defer tr.CloseIdleConnections()
	resp, err := Get(s.URL, WithTransport(tr))
	if err != nil {
		t.Fatalf("Get error: %v", err)
	}
	resp.Close()
}
// TestWithContextOpt verifies that a context passed via WithContext bounds the
// request: a 50ms deadline against a handler that would take 200ms must make
// Get return an error.
func TestWithContextOpt(t *testing.T) {
	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Stop as soon as the client gives up. Otherwise s.Close would block
		// until the full delay elapsed.
		select {
		case <-time.After(200 * time.Millisecond):
			w.WriteHeader(http.StatusOK)
		case <-r.Context().Done():
		}
	}))
	defer s.Close()
	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
	defer cancel()
	resp, err := Get(s.URL, WithContext(ctx))
	if err == nil {
		// Don't leak the response if the deadline was unexpectedly not enforced.
		resp.Close()
		t.Fatal("expected context timeout error")
	}
}
// TestWithCustomDNSOpt_ConfigApplied checks that WithCustomDNS records both
// supplied resolver addresses in the request's DNS config without producing
// a request error.
func TestWithCustomDNSOpt_ConfigApplied(t *testing.T) {
	resolvers := []string{"8.8.8.8", "1.1.1.1"}
	req := NewSimpleRequest("http://example.com", "GET", WithCustomDNS(resolvers))
	if err := req.Err(); err != nil {
		t.Fatalf("unexpected err: %v", err)
	}
	if n := len(req.config.DNS.CustomDNS); n != 2 {
		t.Fatalf("custom dns len=%d", n)
	}
}
// TestWithAddCustomIPOpt checks that WithAddCustomIP appends exactly the given
// address to the request's DNS CustomIP list.
func TestWithAddCustomIPOpt(t *testing.T) {
	const wantIP = "1.2.3.4"
	req := NewSimpleRequest("http://example.com", "GET", WithAddCustomIP(wantIP))
	if err := req.Err(); err != nil {
		t.Fatalf("unexpected err: %v", err)
	}
	ips := req.config.DNS.CustomIP
	if !(len(ips) == 1 && ips[0] == wantIP) {
		t.Fatalf("custom ip mismatch: %v", ips)
	}
}
// TestWithCustomIPOpt checks that WithCustomIP stores both supplied addresses
// in the request's DNS CustomIP list without producing a request error.
func TestWithCustomIPOpt(t *testing.T) {
	addrs := []string{"1.1.1.1", "8.8.8.8"}
	req := NewSimpleRequest("http://example.com", "GET", WithCustomIP(addrs))
	if err := req.Err(); err != nil {
		t.Fatalf("unexpected err: %v", err)
	}
	if n := len(req.config.DNS.CustomIP); n != 2 {
		t.Fatalf("custom ip len=%d", n)
	}
}
// TestWithDialFuncOpt checks that WithDialFunc stores the supplied dial hook in
// the request's network config and that the stored value is the caller's
// function (invoking it flips the called flag).
func TestWithDialFuncOpt(t *testing.T) {
	var called int32
	fn := func(ctx context.Context, network, addr string) (net.Conn, error) {
		atomic.StoreInt32(&called, 1)
		return nil, io.EOF
	}
	req := NewSimpleRequest("http://example.com", "GET", WithDialFunc(fn))
	// Like the other option tests, make sure applying the option did not
	// record an error on the request.
	if err := req.Err(); err != nil {
		t.Fatalf("unexpected err: %v", err)
	}
	if req.config.Network.DialFunc == nil {
		t.Fatal("dial func not set")
	}
	// Call the stored hook directly; its io.EOF result is irrelevant here.
	_, _ = req.config.Network.DialFunc(context.Background(), "tcp", "x:1")
	if atomic.LoadInt32(&called) == 0 {
		t.Fatal("dial func not called")
	}
}
// TestWithDialTimeoutOpt checks that WithDialTimeout stores the given duration
// in the request's network config without producing a request error.
func TestWithDialTimeoutOpt(t *testing.T) {
	const want = 123 * time.Millisecond
	req := NewSimpleRequest("http://example.com", "GET", WithDialTimeout(want))
	// Consistent with the other option tests: the option must not record an
	// error on the request.
	if err := req.Err(); err != nil {
		t.Fatalf("unexpected err: %v", err)
	}
	if req.config.Network.DialTimeout != want {
		t.Fatalf("dial timeout=%v", req.config.Network.DialTimeout)
	}
}