package notify import ( "context" "errors" "io" "path/filepath" "runtime" "sync" "testing" "time" ) func BenchmarkBulkEndToEndThroughput(b *testing.B) { cases := []struct { name string network string payloadSize int dedicated bool }{ { name: "tcp_shared_1MiB", network: "tcp", payloadSize: 1024 * 1024, }, { name: "tcp_dedicated_1MiB", network: "tcp", payloadSize: 1024 * 1024, dedicated: true, }, { name: "unix_shared_1MiB", network: "unix", payloadSize: 1024 * 1024, }, { name: "unix_dedicated_1MiB", network: "unix", payloadSize: 1024 * 1024, dedicated: true, }, } for _, tc := range cases { b.Run(tc.name, func(b *testing.B) { benchmarkBulkEndToEndThroughputNetwork(b, tc.network, tc.payloadSize, tc.dedicated) }) } } func BenchmarkBulkEndToEndThroughputConcurrent(b *testing.B) { cases := []struct { name string network string payloadSize int concurrency int dedicated bool }{ { name: "tcp_dedicated_4x1MiB", network: "tcp", payloadSize: 1024 * 1024, concurrency: 4, dedicated: true, }, { name: "unix_dedicated_4x1MiB", network: "unix", payloadSize: 1024 * 1024, concurrency: 4, dedicated: true, }, } for _, tc := range cases { b.Run(tc.name, func(b *testing.B) { benchmarkBulkEndToEndThroughputConcurrentNetwork(b, tc.network, tc.payloadSize, tc.concurrency, tc.dedicated) }) } } func benchmarkBulkEndToEndThroughputNetwork(b *testing.B, network string, payloadSize int, dedicated bool) { b.Helper() if network == "unix" && runtime.GOOS == "windows" { b.Skip("unix socket is not available on windows") } server := newBulkBenchmarkServer(b, network) client := newBulkBenchmarkClient(b, network, server) totalBytes := int64(payloadSize) if b.N > 1 { totalBytes = int64(payloadSize) * int64(b.N) } bulk, accepted := openBenchmarkBulkPair(b, client, server.acceptCh, BulkOpenOptions{ Range: BulkRange{ Offset: 0, Length: totalBytes, }, ChunkSize: payloadSize, Dedicated: dedicated, }) drainDone := make(chan error, 1) go func() { _, err := io.Copy(io.Discard, accepted.Bulk) if err != nil && !errors.Is(err, io.EOF) { drainDone <- err return } drainDone <- nil }() payload := make([]byte, payloadSize) for i := range payload { payload[i] = byte(i) } b.ReportAllocs() b.SetBytes(int64(payloadSize)) b.ResetTimer() for i := 0; i < b.N; i++ { n, err := bulk.Write(payload) if err != nil { b.Fatalf("bulk Write failed at iter %d: %v", i, err) } if n != len(payload) { b.Fatalf("bulk Write bytes mismatch at iter %d: got %d want %d", i, n, len(payload)) } } if err := bulk.CloseWrite(); err != nil { b.Fatalf("bulk CloseWrite failed: %v", err) } select { case err := <-drainDone: if err != nil { b.Fatalf("server drain failed: %v", err) } case <-time.After(15 * time.Second): b.Fatal("timed out waiting for server drain") } b.StopTimer() _ = accepted.Bulk.Close() _ = bulk.Close() } func benchmarkBulkEndToEndThroughputConcurrentNetwork(b *testing.B, network string, payloadSize int, concurrency int, dedicated bool) { b.Helper() if concurrency <= 0 { b.Fatal("concurrency must be > 0") } if network == "unix" && runtime.GOOS == "windows" { b.Skip("unix socket is not available on windows") } server := newBulkBenchmarkServer(b, network) client := newBulkBenchmarkClient(b, network, server) totalBytes := int64(payloadSize) if b.N > 1 { totalBytes = int64(payloadSize) * int64(b.N) } bulks := make([]Bulk, 0, concurrency) acceptedBulks := make([]Bulk, 0, concurrency) for index := 0; index < concurrency; index++ { bulk, accepted := openBenchmarkBulkPair(b, client, server.acceptCh, BulkOpenOptions{ Range: BulkRange{ Offset: int64(index) * totalBytes, Length: totalBytes, }, ChunkSize: payloadSize, Dedicated: dedicated, }) bulks = append(bulks, bulk) acceptedBulks = append(acceptedBulks, accepted.Bulk) } drainDone := make(chan error, concurrency) for _, acceptedBulk := range acceptedBulks { bulk := acceptedBulk go func() { _, err := io.Copy(io.Discard, bulk) if err != nil && !errors.Is(err, io.EOF) { drainDone <- err return } drainDone <- nil }() } payload := make([]byte, payloadSize) for i := range payload { payload[i] = byte(i) } b.ReportAllocs() b.SetBytes(int64(payloadSize)) b.ResetTimer() var wg sync.WaitGroup errCh := make(chan error, concurrency) for index, bulk := range bulks { count := b.N / concurrency if index < b.N%concurrency { count++ } wg.Add(1) go func(bulk Bulk, count int) { defer wg.Done() for i := 0; i < count; i++ { n, err := bulk.Write(payload) if err != nil { errCh <- err return } if n != len(payload) { errCh <- errors.New("bulk write bytes mismatch") return } } }(bulk, count) } wg.Wait() close(errCh) for err := range errCh { if err != nil { b.Fatalf("concurrent bulk write failed: %v", err) } } for index, bulk := range bulks { if err := bulk.CloseWrite(); err != nil { b.Fatalf("bulk %d CloseWrite failed: %v", index, err) } } for index := 0; index < concurrency; index++ { select { case err := <-drainDone: if err != nil { b.Fatalf("server drain failed: %v", err) } case <-time.After(15 * time.Second): b.Fatalf("timed out waiting for server drain %d/%d", index+1, concurrency) } } b.StopTimer() for _, bulk := range acceptedBulks { _ = bulk.Close() } for _, bulk := range bulks { _ = bulk.Close() } } type bulkBenchmarkServer struct { server *ServerCommon acceptCh chan BulkAcceptInfo addr string } func newBulkBenchmarkServer(tb testing.TB, network string) bulkBenchmarkServer { tb.Helper() server := NewServer().(*ServerCommon) if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { tb.Fatalf("UseModernPSKServer failed: %v", err) } if network == "udp" { if err := UseSignalReliabilityServer(server, bulkBenchmarkSignalReliabilityOptions()); err != nil { tb.Fatalf("UseSignalReliabilityServer failed: %v", err) } } acceptCh := make(chan BulkAcceptInfo, 32) server.SetBulkHandler(func(info BulkAcceptInfo) error { acceptCh <- info return nil }) addr := bulkBenchmarkListenAddr(tb, network) if err := server.Listen(network, addr); err != nil { tb.Fatalf("server Listen failed: %v", err) } tb.Cleanup(func() { _ = server.Stop() }) return bulkBenchmarkServer{ server: server, acceptCh: acceptCh, addr: signalRoundTripServerAddr(server, addr), } } func newBulkBenchmarkClient(tb testing.TB, network string, server bulkBenchmarkServer) *ClientCommon { tb.Helper() client := NewClient().(*ClientCommon) if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { tb.Fatalf("UseModernPSKClient failed: %v", err) } if network == "udp" { if err := UseSignalReliabilityClient(client, bulkBenchmarkSignalReliabilityOptions()); err != nil { tb.Fatalf("UseSignalReliabilityClient failed: %v", err) } } if err := client.Connect(network, server.addr); err != nil { tb.Fatalf("client Connect failed: %v", err) } tb.Cleanup(func() { _ = client.Stop() }) return client } func openBenchmarkBulkPair(tb testing.TB, client *ClientCommon, acceptCh <-chan BulkAcceptInfo, opt BulkOpenOptions) (Bulk, BulkAcceptInfo) { tb.Helper() bulk, err := client.OpenBulk(context.Background(), opt) if err != nil { tb.Fatalf("client OpenBulk failed: %v", err) } return bulk, waitBenchmarkAcceptedBulk(tb, acceptCh, 5*time.Second) } func bulkBenchmarkListenAddr(tb testing.TB, network string) string { tb.Helper() switch network { case "unix": return filepath.Join(tb.TempDir(), "notify-bulk.sock") case "udp", "tcp": return "127.0.0.1:0" default: tb.Fatalf("unsupported benchmark network %q", network) return "" } } func bulkBenchmarkSignalReliabilityOptions() *SignalReliabilityOptions { return &SignalReliabilityOptions{ Enabled: true, AckTimeout: 3 * time.Second, SendRetry: 8, ReceiveCacheLimit: 512, } }