package notify

import (
	"context"
	"errors"
	"io"
	"net"
	"testing"
	"time"
)

// BenchmarkModernPSKSealPlainThroughput measures fastPlainEncode on
// fixed-size payloads entirely in memory (no connection is opened). Each
// sub-benchmark reports bytes/op through b.SetBytes and allocations through
// b.ReportAllocs.
func BenchmarkModernPSKSealPlainThroughput(b *testing.B) {
	type sealCase struct {
		name        string
		payloadSize int
	}
	sealCases := []sealCase{
		{name: "seal_1MiB", payloadSize: 1024 * 1024},
		{name: "seal_4MiB", payloadSize: 4 * 1024 * 1024},
	}

	key, aad, err := deriveModernPSKKey(integrationSharedSecret, integrationModernPSKOptions())
	if err != nil {
		b.Fatalf("deriveModernPSKKey failed: %v", err)
	}
	transport := buildModernPSKTransportBundle(aad)

	for _, sc := range sealCases {
		size := sc.payloadSize
		b.Run(sc.name, func(b *testing.B) {
			// Deterministic, non-zero payload contents.
			src := make([]byte, size)
			for idx := range src {
				src[idx] = byte(idx)
			}

			// Holds the most recent encoded output so the result is consumed.
			var last []byte

			b.ReportAllocs()
			b.SetBytes(int64(size))
			b.ResetTimer()
			for iter := 0; iter < b.N; iter++ {
				out, encErr := transport.fastPlainEncode(key, len(src), func(dst []byte) error {
					copy(dst, src)
					return nil
				})
				if encErr != nil {
					b.Fatalf("fastPlainEncode failed: %v", encErr)
				}
				last = out
			}
			b.StopTimer()
			_ = last
		})
	}
}

// BenchmarkDedicatedWireLocalhostThroughput runs the bulk dedicated sender
// over a loopback TCP connection for each payload size. The per-size work is
// delegated to benchmarkDedicatedWireLocalhostThroughput.
func BenchmarkDedicatedWireLocalhostThroughput(b *testing.B) {
	sizes := []struct {
		label string
		bytes int
	}{
		{"wire_1MiB", 1024 * 1024},
		{"wire_4MiB", 4 * 1024 * 1024},
	}

	key, aad, err := deriveModernPSKKey(integrationSharedSecret, integrationModernPSKOptions())
	if err != nil {
		b.Fatalf("deriveModernPSKKey failed: %v", err)
	}
	transport := buildModernPSKTransportBundle(aad)

	for _, s := range sizes {
		size := s.bytes
		b.Run(s.label, func(b *testing.B) {
			benchmarkDedicatedWireLocalhostThroughput(b, key, transport.fastPlainEncode, size)
		})
	}
}

// benchmarkDedicatedWireLocalhostThroughput times b.N calls to
// sender.submitWrite with a payloadSize-byte payload. The sender writes over
// a freshly created 127.0.0.1 TCP connection.
//
// All setup runs before b.ResetTimer, so only the submitWrite loop is
// measured. This covers the listener, the connection pair, the server-side
// drain goroutine, and the sender. The server side copies everything it
// receives into io.Discard. After the timed loop the client connection is
// closed and the function waits, up to 10s, for the drain goroutine to finish.
// Listener and connections are closed through b.Cleanup.
func benchmarkDedicatedWireLocalhostThroughput(b *testing.B, key []byte, encode transportFastPlainEncoder, payloadSize int) {
	b.Helper()

	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		b.Fatalf("net.Listen failed: %v", err)
	}
	b.Cleanup(func() { _ = ln.Close() })

	// Both channels are buffered, so the accept goroutine can always deliver
	// its single result and exit, even if nobody is receiving any more.
	accepted := make(chan net.Conn, 1)
	acceptFailed := make(chan error, 1)
	go func() {
		c, acceptErr := ln.Accept()
		if acceptErr == nil {
			accepted <- c
			return
		}
		acceptFailed <- acceptErr
	}()

	client, err := net.Dial("tcp", ln.Addr().String())
	if err != nil {
		b.Fatalf("net.Dial failed: %v", err)
	}
	b.Cleanup(func() { _ = client.Close() })
	// Explicitly request TCP_NODELAY. The error is ignored because this only
	// tunes the socket.
	if tcp, isTCP := client.(*net.TCPConn); isTCP {
		_ = tcp.SetNoDelay(true)
	}

	var server net.Conn
	select {
	case server = <-accepted:
	case acceptErr := <-acceptFailed:
		b.Fatalf("Accept failed: %v", acceptErr)
	case <-time.After(5 * time.Second):
		b.Fatal("timed out waiting for accept")
	}
	b.Cleanup(func() {
		if server != nil {
			_ = server.Close()
		}
	})

	// Server side: discard everything until the client closes. A clean end of
	// stream is reported as nil.
	drained := make(chan error, 1)
	go func() {
		_, copyErr := io.Copy(io.Discard, server)
		if copyErr == nil || errors.Is(copyErr, io.EOF) {
			drained <- nil
			return
		}
		drained <- copyErr
	}()

	plainEncode := func(plain []byte) ([]byte, error) {
		return encode(key, len(plain), func(dst []byte) error {
			copy(dst, plain)
			return nil
		})
	}
	batchEncode := func(items []bulkDedicatedSendRequest) ([]byte, error) {
		return encodeBulkDedicatedBatchPayloadFast(encode, key, 1, items)
	}
	sender := newBulkDedicatedSender(client, 1, plainEncode, batchEncode, nil)
	defer sender.stop()

	// Deterministic, non-zero payload contents.
	payload := make([]byte, payloadSize)
	for i := range payload {
		payload[i] = byte(i)
	}

	b.ReportAllocs()
	b.SetBytes(int64(payloadSize))
	b.ResetTimer()

	ctx := context.Background()
	for i, seq := 0, uint64(1); i < b.N; i, seq = i+1, seq+1 {
		written, writeErr := sender.submitWrite(ctx, seq, payload, payloadSize)
		if writeErr != nil {
			b.Fatalf("submitWrite failed at iter %d: %v", i, writeErr)
		}
		if written != len(payload) {
			b.Fatalf("submitWrite bytes mismatch at iter %d: got %d want %d", i, written, len(payload))
		}
	}
	b.StopTimer()

	// Closing the client lets the server-side io.Copy reach end of stream.
	_ = client.Close()
	select {
	case drainErr := <-drained:
		if drainErr != nil {
			b.Fatalf("server drain failed: %v", drainErr)
		}
	case <-time.After(10 * time.Second):
		b.Fatal("timed out waiting for server drain")
	}
}