From b218e76328884caeca60b1de457db47babcd4d42 Mon Sep 17 00:00:00 2001 From: Sun Yimin Date: Fri, 30 May 2025 15:25:37 +0800 Subject: [PATCH] mldsa: add benchmark for Verify --- mldsa/field_test.go | 56 +++++++++++++++---------------------------- mldsa/mldsa44_test.go | 20 ++++++++++++++++ mldsa/mldsa65_test.go | 20 ++++++++++++++++ mldsa/mldsa87_test.go | 20 ++++++++++++++++ 4 files changed, 79 insertions(+), 37 deletions(-) diff --git a/mldsa/field_test.go b/mldsa/field_test.go index f07d59d..8df650f 100644 --- a/mldsa/field_test.go +++ b/mldsa/field_test.go @@ -7,7 +7,6 @@ package mldsa import ( - "fmt" "math/big" mathrand "math/rand/v2" "testing" @@ -110,10 +109,6 @@ func TestFieldMul(t *testing.T) { } } } - for _, z := range zetasMontgomery { - fmt.Printf("%v, ", fieldReduce(uint64(z))) - } - fmt.Println() } func TestFieldBarrettMul(t *testing.T) { @@ -138,67 +133,54 @@ func randomRingElement() ringElement { func TestNTT(t *testing.T) { r := randomRingElement() - r1 := r - r2 := ntt(r) - r3 := barrettNTT(r1) - for i, v := range r3 { - if v != r2[i] { - t.Errorf("expected %v, got %v", v, r2[i]) + rNTT := ntt(r) + rBarretNTT := barrettNTT(r) + for i, v := range rNTT { + if v != rBarretNTT[i] { + t.Errorf("expected %v, got %v", v, rBarretNTT[i]) } } } func TestInverseNTT(t *testing.T) { r := randomRingElement() - r1 := r - r2 := ntt(r1) - r3 := inverseNTT(r2) + ret := inverseNTT(ntt(r)) for i, v := range r { - if v != fieldReduce(uint64(r3[i])) { - t.Errorf("expected %v, got %v", v, fieldReduce(uint64(r3[i]))) + if v != fieldReduce(uint64(ret[i])) { + t.Errorf("expected %v, got %v", v, fieldReduce(uint64(ret[i]))) } } } func TestInverseBarrettNTT(t *testing.T) { r := randomRingElement() - r1 := r - r2 := barrettNTT(r1) - r3 := inverseBarrettNTT(r2) + ret := inverseBarrettNTT(barrettNTT(r)) for i, v := range r { - if v != r3[i] { - t.Errorf("expected %v, got %v", v, r3[i]) + if v != ret[i] { + t.Errorf("expected %v, got %v", v, ret[i]) } } } // this is the real use case for NTT: // -// - convert to NTT -// - multiply in NTT -// - inverse NTT +// - convert to NTT +// - multiply in NTT +// - inverse NTT func TestInverseNTTWithMultiply(t *testing.T) { r1 := randomRingElement() r2 := randomRingElement() // Montgomery Method - r11 := r1 - r111 := ntt(r11) - r22 := r2 - r222 := ntt(r22) - r31 := nttMul(r111, r222) - r32 := inverseNTT(r31) + ret1 := inverseNTT(nttMul(ntt(r1), ntt(r2))) // Barrett Method - b11 := barrettNTT(r1) - b22 := barrettNTT(r2) - r33 := nttBarrettMul(b11, b22) - r34 := inverseBarrettNTT(r33) + ret2 := inverseBarrettNTT(nttBarrettMul(barrettNTT(r1), barrettNTT(r2))) // Check if the results are equal - for i := range r32 { - if r32[i] != r34[i] { - t.Errorf("expected %v, got %v", r34[i], r32[i]) + for i := range ret1 { + if ret1[i] != ret2[i] { + t.Errorf("expected %v, got %v", ret2[i], ret1[i]) } } } diff --git a/mldsa/mldsa44_test.go b/mldsa/mldsa44_test.go index 6339395..7dbb7e2 100644 --- a/mldsa/mldsa44_test.go +++ b/mldsa/mldsa44_test.go @@ -339,3 +339,23 @@ func BenchmarkSign44(b *testing.B) { } } } + +func BenchmarkVerify44(b *testing.B) { + c := sigVer44InternalProjectionCases[0] + pk, _ := hex.DecodeString(c.pk) + sig, _ := hex.DecodeString(c.sig) + msg, _ := hex.DecodeString(c.message) + ctx, _ := hex.DecodeString(c.context) + + pub, err := NewPublicKey44(pk) + if err != nil { + b.Fatalf("NewPublicKey44 failed: %v", err) + } + b.ReportAllocs() + b.ResetTimer() + for b.Loop() { + if !pub.Verify(sig, msg, ctx) { + b.Errorf("Verify failed") + } + } +} diff --git a/mldsa/mldsa65_test.go b/mldsa/mldsa65_test.go index 453feaf..56d1f06 100644 --- a/mldsa/mldsa65_test.go +++ b/mldsa/mldsa65_test.go @@ -329,3 +329,23 @@ func BenchmarkSign65(b *testing.B) { } } } + +func BenchmarkVerify65(b *testing.B) { + c := sigVer65InternalProjectionCases[1] + pk, _ := hex.DecodeString(c.pk) + sig, _ := hex.DecodeString(c.sig) + msg, _ := hex.DecodeString(c.message) + ctx, _ := hex.DecodeString(c.context) + + pub, err := NewPublicKey65(pk) + if err != nil { + b.Fatalf("NewPublicKey65 failed: %v", err) + } + b.ReportAllocs() + b.ResetTimer() + for b.Loop() { + if !pub.Verify(sig, msg, ctx) { + b.Errorf("Verify failed") + } + } +} diff --git a/mldsa/mldsa87_test.go b/mldsa/mldsa87_test.go index f2d55ca..95554d6 100644 --- a/mldsa/mldsa87_test.go +++ b/mldsa/mldsa87_test.go @@ -289,3 +289,23 @@ func BenchmarkSign87(b *testing.B) { } } } + +func BenchmarkVerify87(b *testing.B) { + c := sigVer87InternalProjectionCases[2] + pk, _ := hex.DecodeString(c.pk) + sig, _ := hex.DecodeString(c.sig) + msg, _ := hex.DecodeString(c.message) + ctx, _ := hex.DecodeString(c.context) + + pub, err := NewPublicKey87(pk) + if err != nil { + b.Fatalf("NewPublicKey87 failed: %v", err) + } + b.ReportAllocs() + b.ResetTimer() + for b.Loop() { + if !pub.Verify(sig, msg, ctx) { + b.Errorf("Verify failed") + } + } +}