151 lines
4.0 KiB
Go
151 lines
4.0 KiB
Go
|
|
package starnet
|
||
|
|
|
||
|
|
import (
|
||
|
|
"crypto/tls"
|
||
|
|
"net/http"
|
||
|
|
"net/http/httptest"
|
||
|
|
"testing"
|
||
|
|
)
|
||
|
|
|
||
|
|
func TestRequestSetURLDoesNotMutateProvidedTLSConfig(t *testing.T) {
|
||
|
|
cfg := &tls.Config{}
|
||
|
|
|
||
|
|
req := NewSimpleRequest("https://example.com", http.MethodGet).
|
||
|
|
SetTLSConfig(cfg).
|
||
|
|
SetURL("https://other.example")
|
||
|
|
|
||
|
|
if req.Err() != nil {
|
||
|
|
t.Fatalf("unexpected request error: %v", req.Err())
|
||
|
|
}
|
||
|
|
if cfg.ServerName != "" {
|
||
|
|
t.Fatalf("provided tls.Config was mutated, ServerName=%q", cfg.ServerName)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestRequestPrepareSetTLSServerNameDoesNotMutateProvidedTLSConfig(t *testing.T) {
|
||
|
|
cfg := &tls.Config{InsecureSkipVerify: true}
|
||
|
|
|
||
|
|
req := NewSimpleRequest("https://example.com", http.MethodGet).
|
||
|
|
SetTLSConfig(cfg).
|
||
|
|
SetTLSServerName("override.example")
|
||
|
|
|
||
|
|
if err := req.prepare(); err != nil {
|
||
|
|
t.Fatalf("prepare error: %v", err)
|
||
|
|
}
|
||
|
|
if cfg.ServerName != "" {
|
||
|
|
t.Fatalf("provided tls.Config was mutated, ServerName=%q", cfg.ServerName)
|
||
|
|
}
|
||
|
|
|
||
|
|
rc := getRequestContext(req.execCtx)
|
||
|
|
if rc.TLSConfig == nil {
|
||
|
|
t.Fatal("expected injected tls config")
|
||
|
|
}
|
||
|
|
if rc.TLSConfig == cfg {
|
||
|
|
t.Fatal("expected injected tls config to be cloned")
|
||
|
|
}
|
||
|
|
if rc.TLSConfig.ServerName != "override.example" {
|
||
|
|
t.Fatalf("injected ServerName=%q", rc.TLSConfig.ServerName)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestRequestPrepareWithTLSServerNameWithoutTLSConfig(t *testing.T) {
|
||
|
|
req := NewSimpleRequest("https://example.com", http.MethodGet).
|
||
|
|
SetTLSServerName("override.example")
|
||
|
|
|
||
|
|
if err := req.prepare(); err != nil {
|
||
|
|
t.Fatalf("prepare error: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
rc := getRequestContext(req.execCtx)
|
||
|
|
if rc.TLSConfig == nil {
|
||
|
|
t.Fatal("expected injected tls config")
|
||
|
|
}
|
||
|
|
if rc.TLSConfig.ServerName != "override.example" {
|
||
|
|
t.Fatalf("injected ServerName=%q", rc.TLSConfig.ServerName)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestRequestPrepareDefaultPathSkipsRequestContextInjection(t *testing.T) {
|
||
|
|
req := NewSimpleRequest("https://example.com", http.MethodGet)
|
||
|
|
|
||
|
|
if err := req.prepare(); err != nil {
|
||
|
|
t.Fatalf("prepare error: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
if got := req.execCtx.Value(ctxKeyRequestContext); got != nil {
|
||
|
|
t.Fatalf("unexpected request context injection: %#v", got)
|
||
|
|
}
|
||
|
|
|
||
|
|
rc := getRequestContext(req.execCtx)
|
||
|
|
if needsDynamicTransport(rc) {
|
||
|
|
t.Fatalf("default path unexpectedly marked dynamic: %#v", rc)
|
||
|
|
}
|
||
|
|
if rc.TLSServerName != "" {
|
||
|
|
t.Fatalf("default path unexpectedly injected tls server name: %q", rc.TLSServerName)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestRequestPrepareDynamicPathInjectsAggregatedRequestContext(t *testing.T) {
|
||
|
|
req := NewSimpleRequest("https://example.com", http.MethodGet).
|
||
|
|
SetCustomIP([]string{"127.0.0.1"}).
|
||
|
|
SetSkipTLSVerify(true)
|
||
|
|
|
||
|
|
if err := req.prepare(); err != nil {
|
||
|
|
t.Fatalf("prepare error: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
raw := req.execCtx.Value(ctxKeyRequestContext)
|
||
|
|
rc, ok := raw.(*RequestContext)
|
||
|
|
if !ok || rc == nil {
|
||
|
|
t.Fatalf("expected aggregated request context, got %#v", raw)
|
||
|
|
}
|
||
|
|
if len(rc.CustomIP) != 1 || rc.CustomIP[0] != "127.0.0.1" {
|
||
|
|
t.Fatalf("custom ip=%v", rc.CustomIP)
|
||
|
|
}
|
||
|
|
if rc.TLSConfig == nil || !rc.TLSConfig.InsecureSkipVerify {
|
||
|
|
t.Fatal("expected tls config with skip verify")
|
||
|
|
}
|
||
|
|
if rc.TLSServerName != "example.com" {
|
||
|
|
t.Fatalf("default tls server name=%q", rc.TLSServerName)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestRequestSetHostOverridesRequestHost(t *testing.T) {
|
||
|
|
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
|
|
if r.Host != "override.example" {
|
||
|
|
t.Fatalf("host=%q", r.Host)
|
||
|
|
}
|
||
|
|
w.WriteHeader(http.StatusOK)
|
||
|
|
}))
|
||
|
|
defer s.Close()
|
||
|
|
|
||
|
|
resp, err := NewSimpleRequest(s.URL, http.MethodGet).
|
||
|
|
SetHost("override.example").
|
||
|
|
Do()
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("Do() error: %v", err)
|
||
|
|
}
|
||
|
|
defer resp.Close()
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestWithHostOverridesRequestHost(t *testing.T) {
|
||
|
|
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
|
|
if r.Host != "option.example" {
|
||
|
|
t.Fatalf("host=%q", r.Host)
|
||
|
|
}
|
||
|
|
w.WriteHeader(http.StatusOK)
|
||
|
|
}))
|
||
|
|
defer s.Close()
|
||
|
|
|
||
|
|
resp, err := NewRequest(s.URL, http.MethodGet, WithHost("option.example"))
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("NewRequest() error: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
got, err := resp.Do()
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("Do() error: %v", err)
|
||
|
|
}
|
||
|
|
defer got.Close()
|
||
|
|
}
|