// starnet/transport_cache_test.go: tests for the Transport dynamic-transport
// cache and for prepareProxyTargetRequest.

package starnet
import (
"crypto/tls"
"net"
"net/http"
"strconv"
"sync"
"testing"
"time"
)
// TestTransportDynamicCacheReusesSafeProfile checks that two requests with the
// same proxy, dial timeout, custom IP and TLS server name are served by the
// same cached dynamic transport, leaving exactly one cache entry.
func TestTransportDynamicCacheReusesSafeProfile(t *testing.T) {
	tr := &Transport{base: newBaseHTTPTransport()}

	// Each call builds a fresh, field-for-field identical context, so the cache
	// has to match on the context's contents rather than on pointer identity.
	profile := func() *RequestContext {
		return &RequestContext{
			Proxy:         "http://127.0.0.1:8080",
			DialTimeout:   2 * time.Second,
			CustomIP:      []string{"127.0.0.1"},
			TLSServerName: "cache.test",
		}
	}

	a := tr.getDynamicTransport(profile(), nil)
	b := tr.getDynamicTransport(profile(), nil)
	if a != b {
		t.Fatal("expected cached dynamic transport to be reused")
	}
	if n := len(tr.dynamicCache); n != 1 {
		t.Fatalf("dynamic cache size=%d; want 1", n)
	}
}
// TestTransportDynamicCacheSeparatesTLSServerName checks that contexts which
// differ only in TLSServerName get different dynamic transports and produce
// two separate cache entries.
func TestTransportDynamicCacheSeparatesTLSServerName(t *testing.T) {
	tr := &Transport{base: newBaseHTTPTransport()}
	withServerName := func(name string) *RequestContext {
		return &RequestContext{
			CustomIP:      []string{"127.0.0.1"},
			TLSServerName: name,
		}
	}

	a := tr.getDynamicTransport(withServerName("first.test"), nil)
	b := tr.getDynamicTransport(withServerName("second.test"), nil)
	if a == b {
		t.Fatal("expected distinct tls server names to use different transports")
	}
	if n := len(tr.dynamicCache); n != 2 {
		t.Fatalf("dynamic cache size=%d; want 2", n)
	}
}
// TestTransportDynamicCacheSkipsUserTLSConfig checks that a request carrying its
// own TLSConfig bypasses the dynamic transport cache. Two lookups with the same
// context must return different transports, and no cache entry may be stored.
func TestTransportDynamicCacheSkipsUserTLSConfig(t *testing.T) {
	tr := &Transport{base: newBaseHTTPTransport()}
	rc := &RequestContext{
		CustomIP:  []string{"127.0.0.1"},
		TLSConfig: &tls.Config{InsecureSkipVerify: true},
	}

	// Passing the very same *RequestContext twice rules out any difference in
	// input between the two lookups.
	if a, b := tr.getDynamicTransport(rc, nil), tr.getDynamicTransport(rc, nil); a == b {
		t.Fatal("expected user tls config to bypass dynamic transport cache")
	}
	if n := len(tr.dynamicCache); n != 0 {
		t.Fatalf("dynamic cache size=%d; want 0", n)
	}
}
// TestTransportDynamicCacheResetOnDefaultTLSChange checks that calling
// SetDefaultSkipTLSVerify on the client empties the dynamic transport cache.
// The next lookup for the same context must then return a different transport
// from the one created before the change.
func TestTransportDynamicCacheResetOnDefaultTLSChange(t *testing.T) {
	c := NewClientNoErr()
	tr, ok := c.HTTPClient().Transport.(*Transport)
	if !ok {
		t.Fatalf("transport type=%T; want *Transport", c.HTTPClient().Transport)
	}
	cacheLen := func() int { return len(tr.dynamicCache) }
	rc := &RequestContext{CustomIP: []string{"127.0.0.1"}}

	before := tr.getDynamicTransport(rc, nil)
	if n := cacheLen(); n != 1 {
		t.Fatalf("dynamic cache size=%d; want 1 before reset", n)
	}

	c.SetDefaultSkipTLSVerify(true)
	if n := cacheLen(); n != 0 {
		t.Fatalf("dynamic cache size=%d; want 0 after reset", n)
	}

	if after := tr.getDynamicTransport(rc, nil); before == after {
		t.Fatal("expected cache reset after default tls change")
	}
}
// TestDynamicTransportCacheReusesConnectionForCustomIP sends two GET requests to
// the same custom-IP target and uses the GotConn trace hook to check the
// connections. The first request must get a new connection and the second must
// reuse it. It also checks that exactly one dynamic-cache entry was created.
func TestDynamicTransportCacheReusesConnectionForCustomIP(t *testing.T) {
	server := newIPv4Server(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte("ok"))
	}))
	defer server.Close()

	addr, err := net.ResolveTCPAddr("tcp", server.Listener.Addr().String())
	if err != nil {
		t.Fatalf("ResolveTCPAddr() error: %v", err)
	}

	client := NewClientNoErr()
	// This test deliberately leaves keep-alive connections pooled, so release
	// them when it ends instead of letting them outlive the test. This defer is
	// registered after server.Close, so it runs first. Note that
	// http.Client.CloseIdleConnections only reaches pooled connections if the
	// client's Transport implements CloseIdleConnections.
	defer client.HTTPClient().CloseIdleConnections()

	// cache-reuse.test is a placeholder under the reserved .test TLD. The dial
	// target comes from SetCustomIP below, and only the port is taken from the
	// real listener.
	targetURL := "http://cache-reuse.test:" + strconv.Itoa(addr.Port)

	// runRequest performs one request and reads the body via Body().Bytes(), so
	// the connection can return to the idle pool. It reports whether the
	// GotConn hook saw a reused connection.
	runRequest := func() bool {
		var (
			mu      sync.Mutex
			gotConn bool
			reused  bool
		)
		resp, err := client.NewSimpleRequest(targetURL, http.MethodGet).
			SetCustomIP([]string{"127.0.0.1"}).
			SetTraceHooks(&TraceHooks{
				GotConn: func(info TraceGotConnInfo) {
					mu.Lock()
					gotConn = true
					reused = info.Reused
					mu.Unlock()
				},
			}).
			Do()
		if err != nil {
			t.Fatalf("Do() error: %v", err)
		}
		defer resp.Close()
		if _, err := resp.Body().Bytes(); err != nil {
			t.Fatalf("Body().Bytes() error: %v", err)
		}
		mu.Lock()
		defer mu.Unlock()
		if !gotConn {
			t.Fatal("expected GotConn trace event")
		}
		return reused
	}

	if runRequest() {
		t.Fatal("first request unexpectedly reused a connection")
	}
	if !runRequest() {
		t.Fatal("second request did not reuse cached dynamic transport connection")
	}

	transport, ok := client.HTTPClient().Transport.(*Transport)
	if !ok {
		t.Fatalf("transport type=%T; want *Transport", client.HTTPClient().Transport)
	}
	if got := len(transport.dynamicCache); got != 1 {
		t.Fatalf("dynamic cache size=%d; want 1", got)
	}
}
// TestPrepareProxyTargetRequestSingleTargetRewritesExecRequest covers the case
// of a proxy with a single custom IP. prepareProxyTargetRequest must:
//   - return a different *http.Request whose URL host is that IP, keeping the port;
//   - leave the caller's request URL untouched;
//   - return no fallback targets;
//   - supply a TLS config whose ServerName is the original hostname.
func TestPrepareProxyTargetRequestSingleTargetRewritesExecRequest(t *testing.T) {
	const (
		origHost = "proxy-single.test:8443"
		wantHost = "127.0.0.1:8443"
		wantSNI  = "proxy-single.test"
	)

	req, err := http.NewRequest(http.MethodGet, "https://"+origHost+"/path", nil)
	if err != nil {
		t.Fatalf("http.NewRequest() error: %v", err)
	}
	req.Host = req.URL.Host

	execReq, execReqCtx, targetAddrs, err := prepareProxyTargetRequest(req, &RequestContext{
		Proxy:    "http://127.0.0.1:8080",
		CustomIP: []string{"127.0.0.1"},
	}, nil)
	switch {
	case err != nil:
		t.Fatalf("prepareProxyTargetRequest() error: %v", err)
	case execReq == req:
		t.Fatal("expected cloned request for proxy target preparation")
	}

	if got := execReq.URL.Host; got != wantHost {
		t.Fatalf("execReq.URL.Host=%q; want %q", got, wantHost)
	}
	if got := req.URL.Host; got != origHost {
		t.Fatalf("original req.URL.Host=%q; want %q", got, origHost)
	}
	if len(targetAddrs) != 0 {
		t.Fatalf("targetAddrs=%v; want empty after single target rewrite", targetAddrs)
	}
	if execReqCtx == nil || execReqCtx.TLSConfig == nil {
		t.Fatal("expected synthesized tls config for single target proxy request")
	}
	if got := execReqCtx.TLSConfig.ServerName; got != wantSNI {
		t.Fatalf("tls server name=%q; want %q", got, wantSNI)
	}
}
func TestPrepareProxyTargetRequestMultiTargetPreservesFallbackList(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, "https://proxy-multi.test:9443/path", nil)
if err != nil {
t.Fatalf("http.NewRequest() error: %v", err)
}
req.Host = req.URL.Host
execReq, _, targetAddrs, err := prepareProxyTargetRequest(req, &RequestContext{
Proxy: "http://127.0.0.1:8080",
CustomIP: []string{"127.0.0.1", "127.0.0.2"},
}, nil)
if err != nil {
t.Fatalf("prepareProxyTargetRequest() error: %v", err)
}
if got := execReq.URL.Host; got != "proxy-multi.test:9443" {
t.Fatalf("execReq.URL.Host=%q; want original host", got)
}
if len(targetAddrs) != 2 {
t.Fatalf("targetAddrs=%v; want 2 targets", targetAddrs)
}
if targetAddrs[0] != "127.0.0.1:9443" || targetAddrs[1] != "127.0.0.2:9443" {
t.Fatalf("targetAddrs=%v; want ordered fallback targets", targetAddrs)
}
}