package starnet import ( "context" "crypto/tls" "net" "net/http" "strconv" "strings" "testing" ) func TestTraceRecorderCapturesHTTPSummary(t *testing.T) { server := newIPv4TLSServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte("ok")) })) defer server.Close() recorder := NewTraceRecorder() req := NewSimpleRequest(server.URL, http.MethodGet). SetSkipTLSVerify(true). SetTraceRecorder(recorder) resp, err := req.Do() if err != nil { t.Fatalf("Do() error: %v", err) } defer resp.Close() if _, err := resp.Body().Bytes(); err != nil { t.Fatalf("Body().Bytes() error: %v", err) } summary := recorder.Summary() if summary.Method != http.MethodGet { t.Fatalf("method=%q", summary.Method) } if summary.URL != server.URL { t.Fatalf("url=%q", summary.URL) } if summary.StatusCode != http.StatusOK { t.Fatalf("status=%d", summary.StatusCode) } if summary.ResponseProto == "" { t.Fatal("expected response proto") } if summary.RequestWrittenAt.IsZero() { t.Fatal("expected request write timestamp") } if summary.FirstResponseByteAt.IsZero() { t.Fatal("expected first response byte timestamp") } if summary.Conn.Addr == "" { t.Fatal("expected get-conn target address") } if summary.TLS == nil { t.Fatal("expected tls summary") } tlsSummary := summary.TLS if tlsSummary.Version == 0 || tlsSummary.VersionName == "" { t.Fatalf("unexpected tls version summary: %+v", tlsSummary) } if tlsSummary.CipherSuite == 0 || tlsSummary.CipherSuiteName == "" { t.Fatalf("unexpected cipher suite summary: %+v", tlsSummary) } if tlsSummary.ServerName == "" { t.Fatal("expected tls server name") } if resp.TLS == nil { t.Fatal("expected response TLS state") } if tlsSummary.NegotiatedProtocol != resp.TLS.NegotiatedProtocol { t.Fatalf("alpn=%q resp=%q", tlsSummary.NegotiatedProtocol, resp.TLS.NegotiatedProtocol) } if len(tlsSummary.PeerCertificates) == 0 { t.Fatal("expected certificate summaries") } leaf := tlsSummary.PeerCertificates[0] if leaf.Subject == "" || leaf.Issuer == "" { t.Fatalf("unexpected leaf certificate summary: %+v", leaf) } if len(leaf.DNSNames) == 0 && len(leaf.IPAddresses) == 0 { t.Fatalf("expected DNS or IP SANs in leaf certificate: %+v", leaf) } if got := req.TraceSummary(); got == nil || got.StatusCode != http.StatusOK { t.Fatalf("request trace summary=%+v", got) } if got := resp.TraceSummary(); got == nil || got.StatusCode != http.StatusOK { t.Fatalf("response trace summary=%+v", got) } } func TestTraceRecorderCapturesDNSAndConnectSummary(t *testing.T) { server := newIPv4Server(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte("ok")) })) defer server.Close() addr, err := net.ResolveTCPAddr("tcp", server.Listener.Addr().String()) if err != nil { t.Fatalf("ResolveTCPAddr() error: %v", err) } recorder := NewTraceRecorder() targetURL := "http://trace-summary.example.test:" + strconv.Itoa(addr.Port) resp, err := NewSimpleRequest(targetURL, http.MethodGet). SetLookupFunc(func(ctx context.Context, host string) ([]net.IPAddr, error) { return []net.IPAddr{{IP: addr.IP}}, nil }). SetTraceRecorder(recorder). Do() if err != nil { t.Fatalf("Do() error: %v", err) } defer resp.Close() if _, err := resp.Body().Bytes(); err != nil { t.Fatalf("Body().Bytes() error: %v", err) } summary := recorder.Summary() if summary.DNS == nil { t.Fatal("expected dns summary") } if summary.DNS.Host != "trace-summary.example.test" { t.Fatalf("dns host=%q", summary.DNS.Host) } if len(summary.DNS.Addrs) == 0 { t.Fatal("expected resolved addresses") } if !strings.Contains(summary.DNS.Addrs[0], addr.IP.String()) { t.Fatalf("dns addrs=%v", summary.DNS.Addrs) } if summary.DNS.CompletedAt.IsZero() { t.Fatal("expected dns completion timestamp") } if len(summary.Connect) == 0 { t.Fatal("expected connect attempts") } connect := summary.Connect[0] if connect.Network == "" || connect.Addr == "" { t.Fatalf("unexpected connect summary: %+v", connect) } if connect.CompletedAt.IsZero() { t.Fatalf("expected connect completion timestamp: %+v", connect) } } func TestTraceRecorderUsesResponseTLSForReusedConnection(t *testing.T) { server := newIPv4TLSServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte("ok")) })) defer server.Close() client := NewClientNoErr() firstResp, err := client.NewSimpleRequest(server.URL, http.MethodGet). SetSkipTLSVerify(true). Do() if err != nil { t.Fatalf("first Do() error: %v", err) } if _, err := firstResp.Body().Bytes(); err != nil { t.Fatalf("first Body().Bytes() error: %v", err) } if err := firstResp.Close(); err != nil { t.Fatalf("first Close() error: %v", err) } recorder := NewTraceRecorder() secondResp, err := client.NewSimpleRequest(server.URL, http.MethodGet). SetSkipTLSVerify(true). SetTraceRecorder(recorder). Do() if err != nil { t.Fatalf("second Do() error: %v", err) } defer secondResp.Close() if _, err := secondResp.Body().Bytes(); err != nil { t.Fatalf("second Body().Bytes() error: %v", err) } summary := recorder.Summary() if !summary.Conn.Reused { t.Fatalf("expected reused connection summary, got %+v", summary.Conn) } if summary.TLS == nil || summary.TLS.Version == 0 { t.Fatalf("expected tls summary from response fallback, got %+v", summary.TLS) } } func TestTraceRecorderCoexistsWithTraceHooks(t *testing.T) { server := newIPv4Server(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte("ok")) })) defer server.Close() recorder := NewTraceRecorder() wroteRequest := 0 resp, err := NewSimpleRequest(server.URL, http.MethodGet). SetTraceHooks(&TraceHooks{ WroteRequest: func(info TraceWroteRequestInfo) { wroteRequest++ }, }). SetTraceRecorder(recorder). Do() if err != nil { t.Fatalf("Do() error: %v", err) } defer resp.Close() if _, err := resp.Body().Bytes(); err != nil { t.Fatalf("Body().Bytes() error: %v", err) } if wroteRequest == 0 { t.Fatal("expected custom trace hook to run") } summary := recorder.Summary() if summary.RequestWrittenAt.IsZero() { t.Fatal("expected recorder to capture wrote-request event") } } func TestTraceRecorderPreservesMultipleDNSEvents(t *testing.T) { recorder := NewTraceRecorder() hooks := recorder.Hooks() hooks.DNSStart(TraceDNSStartInfo{Host: "target.example.test"}) hooks.DNSDone(TraceDNSDoneInfo{ Addrs: []net.IPAddr{{IP: net.ParseIP("127.0.0.1")}}, }) hooks.DNSStart(TraceDNSStartInfo{Host: "proxy.example.test"}) hooks.DNSDone(TraceDNSDoneInfo{ Addrs: []net.IPAddr{{IP: net.ParseIP("127.0.0.2")}}, }) summary := recorder.Summary() if len(summary.DNSEvents) != 2 { t.Fatalf("dns events=%d", len(summary.DNSEvents)) } if summary.DNSEvents[0].Host != "target.example.test" { t.Fatalf("first dns host=%q", summary.DNSEvents[0].Host) } if summary.DNSEvents[1].Host != "proxy.example.test" { t.Fatalf("second dns host=%q", summary.DNSEvents[1].Host) } if summary.DNS == nil || summary.DNS.Host != "proxy.example.test" { t.Fatalf("last dns summary=%+v", summary.DNS) } } func TestTraceHooksStandardTLSPathIncludesMetadata(t *testing.T) { server, pool := newTrustedIPv4TLSServer(t, "localhost", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte("ok")) })) defer server.Close() client := NewClientNoErr() transport, ok := client.HTTPClient().Transport.(*Transport) if !ok { t.Fatalf("transport type=%T", client.HTTPClient().Transport) } base := newBaseHTTPTransport() base.TLSClientConfig = &tls.Config{RootCAs: pool} transport.SetBase(base) targetURL := httpsURLForHost(t, server, "localhost") var startInfo TraceTLSHandshakeStartInfo var doneInfo TraceTLSHandshakeDoneInfo resp, err := client.NewSimpleRequest(targetURL, http.MethodGet). SetTraceHooks(&TraceHooks{ TLSHandshakeStart: func(info TraceTLSHandshakeStartInfo) { startInfo = info }, TLSHandshakeDone: func(info TraceTLSHandshakeDoneInfo) { doneInfo = info }, }). Do() if err != nil { t.Fatalf("Do() error: %v", err) } defer resp.Close() if _, err := resp.Body().Bytes(); err != nil { t.Fatalf("Body().Bytes() error: %v", err) } wantAddr := strings.TrimPrefix(targetURL, "https://") if startInfo.Network != "tcp" { t.Fatalf("start network=%q", startInfo.Network) } if startInfo.Addr != wantAddr { t.Fatalf("start addr=%q want=%q", startInfo.Addr, wantAddr) } if startInfo.ServerName != "localhost" { t.Fatalf("start server name=%q", startInfo.ServerName) } if doneInfo.Network != "tcp" || doneInfo.Addr != wantAddr || doneInfo.ServerName != "localhost" { t.Fatalf("done info=%+v", doneInfo) } if doneInfo.ConnectionState.Version == 0 { t.Fatalf("done state=%+v", doneInfo.ConnectionState) } } func TestTraceHooksWroteHeaderFieldCopiesValues(t *testing.T) { var captured []string traceState := newTraceState(&TraceHooks{ WroteHeaderField: func(info TraceWroteHeaderFieldInfo) { captured = info.Values }, }) trace := traceState.clientTrace() values := []string{"a", "b"} trace.WroteHeaderField("X-Test", values) values[0] = "mutated" if len(captured) != 2 { t.Fatalf("captured=%v", captured) } if captured[0] != "a" { t.Fatalf("captured=%v", captured) } } func TestTraceRecorderSharedAcrossCloneKeepsPerRequestSummaries(t *testing.T) { server := newIPv4Server(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte(r.URL.Path)) })) defer server.Close() recorder := NewTraceRecorder() client := NewClientNoErr(WithTraceRecorder(recorder)) req1 := client.NewSimpleRequest(server.URL+"/one", http.MethodGet) resp1, err := req1.Do() if err != nil { t.Fatalf("first Do() error: %v", err) } defer resp1.Close() if _, err := resp1.Body().Bytes(); err != nil { t.Fatalf("first Body().Bytes() error: %v", err) } req2 := req1.Clone().SetURL(server.URL + "/two") resp2, err := req2.Do() if err != nil { t.Fatalf("second Do() error: %v", err) } defer resp2.Close() if _, err := resp2.Body().Bytes(); err != nil { t.Fatalf("second Body().Bytes() error: %v", err) } if got := req1.TraceSummary(); got == nil || got.URL != server.URL+"/one" { t.Fatalf("req1 trace summary=%+v", got) } if got := resp1.TraceSummary(); got == nil || got.URL != server.URL+"/one" { t.Fatalf("resp1 trace summary=%+v", got) } if got := req2.TraceSummary(); got == nil || got.URL != server.URL+"/two" { t.Fatalf("req2 trace summary=%+v", got) } if got := resp2.TraceSummary(); got == nil || got.URL != server.URL+"/two" { t.Fatalf("resp2 trace summary=%+v", got) } if got := recorder.Summary(); got.URL != server.URL+"/two" { t.Fatalf("shared recorder summary=%+v", got) } } func TestResponseTraceSummaryIsStableAcrossRequestReuse(t *testing.T) { server := newIPv4Server(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte(r.URL.Path)) })) defer server.Close() req := NewSimpleRequest(server.URL+"/first", http.MethodGet). SetTraceRecorder(NewTraceRecorder()) resp1, err := req.Do() if err != nil { t.Fatalf("first Do() error: %v", err) } defer resp1.Close() if _, err := resp1.Body().Bytes(); err != nil { t.Fatalf("first Body().Bytes() error: %v", err) } req.SetURL(server.URL + "/second") resp2, err := req.Do() if err != nil { t.Fatalf("second Do() error: %v", err) } defer resp2.Close() if _, err := resp2.Body().Bytes(); err != nil { t.Fatalf("second Body().Bytes() error: %v", err) } if got := resp1.TraceSummary(); got == nil || got.URL != server.URL+"/first" { t.Fatalf("resp1 trace summary=%+v", got) } if got := req.TraceSummary(); got == nil || got.URL != server.URL+"/second" { t.Fatalf("request trace summary=%+v", got) } if got := resp2.TraceSummary(); got == nil || got.URL != server.URL+"/second" { t.Fatalf("resp2 trace summary=%+v", got) } } func TestTraceHooksCustomDialDoesNotInventTLSAddr(t *testing.T) { server, pool := newTrustedIPv4TLSServer(t, "trace-custom.example.test", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte("ok")) })) defer server.Close() client := NewClientNoErr() transport, ok := client.HTTPClient().Transport.(*Transport) if !ok { t.Fatalf("transport type=%T", client.HTTPClient().Transport) } base := newBaseHTTPTransport() base.TLSClientConfig = &tls.Config{RootCAs: pool} transport.SetBase(base) targetURL := httpsURLForHost(t, server, "trace-custom.example.test") serverAddr := server.Listener.Addr().String() var startInfo TraceTLSHandshakeStartInfo var doneInfo TraceTLSHandshakeDoneInfo resp, err := client.NewSimpleRequest(targetURL, http.MethodGet). SetDialFunc(func(ctx context.Context, network, addr string) (net.Conn, error) { return (&net.Dialer{}).DialContext(ctx, "tcp", serverAddr) }). SetTraceHooks(&TraceHooks{ TLSHandshakeStart: func(info TraceTLSHandshakeStartInfo) { startInfo = info }, TLSHandshakeDone: func(info TraceTLSHandshakeDoneInfo) { doneInfo = info }, }). Do() if err != nil { t.Fatalf("Do() error: %v", err) } defer resp.Close() if _, err := resp.Body().Bytes(); err != nil { t.Fatalf("Body().Bytes() error: %v", err) } if startInfo.Network != "" || startInfo.Addr != "" { t.Fatalf("start info=%+v", startInfo) } if doneInfo.Network != "" || doneInfo.Addr != "" { t.Fatalf("done info=%+v", doneInfo) } if startInfo.ServerName != "trace-custom.example.test" { t.Fatalf("start server name=%q", startInfo.ServerName) } if doneInfo.ServerName != "trace-custom.example.test" { t.Fatalf("done server name=%q", doneInfo.ServerName) } if doneInfo.ConnectionState.Version == 0 { t.Fatalf("done state=%+v", doneInfo.ConnectionState) } }