package starnet

import (
	"crypto/tls"
	"net"
	"net/http"
	"strconv"
	"sync"
	"testing"
	"time"
)

// TestTransportDynamicCacheReusesSafeProfile checks that two lookups made with
// separate RequestContext values carrying identical settings return the same
// dynamic transport and leave exactly one entry in the cache.
func TestTransportDynamicCacheReusesSafeProfile(t *testing.T) {
	tr := &Transport{base: newBaseHTTPTransport()}

	// Each lookup gets its own RequestContext value with identical fields.
	newProfile := func() *RequestContext {
		return &RequestContext{
			Proxy:         "http://127.0.0.1:8080",
			DialTimeout:   2 * time.Second,
			CustomIP:      []string{"127.0.0.1"},
			TLSServerName: "cache.test",
		}
	}

	a := tr.getDynamicTransport(newProfile(), nil)
	b := tr.getDynamicTransport(newProfile(), nil)
	if a != b {
		t.Fatal("expected cached dynamic transport to be reused")
	}

	if size := len(tr.dynamicCache); size != 1 {
		t.Fatalf("dynamic cache size=%d; want 1", size)
	}
}

// TestTransportDynamicCacheSeparatesTLSServerName checks that contexts which
// differ only in TLSServerName yield different transports and two cache
// entries.
func TestTransportDynamicCacheSeparatesTLSServerName(t *testing.T) {
	tr := &Transport{base: newBaseHTTPTransport()}

	// Only the TLS server name varies between the two lookups.
	profileFor := func(serverName string) *RequestContext {
		return &RequestContext{
			CustomIP:      []string{"127.0.0.1"},
			TLSServerName: serverName,
		}
	}

	a := tr.getDynamicTransport(profileFor("first.test"), nil)
	b := tr.getDynamicTransport(profileFor("second.test"), nil)
	if a == b {
		t.Fatal("expected distinct tls server names to use different transports")
	}

	if size := len(tr.dynamicCache); size != 2 {
		t.Fatalf("dynamic cache size=%d; want 2", size)
	}
}

// TestTransportDynamicCacheSkipsUserTLSConfig checks that a context carrying a
// caller-supplied TLSConfig gets a new transport on every lookup and that
// nothing is stored in the cache.
func TestTransportDynamicCacheSkipsUserTLSConfig(t *testing.T) {
	tr := &Transport{base: newBaseHTTPTransport()}

	userTLS := &tls.Config{InsecureSkipVerify: true}
	reqCtx := &RequestContext{
		CustomIP:  []string{"127.0.0.1"},
		TLSConfig: userTLS,
	}

	// The very same context is passed twice; the results must still differ.
	a := tr.getDynamicTransport(reqCtx, nil)
	b := tr.getDynamicTransport(reqCtx, nil)
	if a == b {
		t.Fatal("expected user tls config to bypass dynamic transport cache")
	}

	if size := len(tr.dynamicCache); size != 0 {
		t.Fatalf("dynamic cache size=%d; want 0", size)
	}
}

// TestTransportDynamicCacheResetOnDefaultTLSChange checks that calling
// SetDefaultSkipTLSVerify on the client empties the dynamic cache, so the next
// lookup with the same context returns a different transport.
func TestTransportDynamicCacheResetOnDefaultTLSChange(t *testing.T) {
	client := NewClientNoErr()

	tr, isTransport := client.HTTPClient().Transport.(*Transport)
	if !isTransport {
		t.Fatalf("transport type=%T; want *Transport", client.HTTPClient().Transport)
	}

	reqCtx := &RequestContext{CustomIP: []string{"127.0.0.1"}}

	before := tr.getDynamicTransport(reqCtx, nil)
	if size := len(tr.dynamicCache); size != 1 {
		t.Fatalf("dynamic cache size=%d; want 1 before reset", size)
	}

	client.SetDefaultSkipTLSVerify(true)
	if size := len(tr.dynamicCache); size != 0 {
		t.Fatalf("dynamic cache size=%d; want 0 after reset", size)
	}

	after := tr.getDynamicTransport(reqCtx, nil)
	if before == after {
		t.Fatal("expected cache reset after default tls change")
	}
}

// TestDynamicTransportCacheReusesConnectionForCustomIP issues two requests
// with the same custom IP against a local server and checks, via the GotConn
// trace hook, that the first one does not report a reused connection and the
// second one does. It also checks that one dynamic cache entry exists
// afterwards.
func TestDynamicTransportCacheReusesConnectionForCustomIP(t *testing.T) {
	okHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte("ok"))
	})
	server := newIPv4Server(t, okHandler)
	defer server.Close()

	listenAddr := server.Listener.Addr().String()
	addr, err := net.ResolveTCPAddr("tcp", listenAddr)
	if err != nil {
		t.Fatalf("ResolveTCPAddr() error: %v", err)
	}

	client := NewClientNoErr()
	// The host name is fixed; only the port comes from the test server.
	targetURL := "http://" + net.JoinHostPort("cache-reuse.test", strconv.Itoa(addr.Port))

	// doRequest performs one GET and reports whether the GotConn hook saw a
	// reused connection.
	doRequest := func() bool {
		var (
			mu        sync.Mutex
			sawConn   bool
			wasReused bool
		)

		hooks := &TraceHooks{
			GotConn: func(info TraceGotConnInfo) {
				mu.Lock()
				defer mu.Unlock()
				sawConn = true
				wasReused = info.Reused
			},
		}

		resp, err := client.NewSimpleRequest(targetURL, http.MethodGet).
			SetCustomIP([]string{"127.0.0.1"}).
			SetTraceHooks(hooks).
			Do()
		if err != nil {
			t.Fatalf("Do() error: %v", err)
		}
		defer resp.Close()

		// Read the whole body before inspecting the trace results.
		if _, readErr := resp.Body().Bytes(); readErr != nil {
			t.Fatalf("Body().Bytes() error: %v", readErr)
		}

		mu.Lock()
		defer mu.Unlock()
		if !sawConn {
			t.Fatal("expected GotConn trace event")
		}
		return wasReused
	}

	firstReused := doRequest()
	if firstReused {
		t.Fatal("first request unexpectedly reused a connection")
	}

	secondReused := doRequest()
	if !secondReused {
		t.Fatal("second request did not reuse cached dynamic transport connection")
	}

	tr, isTransport := client.HTTPClient().Transport.(*Transport)
	if !isTransport {
		t.Fatalf("transport type=%T; want *Transport", client.HTTPClient().Transport)
	}
	if size := len(tr.dynamicCache); size != 1 {
		t.Fatalf("dynamic cache size=%d; want 1", size)
	}
}

// TestPrepareProxyTargetRequestSingleTargetRewritesExecRequest checks that,
// with a proxy and a single custom IP, prepareProxyTargetRequest returns a
// cloned request whose URL host is the custom IP plus the original port,
// leaves the original request untouched, returns no target addresses, and
// returns a context whose TLS config names the original host.
func TestPrepareProxyTargetRequestSingleTargetRewritesExecRequest(t *testing.T) {
	const (
		originalHost  = "proxy-single.test:8443"
		rewrittenHost = "127.0.0.1:8443"
		serverName    = "proxy-single.test"
	)

	req, err := http.NewRequest(http.MethodGet, "https://proxy-single.test:8443/path", nil)
	if err != nil {
		t.Fatalf("http.NewRequest() error: %v", err)
	}
	req.Host = req.URL.Host

	reqCtx := &RequestContext{
		Proxy:    "http://127.0.0.1:8080",
		CustomIP: []string{"127.0.0.1"},
	}
	execReq, execReqCtx, targetAddrs, err := prepareProxyTargetRequest(req, reqCtx, nil)
	if err != nil {
		t.Fatalf("prepareProxyTargetRequest() error: %v", err)
	}

	if execReq == req {
		t.Fatal("expected cloned request for proxy target preparation")
	}
	if host := execReq.URL.Host; host != rewrittenHost {
		t.Fatalf("execReq.URL.Host=%q; want %q", host, rewrittenHost)
	}
	if host := req.URL.Host; host != originalHost {
		t.Fatalf("original req.URL.Host=%q; want %q", host, originalHost)
	}
	if len(targetAddrs) != 0 {
		t.Fatalf("targetAddrs=%v; want empty after single target rewrite", targetAddrs)
	}
	if execReqCtx == nil || execReqCtx.TLSConfig == nil {
		t.Fatal("expected synthesized tls config for single target proxy request")
	}
	if name := execReqCtx.TLSConfig.ServerName; name != serverName {
		t.Fatalf("tls server name=%q; want %q", name, serverName)
	}
}

// TestPrepareProxyTargetRequestMultiTargetPreservesFallbackList checks that,
// with a proxy and two custom IPs, prepareProxyTargetRequest keeps the
// original URL host on the returned request and returns both IPs, in order,
// paired with the original port.
func TestPrepareProxyTargetRequestMultiTargetPreservesFallbackList(t *testing.T) {
	req, err := http.NewRequest(http.MethodGet, "https://proxy-multi.test:9443/path", nil)
	if err != nil {
		t.Fatalf("http.NewRequest() error: %v", err)
	}
	req.Host = req.URL.Host

	reqCtx := &RequestContext{
		Proxy:    "http://127.0.0.1:8080",
		CustomIP: []string{"127.0.0.1", "127.0.0.2"},
	}
	execReq, _, targetAddrs, err := prepareProxyTargetRequest(req, reqCtx, nil)
	if err != nil {
		t.Fatalf("prepareProxyTargetRequest() error: %v", err)
	}

	if host := execReq.URL.Host; host != "proxy-multi.test:9443" {
		t.Fatalf("execReq.URL.Host=%q; want original host", host)
	}
	if len(targetAddrs) != 2 {
		t.Fatalf("targetAddrs=%v; want 2 targets", targetAddrs)
	}

	// Order matters: the addresses must follow the CustomIP order.
	wantAddrs := [2]string{"127.0.0.1:9443", "127.0.0.2:9443"}
	for i, want := range wantAddrs {
		if targetAddrs[i] != want {
			t.Fatalf("targetAddrs=%v; want ordered fallback targets", targetAddrs)
		}
	}
}