- 分离 Request 的配置态与执行态,修复二次 Do、raw 模式网络配置失效和 body 来源互斥问题 - 新增 starnet trace 抽象,补齐 DNS/连接/TLS/重试事件,并优化动态 transport 缓存与代理解析路径 - 收紧非法代理为 fail-fast,多目标目标回退仅限幂等请求,修复 Host/TLS/SNI 等语义边界 - 补充防御性拷贝、专项回归测试、本地代理/TLS 用例与 README 行为说明
245 lines
6.0 KiB
Go
245 lines
6.0 KiB
Go
package starnet
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net"
|
|
"net/http"
|
|
"strconv"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
func TestRequestProxyWithCustomIPFallbackTriesNextResolvedTarget(t *testing.T) {
|
|
tlsServer := newIPv4TLSServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
defer tlsServer.Close()
|
|
|
|
_, port, err := net.SplitHostPort(tlsServer.Listener.Addr().String())
|
|
if err != nil {
|
|
t.Fatalf("split tls server addr: %v", err)
|
|
}
|
|
|
|
firstTarget := net.JoinHostPort("127.0.0.2", port)
|
|
secondTarget := net.JoinHostPort("127.0.0.1", port)
|
|
|
|
var (
|
|
mu sync.Mutex
|
|
connectTargets []string
|
|
)
|
|
proxyServer := newIPv4Server(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodConnect {
|
|
http.Error(w, "connect required", http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
|
|
mu.Lock()
|
|
connectTargets = append(connectTargets, r.Host)
|
|
mu.Unlock()
|
|
|
|
if r.Host == firstTarget {
|
|
http.Error(w, "first target failed", http.StatusBadGateway)
|
|
return
|
|
}
|
|
|
|
targetConn, err := net.Dial("tcp", r.Host)
|
|
if err != nil {
|
|
http.Error(w, err.Error(), http.StatusBadGateway)
|
|
return
|
|
}
|
|
|
|
hijacker, ok := w.(http.Hijacker)
|
|
if !ok {
|
|
targetConn.Close()
|
|
t.Fatal("proxy response writer is not a hijacker")
|
|
}
|
|
|
|
clientConn, rw, err := hijacker.Hijack()
|
|
if err != nil {
|
|
targetConn.Close()
|
|
t.Fatalf("hijack proxy conn: %v", err)
|
|
}
|
|
if _, err := rw.WriteString("HTTP/1.1 200 Connection Established\r\n\r\n"); err != nil {
|
|
clientConn.Close()
|
|
targetConn.Close()
|
|
t.Fatalf("write connect response: %v", err)
|
|
}
|
|
if err := rw.Flush(); err != nil {
|
|
clientConn.Close()
|
|
targetConn.Close()
|
|
t.Fatalf("flush connect response: %v", err)
|
|
}
|
|
|
|
relayProxyConns(clientConn, targetConn)
|
|
}))
|
|
defer proxyServer.Close()
|
|
|
|
reqURL := fmt.Sprintf("https://proxy-fallback.test:%s", port)
|
|
resp, err := NewSimpleRequest(reqURL, http.MethodGet).
|
|
SetProxy(proxyServer.URL).
|
|
SetCustomIP([]string{"127.0.0.2", "127.0.0.1"}).
|
|
SetSkipTLSVerify(true).
|
|
Do()
|
|
if err != nil {
|
|
t.Fatalf("Do() error: %v", err)
|
|
}
|
|
defer resp.Close()
|
|
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
if len(connectTargets) != 2 {
|
|
t.Fatalf("connect target attempts=%d; want 2 (%v)", len(connectTargets), connectTargets)
|
|
}
|
|
if connectTargets[0] != firstTarget {
|
|
t.Fatalf("first connect target=%q; want %q", connectTargets[0], firstTarget)
|
|
}
|
|
if connectTargets[1] != secondTarget {
|
|
t.Fatalf("second connect target=%q; want %q", connectTargets[1], secondTarget)
|
|
}
|
|
}
|
|
|
|
func TestTraceHooksDefaultResolverEmitsDNSEvents(t *testing.T) {
|
|
server := newIPv4Server(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
defer server.Close()
|
|
|
|
addr, err := net.ResolveTCPAddr("tcp", server.Listener.Addr().String())
|
|
if err != nil {
|
|
t.Fatalf("ResolveTCPAddr() error: %v", err)
|
|
}
|
|
|
|
var (
|
|
mu sync.Mutex
|
|
dnsStartCount int
|
|
dnsDoneCount int
|
|
lastHost string
|
|
)
|
|
hooks := &TraceHooks{
|
|
DNSStart: func(info TraceDNSStartInfo) {
|
|
mu.Lock()
|
|
dnsStartCount++
|
|
lastHost = info.Host
|
|
mu.Unlock()
|
|
},
|
|
DNSDone: func(info TraceDNSDoneInfo) {
|
|
mu.Lock()
|
|
dnsDoneCount++
|
|
mu.Unlock()
|
|
if info.Err != nil {
|
|
t.Errorf("unexpected dns error: %v", info.Err)
|
|
}
|
|
},
|
|
}
|
|
|
|
reqURL := "http://localhost:" + strconv.Itoa(addr.Port)
|
|
resp, err := NewSimpleRequest(reqURL, http.MethodGet).
|
|
SetDialTimeout(DefaultDialTimeout + 200*time.Millisecond).
|
|
SetTraceHooks(hooks).
|
|
Do()
|
|
if err != nil {
|
|
t.Fatalf("Do() error: %v", err)
|
|
}
|
|
defer resp.Close()
|
|
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
if dnsStartCount != 1 {
|
|
t.Fatalf("dnsStartCount=%d", dnsStartCount)
|
|
}
|
|
if dnsDoneCount != 1 {
|
|
t.Fatalf("dnsDoneCount=%d", dnsDoneCount)
|
|
}
|
|
if lastHost != "localhost" {
|
|
t.Fatalf("lastHost=%q; want localhost", lastHost)
|
|
}
|
|
}
|
|
|
|
func TestRequestHeadersReturnsCopy(t *testing.T) {
|
|
req := NewSimpleRequest("http://example.com", http.MethodGet).
|
|
SetHeader("X-Test", "one").
|
|
SetHost("origin.example")
|
|
|
|
headers := req.Headers()
|
|
headers.Set("X-Test", "two")
|
|
headers.Set("Host", "mutated.example")
|
|
|
|
if got := req.GetHeader("X-Test"); got != "one" {
|
|
t.Fatalf("request header=%q; want one", got)
|
|
}
|
|
if got := req.Host(); got != "origin.example" {
|
|
t.Fatalf("request host=%q; want origin.example", got)
|
|
}
|
|
}
|
|
|
|
func TestRequestCookiesIsolation(t *testing.T) {
|
|
req := NewSimpleRequest("http://example.com", http.MethodGet)
|
|
source := []*http.Cookie{{
|
|
Name: "session",
|
|
Value: "one",
|
|
Path: "/",
|
|
}}
|
|
|
|
req.SetCookies(source)
|
|
source[0].Value = "mutated-outside"
|
|
|
|
got := req.Cookies()
|
|
if len(got) != 1 || got[0].Value != "one" {
|
|
t.Fatalf("cookies after SetCookies=%v", got)
|
|
}
|
|
|
|
got[0].Value = "mutated-copy"
|
|
if latest := req.Cookies()[0].Value; latest != "one" {
|
|
t.Fatalf("internal cookie mutated via getter, got %q", latest)
|
|
}
|
|
|
|
cookie := &http.Cookie{Name: "auth", Value: "token"}
|
|
req.ResetCookies().AddCookie(cookie)
|
|
cookie.Value = "changed"
|
|
|
|
latest := req.Cookies()
|
|
if len(latest) != 1 || latest[0].Value != "token" {
|
|
t.Fatalf("cookies after AddCookie=%v", latest)
|
|
}
|
|
}
|
|
|
|
func TestTraceHooksLookupFuncStillEmitsDNSEvents(t *testing.T) {
|
|
server := newIPv4Server(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
defer server.Close()
|
|
|
|
addr, err := net.ResolveTCPAddr("tcp", server.Listener.Addr().String())
|
|
if err != nil {
|
|
t.Fatalf("ResolveTCPAddr() error: %v", err)
|
|
}
|
|
|
|
var dnsStartCount int
|
|
var dnsDoneCount int
|
|
hooks := &TraceHooks{
|
|
DNSStart: func(info TraceDNSStartInfo) {
|
|
dnsStartCount++
|
|
},
|
|
DNSDone: func(info TraceDNSDoneInfo) {
|
|
dnsDoneCount++
|
|
},
|
|
}
|
|
|
|
resp, err := NewSimpleRequest("http://lookup-copy.test:"+strconv.Itoa(addr.Port), http.MethodGet).
|
|
SetLookupFunc(func(ctx context.Context, host string) ([]net.IPAddr, error) {
|
|
return []net.IPAddr{{IP: addr.IP}}, nil
|
|
}).
|
|
SetTraceHooks(hooks).
|
|
Do()
|
|
if err != nil {
|
|
t.Fatalf("Do() error: %v", err)
|
|
}
|
|
defer resp.Close()
|
|
|
|
if dnsStartCount != 1 || dnsDoneCount != 1 {
|
|
t.Fatalf("dns trace counts start=%d done=%d", dnsStartCount, dnsDoneCount)
|
|
}
|
|
}
|