From 8fc001fb45e0a656bb3296b3437939c27416929a Mon Sep 17 00:00:00 2001 From: Sun Yimin Date: Fri, 30 May 2025 10:06:23 +0800 Subject: [PATCH] mldsa: supplement test cases and comments --- mldsa/compress.go | 15 +++++++ mldsa/field.go | 28 ++++--------- mldsa/field_barrett.go | 14 +++---- mldsa/field_test.go | 95 ++++++++++++++++++++++++++++++++++++++++++ mldsa/mldsa44.go | 2 +- 5 files changed, 125 insertions(+), 29 deletions(-) diff --git a/mldsa/compress.go b/mldsa/compress.go index 37f8e5d..5f45caf 100644 --- a/mldsa/compress.go +++ b/mldsa/compress.go @@ -97,3 +97,18 @@ func useHint(h, r fieldElement, gamma2 uint32) fieldElement { return fieldElement(r1 - 1) } } + +func vectorMakeHint(ct0, cs2, w, hint []ringElement, gamma2 uint32) { + for i := range ct0 { + for j := range ct0[i] { + hint[i][j] = makeHint(ct0[i][j], cs2[i][j], w[i][j], gamma2) + } + } +} + +func makeHint(ct0, cs2, w fieldElement, gamma2 uint32) fieldElement { + rPlusZ := fieldSub(w, cs2) + r := fieldAdd(rPlusZ, ct0) + + return fieldElement(1 ^ uint32(subtle.ConstantTimeEq(int32(compressHighBits(r, gamma2)), int32(compressHighBits(rPlusZ, gamma2))))) +} diff --git a/mldsa/field.go b/mldsa/field.go index 02c7345..aafbafb 100644 --- a/mldsa/field.go +++ b/mldsa/field.go @@ -40,9 +40,9 @@ func fieldSub(a, b fieldElement) fieldElement { } const ( - qInv = 58728449 - qNegInv = 4236238847 - r = 4193792 // 2^32 mod q + qInv = 58728449 // q^-1 satisfies: q^-1 * q = 1 mod 2^32 + qNegInv = 4236238847 // inverse of -q modulo 2^32 + r = 4193792 // 2^32 mod q ) func fieldReduce(a uint64) fieldElement { @@ -90,7 +90,7 @@ type nttElement [n]fieldElement // As this implementation uses montgomery form with a multiplier of 2^32. // The values need to be transformed i.e. // -// zetasMontgomery[k] = fieldReduce(zeta[k] * (2^32 * 2^32 mod(q))) +// zetasMontgomery[k] = fieldReduce(zeta[k] * (2^32 * 2^32 mod(q))) = (zeta[k] * r^2) mod q var zetasMontgomery = [n]fieldElement{ 4193792, 25847, 5771523, 7861508, 237124, 7602457, 7504169, 466468, 1826347, 2353451, 8021166, 6288512, 3119733, 5495562, 3111497, 2680103, @@ -134,7 +134,7 @@ func ntt(f ringElement) nttElement { // len: 128, 64, 32, ..., 1 for len := 128; len >= 1; len /= 2 { // start - for start := 0; start < 256; start += 2 * len { + for start := 0; start < n; start += 2 * len { zeta := zetasMontgomery[k] k++ // Bounds check elimination hint. @@ -154,8 +154,8 @@ func ntt(f ringElement) nttElement { // It implements NTT⁻¹, according to FIPS 204, Algorithm 42. func inverseNTT(f nttElement) ringElement { k := 255 - for len := 1; len < 256; len *= 2 { - for start := 0; start < 256; start += 2 * len { + for len := 1; len < n; len *= 2 { + for start := 0; start < n; start += 2 * len { zeta := q - zetasMontgomery[k] k-- // Bounds check elimination hint. @@ -240,17 +240,3 @@ func vectorCountOnes(a []ringElement) int { return oneCount } -func vectorMakeHint(ct0, cs2, w, hint []ringElement, gamma2 uint32) { - for i := range ct0 { - for j := range ct0[i] { - hint[i][j] = makeHint(ct0[i][j], cs2[i][j], w[i][j], gamma2) - } - } -} - -func makeHint(ct0, cs2, w fieldElement, gamma2 uint32) fieldElement { - rPulusZ := fieldSub(w, cs2) - r := fieldAdd(rPulusZ, ct0) - - return fieldElement(1 ^ uint32(subtle.ConstantTimeEq(int32(compressHighBits(r, gamma2)), int32(compressHighBits(rPulusZ, gamma2))))) -} diff --git a/mldsa/field_barrett.go b/mldsa/field_barrett.go index 74494e7..f5316bb 100644 --- a/mldsa/field_barrett.go +++ b/mldsa/field_barrett.go @@ -105,10 +105,10 @@ func inverseBarrettNTT(f nttElement) ringElement { return ringElement(f) } -//func nttBarrettMul(f, g nttElement) nttElement { -// var ret nttElement -// for i, v := range f { -// ret[i] = fieldBarrettMul(v, g[i]) -// } -// return ret -//} +func nttBarrettMul(f, g nttElement) nttElement { + var ret nttElement + for i, v := range f { + ret[i] = fieldBarrettMul(v, g[i]) + } + return ret +} diff --git a/mldsa/field_test.go b/mldsa/field_test.go index ea398a4..f07d59d 100644 --- a/mldsa/field_test.go +++ b/mldsa/field_test.go @@ -8,10 +8,74 @@ package mldsa import ( "fmt" + "math/big" mathrand "math/rand/v2" "testing" ) +func bitreverse(x byte) byte { + var y byte + for i := range 8 { + y |= (x & 1) << (7 - i) + x >>= 1 + } + return y +} + +func TestConstants(t *testing.T) { + q1 := big.NewInt(q) + a := big.NewInt(1 << 32) + + q1Inv := new(big.Int) + q1Inv.ModInverse(q1, a) + if q1Inv.Cmp(big.NewInt(int64(qInv))) != 0 { + t.Fatalf("q^-1 mod 2^32 = %d, expected %d", q1, qInv) + } + + q1Neg := new(big.Int) + q1Neg.Sub(a, q1) + q1NegInv := new(big.Int) + q1NegInv.ModInverse(q1Neg, a) + if q1NegInv.Cmp(big.NewInt(int64(qNegInv))) != 0 { + t.Fatalf("-q^-1 mod 2^32 = %d, expected %d", q1Neg, qNegInv) + } + + r1 := new(big.Int) + r1.Mod(a, q1) + if r1.Cmp(big.NewInt(int64(r))) != 0 { + t.Fatalf("r = 2^32 mod q = %d, expected %d", r1, r) + } + + dgreeInv := new(big.Int) + dgreeInv.ModInverse(big.NewInt(int64(256)), q1) + dgreeInv.Mul(dgreeInv, r1) + dgreeInv.Mul(dgreeInv, r1) + dgreeInv.Mod(dgreeInv, q1) + if dgreeInv.Int64() != 41978 { + t.Fatalf("dgreeInv = ((256^(-1) mod q) * r^2) mod q = %d, expected 41978", dgreeInv) + } + + // test zetas + zeta := big.NewInt(1753) + for i := 1; i < 256; i++ { + bitRev := bitreverse(byte(i)) + zetaV := new(big.Int).Exp(zeta, big.NewInt(int64(bitRev)), q1) + if uint32(zetaV.Int64()) != uint32(zetas[i]) { + t.Fatalf("zetas[%d] = %d, expected %d", i, uint32(zetaV.Int64()), zetas[i]) + } + } + + // test zetasMontgomery + for i, z := range zetas { + zMont := big.NewInt(int64(z)) + zMont.Mul(zMont, r1) + zMont.Mod(zMont, q1) + if zMont.Cmp(big.NewInt(int64(zetasMontgomery[i]))) != 0 { + t.Fatalf("zetasMontgomery[%d] = %d, expected %d", i, zMont, zetasMontgomery[i]) + } + } +} + func TestFieldAdd(t *testing.T) { for a := fieldElement(q - 1000); a < q; a++ { for b := fieldElement(q - 1000); b < q; b++ { @@ -108,6 +172,37 @@ func TestInverseBarrettNTT(t *testing.T) { } } +// this is the real use case for NTT: +// +// - convert to NTT +// - multiply in NTT +// - inverse NTT +func TestInverseNTTWithMultiply(t *testing.T) { + r1 := randomRingElement() + r2 := randomRingElement() + + // Montgomery Method + r11 := r1 + r111 := ntt(r11) + r22 := r2 + r222 := ntt(r22) + r31 := nttMul(r111, r222) + r32 := inverseNTT(r31) + + // Barrett Method + b11 := barrettNTT(r1) + b22 := barrettNTT(r2) + r33 := nttBarrettMul(b11, b22) + r34 := inverseBarrettNTT(r33) + + // Check if the results are equal + for i := range r32 { + if r32[i] != r34[i] { + t.Errorf("expected %v, got %v", r34[i], r32[i]) + } + } +} + func TestInfinityNorm(t *testing.T) { cases := []struct { input fieldElement diff --git a/mldsa/mldsa44.go b/mldsa/mldsa44.go index d43320e..cc6b7d7 100644 --- a/mldsa/mldsa44.go +++ b/mldsa/mldsa44.go @@ -25,7 +25,7 @@ import ( const ( // ML-DSA global constants. - n = 256 + n = 256 // # of coefficients in the polynomials q = 8380417 // 2^23 - 2^13 + 1 qMinus1Div2 = (q - 1) / 2 d = 13 // # of dropped bits from t