starnet/review_regression_test.go

245 lines
6.0 KiB
Go
Raw Permalink Normal View History

package starnet
import (
"context"
"fmt"
"net"
"net/http"
"strconv"
"sync"
"testing"
"time"
)
// TestRequestProxyWithCustomIPFallbackTriesNextResolvedTarget verifies that when
// an HTTP proxy is combined with SetCustomIP, a failed CONNECT to the first
// custom IP makes the request retry the CONNECT against the next custom IP.
//
// The fake proxy rejects CONNECT to 127.0.0.2 with 502 Bad Gateway and tunnels
// CONNECT to 127.0.0.1 through to the TLS test server. The test then asserts
// that exactly those two targets were requested, in that order.
func TestRequestProxyWithCustomIPFallbackTriesNextResolvedTarget(t *testing.T) {
	tlsServer := newIPv4TLSServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	defer tlsServer.Close()

	_, port, err := net.SplitHostPort(tlsServer.Listener.Addr().String())
	if err != nil {
		t.Fatalf("split tls server addr: %v", err)
	}
	firstTarget := net.JoinHostPort("127.0.0.2", port)
	secondTarget := net.JoinHostPort("127.0.0.1", port)

	var (
		mu             sync.Mutex
		connectTargets []string
	)
	// The handler runs on a server goroutine, not on the goroutine running the
	// test, so failures are reported with t.Errorf followed by return:
	// t.Fatal/t.FailNow may only be called from the test goroutine.
	proxyServer := newIPv4Server(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodConnect {
			http.Error(w, "connect required", http.StatusMethodNotAllowed)
			return
		}
		mu.Lock()
		connectTargets = append(connectTargets, r.Host)
		mu.Unlock()

		// Simulate the first resolved target being unreachable via the proxy.
		if r.Host == firstTarget {
			http.Error(w, "first target failed", http.StatusBadGateway)
			return
		}

		targetConn, err := net.Dial("tcp", r.Host)
		if err != nil {
			http.Error(w, err.Error(), http.StatusBadGateway)
			return
		}
		hijacker, ok := w.(http.Hijacker)
		if !ok {
			targetConn.Close()
			t.Errorf("proxy response writer is not a hijacker")
			// Respond explicitly so the client does not treat an implicit
			// empty 200 as an established tunnel.
			http.Error(w, "hijack unsupported", http.StatusInternalServerError)
			return
		}
		clientConn, rw, err := hijacker.Hijack()
		if err != nil {
			targetConn.Close()
			t.Errorf("hijack proxy conn: %v", err)
			return
		}
		if _, err := rw.WriteString("HTTP/1.1 200 Connection Established\r\n\r\n"); err != nil {
			clientConn.Close()
			targetConn.Close()
			t.Errorf("write connect response: %v", err)
			return
		}
		if err := rw.Flush(); err != nil {
			clientConn.Close()
			targetConn.Close()
			t.Errorf("flush connect response: %v", err)
			return
		}
		relayProxyConns(clientConn, targetConn)
	}))
	defer proxyServer.Close()

	reqURL := fmt.Sprintf("https://proxy-fallback.test:%s", port)
	resp, err := NewSimpleRequest(reqURL, http.MethodGet).
		SetProxy(proxyServer.URL).
		SetCustomIP([]string{"127.0.0.2", "127.0.0.1"}).
		SetSkipTLSVerify(true).
		Do()
	if err != nil {
		t.Fatalf("Do() error: %v", err)
	}
	defer resp.Close()

	mu.Lock()
	defer mu.Unlock()
	if len(connectTargets) != 2 {
		t.Fatalf("connect target attempts=%d; want 2 (%v)", len(connectTargets), connectTargets)
	}
	if connectTargets[0] != firstTarget {
		t.Fatalf("first connect target=%q; want %q", connectTargets[0], firstTarget)
	}
	if connectTargets[1] != secondTarget {
		t.Fatalf("second connect target=%q; want %q", connectTargets[1], secondTarget)
	}
}
// TestTraceHooksDefaultResolverEmitsDNSEvents checks that a request addressed
// to "localhost" (resolved by the default resolver, since no lookup function
// is set) fires the DNSStart and DNSDone trace hooks exactly once each, with no
// DNS error, and that DNSStart reports the unresolved host name.
func TestTraceHooksDefaultResolverEmitsDNSEvents(t *testing.T) {
	srv := newIPv4Server(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	defer srv.Close()

	tcpAddr, err := net.ResolveTCPAddr("tcp", srv.Listener.Addr().String())
	if err != nil {
		t.Fatalf("ResolveTCPAddr() error: %v", err)
	}

	// Observations made by the hooks; guarded by lock in case the hooks run
	// on a goroutine other than the test's.
	var (
		lock     sync.Mutex
		starts   int
		dones    int
		seenHost string
	)
	hooks := &TraceHooks{
		DNSStart: func(info TraceDNSStartInfo) {
			lock.Lock()
			defer lock.Unlock()
			starts++
			seenHost = info.Host
		},
		DNSDone: func(info TraceDNSDoneInfo) {
			lock.Lock()
			dones++
			lock.Unlock()
			if info.Err != nil {
				t.Errorf("unexpected dns error: %v", info.Err)
			}
		},
	}

	target := "http://localhost:" + strconv.Itoa(tcpAddr.Port)
	// NOTE(review): the dial timeout is set slightly above DefaultDialTimeout;
	// the reason is not visible here — confirm against the request implementation.
	resp, err := NewSimpleRequest(target, http.MethodGet).
		SetDialTimeout(DefaultDialTimeout + 200*time.Millisecond).
		SetTraceHooks(hooks).
		Do()
	if err != nil {
		t.Fatalf("Do() error: %v", err)
	}
	defer resp.Close()

	lock.Lock()
	defer lock.Unlock()
	switch {
	case starts != 1:
		t.Fatalf("dnsStartCount=%d", starts)
	case dones != 1:
		t.Fatalf("dnsDoneCount=%d", dones)
	case seenHost != "localhost":
		t.Fatalf("lastHost=%q; want localhost", seenHost)
	}
}
// TestRequestHeadersReturnsCopy ensures that the value returned by Headers()
// is detached from the request: setting "X-Test" and "Host" on it must not
// alter the request's own X-Test header or its configured host.
func TestRequestHeadersReturnsCopy(t *testing.T) {
	req := NewSimpleRequest("http://example.com", http.MethodGet).
		SetHeader("X-Test", "one").
		SetHost("origin.example")

	// Mutate the returned headers; the request itself must stay unchanged.
	snapshot := req.Headers()
	snapshot.Set("X-Test", "two")
	snapshot.Set("Host", "mutated.example")

	gotHeader := req.GetHeader("X-Test")
	if gotHeader != "one" {
		t.Fatalf("request header=%q; want one", gotHeader)
	}
	gotHost := req.Host()
	if gotHost != "origin.example" {
		t.Fatalf("request host=%q; want origin.example", gotHost)
	}
}
// TestRequestCookiesIsolation verifies that cookie state is isolated at each
// API boundary:
//   - mutating a cookie after passing it to SetCookies does not affect the request;
//   - mutating a cookie returned by Cookies() does not affect the request;
//   - mutating a cookie after passing it to AddCookie (following ResetCookies)
//     does not affect the request.
func TestRequestCookiesIsolation(t *testing.T) {
	req := NewSimpleRequest("http://example.com", http.MethodGet)

	// SetCookies must snapshot the caller's cookie values.
	sessionCookie := &http.Cookie{Name: "session", Value: "one", Path: "/"}
	req.SetCookies([]*http.Cookie{sessionCookie})
	sessionCookie.Value = "mutated-outside"

	snapshot := req.Cookies()
	if len(snapshot) != 1 || snapshot[0].Value != "one" {
		t.Fatalf("cookies after SetCookies=%v", snapshot)
	}

	// Cookies() must hand out copies, not the request's internal cookies.
	snapshot[0].Value = "mutated-copy"
	if v := req.Cookies()[0].Value; v != "one" {
		t.Fatalf("internal cookie mutated via getter, got %q", v)
	}

	// AddCookie must snapshot the caller's cookie value as well.
	authCookie := &http.Cookie{Name: "auth", Value: "token"}
	req.ResetCookies().AddCookie(authCookie)
	authCookie.Value = "changed"
	if after := req.Cookies(); len(after) != 1 || after[0].Value != "token" {
		t.Fatalf("cookies after AddCookie=%v", after)
	}
}
// TestTraceHooksLookupFuncStillEmitsDNSEvents checks that the DNSStart and
// DNSDone trace hooks each fire exactly once even when a custom lookup
// function (installed via SetLookupFunc) replaces the default resolver.
func TestTraceHooksLookupFuncStillEmitsDNSEvents(t *testing.T) {
	server := newIPv4Server(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	defer server.Close()

	addr, err := net.ResolveTCPAddr("tcp", server.Listener.Addr().String())
	if err != nil {
		t.Fatalf("ResolveTCPAddr() error: %v", err)
	}

	// The counters are written by the trace hooks and read by the test
	// goroutine. Guard them with a mutex, as the default-resolver trace test
	// does, so the test stays race-free under -race if the hooks run on a
	// different goroutine.
	var (
		mu            sync.Mutex
		dnsStartCount int
		dnsDoneCount  int
	)
	hooks := &TraceHooks{
		DNSStart: func(info TraceDNSStartInfo) {
			mu.Lock()
			dnsStartCount++
			mu.Unlock()
		},
		DNSDone: func(info TraceDNSDoneInfo) {
			mu.Lock()
			dnsDoneCount++
			mu.Unlock()
		},
	}

	resp, err := NewSimpleRequest("http://lookup-copy.test:"+strconv.Itoa(addr.Port), http.MethodGet).
		SetLookupFunc(func(ctx context.Context, host string) ([]net.IPAddr, error) {
			// Resolve any host name to the test server's IP.
			return []net.IPAddr{{IP: addr.IP}}, nil
		}).
		SetTraceHooks(hooks).
		Do()
	if err != nil {
		t.Fatalf("Do() error: %v", err)
	}
	defer resp.Close()

	mu.Lock()
	defer mu.Unlock()
	if dnsStartCount != 1 || dnsDoneCount != 1 {
		t.Fatalf("dns trace counts start=%d done=%d", dnsStartCount, dnsDoneCount)
	}
}