starnet/proxy_custom_ip_test.go

111 lines
2.7 KiB
Go
Raw Normal View History

package starnet
import (
"fmt"
"net"
"net/http"
"testing"
)
// TestRequestProxyWithCustomIPTargetsOriginWithoutRewritingProxyDial sends an
// HTTPS GET for a placeholder hostname through a proxy while a custom IP of
// 127.0.0.1 is configured on the request. It then checks three things:
//
//   - the proxy recorded exactly one target, equal to 127.0.0.1:<origin port>;
//   - the origin handler saw the Host header as <hostname>:<origin port>;
//   - the TLS ServerName seen by the origin is the placeholder hostname.
func TestRequestProxyWithCustomIPTargetsOriginWithoutRewritingProxyDial(t *testing.T) {
	// originView is what the origin handler observed for the incoming request.
	type originView struct {
		host string
		sni  string
	}
	// Buffered so the handler's single send never blocks on the test goroutine.
	seen := make(chan originView, 1)

	handler := func(w http.ResponseWriter, r *http.Request) {
		seen <- originView{host: r.Host, sni: r.TLS.ServerName}
		_, _ = w.Write([]byte("ok"))
	}
	origin := newIPv4TLSServer(t, http.HandlerFunc(handler))
	defer origin.Close()

	// Only the port of the origin listener is needed; the host part is discarded.
	_, port, err := net.SplitHostPort(origin.Listener.Addr().String())
	if err != nil {
		t.Fatalf("split tls server addr: %v", err)
	}

	proxy := newIPv4ConnectProxyServer(t, nil)
	defer proxy.Close()

	const targetHost = "proxy-custom-ip.test"
	req := NewSimpleRequest(fmt.Sprintf("https://%s:%s", targetHost, port), http.MethodGet).
		SetProxy(proxy.URL).
		SetCustomIP([]string{"127.0.0.1"}).
		SetSkipTLSVerify(true)

	resp, err := req.Do()
	if err != nil {
		t.Fatalf("Do error: %v", err)
	}
	defer resp.Close()

	// The proxy must have been asked for the custom IP, and only once.
	targets := proxy.Targets()
	if len(targets) != 1 {
		t.Fatalf("connect targets=%v; want 1 target", targets)
	}
	if got, want := targets[0], net.JoinHostPort("127.0.0.1", port); got != want {
		t.Fatalf("CONNECT target = %q; want %q", got, want)
	}

	// The origin must still see the original hostname in both Host and SNI.
	view := <-seen
	if want := net.JoinHostPort(targetHost, port); view.host != want {
		t.Fatalf("request host = %q; want %q", view.host, want)
	}
	if view.sni != targetHost {
		t.Fatalf("tls sni = %q; want %q", view.sni, targetHost)
	}
}
// TestRequestCustomIPPreservesOriginalHostAndSNI issues a direct (no proxy)
// HTTPS GET for a placeholder hostname with a custom IP of 127.0.0.1 set on
// the request. It asserts that the origin handler sees the Host header as
// <hostname>:<port> and the TLS ServerName as the placeholder hostname.
func TestRequestCustomIPPreservesOriginalHostAndSNI(t *testing.T) {
	// capture holds the Host header and TLS ServerName recorded by the handler.
	type capture struct {
		host string
		sni  string
	}
	// Capacity 1 lets the handler record its observation without blocking.
	captured := make(chan capture, 1)

	recordAndReply := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		var c capture
		c.host = r.Host
		c.sni = r.TLS.ServerName
		captured <- c
		_, _ = w.Write([]byte("ok"))
	})
	srv := newIPv4TLSServer(t, recordAndReply)
	defer srv.Close()

	// The request URL reuses the server's port with the placeholder hostname.
	listenAddr := srv.Listener.Addr().String()
	_, port, err := net.SplitHostPort(listenAddr)
	if err != nil {
		t.Fatalf("split tls server addr: %v", err)
	}

	const targetHost = "custom-ip-direct.test"
	rawURL := fmt.Sprintf("https://%s:%s", targetHost, port)
	req := NewSimpleRequest(rawURL, http.MethodGet).
		SetCustomIP([]string{"127.0.0.1"}).
		SetSkipTLSVerify(true)

	resp, err := req.Do()
	if err != nil {
		t.Fatalf("Do error: %v", err)
	}
	defer resp.Close()

	got := <-captured
	expectedHost := net.JoinHostPort(targetHost, port)
	if expectedHost != got.host {
		t.Fatalf("request host = %q; want %q", got.host, expectedHost)
	}
	if targetHost != got.sni {
		t.Fatalf("tls sni = %q; want %q", got.sni, targetHost)
	}
}