starnet/trace_test.go

325 lines
7.3 KiB
Go
Raw Normal View History

package starnet
import (
"context"
"errors"
"net"
"net/http"
"net/http/httptest"
"strconv"
"sync"
"testing"
"time"
)
// TestTraceHooksStandardHTTPSPath sends a GET to a local TLS test server with
// connection, TLS, write and first-byte hooks installed, then verifies that
// each of those hooks was invoked at least once. TLS handshake and request
// write errors reported through the hooks fail the test.
func TestTraceHooksStandardHTTPSPath(t *testing.T) {
	handler := func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte("ok"))
	}
	srv := httptest.NewTLSServer(http.HandlerFunc(handler))
	defer srv.Close()

	var (
		guard sync.Mutex
		seen  = make(map[string]int)
	)
	// record bumps the counter for one event name while holding the lock,
	// since hooks may fire on goroutines other than the test's.
	record := func(name string) {
		guard.Lock()
		seen[name]++
		guard.Unlock()
	}

	hooks := &TraceHooks{
		GetConn:           func(TraceGetConnInfo) { record("get_conn") },
		GotConn:           func(TraceGotConnInfo) { record("got_conn") },
		TLSHandshakeStart: func(TraceTLSHandshakeStartInfo) { record("tls_start") },
		TLSHandshakeDone: func(done TraceTLSHandshakeDoneInfo) {
			record("tls_done")
			if done.Err != nil {
				t.Errorf("unexpected tls handshake error: %v", done.Err)
			}
		},
		WroteHeaders: func() { record("wrote_headers") },
		WroteRequest: func(wrote TraceWroteRequestInfo) {
			record("wrote_request")
			if wrote.Err != nil {
				t.Errorf("unexpected write error: %v", wrote.Err)
			}
		},
		GotFirstResponseByte: func() { record("first_byte") },
	}

	req := NewSimpleRequest(srv.URL, http.MethodGet).
		SetSkipTLSVerify(true).
		SetTraceHooks(hooks)
	resp, err := req.Do()
	if err != nil {
		t.Fatalf("Do() error: %v", err)
	}
	defer resp.Close()

	required := []string{
		"get_conn",
		"got_conn",
		"tls_start",
		"tls_done",
		"wrote_headers",
		"wrote_request",
		"first_byte",
	}

	guard.Lock()
	defer guard.Unlock()
	for _, name := range required {
		if seen[name] > 0 {
			continue
		}
		t.Fatalf("expected trace event %q", name)
	}
}
// TestTraceHooksDynamicHTTPSPathDoesNotDuplicateTLSHandshake performs a GET
// against a local TLS test server with a dial timeout configured and checks
// that the TLS handshake start and done hooks each fire exactly once, that the
// handshake reported no error, and that the reported connection state carries
// a non-zero TLS version.
//
// NOTE(review): presumably SetDialTimeout is what routes the request through
// the "dynamic" path named in the test — confirm against the client code.
func TestTraceHooksDynamicHTTPSPathDoesNotDuplicateTLSHandshake(t *testing.T) {
	srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	defer srv.Close()

	var (
		guard    sync.Mutex
		started  int
		finished int
		lastDone TraceTLSHandshakeDoneInfo
	)

	hooks := &TraceHooks{
		TLSHandshakeStart: func(TraceTLSHandshakeStartInfo) {
			guard.Lock()
			started++
			guard.Unlock()
		},
		TLSHandshakeDone: func(done TraceTLSHandshakeDoneInfo) {
			guard.Lock()
			finished++
			lastDone = done
			guard.Unlock()
		},
	}

	req := NewSimpleRequest(srv.URL, http.MethodGet).
		SetSkipTLSVerify(true).
		SetDialTimeout(1500 * time.Millisecond).
		SetTraceHooks(hooks)
	resp, err := req.Do()
	if err != nil {
		t.Fatalf("Do() error: %v", err)
	}
	defer resp.Close()

	guard.Lock()
	defer guard.Unlock()

	// Cases are evaluated top to bottom; the first failing check aborts the test.
	switch {
	case started != 1:
		t.Fatalf("tlsStartCount=%d", started)
	case finished != 1:
		t.Fatalf("tlsDoneCount=%d", finished)
	case lastDone.Err != nil:
		t.Fatalf("unexpected tls handshake error: %v", lastDone.Err)
	case lastDone.ConnectionState.Version == 0:
		t.Fatal("expected tls connection state")
	}
}
// TestTraceHooksCustomLookupFuncEmitsDNSEvents requests a made-up hostname on
// the port of a local test server, supplying a lookup function that returns the
// server's IP. It then checks that the DNS start and done hooks each fired
// exactly once and that the start hook saw the made-up hostname.
func TestTraceHooksCustomLookupFuncEmitsDNSEvents(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	defer srv.Close()

	tcpAddr, err := net.ResolveTCPAddr("tcp", srv.Listener.Addr().String())
	if err != nil {
		t.Fatalf("ResolveTCPAddr() error: %v", err)
	}

	var (
		guard      sync.Mutex
		startCalls int
		doneCalls  int
		startHost  string
	)

	hooks := &TraceHooks{
		DNSStart: func(start TraceDNSStartInfo) {
			guard.Lock()
			defer guard.Unlock()
			startCalls++
			startHost = start.Host
		},
		DNSDone: func(done TraceDNSDoneInfo) {
			guard.Lock()
			doneCalls++
			guard.Unlock()
			if done.Err != nil {
				t.Errorf("unexpected dns error: %v", done.Err)
			}
		},
	}

	// The hostname is fictitious; only the custom lookup below maps it to the
	// test server's address.
	target := "http://" + net.JoinHostPort("trace.example.test", strconv.Itoa(tcpAddr.Port))
	lookup := func(ctx context.Context, host string) ([]net.IPAddr, error) {
		return []net.IPAddr{{IP: tcpAddr.IP}}, nil
	}

	resp, err := NewSimpleRequest(target, http.MethodGet).
		SetLookupFunc(lookup).
		SetTraceHooks(hooks).
		Do()
	if err != nil {
		t.Fatalf("Do() error: %v", err)
	}
	defer resp.Close()

	guard.Lock()
	defer guard.Unlock()

	switch {
	case startCalls != 1:
		t.Fatalf("dnsStartCount=%d", startCalls)
	case doneCalls != 1:
		t.Fatalf("dnsDoneCount=%d", doneCalls)
	case startHost != "trace.example.test":
		t.Fatalf("dnsStartHost=%q", startHost)
	}
}
// TestTraceHooksCustomDialFuncEmitsConnectEvents performs a GET against a local
// test server through a caller-supplied dial function and checks that the
// connect start and done hooks each fire exactly once, with no connect error.
func TestTraceHooksCustomDialFuncEmitsConnectEvents(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	defer srv.Close()

	var (
		guard     sync.Mutex
		startSeen int
		doneSeen  int
	)

	hooks := &TraceHooks{
		ConnectStart: func(TraceConnectStartInfo) {
			guard.Lock()
			startSeen++
			guard.Unlock()
		},
		ConnectDone: func(done TraceConnectDoneInfo) {
			guard.Lock()
			doneSeen++
			guard.Unlock()
			if done.Err != nil {
				t.Errorf("unexpected connect error: %v", done.Err)
			}
		},
	}

	// NOTE(review): the dialer is handed context.Background() rather than the
	// ctx it receives. Presumably this keeps net.Dialer from emitting connect
	// events itself via anything carried on ctx, so the counts below reflect
	// only what the client emits around the custom dial func — confirm before
	// switching this to ctx.
	dial := func(ctx context.Context, network, addr string) (net.Conn, error) {
		return (&net.Dialer{}).DialContext(context.Background(), network, addr)
	}

	resp, err := NewSimpleRequest(srv.URL, http.MethodGet).
		SetDialFunc(dial).
		SetTraceHooks(hooks).
		Do()
	if err != nil {
		t.Fatalf("Do() error: %v", err)
	}
	defer resp.Close()

	guard.Lock()
	defer guard.Unlock()

	switch {
	case startSeen != 1:
		t.Fatalf("connectStartCount=%d", startSeen)
	case doneSeen != 1:
		t.Fatalf("connectDoneCount=%d", doneSeen)
	}
}
// TestTraceHooksRetryEvents runs a GET with one retry configured against a
// local test server that answers 500 to the first request and 200 afterwards.
// It checks that the retry hooks report two attempt starts, two attempt
// completions and one backoff, and that the last completion reports
// WillRetry == false with status 200.
func TestTraceHooksRetryEvents(t *testing.T) {
	// The handler runs on server goroutines, and separate attempts may arrive
	// on separate connections, so the hit counter is guarded by its own mutex.
	var (
		hitsMu sync.Mutex
		hits   int
	)
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		hitsMu.Lock()
		hits++
		first := hits == 1
		hitsMu.Unlock()

		if first {
			w.WriteHeader(http.StatusInternalServerError)
			return
		}
		w.WriteHeader(http.StatusOK)
	}))
	defer server.Close()

	var mu sync.Mutex
	starts := 0
	dones := 0
	backoffs := 0
	var finalDone TraceRetryAttemptDoneInfo

	hooks := &TraceHooks{
		RetryAttemptStart: func(info TraceRetryAttemptStartInfo) {
			mu.Lock()
			starts++
			mu.Unlock()
		},
		RetryAttemptDone: func(info TraceRetryAttemptDoneInfo) {
			mu.Lock()
			dones++
			// Overwritten on every attempt, so it ends up holding the last one.
			finalDone = info
			mu.Unlock()
		},
		RetryBackoff: func(info TraceRetryBackoffInfo) {
			mu.Lock()
			backoffs++
			mu.Unlock()
		},
	}

	// Millisecond backoff and zero jitter keep the test fast.
	resp, err := NewSimpleRequest(server.URL, http.MethodGet).
		SetRetry(1, WithRetryBackoff(time.Millisecond, time.Millisecond, 1), WithRetryJitter(0)).
		SetTraceHooks(hooks).
		Do()
	if err != nil {
		t.Fatalf("Do() error: %v", err)
	}
	defer resp.Close()

	mu.Lock()
	defer mu.Unlock()
	if starts != 2 {
		t.Fatalf("starts=%d", starts)
	}
	if dones != 2 {
		t.Fatalf("dones=%d", dones)
	}
	if backoffs != 1 {
		t.Fatalf("backoffs=%d", backoffs)
	}
	if finalDone.WillRetry {
		t.Fatal("expected final attempt not to retry")
	}
	if finalDone.StatusCode != http.StatusOK {
		t.Fatalf("final status=%d", finalDone.StatusCode)
	}
}
// TestTraceHooksCustomLookupFuncPropagatesDNSError issues a request whose
// custom lookup function always fails and checks that the request returns an
// error and that the DNSDone hook received the "lookup failed" error.
func TestTraceHooksCustomLookupFuncPropagatesDNSError(t *testing.T) {
	// The hook may run on a goroutine other than the test's, so the captured
	// error is guarded the same way as hook state in the other tests here.
	var (
		mu     sync.Mutex
		gotErr error
	)
	hooks := &TraceHooks{
		DNSDone: func(info TraceDNSDoneInfo) {
			mu.Lock()
			gotErr = info.Err
			mu.Unlock()
		},
	}

	resp, err := NewSimpleRequest("http://trace.example.test:80", http.MethodGet).
		SetLookupFunc(func(ctx context.Context, host string) ([]net.IPAddr, error) {
			return nil, errors.New("lookup failed")
		}).
		SetTraceHooks(hooks).
		Do()
	if err == nil {
		// Release the unexpected response before failing so it is not leaked.
		resp.Close()
		t.Fatal("expected request error")
	}

	mu.Lock()
	defer mu.Unlock()
	if gotErr == nil || gotErr.Error() != "lookup failed" {
		t.Fatalf("gotErr=%v", gotErr)
	}
}