2025-05-07 10:05:13 +08:00
|
|
|
// Copyright 2025 Sun Yimin. All rights reserved.
|
|
|
|
// Use of this source code is governed by a MIT-style
|
|
|
|
// license that can be found in the LICENSE file.
|
|
|
|
|
2025-05-07 10:09:48 +08:00
|
|
|
//go:build go1.24
|
|
|
|
|
2025-05-07 10:05:13 +08:00
|
|
|
package mldsa
|
|
|
|
|
|
|
|
import (
|
2025-05-30 10:06:23 +08:00
|
|
|
"math/big"
|
2025-05-07 10:05:13 +08:00
|
|
|
mathrand "math/rand/v2"
|
|
|
|
"testing"
|
|
|
|
)
|
|
|
|
|
2025-05-30 10:06:23 +08:00
|
|
|
// bitreverse returns x with the order of its 8 bits reversed, so bit 0 of
// the input becomes bit 7 of the result and vice versa.
func bitreverse(x byte) byte {
	var rev byte
	// Peel bits off the low end of x and drop them in from the high end.
	for shift := 7; shift >= 0; shift-- {
		rev |= (x & 1) << shift
		x >>= 1
	}
	return rev
}
|
|
|
|
|
|
|
|
func TestConstants(t *testing.T) {
|
|
|
|
q1 := big.NewInt(q)
|
|
|
|
a := big.NewInt(1 << 32)
|
|
|
|
|
|
|
|
q1Inv := new(big.Int)
|
|
|
|
q1Inv.ModInverse(q1, a)
|
|
|
|
if q1Inv.Cmp(big.NewInt(int64(qInv))) != 0 {
|
|
|
|
t.Fatalf("q^-1 mod 2^32 = %d, expected %d", q1, qInv)
|
|
|
|
}
|
|
|
|
|
|
|
|
q1Neg := new(big.Int)
|
|
|
|
q1Neg.Sub(a, q1)
|
|
|
|
q1NegInv := new(big.Int)
|
|
|
|
q1NegInv.ModInverse(q1Neg, a)
|
|
|
|
if q1NegInv.Cmp(big.NewInt(int64(qNegInv))) != 0 {
|
|
|
|
t.Fatalf("-q^-1 mod 2^32 = %d, expected %d", q1Neg, qNegInv)
|
|
|
|
}
|
|
|
|
|
|
|
|
r1 := new(big.Int)
|
|
|
|
r1.Mod(a, q1)
|
|
|
|
if r1.Cmp(big.NewInt(int64(r))) != 0 {
|
|
|
|
t.Fatalf("r = 2^32 mod q = %d, expected %d", r1, r)
|
|
|
|
}
|
|
|
|
|
|
|
|
dgreeInv := new(big.Int)
|
|
|
|
dgreeInv.ModInverse(big.NewInt(int64(256)), q1)
|
|
|
|
dgreeInv.Mul(dgreeInv, r1)
|
|
|
|
dgreeInv.Mul(dgreeInv, r1)
|
|
|
|
dgreeInv.Mod(dgreeInv, q1)
|
|
|
|
if dgreeInv.Int64() != 41978 {
|
|
|
|
t.Fatalf("dgreeInv = ((256^(-1) mod q) * r^2) mod q = %d, expected 41978", dgreeInv)
|
|
|
|
}
|
|
|
|
|
|
|
|
// test zetas
|
|
|
|
zeta := big.NewInt(1753)
|
|
|
|
for i := 1; i < 256; i++ {
|
|
|
|
bitRev := bitreverse(byte(i))
|
|
|
|
zetaV := new(big.Int).Exp(zeta, big.NewInt(int64(bitRev)), q1)
|
|
|
|
if uint32(zetaV.Int64()) != uint32(zetas[i]) {
|
|
|
|
t.Fatalf("zetas[%d] = %d, expected %d", i, uint32(zetaV.Int64()), zetas[i])
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// test zetasMontgomery
|
|
|
|
for i, z := range zetas {
|
|
|
|
zMont := big.NewInt(int64(z))
|
|
|
|
zMont.Mul(zMont, r1)
|
|
|
|
zMont.Mod(zMont, q1)
|
|
|
|
if zMont.Cmp(big.NewInt(int64(zetasMontgomery[i]))) != 0 {
|
|
|
|
t.Fatalf("zetasMontgomery[%d] = %d, expected %d", i, zMont, zetasMontgomery[i])
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2025-05-07 10:05:13 +08:00
|
|
|
func TestFieldAdd(t *testing.T) {
|
|
|
|
for a := fieldElement(q - 1000); a < q; a++ {
|
|
|
|
for b := fieldElement(q - 1000); b < q; b++ {
|
|
|
|
got := fieldAdd(a, b)
|
|
|
|
exp := (a + b) % q
|
|
|
|
if got != exp {
|
|
|
|
t.Fatalf("%d + %d = %d, expected %d", a, b, got, exp)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestFieldSub(t *testing.T) {
|
|
|
|
for a := fieldElement(0); a < 2000; a++ {
|
|
|
|
for b := fieldElement(q - 1000); b < q; b++ {
|
|
|
|
got := fieldSub(a, b)
|
|
|
|
exp := (a - b + q) % q
|
|
|
|
if got != exp {
|
|
|
|
t.Fatalf("%d - %d = %d, expected %d", a, b, got, exp)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestFieldMul(t *testing.T) {
|
|
|
|
for a := fieldElement(q - 1000); a < q; a++ {
|
|
|
|
for b := fieldElement(q - 1000); b < q; b++ {
|
|
|
|
got := fieldMul(fieldElement((uint64(a)*uint64(r))%q), b)
|
|
|
|
exp := fieldElement((uint64(a) * uint64(b)) % q)
|
|
|
|
if got != exp {
|
|
|
|
t.Fatalf("%d * %d = %d, expected %d", a, b, got, exp)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestFieldBarrettMul(t *testing.T) {
|
|
|
|
for a := fieldElement(q - 1000); a < q; a++ {
|
|
|
|
for b := fieldElement(q - 1000); b < q; b++ {
|
|
|
|
got := fieldBarrettMul(a, b)
|
|
|
|
exp := fieldElement((uint64(a) * uint64(b)) % q)
|
|
|
|
if got != exp {
|
|
|
|
t.Fatalf("%d * %d = %d, expected %d", a, b, got, exp)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func randomRingElement() ringElement {
|
|
|
|
var r ringElement
|
|
|
|
for i := range r {
|
|
|
|
r[i] = fieldElement(mathrand.IntN(q))
|
|
|
|
}
|
|
|
|
return r
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestNTT(t *testing.T) {
|
|
|
|
r := randomRingElement()
|
2025-05-30 15:25:37 +08:00
|
|
|
rNTT := ntt(r)
|
|
|
|
rBarretNTT := barrettNTT(r)
|
|
|
|
for i, v := range rNTT {
|
|
|
|
if v != rBarretNTT[i] {
|
|
|
|
t.Errorf("expected %v, got %v", v, rBarretNTT[i])
|
2025-05-07 10:05:13 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestInverseNTT(t *testing.T) {
|
|
|
|
r := randomRingElement()
|
2025-05-30 15:25:37 +08:00
|
|
|
ret := inverseNTT(ntt(r))
|
2025-05-07 10:05:13 +08:00
|
|
|
for i, v := range r {
|
2025-05-30 15:25:37 +08:00
|
|
|
if v != fieldReduce(uint64(ret[i])) {
|
|
|
|
t.Errorf("expected %v, got %v", v, fieldReduce(uint64(ret[i])))
|
2025-05-07 10:05:13 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestInverseBarrettNTT(t *testing.T) {
|
|
|
|
r := randomRingElement()
|
2025-05-30 15:25:37 +08:00
|
|
|
ret := inverseBarrettNTT(barrettNTT(r))
|
2025-05-07 10:05:13 +08:00
|
|
|
for i, v := range r {
|
2025-05-30 15:25:37 +08:00
|
|
|
if v != ret[i] {
|
|
|
|
t.Errorf("expected %v, got %v", v, ret[i])
|
2025-05-07 10:05:13 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2025-05-30 10:06:23 +08:00
|
|
|
// this is the real use case for NTT:
|
|
|
|
//
|
2025-05-30 15:25:37 +08:00
|
|
|
// - convert to NTT
|
|
|
|
// - multiply in NTT
|
|
|
|
// - inverse NTT
|
2025-05-30 10:06:23 +08:00
|
|
|
func TestInverseNTTWithMultiply(t *testing.T) {
|
|
|
|
r1 := randomRingElement()
|
|
|
|
r2 := randomRingElement()
|
|
|
|
|
|
|
|
// Montgomery Method
|
2025-05-30 15:25:37 +08:00
|
|
|
ret1 := inverseNTT(nttMul(ntt(r1), ntt(r2)))
|
2025-05-30 10:06:23 +08:00
|
|
|
|
|
|
|
// Barrett Method
|
2025-05-30 15:25:37 +08:00
|
|
|
ret2 := inverseBarrettNTT(nttBarrettMul(barrettNTT(r1), barrettNTT(r2)))
|
2025-05-30 10:06:23 +08:00
|
|
|
|
|
|
|
// Check if the results are equal
|
2025-05-30 15:25:37 +08:00
|
|
|
for i := range ret1 {
|
|
|
|
if ret1[i] != ret2[i] {
|
|
|
|
t.Errorf("expected %v, got %v", ret2[i], ret1[i])
|
2025-05-30 10:06:23 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2025-05-07 10:05:13 +08:00
|
|
|
func TestInfinityNorm(t *testing.T) {
|
|
|
|
cases := []struct {
|
|
|
|
input fieldElement
|
|
|
|
expected uint32
|
|
|
|
}{
|
|
|
|
{0, 0},
|
|
|
|
{1, 1},
|
|
|
|
{(q - 1) / 2, (q - 1) / 2},
|
|
|
|
{(q-1)/2 + 1, q - 1 - (q-1)/2},
|
|
|
|
{q - 1, 1},
|
|
|
|
}
|
|
|
|
for _, c := range cases {
|
|
|
|
got := infinityNorm(c.input)
|
|
|
|
if got != c.expected {
|
|
|
|
t.Fatalf("infinityNorm(%d) = %d, expected %d", c.input, got, c.expected)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestPolyInfinityNorm(t *testing.T) {
|
|
|
|
r := randomRingElement()
|
|
|
|
got := polyInfinityNorm(r, 0)
|
|
|
|
var expected int
|
|
|
|
|
|
|
|
for _, v := range r {
|
|
|
|
if v > qMinus1Div2 {
|
|
|
|
v = q - v
|
|
|
|
}
|
|
|
|
if int(v) > expected {
|
|
|
|
expected = int(v)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if got != expected {
|
|
|
|
t.Fatalf("polyInfinityNorm(%v) = %d, expected %d", r, got, expected)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestInfinityNormSigned(t *testing.T) {
|
|
|
|
cases := []struct {
|
|
|
|
input int32
|
|
|
|
expected int
|
|
|
|
}{
|
|
|
|
{0, 0},
|
|
|
|
{1, 1},
|
|
|
|
{-1, 1},
|
|
|
|
{-2, 2},
|
|
|
|
}
|
|
|
|
for _, c := range cases {
|
|
|
|
got := infinityNormSigned(c.input)
|
|
|
|
if got != c.expected {
|
|
|
|
t.Fatalf("infinityNormSigned(%d) = %d, expected %d", c.input, got, c.expected)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestPolyInfinityNormSigned(t *testing.T) {
|
|
|
|
cases := []struct {
|
|
|
|
input []int32
|
|
|
|
expected int
|
|
|
|
}{
|
|
|
|
{[]int32{0, 0, 0}, 0},
|
|
|
|
{[]int32{1, 2, 3}, 3},
|
|
|
|
{[]int32{0, -1, -2, -3, 2}, 3},
|
|
|
|
}
|
|
|
|
for _, c := range cases {
|
|
|
|
got := polyInfinityNormSigned(c.input, 0)
|
|
|
|
if got != c.expected {
|
|
|
|
t.Fatalf("polyInfinityNormSigned(%v) = %d, expected %d", c.input, got, c.expected)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|