// starnet/trace_summary_test.go

package starnet
import (
	"context"
	"crypto/tls"
	"net"
	"net/http"
	"net/url"
	"strconv"
	"strings"
	"testing"
)
// TestTraceRecorderCapturesHTTPSummary performs one HTTPS GET with a
// TraceRecorder attached and checks the recorded summary. The summary must
// carry the request method/URL, the response status and protocol, the write
// and first-byte timestamps, the connection target and the TLS details. It
// must also be reachable from both the request and the response.
func TestTraceRecorderCapturesHTTPSummary(t *testing.T) {
	srv := newIPv4TLSServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte("ok"))
	}))
	defer srv.Close()

	rec := NewTraceRecorder()
	req := NewSimpleRequest(srv.URL, http.MethodGet).
		SetSkipTLSVerify(true).
		SetTraceRecorder(rec)
	resp, err := req.Do()
	if err != nil {
		t.Fatalf("Do() error: %v", err)
	}
	defer resp.Close()
	// Drain the body so the whole exchange has been observed before inspecting.
	if _, err := resp.Body().Bytes(); err != nil {
		t.Fatalf("Body().Bytes() error: %v", err)
	}

	// Cases are evaluated top to bottom and each one aborts the test, so the
	// later s.TLS dereferences only happen once s.TLS is known to be non-nil.
	s := rec.Summary()
	switch {
	case s.Method != http.MethodGet:
		t.Fatalf("method=%q", s.Method)
	case s.URL != srv.URL:
		t.Fatalf("url=%q", s.URL)
	case s.StatusCode != http.StatusOK:
		t.Fatalf("status=%d", s.StatusCode)
	case s.ResponseProto == "":
		t.Fatal("expected response proto")
	case s.RequestWrittenAt.IsZero():
		t.Fatal("expected request write timestamp")
	case s.FirstResponseByteAt.IsZero():
		t.Fatal("expected first response byte timestamp")
	case s.Conn.Addr == "":
		t.Fatal("expected get-conn target address")
	case s.TLS == nil:
		t.Fatal("expected tls summary")
	}

	ts := s.TLS
	switch {
	case ts.Version == 0 || ts.VersionName == "":
		t.Fatalf("unexpected tls version summary: %+v", ts)
	case ts.CipherSuite == 0 || ts.CipherSuiteName == "":
		t.Fatalf("unexpected cipher suite summary: %+v", ts)
	case ts.ServerName == "":
		t.Fatal("expected tls server name")
	case resp.TLS == nil:
		t.Fatal("expected response TLS state")
	case ts.NegotiatedProtocol != resp.TLS.NegotiatedProtocol:
		// The recorded ALPN result must agree with the live response state.
		t.Fatalf("alpn=%q resp=%q", ts.NegotiatedProtocol, resp.TLS.NegotiatedProtocol)
	case len(ts.PeerCertificates) == 0:
		t.Fatal("expected certificate summaries")
	}

	leaf := ts.PeerCertificates[0]
	switch {
	case leaf.Subject == "" || leaf.Issuer == "":
		t.Fatalf("unexpected leaf certificate summary: %+v", leaf)
	case len(leaf.DNSNames) == 0 && len(leaf.IPAddresses) == 0:
		t.Fatalf("expected DNS or IP SANs in leaf certificate: %+v", leaf)
	}

	// The same summary must also be exposed per request and per response.
	if got := req.TraceSummary(); got == nil || got.StatusCode != http.StatusOK {
		t.Fatalf("request trace summary=%+v", got)
	}
	if got := resp.TraceSummary(); got == nil || got.StatusCode != http.StatusOK {
		t.Fatalf("response trace summary=%+v", got)
	}
}
// TestTraceRecorderCapturesDNSAndConnectSummary sends a plain-HTTP request to
// a made-up hostname. A custom lookup function maps that name onto the test
// listener's IP. The test checks that the recorder captures both the DNS
// resolution and the first TCP connect attempt.
func TestTraceRecorderCapturesDNSAndConnectSummary(t *testing.T) {
	srv := newIPv4Server(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte("ok"))
	}))
	defer srv.Close()

	tcpAddr, err := net.ResolveTCPAddr("tcp", srv.Listener.Addr().String())
	if err != nil {
		t.Fatalf("ResolveTCPAddr() error: %v", err)
	}

	const fakeHost = "trace-summary.example.test"
	rec := NewTraceRecorder()
	resp, err := NewSimpleRequest("http://"+fakeHost+":"+strconv.Itoa(tcpAddr.Port), http.MethodGet).
		SetLookupFunc(func(ctx context.Context, host string) ([]net.IPAddr, error) {
			// Resolve every lookup to the listener's IP.
			return []net.IPAddr{{IP: tcpAddr.IP}}, nil
		}).
		SetTraceRecorder(rec).
		Do()
	if err != nil {
		t.Fatalf("Do() error: %v", err)
	}
	defer resp.Close()
	if _, err := resp.Body().Bytes(); err != nil {
		t.Fatalf("Body().Bytes() error: %v", err)
	}

	// Ordered cases: each one aborts the test, so later cases may safely
	// dereference s.DNS and index s.DNS.Addrs / s.Connect.
	s := rec.Summary()
	switch {
	case s.DNS == nil:
		t.Fatal("expected dns summary")
	case s.DNS.Host != fakeHost:
		t.Fatalf("dns host=%q", s.DNS.Host)
	case len(s.DNS.Addrs) == 0:
		t.Fatal("expected resolved addresses")
	case !strings.Contains(s.DNS.Addrs[0], tcpAddr.IP.String()):
		t.Fatalf("dns addrs=%v", s.DNS.Addrs)
	case s.DNS.CompletedAt.IsZero():
		t.Fatal("expected dns completion timestamp")
	case len(s.Connect) == 0:
		t.Fatal("expected connect attempts")
	}

	first := s.Connect[0]
	switch {
	case first.Network == "" || first.Addr == "":
		t.Fatalf("unexpected connect summary: %+v", first)
	case first.CompletedAt.IsZero():
		t.Fatalf("expected connect completion timestamp: %+v", first)
	}
}
// TestTraceRecorderUsesResponseTLSForReusedConnection first sends an untraced
// warm-up request on a client, then sends a traced request on the same client.
// It expects the traced request to report a reused connection. Its TLS summary
// must still be populated, which the failure message attributes to a fallback
// on the response's TLS state.
func TestTraceRecorderUsesResponseTLSForReusedConnection(t *testing.T) {
	srv := newIPv4TLSServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte("ok"))
	}))
	defer srv.Close()

	client := NewClientNoErr()

	// Warm-up exchange: drain and close it so its connection can be reused.
	warm, err := client.NewSimpleRequest(srv.URL, http.MethodGet).
		SetSkipTLSVerify(true).
		Do()
	if err != nil {
		t.Fatalf("first Do() error: %v", err)
	}
	if _, err := warm.Body().Bytes(); err != nil {
		t.Fatalf("first Body().Bytes() error: %v", err)
	}
	if err := warm.Close(); err != nil {
		t.Fatalf("first Close() error: %v", err)
	}

	rec := NewTraceRecorder()
	traced, err := client.NewSimpleRequest(srv.URL, http.MethodGet).
		SetSkipTLSVerify(true).
		SetTraceRecorder(rec).
		Do()
	if err != nil {
		t.Fatalf("second Do() error: %v", err)
	}
	defer traced.Close()
	if _, err := traced.Body().Bytes(); err != nil {
		t.Fatalf("second Body().Bytes() error: %v", err)
	}

	s := rec.Summary()
	switch {
	case !s.Conn.Reused:
		t.Fatalf("expected reused connection summary, got %+v", s.Conn)
	case s.TLS == nil || s.TLS.Version == 0:
		t.Fatalf("expected tls summary from response fallback, got %+v", s.TLS)
	}
}
// TestTraceRecorderCoexistsWithTraceHooks installs user TraceHooks and a
// TraceRecorder on the same request. It checks that neither displaces the
// other: the user's WroteRequest hook runs, and the recorder still records
// the request-written timestamp.
func TestTraceRecorderCoexistsWithTraceHooks(t *testing.T) {
	srv := newIPv4Server(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte("ok"))
	}))
	defer srv.Close()

	rec := NewTraceRecorder()
	var hookCalls int
	userHooks := &TraceHooks{
		WroteRequest: func(TraceWroteRequestInfo) {
			hookCalls++
		},
	}

	resp, err := NewSimpleRequest(srv.URL, http.MethodGet).
		SetTraceHooks(userHooks).
		SetTraceRecorder(rec).
		Do()
	if err != nil {
		t.Fatalf("Do() error: %v", err)
	}
	defer resp.Close()
	// Reading the body to the end happens before the checks below, so the hook
	// counter is inspected only after the full exchange.
	if _, err := resp.Body().Bytes(); err != nil {
		t.Fatalf("Body().Bytes() error: %v", err)
	}

	switch {
	case hookCalls == 0:
		t.Fatal("expected custom trace hook to run")
	case rec.Summary().RequestWrittenAt.IsZero():
		t.Fatal("expected recorder to capture wrote-request event")
	}
}
// TestTraceRecorderPreservesMultipleDNSEvents feeds two DNS start/done pairs
// straight into the recorder's hooks. It checks that both lookups are kept in
// order in DNSEvents, while the single DNS field reflects the last lookup.
func TestTraceRecorderPreservesMultipleDNSEvents(t *testing.T) {
	rec := NewTraceRecorder()
	h := rec.Hooks()

	lookups := []struct {
		host string
		ip   string
	}{
		{host: "target.example.test", ip: "127.0.0.1"},
		{host: "proxy.example.test", ip: "127.0.0.2"},
	}
	for _, l := range lookups {
		h.DNSStart(TraceDNSStartInfo{Host: l.host})
		h.DNSDone(TraceDNSDoneInfo{
			Addrs: []net.IPAddr{{IP: net.ParseIP(l.ip)}},
		})
	}

	s := rec.Summary()
	switch {
	case len(s.DNSEvents) != 2:
		t.Fatalf("dns events=%d", len(s.DNSEvents))
	case s.DNSEvents[0].Host != lookups[0].host:
		t.Fatalf("first dns host=%q", s.DNSEvents[0].Host)
	case s.DNSEvents[1].Host != lookups[1].host:
		t.Fatalf("second dns host=%q", s.DNSEvents[1].Host)
	case s.DNS == nil || s.DNS.Host != lookups[1].host:
		t.Fatalf("last dns summary=%+v", s.DNS)
	}
}
// TestTraceHooksStandardTLSPathIncludesMetadata lets the standard transport dial
// and handshake with a server whose certificate is trusted for "localhost".
// The TLS handshake start/done hooks must report:
//   - the "tcp" network,
//   - the target host:port,
//   - the "localhost" server name,
//   - a non-zero negotiated TLS version (done hook only).
func TestTraceHooksStandardTLSPathIncludesMetadata(t *testing.T) {
	server, pool := newTrustedIPv4TLSServer(t, "localhost", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte("ok"))
	}))
	defer server.Close()

	client := NewClientNoErr()
	transport, ok := client.HTTPClient().Transport.(*Transport)
	if !ok {
		t.Fatalf("transport type=%T", client.HTTPClient().Transport)
	}
	// Trust the test server's certificate pool so verification succeeds.
	base := newBaseHTTPTransport()
	base.TLSClientConfig = &tls.Config{RootCAs: pool}
	transport.SetBase(base)

	targetURL := httpsURLForHost(t, server, "localhost")
	// Take host:port from the parsed URL instead of trimming the scheme, so a
	// trailing slash or path on targetURL cannot skew the expected address.
	parsedURL, err := url.Parse(targetURL)
	if err != nil {
		t.Fatalf("url.Parse(%q) error: %v", targetURL, err)
	}
	wantAddr := parsedURL.Host

	var startInfo TraceTLSHandshakeStartInfo
	var doneInfo TraceTLSHandshakeDoneInfo
	resp, err := client.NewSimpleRequest(targetURL, http.MethodGet).
		SetTraceHooks(&TraceHooks{
			TLSHandshakeStart: func(info TraceTLSHandshakeStartInfo) {
				startInfo = info
			},
			TLSHandshakeDone: func(info TraceTLSHandshakeDoneInfo) {
				doneInfo = info
			},
		}).
		Do()
	if err != nil {
		t.Fatalf("Do() error: %v", err)
	}
	defer resp.Close()
	if _, err := resp.Body().Bytes(); err != nil {
		t.Fatalf("Body().Bytes() error: %v", err)
	}

	if startInfo.Network != "tcp" {
		t.Fatalf("start network=%q", startInfo.Network)
	}
	if startInfo.Addr != wantAddr {
		t.Fatalf("start addr=%q want=%q", startInfo.Addr, wantAddr)
	}
	if startInfo.ServerName != "localhost" {
		t.Fatalf("start server name=%q", startInfo.ServerName)
	}
	if doneInfo.Network != "tcp" || doneInfo.Addr != wantAddr || doneInfo.ServerName != "localhost" {
		t.Fatalf("done info=%+v", doneInfo)
	}
	if doneInfo.ConnectionState.Version == 0 {
		t.Fatalf("done state=%+v", doneInfo.ConnectionState)
	}
}
// TestTraceHooksWroteHeaderFieldCopiesValues checks that the WroteHeaderField
// hook gets its own copy of the header values. Mutating the caller's slice
// after the callback must not change what the hook observed.
func TestTraceHooksWroteHeaderFieldCopiesValues(t *testing.T) {
	var seen []string
	state := newTraceState(&TraceHooks{
		WroteHeaderField: func(info TraceWroteHeaderFieldInfo) {
			seen = info.Values
		},
	})

	vals := []string{"a", "b"}
	state.clientTrace().WroteHeaderField("X-Test", vals)
	// Mutate the source slice; a properly copied slice is unaffected.
	vals[0] = "mutated"

	if len(seen) != 2 || seen[0] != "a" {
		t.Fatalf("captured=%v", seen)
	}
}
// TestTraceRecorderSharedAcrossCloneKeepsPerRequestSummaries attaches one
// recorder at the client level. It sends a request, then a clone of it pointed
// at a different path. Each request/response pair must report its own URL,
// while the shared recorder's summary reflects the most recent request.
func TestTraceRecorderSharedAcrossCloneKeepsPerRequestSummaries(t *testing.T) {
	server := newIPv4Server(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte(r.URL.Path))
	}))
	defer server.Close()

	urlOne := server.URL + "/one"
	urlTwo := server.URL + "/two"

	shared := NewTraceRecorder()
	client := NewClientNoErr(WithTraceRecorder(shared))

	first := client.NewSimpleRequest(urlOne, http.MethodGet)
	firstResp, err := first.Do()
	if err != nil {
		t.Fatalf("first Do() error: %v", err)
	}
	defer firstResp.Close()
	if _, err := firstResp.Body().Bytes(); err != nil {
		t.Fatalf("first Body().Bytes() error: %v", err)
	}

	// Send a clone of the first request at a different path.
	second := first.Clone().SetURL(urlTwo)
	secondResp, err := second.Do()
	if err != nil {
		t.Fatalf("second Do() error: %v", err)
	}
	defer secondResp.Close()
	if _, err := secondResp.Body().Bytes(); err != nil {
		t.Fatalf("second Body().Bytes() error: %v", err)
	}

	if got := first.TraceSummary(); got == nil || got.URL != urlOne {
		t.Fatalf("req1 trace summary=%+v", got)
	}
	if got := firstResp.TraceSummary(); got == nil || got.URL != urlOne {
		t.Fatalf("resp1 trace summary=%+v", got)
	}
	if got := second.TraceSummary(); got == nil || got.URL != urlTwo {
		t.Fatalf("req2 trace summary=%+v", got)
	}
	if got := secondResp.TraceSummary(); got == nil || got.URL != urlTwo {
		t.Fatalf("resp2 trace summary=%+v", got)
	}
	if got := shared.Summary(); got.URL != urlTwo {
		t.Fatalf("shared recorder summary=%+v", got)
	}
}
// TestResponseTraceSummaryIsStableAcrossRequestReuse sends the same request
// object twice, re-pointed at a new URL in between. The first response must
// keep the summary of its own exchange. The request and the second response
// must report the new URL.
func TestResponseTraceSummaryIsStableAcrossRequestReuse(t *testing.T) {
	server := newIPv4Server(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte(r.URL.Path))
	}))
	defer server.Close()

	firstURL := server.URL + "/first"
	secondURL := server.URL + "/second"

	req := NewSimpleRequest(firstURL, http.MethodGet).
		SetTraceRecorder(NewTraceRecorder())

	before, err := req.Do()
	if err != nil {
		t.Fatalf("first Do() error: %v", err)
	}
	defer before.Close()
	if _, err := before.Body().Bytes(); err != nil {
		t.Fatalf("first Body().Bytes() error: %v", err)
	}

	// Re-point the very same request object and send it again.
	req.SetURL(secondURL)
	after, err := req.Do()
	if err != nil {
		t.Fatalf("second Do() error: %v", err)
	}
	defer after.Close()
	if _, err := after.Body().Bytes(); err != nil {
		t.Fatalf("second Body().Bytes() error: %v", err)
	}

	if got := before.TraceSummary(); got == nil || got.URL != firstURL {
		t.Fatalf("resp1 trace summary=%+v", got)
	}
	if got := req.TraceSummary(); got == nil || got.URL != secondURL {
		t.Fatalf("request trace summary=%+v", got)
	}
	if got := after.TraceSummary(); got == nil || got.URL != secondURL {
		t.Fatalf("resp2 trace summary=%+v", got)
	}
}
// TestTraceHooksCustomDialDoesNotInventTLSAddr sends an HTTPS request whose
// connection comes from a user-supplied dial function. The TLS handshake hooks
// must leave Network and Addr empty rather than filling them in. They must
// still report the server name and, on completion, a negotiated TLS version.
func TestTraceHooksCustomDialDoesNotInventTLSAddr(t *testing.T) {
	const sniHost = "trace-custom.example.test"
	server, pool := newTrustedIPv4TLSServer(t, sniHost, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte("ok"))
	}))
	defer server.Close()

	client := NewClientNoErr()
	tr, ok := client.HTTPClient().Transport.(*Transport)
	if !ok {
		t.Fatalf("transport type=%T", client.HTTPClient().Transport)
	}
	base := newBaseHTTPTransport()
	base.TLSClientConfig = &tls.Config{RootCAs: pool}
	tr.SetBase(base)

	targetURL := httpsURLForHost(t, server, sniHost)
	listenAddr := server.Listener.Addr().String()

	var startInfo TraceTLSHandshakeStartInfo
	var doneInfo TraceTLSHandshakeDoneInfo
	resp, err := client.NewSimpleRequest(targetURL, http.MethodGet).
		SetDialFunc(func(ctx context.Context, network, addr string) (net.Conn, error) {
			// Ignore the requested address and always dial the test listener.
			var d net.Dialer
			return d.DialContext(ctx, "tcp", listenAddr)
		}).
		SetTraceHooks(&TraceHooks{
			TLSHandshakeStart: func(info TraceTLSHandshakeStartInfo) {
				startInfo = info
			},
			TLSHandshakeDone: func(info TraceTLSHandshakeDoneInfo) {
				doneInfo = info
			},
		}).
		Do()
	if err != nil {
		t.Fatalf("Do() error: %v", err)
	}
	defer resp.Close()
	if _, err := resp.Body().Bytes(); err != nil {
		t.Fatalf("Body().Bytes() error: %v", err)
	}

	switch {
	case startInfo.Network != "" || startInfo.Addr != "":
		t.Fatalf("start info=%+v", startInfo)
	case doneInfo.Network != "" || doneInfo.Addr != "":
		t.Fatalf("done info=%+v", doneInfo)
	case startInfo.ServerName != sniHost:
		t.Fatalf("start server name=%q", startInfo.ServerName)
	case doneInfo.ServerName != sniHost:
		t.Fatalf("done server name=%q", doneInfo.ServerName)
	case doneInfo.ConnectionState.Version == 0:
		t.Fatalf("done state=%+v", doneInfo.ConnectionState)
	}
}