235 lines
6.1 KiB
Go
235 lines
6.1 KiB
Go
package starnet
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os"
|
|
"strings"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
func TestWithJSONOpt(t *testing.T) {
|
|
type payload struct {
|
|
Name string `json:"name"`
|
|
Age int `json:"age"`
|
|
}
|
|
|
|
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if ct := r.Header.Get("Content-Type"); ct != ContentTypeJSON {
|
|
t.Fatalf("content-type=%s", ct)
|
|
}
|
|
var p payload
|
|
if err := json.NewDecoder(r.Body).Decode(&p); err != nil {
|
|
t.Fatalf("decode err: %v", err)
|
|
}
|
|
if p.Name != "alice" || p.Age != 18 {
|
|
t.Fatalf("payload mismatch: %+v", p)
|
|
}
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
defer s.Close()
|
|
|
|
resp, err := Post(s.URL, WithJSON(payload{Name: "alice", Age: 18}))
|
|
if err != nil {
|
|
t.Fatalf("Post error: %v", err)
|
|
}
|
|
resp.Close()
|
|
}
|
|
|
|
func TestWithFileOpt(t *testing.T) {
|
|
// temp file + cleanup
|
|
f, err := os.CreateTemp("", "starnet-upload-*.txt")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer os.Remove(f.Name())
|
|
_, _ = f.WriteString("hello-file")
|
|
_ = f.Close()
|
|
|
|
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if err := r.ParseMultipartForm(10 << 20); err != nil {
|
|
t.Fatalf("parse form err: %v", err)
|
|
}
|
|
file, header, err := r.FormFile("file")
|
|
if err != nil {
|
|
t.Fatalf("form file err: %v", err)
|
|
}
|
|
defer file.Close()
|
|
b, _ := io.ReadAll(file)
|
|
if header.Filename == "" || string(b) != "hello-file" {
|
|
t.Fatalf("upload mismatch filename=%q body=%q", header.Filename, string(b))
|
|
}
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
defer s.Close()
|
|
|
|
resp, err := Post(s.URL, WithFile("file", f.Name()))
|
|
if err != nil {
|
|
t.Fatalf("Post error: %v", err)
|
|
}
|
|
resp.Close()
|
|
}
|
|
|
|
func TestWithFileStreamOpt(t *testing.T) {
|
|
content := "stream-content"
|
|
reader := strings.NewReader(content)
|
|
|
|
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if err := r.ParseMultipartForm(10 << 20); err != nil {
|
|
t.Fatalf("parse form err: %v", err)
|
|
}
|
|
file, header, err := r.FormFile("up")
|
|
if err != nil {
|
|
t.Fatalf("form file err: %v", err)
|
|
}
|
|
defer file.Close()
|
|
b, _ := io.ReadAll(file)
|
|
if header.Filename != "a.txt" || string(b) != content {
|
|
t.Fatalf("upload mismatch filename=%q body=%q", header.Filename, string(b))
|
|
}
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
defer s.Close()
|
|
|
|
resp, err := Post(s.URL, WithFileStream("up", "a.txt", int64(len(content)), reader))
|
|
if err != nil {
|
|
t.Fatalf("Post error: %v", err)
|
|
}
|
|
resp.Close()
|
|
}
|
|
|
|
func TestWithQueryOpt(t *testing.T) {
|
|
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Query().Get("k") != "v" {
|
|
t.Fatalf("query mismatch: %v", r.URL.Query())
|
|
}
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
defer s.Close()
|
|
|
|
resp, err := Get(s.URL, WithQuery("k", "v"))
|
|
if err != nil {
|
|
t.Fatalf("Get error: %v", err)
|
|
}
|
|
resp.Close()
|
|
}
|
|
|
|
func TestWithUploadProgressOpt(t *testing.T) {
|
|
var called int32
|
|
var last int64
|
|
content := strings.Repeat("x", 4096)
|
|
|
|
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
_ = r.ParseMultipartForm(10 << 20)
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
defer s.Close()
|
|
|
|
resp, err := Post(s.URL,
|
|
WithUploadProgress(func(filename string, uploaded, total int64) {
|
|
atomic.StoreInt32(&called, 1)
|
|
last = uploaded
|
|
}),
|
|
WithFileStream("f", "p.txt", int64(len(content)), strings.NewReader(content)),
|
|
)
|
|
if err != nil {
|
|
t.Fatalf("Post error: %v", err)
|
|
}
|
|
resp.Close()
|
|
|
|
if atomic.LoadInt32(&called) == 0 {
|
|
t.Fatal("progress not called")
|
|
}
|
|
if last != int64(len(content)) {
|
|
t.Fatalf("last uploaded=%d want=%d", last, len(content))
|
|
}
|
|
}
|
|
|
|
func TestWithTransportOpt(t *testing.T) {
|
|
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
defer s.Close()
|
|
|
|
resp, err := Get(s.URL, WithTransport(&http.Transport{}))
|
|
if err != nil {
|
|
t.Fatalf("Get error: %v", err)
|
|
}
|
|
resp.Close()
|
|
}
|
|
|
|
func TestWithContextOpt(t *testing.T) {
|
|
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
time.Sleep(200 * time.Millisecond)
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
defer s.Close()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
|
defer cancel()
|
|
|
|
_, err := Get(s.URL, WithContext(ctx))
|
|
if err == nil {
|
|
t.Fatal("expected context timeout error")
|
|
}
|
|
}
|
|
|
|
func TestWithCustomDNSOpt_ConfigApplied(t *testing.T) {
|
|
req := NewSimpleRequest("http://example.com", "GET", WithCustomDNS([]string{"8.8.8.8", "1.1.1.1"}))
|
|
if req.Err() != nil {
|
|
t.Fatalf("unexpected err: %v", req.Err())
|
|
}
|
|
if len(req.config.DNS.CustomDNS) != 2 {
|
|
t.Fatalf("custom dns len=%d", len(req.config.DNS.CustomDNS))
|
|
}
|
|
}
|
|
|
|
func TestWithAddCustomIPOpt(t *testing.T) {
|
|
req := NewSimpleRequest("http://example.com", "GET", WithAddCustomIP("1.2.3.4"))
|
|
if req.Err() != nil {
|
|
t.Fatalf("unexpected err: %v", req.Err())
|
|
}
|
|
if len(req.config.DNS.CustomIP) != 1 || req.config.DNS.CustomIP[0] != "1.2.3.4" {
|
|
t.Fatalf("custom ip mismatch: %v", req.config.DNS.CustomIP)
|
|
}
|
|
}
|
|
|
|
func TestWithCustomIPOpt(t *testing.T) {
|
|
req := NewSimpleRequest("http://example.com", "GET", WithCustomIP([]string{"1.1.1.1", "8.8.8.8"}))
|
|
if req.Err() != nil {
|
|
t.Fatalf("unexpected err: %v", req.Err())
|
|
}
|
|
if len(req.config.DNS.CustomIP) != 2 {
|
|
t.Fatalf("custom ip len=%d", len(req.config.DNS.CustomIP))
|
|
}
|
|
}
|
|
|
|
func TestWithDialFuncOpt(t *testing.T) {
|
|
called := int32(0)
|
|
fn := func(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
atomic.StoreInt32(&called, 1)
|
|
return nil, io.EOF
|
|
}
|
|
req := NewSimpleRequest("http://example.com", "GET", WithDialFunc(fn))
|
|
if req.config.Network.DialFunc == nil {
|
|
t.Fatal("dial func not set")
|
|
}
|
|
_, _ = req.config.Network.DialFunc(context.Background(), "tcp", "x:1")
|
|
if atomic.LoadInt32(&called) == 0 {
|
|
t.Fatal("dial func not called")
|
|
}
|
|
}
|
|
|
|
func TestWithDialTimeoutOpt(t *testing.T) {
|
|
req := NewSimpleRequest("http://example.com", "GET", WithDialTimeout(123*time.Millisecond))
|
|
if req.config.Network.DialTimeout != 123*time.Millisecond {
|
|
t.Fatalf("dial timeout=%v", req.config.Network.DialTimeout)
|
|
}
|
|
}
|