package starnet

import (
	"fmt"
	"net"
	"net/http"
	"testing"
	"time"
)

// customIPOriginWait bounds how long a test waits for the TLS origin to report
// the request it served. Without a bound, a request that never reaches the
// origin would block the test until the global `go test` timeout.
const customIPOriginWait = 5 * time.Second

// customIPOriginRequest records what the TLS origin handler observed.
type customIPOriginRequest struct {
	host string // r.Host as seen by the origin handler
	sni  string // TLS ServerName from the handshake; empty if r.TLS was nil
}

// startCustomIPTLSOrigin starts a TLS origin via newIPv4TLSServer that answers
// "ok" and reports the first request it serves on the returned channel. The
// server is closed through t.Cleanup. It returns the server's listening port
// and the report channel.
func startCustomIPTLSOrigin(t *testing.T) (string, <-chan customIPOriginRequest) {
	t.Helper()

	requests := make(chan customIPOriginRequest, 1)
	server := newIPv4TLSServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		got := customIPOriginRequest{host: r.Host}
		if r.TLS != nil {
			got.sni = r.TLS.ServerName
		}
		// Non-blocking send: only the first request is inspected. A blocking
		// send on a second request would stall this handler forever and, if
		// the server's Close waits for in-flight handlers (httptest.Server's
		// does), would also hang test cleanup.
		select {
		case requests <- got:
		default:
		}
		_, _ = w.Write([]byte("ok"))
	}))
	// Registered as a cleanup so it runs after the test's own defers, matching
	// the ordering of the previous top-of-test `defer tlsServer.Close()`.
	t.Cleanup(func() { server.Close() })

	_, port, err := net.SplitHostPort(server.Listener.Addr().String())
	if err != nil {
		t.Fatalf("split tls server addr: %v", err)
	}
	return port, requests
}

// expectCustomIPOriginRequest waits up to customIPOriginWait for the origin's
// report and fails the test unless the origin saw Host "targetHost:port" and
// TLS SNI targetHost, i.e. the custom IP did not leak into Host or SNI.
func expectCustomIPOriginRequest(t *testing.T, requests <-chan customIPOriginRequest, targetHost, port string) {
	t.Helper()

	var got customIPOriginRequest
	select {
	case got = <-requests:
	case <-time.After(customIPOriginWait):
		t.Fatalf("origin did not report a request within %v", customIPOriginWait)
	}

	wantHost := net.JoinHostPort(targetHost, port)
	if got.host != wantHost {
		t.Fatalf("request host = %q; want %q", got.host, wantHost)
	}
	if got.sni != targetHost {
		t.Fatalf("tls sni = %q; want %q", got.sni, targetHost)
	}
}

// TestRequestProxyWithCustomIPTargetsOriginWithoutRewritingProxyDial sends a
// request with both a CONNECT proxy and a custom IP configured. It checks that
// the proxy receives exactly one CONNECT whose target is customIP:port, while
// the origin still sees the original hostname in Host and SNI.
func TestRequestProxyWithCustomIPTargetsOriginWithoutRewritingProxyDial(t *testing.T) {
	port, requests := startCustomIPTLSOrigin(t)

	proxyServer := newIPv4ConnectProxyServer(t, nil)
	defer proxyServer.Close()

	targetHost := "proxy-custom-ip.test"
	reqURL := fmt.Sprintf("https://%s:%s", targetHost, port)
	req := NewSimpleRequest(reqURL, http.MethodGet).
		SetProxy(proxyServer.URL).
		SetCustomIP([]string{"127.0.0.1"}).
		SetSkipTLSVerify(true)

	resp, err := req.Do()
	if err != nil {
		t.Fatalf("Do error: %v", err)
	}
	defer resp.Close()

	targets := proxyServer.Targets()
	if len(targets) != 1 {
		t.Fatalf("connect targets=%v; want 1 target", targets)
	}
	gotConnectTarget := targets[0]
	wantConnectTarget := net.JoinHostPort("127.0.0.1", port)
	if gotConnectTarget != wantConnectTarget {
		t.Fatalf("CONNECT target = %q; want %q", gotConnectTarget, wantConnectTarget)
	}

	expectCustomIPOriginRequest(t, requests, targetHost, port)
}

// TestRequestCustomIPPreservesOriginalHostAndSNI sends a direct (no proxy)
// request with a custom IP. It checks that the origin still sees the original
// hostname in Host and SNI.
func TestRequestCustomIPPreservesOriginalHostAndSNI(t *testing.T) {
	port, requests := startCustomIPTLSOrigin(t)

	targetHost := "custom-ip-direct.test"
	reqURL := fmt.Sprintf("https://%s:%s", targetHost, port)
	req := NewSimpleRequest(reqURL, http.MethodGet).
		SetCustomIP([]string{"127.0.0.1"}).
		SetSkipTLSVerify(true)

	resp, err := req.Do()
	if err != nil {
		t.Fatalf("Do error: %v", err)
	}
	defer resp.Close()

	expectCustomIPOriginRequest(t, requests, targetHost, port)
}