325 lines
7.3 KiB
Go
325 lines
7.3 KiB
Go
|
|
package starnet
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"errors"
|
||
|
|
"net"
|
||
|
|
"net/http"
|
||
|
|
"net/http/httptest"
|
||
|
|
"strconv"
|
||
|
|
"sync"
|
||
|
|
"testing"
|
||
|
|
"time"
|
||
|
|
)
|
||
|
|
|
||
|
|
func TestTraceHooksStandardHTTPSPath(t *testing.T) {
|
||
|
|
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
|
|
w.WriteHeader(http.StatusOK)
|
||
|
|
_, _ = w.Write([]byte("ok"))
|
||
|
|
}))
|
||
|
|
defer server.Close()
|
||
|
|
|
||
|
|
var mu sync.Mutex
|
||
|
|
events := map[string]int{}
|
||
|
|
hooks := &TraceHooks{
|
||
|
|
GetConn: func(info TraceGetConnInfo) {
|
||
|
|
mu.Lock()
|
||
|
|
events["get_conn"]++
|
||
|
|
mu.Unlock()
|
||
|
|
},
|
||
|
|
GotConn: func(info TraceGotConnInfo) {
|
||
|
|
mu.Lock()
|
||
|
|
events["got_conn"]++
|
||
|
|
mu.Unlock()
|
||
|
|
},
|
||
|
|
TLSHandshakeStart: func(info TraceTLSHandshakeStartInfo) {
|
||
|
|
mu.Lock()
|
||
|
|
events["tls_start"]++
|
||
|
|
mu.Unlock()
|
||
|
|
},
|
||
|
|
TLSHandshakeDone: func(info TraceTLSHandshakeDoneInfo) {
|
||
|
|
mu.Lock()
|
||
|
|
events["tls_done"]++
|
||
|
|
mu.Unlock()
|
||
|
|
if info.Err != nil {
|
||
|
|
t.Errorf("unexpected tls handshake error: %v", info.Err)
|
||
|
|
}
|
||
|
|
},
|
||
|
|
WroteHeaders: func() {
|
||
|
|
mu.Lock()
|
||
|
|
events["wrote_headers"]++
|
||
|
|
mu.Unlock()
|
||
|
|
},
|
||
|
|
WroteRequest: func(info TraceWroteRequestInfo) {
|
||
|
|
mu.Lock()
|
||
|
|
events["wrote_request"]++
|
||
|
|
mu.Unlock()
|
||
|
|
if info.Err != nil {
|
||
|
|
t.Errorf("unexpected write error: %v", info.Err)
|
||
|
|
}
|
||
|
|
},
|
||
|
|
GotFirstResponseByte: func() {
|
||
|
|
mu.Lock()
|
||
|
|
events["first_byte"]++
|
||
|
|
mu.Unlock()
|
||
|
|
},
|
||
|
|
}
|
||
|
|
|
||
|
|
resp, err := NewSimpleRequest(server.URL, http.MethodGet).
|
||
|
|
SetSkipTLSVerify(true).
|
||
|
|
SetTraceHooks(hooks).
|
||
|
|
Do()
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("Do() error: %v", err)
|
||
|
|
}
|
||
|
|
defer resp.Close()
|
||
|
|
|
||
|
|
mu.Lock()
|
||
|
|
defer mu.Unlock()
|
||
|
|
for _, key := range []string{"get_conn", "got_conn", "tls_start", "tls_done", "wrote_headers", "wrote_request", "first_byte"} {
|
||
|
|
if events[key] == 0 {
|
||
|
|
t.Fatalf("expected trace event %q", key)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestTraceHooksDynamicHTTPSPathDoesNotDuplicateTLSHandshake(t *testing.T) {
|
||
|
|
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
|
|
w.WriteHeader(http.StatusOK)
|
||
|
|
}))
|
||
|
|
defer server.Close()
|
||
|
|
|
||
|
|
var mu sync.Mutex
|
||
|
|
tlsStartCount := 0
|
||
|
|
tlsDoneCount := 0
|
||
|
|
var lastInfo TraceTLSHandshakeDoneInfo
|
||
|
|
hooks := &TraceHooks{
|
||
|
|
TLSHandshakeStart: func(info TraceTLSHandshakeStartInfo) {
|
||
|
|
mu.Lock()
|
||
|
|
tlsStartCount++
|
||
|
|
mu.Unlock()
|
||
|
|
},
|
||
|
|
TLSHandshakeDone: func(info TraceTLSHandshakeDoneInfo) {
|
||
|
|
mu.Lock()
|
||
|
|
tlsDoneCount++
|
||
|
|
lastInfo = info
|
||
|
|
mu.Unlock()
|
||
|
|
},
|
||
|
|
}
|
||
|
|
|
||
|
|
resp, err := NewSimpleRequest(server.URL, http.MethodGet).
|
||
|
|
SetSkipTLSVerify(true).
|
||
|
|
SetDialTimeout(1500 * time.Millisecond).
|
||
|
|
SetTraceHooks(hooks).
|
||
|
|
Do()
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("Do() error: %v", err)
|
||
|
|
}
|
||
|
|
defer resp.Close()
|
||
|
|
|
||
|
|
mu.Lock()
|
||
|
|
defer mu.Unlock()
|
||
|
|
if tlsStartCount != 1 {
|
||
|
|
t.Fatalf("tlsStartCount=%d", tlsStartCount)
|
||
|
|
}
|
||
|
|
if tlsDoneCount != 1 {
|
||
|
|
t.Fatalf("tlsDoneCount=%d", tlsDoneCount)
|
||
|
|
}
|
||
|
|
if lastInfo.Err != nil {
|
||
|
|
t.Fatalf("unexpected tls handshake error: %v", lastInfo.Err)
|
||
|
|
}
|
||
|
|
if lastInfo.ConnectionState.Version == 0 {
|
||
|
|
t.Fatal("expected tls connection state")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestTraceHooksCustomLookupFuncEmitsDNSEvents(t *testing.T) {
|
||
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
|
|
w.WriteHeader(http.StatusOK)
|
||
|
|
}))
|
||
|
|
defer server.Close()
|
||
|
|
|
||
|
|
addr, err := net.ResolveTCPAddr("tcp", server.Listener.Addr().String())
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("ResolveTCPAddr() error: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
var mu sync.Mutex
|
||
|
|
dnsStartCount := 0
|
||
|
|
dnsDoneCount := 0
|
||
|
|
var dnsStartHost string
|
||
|
|
hooks := &TraceHooks{
|
||
|
|
DNSStart: func(info TraceDNSStartInfo) {
|
||
|
|
mu.Lock()
|
||
|
|
dnsStartCount++
|
||
|
|
dnsStartHost = info.Host
|
||
|
|
mu.Unlock()
|
||
|
|
},
|
||
|
|
DNSDone: func(info TraceDNSDoneInfo) {
|
||
|
|
mu.Lock()
|
||
|
|
dnsDoneCount++
|
||
|
|
mu.Unlock()
|
||
|
|
if info.Err != nil {
|
||
|
|
t.Errorf("unexpected dns error: %v", info.Err)
|
||
|
|
}
|
||
|
|
},
|
||
|
|
}
|
||
|
|
|
||
|
|
url := "http://trace.example.test:" + strconv.Itoa(addr.Port)
|
||
|
|
resp, err := NewSimpleRequest(url, http.MethodGet).
|
||
|
|
SetLookupFunc(func(ctx context.Context, host string) ([]net.IPAddr, error) {
|
||
|
|
return []net.IPAddr{{IP: addr.IP}}, nil
|
||
|
|
}).
|
||
|
|
SetTraceHooks(hooks).
|
||
|
|
Do()
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("Do() error: %v", err)
|
||
|
|
}
|
||
|
|
defer resp.Close()
|
||
|
|
|
||
|
|
mu.Lock()
|
||
|
|
defer mu.Unlock()
|
||
|
|
if dnsStartCount != 1 {
|
||
|
|
t.Fatalf("dnsStartCount=%d", dnsStartCount)
|
||
|
|
}
|
||
|
|
if dnsDoneCount != 1 {
|
||
|
|
t.Fatalf("dnsDoneCount=%d", dnsDoneCount)
|
||
|
|
}
|
||
|
|
if dnsStartHost != "trace.example.test" {
|
||
|
|
t.Fatalf("dnsStartHost=%q", dnsStartHost)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestTraceHooksCustomDialFuncEmitsConnectEvents(t *testing.T) {
|
||
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
|
|
w.WriteHeader(http.StatusOK)
|
||
|
|
}))
|
||
|
|
defer server.Close()
|
||
|
|
|
||
|
|
var mu sync.Mutex
|
||
|
|
connectStartCount := 0
|
||
|
|
connectDoneCount := 0
|
||
|
|
hooks := &TraceHooks{
|
||
|
|
ConnectStart: func(info TraceConnectStartInfo) {
|
||
|
|
mu.Lock()
|
||
|
|
connectStartCount++
|
||
|
|
mu.Unlock()
|
||
|
|
},
|
||
|
|
ConnectDone: func(info TraceConnectDoneInfo) {
|
||
|
|
mu.Lock()
|
||
|
|
connectDoneCount++
|
||
|
|
mu.Unlock()
|
||
|
|
if info.Err != nil {
|
||
|
|
t.Errorf("unexpected connect error: %v", info.Err)
|
||
|
|
}
|
||
|
|
},
|
||
|
|
}
|
||
|
|
|
||
|
|
resp, err := NewSimpleRequest(server.URL, http.MethodGet).
|
||
|
|
SetDialFunc(func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||
|
|
var dialer net.Dialer
|
||
|
|
return dialer.DialContext(context.Background(), network, addr)
|
||
|
|
}).
|
||
|
|
SetTraceHooks(hooks).
|
||
|
|
Do()
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("Do() error: %v", err)
|
||
|
|
}
|
||
|
|
defer resp.Close()
|
||
|
|
|
||
|
|
mu.Lock()
|
||
|
|
defer mu.Unlock()
|
||
|
|
if connectStartCount != 1 {
|
||
|
|
t.Fatalf("connectStartCount=%d", connectStartCount)
|
||
|
|
}
|
||
|
|
if connectDoneCount != 1 {
|
||
|
|
t.Fatalf("connectDoneCount=%d", connectDoneCount)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestTraceHooksRetryEvents(t *testing.T) {
|
||
|
|
var hits int
|
||
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
|
|
hits++
|
||
|
|
if hits == 1 {
|
||
|
|
w.WriteHeader(http.StatusInternalServerError)
|
||
|
|
return
|
||
|
|
}
|
||
|
|
w.WriteHeader(http.StatusOK)
|
||
|
|
}))
|
||
|
|
defer server.Close()
|
||
|
|
|
||
|
|
var mu sync.Mutex
|
||
|
|
starts := 0
|
||
|
|
dones := 0
|
||
|
|
backoffs := 0
|
||
|
|
var finalDone TraceRetryAttemptDoneInfo
|
||
|
|
hooks := &TraceHooks{
|
||
|
|
RetryAttemptStart: func(info TraceRetryAttemptStartInfo) {
|
||
|
|
mu.Lock()
|
||
|
|
starts++
|
||
|
|
mu.Unlock()
|
||
|
|
},
|
||
|
|
RetryAttemptDone: func(info TraceRetryAttemptDoneInfo) {
|
||
|
|
mu.Lock()
|
||
|
|
dones++
|
||
|
|
finalDone = info
|
||
|
|
mu.Unlock()
|
||
|
|
},
|
||
|
|
RetryBackoff: func(info TraceRetryBackoffInfo) {
|
||
|
|
mu.Lock()
|
||
|
|
backoffs++
|
||
|
|
mu.Unlock()
|
||
|
|
},
|
||
|
|
}
|
||
|
|
|
||
|
|
resp, err := NewSimpleRequest(server.URL, http.MethodGet).
|
||
|
|
SetRetry(1, WithRetryBackoff(time.Millisecond, time.Millisecond, 1), WithRetryJitter(0)).
|
||
|
|
SetTraceHooks(hooks).
|
||
|
|
Do()
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("Do() error: %v", err)
|
||
|
|
}
|
||
|
|
defer resp.Close()
|
||
|
|
|
||
|
|
mu.Lock()
|
||
|
|
defer mu.Unlock()
|
||
|
|
if starts != 2 {
|
||
|
|
t.Fatalf("starts=%d", starts)
|
||
|
|
}
|
||
|
|
if dones != 2 {
|
||
|
|
t.Fatalf("dones=%d", dones)
|
||
|
|
}
|
||
|
|
if backoffs != 1 {
|
||
|
|
t.Fatalf("backoffs=%d", backoffs)
|
||
|
|
}
|
||
|
|
if finalDone.WillRetry {
|
||
|
|
t.Fatal("expected final attempt not to retry")
|
||
|
|
}
|
||
|
|
if finalDone.StatusCode != http.StatusOK {
|
||
|
|
t.Fatalf("final status=%d", finalDone.StatusCode)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestTraceHooksCustomLookupFuncPropagatesDNSError(t *testing.T) {
|
||
|
|
var gotErr error
|
||
|
|
hooks := &TraceHooks{
|
||
|
|
DNSDone: func(info TraceDNSDoneInfo) {
|
||
|
|
gotErr = info.Err
|
||
|
|
},
|
||
|
|
}
|
||
|
|
|
||
|
|
_, err := NewSimpleRequest("http://trace.example.test:80", http.MethodGet).
|
||
|
|
SetLookupFunc(func(ctx context.Context, host string) ([]net.IPAddr, error) {
|
||
|
|
return nil, errors.New("lookup failed")
|
||
|
|
}).
|
||
|
|
SetTraceHooks(hooks).
|
||
|
|
Do()
|
||
|
|
if err == nil {
|
||
|
|
t.Fatal("expected request error")
|
||
|
|
}
|
||
|
|
if gotErr == nil || gotErr.Error() != "lookup failed" {
|
||
|
|
t.Fatalf("gotErr=%v", gotErr)
|
||
|
|
}
|
||
|
|
}
|