mldsa: supplement test cases and comments

This commit is contained in:
Sun Yimin 2025-05-30 10:06:23 +08:00 committed by GitHub
parent 8f0bd765ca
commit 8fc001fb45
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 125 additions and 29 deletions

View File

@ -97,3 +97,18 @@ func useHint(h, r fieldElement, gamma2 uint32) fieldElement {
return fieldElement(r1 - 1)
}
}
func vectorMakeHint(ct0, cs2, w, hint []ringElement, gamma2 uint32) {
for i := range ct0 {
for j := range ct0[i] {
hint[i][j] = makeHint(ct0[i][j], cs2[i][j], w[i][j], gamma2)
}
}
}
func makeHint(ct0, cs2, w fieldElement, gamma2 uint32) fieldElement {
rPlusZ := fieldSub(w, cs2)
r := fieldAdd(rPlusZ, ct0)
return fieldElement(1 ^ uint32(subtle.ConstantTimeEq(int32(compressHighBits(r, gamma2)), int32(compressHighBits(rPlusZ, gamma2)))))
}

View File

@ -40,9 +40,9 @@ func fieldSub(a, b fieldElement) fieldElement {
}
const (
qInv = 58728449
qNegInv = 4236238847
r = 4193792 // 2^32 mod q
qInv = 58728449 // q^-1 satisfies: q^-1 * q = 1 mod 2^32
qNegInv = 4236238847 // inverse of -q modulo 2^32
r = 4193792 // 2^32 mod q
)
func fieldReduce(a uint64) fieldElement {
@ -90,7 +90,7 @@ type nttElement [n]fieldElement
// As this implementation uses montgomery form with a multiplier of 2^32.
// The values need to be transformed i.e.
//
// zetasMontgomery[k] = fieldReduce(zeta[k] * (2^32 * 2^32 mod(q)))
// zetasMontgomery[k] = fieldReduce(zeta[k] * (2^32 * 2^32 mod(q))) = (zeta[k] * r^2) mod q
var zetasMontgomery = [n]fieldElement{
4193792, 25847, 5771523, 7861508, 237124, 7602457, 7504169, 466468,
1826347, 2353451, 8021166, 6288512, 3119733, 5495562, 3111497, 2680103,
@ -134,7 +134,7 @@ func ntt(f ringElement) nttElement {
// len: 128, 64, 32, ..., 1
for len := 128; len >= 1; len /= 2 {
// start
for start := 0; start < 256; start += 2 * len {
for start := 0; start < n; start += 2 * len {
zeta := zetasMontgomery[k]
k++
// Bounds check elimination hint.
@ -154,8 +154,8 @@ func ntt(f ringElement) nttElement {
// It implements NTT⁻¹, according to FIPS 204, Algorithm 42.
func inverseNTT(f nttElement) ringElement {
k := 255
for len := 1; len < 256; len *= 2 {
for start := 0; start < 256; start += 2 * len {
for len := 1; len < n; len *= 2 {
for start := 0; start < n; start += 2 * len {
zeta := q - zetasMontgomery[k]
k--
// Bounds check elimination hint.
@ -240,17 +240,3 @@ func vectorCountOnes(a []ringElement) int {
return oneCount
}
func vectorMakeHint(ct0, cs2, w, hint []ringElement, gamma2 uint32) {
for i := range ct0 {
for j := range ct0[i] {
hint[i][j] = makeHint(ct0[i][j], cs2[i][j], w[i][j], gamma2)
}
}
}
func makeHint(ct0, cs2, w fieldElement, gamma2 uint32) fieldElement {
rPulusZ := fieldSub(w, cs2)
r := fieldAdd(rPulusZ, ct0)
return fieldElement(1 ^ uint32(subtle.ConstantTimeEq(int32(compressHighBits(r, gamma2)), int32(compressHighBits(rPulusZ, gamma2)))))
}

View File

@ -105,10 +105,10 @@ func inverseBarrettNTT(f nttElement) ringElement {
return ringElement(f)
}
//func nttBarrettMul(f, g nttElement) nttElement {
// var ret nttElement
// for i, v := range f {
// ret[i] = fieldBarrettMul(v, g[i])
// }
// return ret
//}
func nttBarrettMul(f, g nttElement) nttElement {
var ret nttElement
for i, v := range f {
ret[i] = fieldBarrettMul(v, g[i])
}
return ret
}

View File

@ -8,10 +8,74 @@ package mldsa
import (
"fmt"
"math/big"
mathrand "math/rand/v2"
"testing"
)
func bitreverse(x byte) byte {
var y byte
for i := range 8 {
y |= (x & 1) << (7 - i)
x >>= 1
}
return y
}
func TestConstants(t *testing.T) {
q1 := big.NewInt(q)
a := big.NewInt(1 << 32)
q1Inv := new(big.Int)
q1Inv.ModInverse(q1, a)
if q1Inv.Cmp(big.NewInt(int64(qInv))) != 0 {
t.Fatalf("q^-1 mod 2^32 = %d, expected %d", q1, qInv)
}
q1Neg := new(big.Int)
q1Neg.Sub(a, q1)
q1NegInv := new(big.Int)
q1NegInv.ModInverse(q1Neg, a)
if q1NegInv.Cmp(big.NewInt(int64(qNegInv))) != 0 {
t.Fatalf("-q^-1 mod 2^32 = %d, expected %d", q1Neg, qNegInv)
}
r1 := new(big.Int)
r1.Mod(a, q1)
if r1.Cmp(big.NewInt(int64(r))) != 0 {
t.Fatalf("r = 2^32 mod q = %d, expected %d", r1, r)
}
dgreeInv := new(big.Int)
dgreeInv.ModInverse(big.NewInt(int64(256)), q1)
dgreeInv.Mul(dgreeInv, r1)
dgreeInv.Mul(dgreeInv, r1)
dgreeInv.Mod(dgreeInv, q1)
if dgreeInv.Int64() != 41978 {
t.Fatalf("dgreeInv = ((256^(-1) mod q) * r^2) mod q = %d, expected 41978", dgreeInv)
}
// test zetas
zeta := big.NewInt(1753)
for i := 1; i < 256; i++ {
bitRev := bitreverse(byte(i))
zetaV := new(big.Int).Exp(zeta, big.NewInt(int64(bitRev)), q1)
if uint32(zetaV.Int64()) != uint32(zetas[i]) {
t.Fatalf("zetas[%d] = %d, expected %d", i, uint32(zetaV.Int64()), zetas[i])
}
}
// test zetasMontgomery
for i, z := range zetas {
zMont := big.NewInt(int64(z))
zMont.Mul(zMont, r1)
zMont.Mod(zMont, q1)
if zMont.Cmp(big.NewInt(int64(zetasMontgomery[i]))) != 0 {
t.Fatalf("zetasMontgomery[%d] = %d, expected %d", i, zMont, zetasMontgomery[i])
}
}
}
func TestFieldAdd(t *testing.T) {
for a := fieldElement(q - 1000); a < q; a++ {
for b := fieldElement(q - 1000); b < q; b++ {
@ -108,6 +172,37 @@ func TestInverseBarrettNTT(t *testing.T) {
}
}
// this is the real use case for NTT:
//
// - convert to NTT
// - multiply in NTT
// - inverse NTT
func TestInverseNTTWithMultiply(t *testing.T) {
r1 := randomRingElement()
r2 := randomRingElement()
// Montgomery Method
r11 := r1
r111 := ntt(r11)
r22 := r2
r222 := ntt(r22)
r31 := nttMul(r111, r222)
r32 := inverseNTT(r31)
// Barrett Method
b11 := barrettNTT(r1)
b22 := barrettNTT(r2)
r33 := nttBarrettMul(b11, b22)
r34 := inverseBarrettNTT(r33)
// Check if the results are equal
for i := range r32 {
if r32[i] != r34[i] {
t.Errorf("expected %v, got %v", r34[i], r32[i])
}
}
}
func TestInfinityNorm(t *testing.T) {
cases := []struct {
input fieldElement

View File

@ -25,7 +25,7 @@ import (
const (
// ML-DSA global constants.
n = 256
n = 256 // # of coefficients in the polynomials
q = 8380417 // 2^23 - 2^13 + 1
qMinus1Div2 = (q - 1) / 2
d = 13 // # of dropped bits from t