package starnet

import (
	"context"
	"errors"
	"net"
	"net/http"
	"net/http/httptest"
	"strconv"
	"sync"
	"testing"
	"time"
)

// TestTraceHooksStandardHTTPSPath verifies that a plain HTTPS request with
// trace hooks installed fires every connection, TLS handshake, request-write,
// and first-response-byte hook at least once, without reporting errors.
func TestTraceHooksStandardHTTPSPath(t *testing.T) {
	server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte("ok")) // write errors are irrelevant to what this test asserts
	}))
	defer server.Close()

	// Hooks may be invoked from transport goroutines, so the counters are
	// guarded by mu.
	var mu sync.Mutex
	events := map[string]int{}
	record := func(name string) {
		mu.Lock()
		events[name]++
		mu.Unlock()
	}

	hooks := &TraceHooks{
		GetConn:           func(TraceGetConnInfo) { record("get_conn") },
		GotConn:           func(TraceGotConnInfo) { record("got_conn") },
		TLSHandshakeStart: func(TraceTLSHandshakeStartInfo) { record("tls_start") },
		TLSHandshakeDone: func(info TraceTLSHandshakeDoneInfo) {
			record("tls_done")
			if info.Err != nil {
				t.Errorf("unexpected tls handshake error: %v", info.Err)
			}
		},
		WroteHeaders: func() { record("wrote_headers") },
		WroteRequest: func(info TraceWroteRequestInfo) {
			record("wrote_request")
			if info.Err != nil {
				t.Errorf("unexpected write error: %v", info.Err)
			}
		},
		GotFirstResponseByte: func() { record("first_byte") },
	}

	resp, err := NewSimpleRequest(server.URL, http.MethodGet).
		SetSkipTLSVerify(true).
		SetTraceHooks(hooks).
		Do()
	if err != nil {
		t.Fatalf("Do() error: %v", err)
	}
	defer resp.Close()

	mu.Lock()
	defer mu.Unlock()
	// Report every missing event rather than stopping at the first one.
	for _, key := range []string{"get_conn", "got_conn", "tls_start", "tls_done", "wrote_headers", "wrote_request", "first_byte"} {
		if events[key] == 0 {
			t.Errorf("expected trace event %q", key)
		}
	}
}

// TestTraceHooksDynamicHTTPSPathDoesNotDuplicateTLSHandshake verifies that,
// with a dial timeout configured, TLSHandshakeStart and TLSHandshakeDone each
// fire exactly once and the done hook reports a successful handshake with a
// populated connection state.
//
// NOTE(review): SetDialTimeout presumably routes the request through the
// library's own ("dynamic") dial path, where the handshake could otherwise be
// reported by both the library and net/http — confirm against the
// implementation.
func TestTraceHooksDynamicHTTPSPathDoesNotDuplicateTLSHandshake(t *testing.T) {
	server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	defer server.Close()

	var (
		mu            sync.Mutex
		tlsStartCount int
		tlsDoneCount  int
		lastInfo      TraceTLSHandshakeDoneInfo
	)
	hooks := &TraceHooks{
		TLSHandshakeStart: func(info TraceTLSHandshakeStartInfo) {
			mu.Lock()
			tlsStartCount++
			mu.Unlock()
		},
		TLSHandshakeDone: func(info TraceTLSHandshakeDoneInfo) {
			mu.Lock()
			tlsDoneCount++
			lastInfo = info
			mu.Unlock()
		},
	}

	resp, err := NewSimpleRequest(server.URL, http.MethodGet).
		SetSkipTLSVerify(true).
		SetDialTimeout(1500 * time.Millisecond).
		SetTraceHooks(hooks).
		Do()
	if err != nil {
		t.Fatalf("Do() error: %v", err)
	}
	defer resp.Close()

	mu.Lock()
	defer mu.Unlock()
	if tlsStartCount != 1 {
		t.Fatalf("tlsStartCount=%d", tlsStartCount)
	}
	if tlsDoneCount != 1 {
		t.Fatalf("tlsDoneCount=%d", tlsDoneCount)
	}
	if lastInfo.Err != nil {
		t.Fatalf("unexpected tls handshake error: %v", lastInfo.Err)
	}
	if lastInfo.ConnectionState.Version == 0 {
		t.Fatal("expected tls connection state")
	}
}

// TestTraceHooksCustomLookupFuncEmitsDNSEvents verifies that a request whose
// host is resolved by a custom lookup function still fires DNSStart and
// DNSDone exactly once, with DNSStart reporting the URL's host name.
func TestTraceHooksCustomLookupFuncEmitsDNSEvents(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	defer server.Close()

	addr, err := net.ResolveTCPAddr("tcp", server.Listener.Addr().String())
	if err != nil {
		t.Fatalf("ResolveTCPAddr() error: %v", err)
	}

	var (
		mu            sync.Mutex
		dnsStartCount int
		dnsDoneCount  int
		dnsStartHost  string
	)
	hooks := &TraceHooks{
		DNSStart: func(info TraceDNSStartInfo) {
			mu.Lock()
			dnsStartCount++
			dnsStartHost = info.Host
			mu.Unlock()
		},
		DNSDone: func(info TraceDNSDoneInfo) {
			mu.Lock()
			dnsDoneCount++
			mu.Unlock()
			if info.Err != nil {
				t.Errorf("unexpected dns error: %v", info.Err)
			}
		},
	}

	// ".test" is a reserved TLD, so this host can only resolve through the
	// custom lookup func below, which maps it to the test server's IP.
	targetURL := "http://trace.example.test:" + strconv.Itoa(addr.Port)
	resp, err := NewSimpleRequest(targetURL, http.MethodGet).
		SetLookupFunc(func(ctx context.Context, host string) ([]net.IPAddr, error) {
			return []net.IPAddr{{IP: addr.IP}}, nil
		}).
		SetTraceHooks(hooks).
		Do()
	if err != nil {
		t.Fatalf("Do() error: %v", err)
	}
	defer resp.Close()

	mu.Lock()
	defer mu.Unlock()
	if dnsStartCount != 1 {
		t.Fatalf("dnsStartCount=%d", dnsStartCount)
	}
	if dnsDoneCount != 1 {
		t.Fatalf("dnsDoneCount=%d", dnsDoneCount)
	}
	if dnsStartHost != "trace.example.test" {
		t.Fatalf("dnsStartHost=%q", dnsStartHost)
	}
}

// TestTraceHooksCustomDialFuncEmitsConnectEvents verifies that a request using
// a custom dial function fires ConnectStart and ConnectDone exactly once,
// without a connect error.
func TestTraceHooksCustomDialFuncEmitsConnectEvents(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	defer server.Close()

	var (
		mu                sync.Mutex
		connectStartCount int
		connectDoneCount  int
	)
	hooks := &TraceHooks{
		ConnectStart: func(info TraceConnectStartInfo) {
			mu.Lock()
			connectStartCount++
			mu.Unlock()
		},
		ConnectDone: func(info TraceConnectDoneInfo) {
			mu.Lock()
			connectDoneCount++
			mu.Unlock()
			if info.Err != nil {
				t.Errorf("unexpected connect error: %v", info.Err)
			}
		},
	}

	resp, err := NewSimpleRequest(server.URL, http.MethodGet).
		SetDialFunc(func(ctx context.Context, network, addr string) (net.Conn, error) {
			// Deliberately dials with context.Background() instead of ctx.
			// NOTE(review): presumably so any trace hooks carried on ctx cannot
			// be fired by net.Dialer itself, ensuring the counts below come only
			// from the library's custom-dial instrumentation — confirm.
			var dialer net.Dialer
			return dialer.DialContext(context.Background(), network, addr)
		}).
		SetTraceHooks(hooks).
		Do()
	if err != nil {
		t.Fatalf("Do() error: %v", err)
	}
	defer resp.Close()

	mu.Lock()
	defer mu.Unlock()
	if connectStartCount != 1 {
		t.Fatalf("connectStartCount=%d", connectStartCount)
	}
	if connectDoneCount != 1 {
		t.Fatalf("connectDoneCount=%d", connectDoneCount)
	}
}

// TestTraceHooksRetryEvents verifies the retry hooks for one retry: the server
// fails the first attempt with 500 and succeeds on the second, so two attempt
// starts/dones and one backoff are expected, and the final attempt must report
// status 200 with WillRetry false.
func TestTraceHooksRetryEvents(t *testing.T) {
	// http.Server may run each attempt's handler on a different goroutine
	// (one per connection), so the hit counter needs its own lock.
	var (
		hitsMu sync.Mutex
		hits   int
	)
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		hitsMu.Lock()
		hits++
		attempt := hits
		hitsMu.Unlock()

		if attempt == 1 {
			w.WriteHeader(http.StatusInternalServerError)
			return
		}
		w.WriteHeader(http.StatusOK)
	}))
	defer server.Close()

	var (
		mu        sync.Mutex
		starts    int
		dones     int
		backoffs  int
		finalDone TraceRetryAttemptDoneInfo
	)
	hooks := &TraceHooks{
		RetryAttemptStart: func(info TraceRetryAttemptStartInfo) {
			mu.Lock()
			starts++
			mu.Unlock()
		},
		RetryAttemptDone: func(info TraceRetryAttemptDoneInfo) {
			mu.Lock()
			dones++
			finalDone = info
			mu.Unlock()
		},
		RetryBackoff: func(info TraceRetryBackoffInfo) {
			mu.Lock()
			backoffs++
			mu.Unlock()
		},
	}

	resp, err := NewSimpleRequest(server.URL, http.MethodGet).
		SetRetry(1, WithRetryBackoff(time.Millisecond, time.Millisecond, 1), WithRetryJitter(0)).
		SetTraceHooks(hooks).
		Do()
	if err != nil {
		t.Fatalf("Do() error: %v", err)
	}
	defer resp.Close()

	hitsMu.Lock()
	serverHits := hits
	hitsMu.Unlock()
	if serverHits != 2 {
		t.Fatalf("server hits=%d", serverHits)
	}

	mu.Lock()
	defer mu.Unlock()
	if starts != 2 {
		t.Fatalf("starts=%d", starts)
	}
	if dones != 2 {
		t.Fatalf("dones=%d", dones)
	}
	if backoffs != 1 {
		t.Fatalf("backoffs=%d", backoffs)
	}
	if finalDone.WillRetry {
		t.Fatal("expected final attempt not to retry")
	}
	if finalDone.StatusCode != http.StatusOK {
		t.Fatalf("final status=%d", finalDone.StatusCode)
	}
}

// TestTraceHooksCustomLookupFuncPropagatesDNSError verifies that an error
// returned by a custom lookup function fails the request and is delivered to
// the DNSDone hook.
func TestTraceHooksCustomLookupFuncPropagatesDNSError(t *testing.T) {
	errLookup := errors.New("lookup failed")

	// The hook may run on a transport goroutine, so gotErr is guarded by mu
	// like the counters in the other tests.
	var (
		mu     sync.Mutex
		gotErr error
	)
	hooks := &TraceHooks{
		DNSDone: func(info TraceDNSDoneInfo) {
			mu.Lock()
			gotErr = info.Err
			mu.Unlock()
		},
	}

	_, err := NewSimpleRequest("http://trace.example.test:80", http.MethodGet).
		SetLookupFunc(func(ctx context.Context, host string) ([]net.IPAddr, error) {
			return nil, errLookup
		}).
		SetTraceHooks(hooks).
		Do()
	if err == nil {
		t.Fatal("expected request error")
	}

	mu.Lock()
	defer mu.Unlock()
	// Match by identity (tolerating wrapping) rather than by message text.
	if !errors.Is(gotErr, errLookup) {
		t.Fatalf("gotErr=%v", gotErr)
	}
}