mirror of
https://github.com/emmansun/gmsm.git
synced 2025-06-28 08:23:26 +08:00
mldsa: supplement test cases and comments
This commit is contained in:
parent
8f0bd765ca
commit
8fc001fb45
@ -97,3 +97,18 @@ func useHint(h, r fieldElement, gamma2 uint32) fieldElement {
|
||||
return fieldElement(r1 - 1)
|
||||
}
|
||||
}
|
||||
|
||||
func vectorMakeHint(ct0, cs2, w, hint []ringElement, gamma2 uint32) {
|
||||
for i := range ct0 {
|
||||
for j := range ct0[i] {
|
||||
hint[i][j] = makeHint(ct0[i][j], cs2[i][j], w[i][j], gamma2)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func makeHint(ct0, cs2, w fieldElement, gamma2 uint32) fieldElement {
|
||||
rPlusZ := fieldSub(w, cs2)
|
||||
r := fieldAdd(rPlusZ, ct0)
|
||||
|
||||
return fieldElement(1 ^ uint32(subtle.ConstantTimeEq(int32(compressHighBits(r, gamma2)), int32(compressHighBits(rPlusZ, gamma2)))))
|
||||
}
|
||||
|
@ -40,8 +40,8 @@ func fieldSub(a, b fieldElement) fieldElement {
|
||||
}
|
||||
|
||||
const (
|
||||
qInv = 58728449
|
||||
qNegInv = 4236238847
|
||||
qInv = 58728449 // q^-1 satisfies: q^-1 * q = 1 mod 2^32
|
||||
qNegInv = 4236238847 // inverse of -q modulo 2^32
|
||||
r = 4193792 // 2^32 mod q
|
||||
)
|
||||
|
||||
@ -90,7 +90,7 @@ type nttElement [n]fieldElement
|
||||
// As this implementation uses montgomery form with a multiplier of 2^32.
|
||||
// The values need to be transformed i.e.
|
||||
//
|
||||
// zetasMontgomery[k] = fieldReduce(zeta[k] * (2^32 * 2^32 mod(q)))
|
||||
// zetasMontgomery[k] = fieldReduce(zeta[k] * (2^32 * 2^32 mod(q))) = (zeta[k] * r^2) mod q
|
||||
var zetasMontgomery = [n]fieldElement{
|
||||
4193792, 25847, 5771523, 7861508, 237124, 7602457, 7504169, 466468,
|
||||
1826347, 2353451, 8021166, 6288512, 3119733, 5495562, 3111497, 2680103,
|
||||
@ -134,7 +134,7 @@ func ntt(f ringElement) nttElement {
|
||||
// len: 128, 64, 32, ..., 1
|
||||
for len := 128; len >= 1; len /= 2 {
|
||||
// start
|
||||
for start := 0; start < 256; start += 2 * len {
|
||||
for start := 0; start < n; start += 2 * len {
|
||||
zeta := zetasMontgomery[k]
|
||||
k++
|
||||
// Bounds check elimination hint.
|
||||
@ -154,8 +154,8 @@ func ntt(f ringElement) nttElement {
|
||||
// It implements NTT⁻¹, according to FIPS 204, Algorithm 42.
|
||||
func inverseNTT(f nttElement) ringElement {
|
||||
k := 255
|
||||
for len := 1; len < 256; len *= 2 {
|
||||
for start := 0; start < 256; start += 2 * len {
|
||||
for len := 1; len < n; len *= 2 {
|
||||
for start := 0; start < n; start += 2 * len {
|
||||
zeta := q - zetasMontgomery[k]
|
||||
k--
|
||||
// Bounds check elimination hint.
|
||||
@ -240,17 +240,3 @@ func vectorCountOnes(a []ringElement) int {
|
||||
return oneCount
|
||||
}
|
||||
|
||||
func vectorMakeHint(ct0, cs2, w, hint []ringElement, gamma2 uint32) {
|
||||
for i := range ct0 {
|
||||
for j := range ct0[i] {
|
||||
hint[i][j] = makeHint(ct0[i][j], cs2[i][j], w[i][j], gamma2)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func makeHint(ct0, cs2, w fieldElement, gamma2 uint32) fieldElement {
|
||||
rPulusZ := fieldSub(w, cs2)
|
||||
r := fieldAdd(rPulusZ, ct0)
|
||||
|
||||
return fieldElement(1 ^ uint32(subtle.ConstantTimeEq(int32(compressHighBits(r, gamma2)), int32(compressHighBits(rPulusZ, gamma2)))))
|
||||
}
|
||||
|
@ -105,10 +105,10 @@ func inverseBarrettNTT(f nttElement) ringElement {
|
||||
return ringElement(f)
|
||||
}
|
||||
|
||||
//func nttBarrettMul(f, g nttElement) nttElement {
|
||||
// var ret nttElement
|
||||
// for i, v := range f {
|
||||
// ret[i] = fieldBarrettMul(v, g[i])
|
||||
// }
|
||||
// return ret
|
||||
//}
|
||||
func nttBarrettMul(f, g nttElement) nttElement {
|
||||
var ret nttElement
|
||||
for i, v := range f {
|
||||
ret[i] = fieldBarrettMul(v, g[i])
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
@ -8,10 +8,74 @@ package mldsa
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/big"
|
||||
mathrand "math/rand/v2"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func bitreverse(x byte) byte {
|
||||
var y byte
|
||||
for i := range 8 {
|
||||
y |= (x & 1) << (7 - i)
|
||||
x >>= 1
|
||||
}
|
||||
return y
|
||||
}
|
||||
|
||||
func TestConstants(t *testing.T) {
|
||||
q1 := big.NewInt(q)
|
||||
a := big.NewInt(1 << 32)
|
||||
|
||||
q1Inv := new(big.Int)
|
||||
q1Inv.ModInverse(q1, a)
|
||||
if q1Inv.Cmp(big.NewInt(int64(qInv))) != 0 {
|
||||
t.Fatalf("q^-1 mod 2^32 = %d, expected %d", q1, qInv)
|
||||
}
|
||||
|
||||
q1Neg := new(big.Int)
|
||||
q1Neg.Sub(a, q1)
|
||||
q1NegInv := new(big.Int)
|
||||
q1NegInv.ModInverse(q1Neg, a)
|
||||
if q1NegInv.Cmp(big.NewInt(int64(qNegInv))) != 0 {
|
||||
t.Fatalf("-q^-1 mod 2^32 = %d, expected %d", q1Neg, qNegInv)
|
||||
}
|
||||
|
||||
r1 := new(big.Int)
|
||||
r1.Mod(a, q1)
|
||||
if r1.Cmp(big.NewInt(int64(r))) != 0 {
|
||||
t.Fatalf("r = 2^32 mod q = %d, expected %d", r1, r)
|
||||
}
|
||||
|
||||
dgreeInv := new(big.Int)
|
||||
dgreeInv.ModInverse(big.NewInt(int64(256)), q1)
|
||||
dgreeInv.Mul(dgreeInv, r1)
|
||||
dgreeInv.Mul(dgreeInv, r1)
|
||||
dgreeInv.Mod(dgreeInv, q1)
|
||||
if dgreeInv.Int64() != 41978 {
|
||||
t.Fatalf("dgreeInv = ((256^(-1) mod q) * r^2) mod q = %d, expected 41978", dgreeInv)
|
||||
}
|
||||
|
||||
// test zetas
|
||||
zeta := big.NewInt(1753)
|
||||
for i := 1; i < 256; i++ {
|
||||
bitRev := bitreverse(byte(i))
|
||||
zetaV := new(big.Int).Exp(zeta, big.NewInt(int64(bitRev)), q1)
|
||||
if uint32(zetaV.Int64()) != uint32(zetas[i]) {
|
||||
t.Fatalf("zetas[%d] = %d, expected %d", i, uint32(zetaV.Int64()), zetas[i])
|
||||
}
|
||||
}
|
||||
|
||||
// test zetasMontgomery
|
||||
for i, z := range zetas {
|
||||
zMont := big.NewInt(int64(z))
|
||||
zMont.Mul(zMont, r1)
|
||||
zMont.Mod(zMont, q1)
|
||||
if zMont.Cmp(big.NewInt(int64(zetasMontgomery[i]))) != 0 {
|
||||
t.Fatalf("zetasMontgomery[%d] = %d, expected %d", i, zMont, zetasMontgomery[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFieldAdd(t *testing.T) {
|
||||
for a := fieldElement(q - 1000); a < q; a++ {
|
||||
for b := fieldElement(q - 1000); b < q; b++ {
|
||||
@ -108,6 +172,37 @@ func TestInverseBarrettNTT(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// this is the real use case for NTT:
|
||||
//
|
||||
// - convert to NTT
|
||||
// - multiply in NTT
|
||||
// - inverse NTT
|
||||
func TestInverseNTTWithMultiply(t *testing.T) {
|
||||
r1 := randomRingElement()
|
||||
r2 := randomRingElement()
|
||||
|
||||
// Montgomery Method
|
||||
r11 := r1
|
||||
r111 := ntt(r11)
|
||||
r22 := r2
|
||||
r222 := ntt(r22)
|
||||
r31 := nttMul(r111, r222)
|
||||
r32 := inverseNTT(r31)
|
||||
|
||||
// Barrett Method
|
||||
b11 := barrettNTT(r1)
|
||||
b22 := barrettNTT(r2)
|
||||
r33 := nttBarrettMul(b11, b22)
|
||||
r34 := inverseBarrettNTT(r33)
|
||||
|
||||
// Check if the results are equal
|
||||
for i := range r32 {
|
||||
if r32[i] != r34[i] {
|
||||
t.Errorf("expected %v, got %v", r34[i], r32[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestInfinityNorm(t *testing.T) {
|
||||
cases := []struct {
|
||||
input fieldElement
|
||||
|
@ -25,7 +25,7 @@ import (
|
||||
|
||||
const (
|
||||
// ML-DSA global constants.
|
||||
n = 256
|
||||
n = 256 // # of coefficients in the polynomials
|
||||
q = 8380417 // 2^23 - 2^13 + 1
|
||||
qMinus1Div2 = (q - 1) / 2
|
||||
d = 13 // # of dropped bits from t
|
||||
|
Loading…
x
Reference in New Issue
Block a user