package starnet

import (
	"context"
	"fmt"
	"net"
	"net/http"
	"strconv"
	"sync"
	"testing"
	"time"
)

// TestRequestProxyWithCustomIPFallbackTriesNextResolvedTarget verifies that a
// request configured with several custom IPs and a CONNECT proxy retries the
// tunnel with the next custom IP when the proxy rejects the first target.
func TestRequestProxyWithCustomIPFallbackTriesNextResolvedTarget(t *testing.T) {
	tlsServer := newIPv4TLSServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	defer tlsServer.Close()

	_, port, err := net.SplitHostPort(tlsServer.Listener.Addr().String())
	if err != nil {
		t.Fatalf("split tls server addr: %v", err)
	}
	// Both custom IPs share the TLS server's port. The proxy below refuses the
	// first one, so only the second can complete the tunnel.
	firstTarget := net.JoinHostPort("127.0.0.2", port)
	secondTarget := net.JoinHostPort("127.0.0.1", port)

	var (
		mu             sync.Mutex
		connectTargets []string // CONNECT authorities seen by the proxy, in order
	)
	// The handler runs on HTTP server goroutines, not the test goroutine, so it
	// must not call t.Fatal/t.FailNow; it reports with t.Errorf and returns.
	proxyServer := newIPv4Server(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodConnect {
			http.Error(w, "connect required", http.StatusMethodNotAllowed)
			return
		}
		mu.Lock()
		connectTargets = append(connectTargets, r.Host)
		mu.Unlock()

		// Simulate an unreachable first resolved address.
		if r.Host == firstTarget {
			http.Error(w, "first target failed", http.StatusBadGateway)
			return
		}

		targetConn, err := net.Dial("tcp", r.Host)
		if err != nil {
			http.Error(w, err.Error(), http.StatusBadGateway)
			return
		}
		hijacker, ok := w.(http.Hijacker)
		if !ok {
			targetConn.Close()
			t.Errorf("proxy response writer is not a hijacker")
			// Send an explicit failure so the client does not treat an implicit
			// 200 as an established tunnel.
			http.Error(w, "proxy response writer is not a hijacker", http.StatusInternalServerError)
			return
		}
		clientConn, rw, err := hijacker.Hijack()
		if err != nil {
			targetConn.Close()
			t.Errorf("hijack proxy conn: %v", err)
			return
		}
		if _, err := rw.WriteString("HTTP/1.1 200 Connection Established\r\n\r\n"); err != nil {
			clientConn.Close()
			targetConn.Close()
			t.Errorf("write connect response: %v", err)
			return
		}
		if err := rw.Flush(); err != nil {
			clientConn.Close()
			targetConn.Close()
			t.Errorf("flush connect response: %v", err)
			return
		}
		relayProxyConns(clientConn, targetConn)
	}))
	defer proxyServer.Close()

	reqURL := fmt.Sprintf("https://proxy-fallback.test:%s", port)
	resp, err := NewSimpleRequest(reqURL, http.MethodGet).
		SetProxy(proxyServer.URL).
		SetCustomIP([]string{"127.0.0.2", "127.0.0.1"}).
		SetSkipTLSVerify(true).
		Do()
	if err != nil {
		t.Fatalf("Do() error: %v", err)
	}
	defer resp.Close()

	mu.Lock()
	defer mu.Unlock()
	if len(connectTargets) != 2 {
		t.Fatalf("connect target attempts=%d; want 2 (%v)", len(connectTargets), connectTargets)
	}
	if connectTargets[0] != firstTarget {
		t.Fatalf("first connect target=%q; want %q", connectTargets[0], firstTarget)
	}
	if connectTargets[1] != secondTarget {
		t.Fatalf("second connect target=%q; want %q", connectTargets[1], secondTarget)
	}
}

// TestTraceHooksDefaultResolverEmitsDNSEvents verifies that, without a custom
// lookup function, a request to "localhost" fires DNSStart and DNSDone exactly
// once each, with the hostname reported in DNSStart and no DNS error.
func TestTraceHooksDefaultResolverEmitsDNSEvents(t *testing.T) {
	server := newIPv4Server(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	defer server.Close()

	addr, err := net.ResolveTCPAddr("tcp", server.Listener.Addr().String())
	if err != nil {
		t.Fatalf("ResolveTCPAddr() error: %v", err)
	}

	// Hooks may fire on transport goroutines, so all counters are mutex-guarded.
	var (
		mu            sync.Mutex
		dnsStartCount int
		dnsDoneCount  int
		lastHost      string
	)
	hooks := &TraceHooks{
		DNSStart: func(info TraceDNSStartInfo) {
			mu.Lock()
			dnsStartCount++
			lastHost = info.Host
			mu.Unlock()
		},
		DNSDone: func(info TraceDNSDoneInfo) {
			mu.Lock()
			dnsDoneCount++
			mu.Unlock()
			// t.Errorf (unlike t.Fatal) is safe to call from non-test goroutines.
			if info.Err != nil {
				t.Errorf("unexpected dns error: %v", info.Err)
			}
		},
	}

	reqURL := "http://localhost:" + strconv.Itoa(addr.Port)
	// NOTE(review): the non-default dial timeout presumably routes the request
	// through the library's own dialer (where DNS hooks are emitted) — confirm.
	resp, err := NewSimpleRequest(reqURL, http.MethodGet).
		SetDialTimeout(DefaultDialTimeout + 200*time.Millisecond).
		SetTraceHooks(hooks).
		Do()
	if err != nil {
		t.Fatalf("Do() error: %v", err)
	}
	defer resp.Close()

	mu.Lock()
	defer mu.Unlock()
	if dnsStartCount != 1 {
		t.Fatalf("dnsStartCount=%d", dnsStartCount)
	}
	if dnsDoneCount != 1 {
		t.Fatalf("dnsDoneCount=%d", dnsDoneCount)
	}
	if lastHost != "localhost" {
		t.Fatalf("lastHost=%q; want localhost", lastHost)
	}
}

// TestRequestHeadersReturnsCopy verifies that mutating the header set returned
// by Headers() changes neither the request's stored headers nor its Host.
func TestRequestHeadersReturnsCopy(t *testing.T) {
	req := NewSimpleRequest("http://example.com", http.MethodGet).
		SetHeader("X-Test", "one").
		SetHost("origin.example")

	headers := req.Headers()
	headers.Set("X-Test", "two")
	headers.Set("Host", "mutated.example")

	if got := req.GetHeader("X-Test"); got != "one" {
		t.Fatalf("request header=%q; want one", got)
	}
	if got := req.Host(); got != "origin.example" {
		t.Fatalf("request host=%q; want origin.example", got)
	}
}

// TestRequestCookiesIsolation verifies that cookies passed to SetCookies and
// AddCookie are copied on the way in, and that Cookies() returns copies, so
// neither the caller's originals nor the getter's results alias internal state.
func TestRequestCookiesIsolation(t *testing.T) {
	req := NewSimpleRequest("http://example.com", http.MethodGet)

	source := []*http.Cookie{{
		Name:  "session",
		Value: "one",
		Path:  "/",
	}}
	req.SetCookies(source)
	source[0].Value = "mutated-outside"

	got := req.Cookies()
	if len(got) != 1 || got[0].Value != "one" {
		t.Fatalf("cookies after SetCookies=%v", got)
	}

	got[0].Value = "mutated-copy"
	// Length-guard before indexing so a regression fails cleanly instead of panicking.
	if again := req.Cookies(); len(again) != 1 || again[0].Value != "one" {
		t.Fatalf("internal cookie mutated via getter, got %v", again)
	}

	cookie := &http.Cookie{Name: "auth", Value: "token"}
	req.ResetCookies().AddCookie(cookie)
	cookie.Value = "changed"

	latest := req.Cookies()
	if len(latest) != 1 || latest[0].Value != "token" {
		t.Fatalf("cookies after AddCookie=%v", latest)
	}
}

// TestTraceHooksLookupFuncStillEmitsDNSEvents verifies that DNSStart and
// DNSDone still fire exactly once each when a custom lookup function replaces
// the default resolver.
func TestTraceHooksLookupFuncStillEmitsDNSEvents(t *testing.T) {
	server := newIPv4Server(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	defer server.Close()

	addr, err := net.ResolveTCPAddr("tcp", server.Listener.Addr().String())
	if err != nil {
		t.Fatalf("ResolveTCPAddr() error: %v", err)
	}

	// Hooks may run on the transport's dial goroutine, so the counters are
	// mutex-guarded (matching TestTraceHooksDefaultResolverEmitsDNSEvents).
	var (
		mu            sync.Mutex
		dnsStartCount int
		dnsDoneCount  int
	)
	hooks := &TraceHooks{
		DNSStart: func(TraceDNSStartInfo) {
			mu.Lock()
			dnsStartCount++
			mu.Unlock()
		},
		DNSDone: func(TraceDNSDoneInfo) {
			mu.Lock()
			dnsDoneCount++
			mu.Unlock()
		},
	}

	resp, err := NewSimpleRequest("http://lookup-copy.test:"+strconv.Itoa(addr.Port), http.MethodGet).
		SetLookupFunc(func(ctx context.Context, host string) ([]net.IPAddr, error) {
			// Resolve every host to the local test server.
			return []net.IPAddr{{IP: addr.IP}}, nil
		}).
		SetTraceHooks(hooks).
		Do()
	if err != nil {
		t.Fatalf("Do() error: %v", err)
	}
	defer resp.Close()

	mu.Lock()
	defer mu.Unlock()
	if dnsStartCount != 1 || dnsDoneCount != 1 {
		t.Fatalf("dns trace counts start=%d done=%d", dnsStartCount, dnsDoneCount)
	}
}