mirror of
https://github.com/emmansun/gmsm.git
synced 2025-06-28 08:23:26 +08:00
mldsa: supplement test cases and comments
This commit is contained in:
parent
8f0bd765ca
commit
8fc001fb45
@ -97,3 +97,18 @@ func useHint(h, r fieldElement, gamma2 uint32) fieldElement {
|
|||||||
return fieldElement(r1 - 1)
|
return fieldElement(r1 - 1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func vectorMakeHint(ct0, cs2, w, hint []ringElement, gamma2 uint32) {
|
||||||
|
for i := range ct0 {
|
||||||
|
for j := range ct0[i] {
|
||||||
|
hint[i][j] = makeHint(ct0[i][j], cs2[i][j], w[i][j], gamma2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeHint(ct0, cs2, w fieldElement, gamma2 uint32) fieldElement {
|
||||||
|
rPlusZ := fieldSub(w, cs2)
|
||||||
|
r := fieldAdd(rPlusZ, ct0)
|
||||||
|
|
||||||
|
return fieldElement(1 ^ uint32(subtle.ConstantTimeEq(int32(compressHighBits(r, gamma2)), int32(compressHighBits(rPlusZ, gamma2)))))
|
||||||
|
}
|
||||||
|
@ -40,8 +40,8 @@ func fieldSub(a, b fieldElement) fieldElement {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
qInv = 58728449
|
qInv = 58728449 // q^-1 satisfies: q^-1 * q = 1 mod 2^32
|
||||||
qNegInv = 4236238847
|
qNegInv = 4236238847 // inverse of -q modulo 2^32
|
||||||
r = 4193792 // 2^32 mod q
|
r = 4193792 // 2^32 mod q
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -90,7 +90,7 @@ type nttElement [n]fieldElement
|
|||||||
// As this implementation uses montgomery form with a multiplier of 2^32.
|
// As this implementation uses montgomery form with a multiplier of 2^32.
|
||||||
// The values need to be transformed i.e.
|
// The values need to be transformed i.e.
|
||||||
//
|
//
|
||||||
// zetasMontgomery[k] = fieldReduce(zeta[k] * (2^32 * 2^32 mod(q)))
|
// zetasMontgomery[k] = fieldReduce(zeta[k] * (2^32 * 2^32 mod(q))) = (zeta[k] * r^2) mod q
|
||||||
var zetasMontgomery = [n]fieldElement{
|
var zetasMontgomery = [n]fieldElement{
|
||||||
4193792, 25847, 5771523, 7861508, 237124, 7602457, 7504169, 466468,
|
4193792, 25847, 5771523, 7861508, 237124, 7602457, 7504169, 466468,
|
||||||
1826347, 2353451, 8021166, 6288512, 3119733, 5495562, 3111497, 2680103,
|
1826347, 2353451, 8021166, 6288512, 3119733, 5495562, 3111497, 2680103,
|
||||||
@ -134,7 +134,7 @@ func ntt(f ringElement) nttElement {
|
|||||||
// len: 128, 64, 32, ..., 1
|
// len: 128, 64, 32, ..., 1
|
||||||
for len := 128; len >= 1; len /= 2 {
|
for len := 128; len >= 1; len /= 2 {
|
||||||
// start
|
// start
|
||||||
for start := 0; start < 256; start += 2 * len {
|
for start := 0; start < n; start += 2 * len {
|
||||||
zeta := zetasMontgomery[k]
|
zeta := zetasMontgomery[k]
|
||||||
k++
|
k++
|
||||||
// Bounds check elimination hint.
|
// Bounds check elimination hint.
|
||||||
@ -154,8 +154,8 @@ func ntt(f ringElement) nttElement {
|
|||||||
// It implements NTT⁻¹, according to FIPS 204, Algorithm 42.
|
// It implements NTT⁻¹, according to FIPS 204, Algorithm 42.
|
||||||
func inverseNTT(f nttElement) ringElement {
|
func inverseNTT(f nttElement) ringElement {
|
||||||
k := 255
|
k := 255
|
||||||
for len := 1; len < 256; len *= 2 {
|
for len := 1; len < n; len *= 2 {
|
||||||
for start := 0; start < 256; start += 2 * len {
|
for start := 0; start < n; start += 2 * len {
|
||||||
zeta := q - zetasMontgomery[k]
|
zeta := q - zetasMontgomery[k]
|
||||||
k--
|
k--
|
||||||
// Bounds check elimination hint.
|
// Bounds check elimination hint.
|
||||||
@ -240,17 +240,3 @@ func vectorCountOnes(a []ringElement) int {
|
|||||||
return oneCount
|
return oneCount
|
||||||
}
|
}
|
||||||
|
|
||||||
func vectorMakeHint(ct0, cs2, w, hint []ringElement, gamma2 uint32) {
|
|
||||||
for i := range ct0 {
|
|
||||||
for j := range ct0[i] {
|
|
||||||
hint[i][j] = makeHint(ct0[i][j], cs2[i][j], w[i][j], gamma2)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func makeHint(ct0, cs2, w fieldElement, gamma2 uint32) fieldElement {
|
|
||||||
rPulusZ := fieldSub(w, cs2)
|
|
||||||
r := fieldAdd(rPulusZ, ct0)
|
|
||||||
|
|
||||||
return fieldElement(1 ^ uint32(subtle.ConstantTimeEq(int32(compressHighBits(r, gamma2)), int32(compressHighBits(rPulusZ, gamma2)))))
|
|
||||||
}
|
|
||||||
|
@ -105,10 +105,10 @@ func inverseBarrettNTT(f nttElement) ringElement {
|
|||||||
return ringElement(f)
|
return ringElement(f)
|
||||||
}
|
}
|
||||||
|
|
||||||
//func nttBarrettMul(f, g nttElement) nttElement {
|
func nttBarrettMul(f, g nttElement) nttElement {
|
||||||
// var ret nttElement
|
var ret nttElement
|
||||||
// for i, v := range f {
|
for i, v := range f {
|
||||||
// ret[i] = fieldBarrettMul(v, g[i])
|
ret[i] = fieldBarrettMul(v, g[i])
|
||||||
// }
|
}
|
||||||
// return ret
|
return ret
|
||||||
//}
|
}
|
||||||
|
@ -8,10 +8,74 @@ package mldsa
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"math/big"
|
||||||
mathrand "math/rand/v2"
|
mathrand "math/rand/v2"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func bitreverse(x byte) byte {
|
||||||
|
var y byte
|
||||||
|
for i := range 8 {
|
||||||
|
y |= (x & 1) << (7 - i)
|
||||||
|
x >>= 1
|
||||||
|
}
|
||||||
|
return y
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConstants(t *testing.T) {
|
||||||
|
q1 := big.NewInt(q)
|
||||||
|
a := big.NewInt(1 << 32)
|
||||||
|
|
||||||
|
q1Inv := new(big.Int)
|
||||||
|
q1Inv.ModInverse(q1, a)
|
||||||
|
if q1Inv.Cmp(big.NewInt(int64(qInv))) != 0 {
|
||||||
|
t.Fatalf("q^-1 mod 2^32 = %d, expected %d", q1, qInv)
|
||||||
|
}
|
||||||
|
|
||||||
|
q1Neg := new(big.Int)
|
||||||
|
q1Neg.Sub(a, q1)
|
||||||
|
q1NegInv := new(big.Int)
|
||||||
|
q1NegInv.ModInverse(q1Neg, a)
|
||||||
|
if q1NegInv.Cmp(big.NewInt(int64(qNegInv))) != 0 {
|
||||||
|
t.Fatalf("-q^-1 mod 2^32 = %d, expected %d", q1Neg, qNegInv)
|
||||||
|
}
|
||||||
|
|
||||||
|
r1 := new(big.Int)
|
||||||
|
r1.Mod(a, q1)
|
||||||
|
if r1.Cmp(big.NewInt(int64(r))) != 0 {
|
||||||
|
t.Fatalf("r = 2^32 mod q = %d, expected %d", r1, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
dgreeInv := new(big.Int)
|
||||||
|
dgreeInv.ModInverse(big.NewInt(int64(256)), q1)
|
||||||
|
dgreeInv.Mul(dgreeInv, r1)
|
||||||
|
dgreeInv.Mul(dgreeInv, r1)
|
||||||
|
dgreeInv.Mod(dgreeInv, q1)
|
||||||
|
if dgreeInv.Int64() != 41978 {
|
||||||
|
t.Fatalf("dgreeInv = ((256^(-1) mod q) * r^2) mod q = %d, expected 41978", dgreeInv)
|
||||||
|
}
|
||||||
|
|
||||||
|
// test zetas
|
||||||
|
zeta := big.NewInt(1753)
|
||||||
|
for i := 1; i < 256; i++ {
|
||||||
|
bitRev := bitreverse(byte(i))
|
||||||
|
zetaV := new(big.Int).Exp(zeta, big.NewInt(int64(bitRev)), q1)
|
||||||
|
if uint32(zetaV.Int64()) != uint32(zetas[i]) {
|
||||||
|
t.Fatalf("zetas[%d] = %d, expected %d", i, uint32(zetaV.Int64()), zetas[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// test zetasMontgomery
|
||||||
|
for i, z := range zetas {
|
||||||
|
zMont := big.NewInt(int64(z))
|
||||||
|
zMont.Mul(zMont, r1)
|
||||||
|
zMont.Mod(zMont, q1)
|
||||||
|
if zMont.Cmp(big.NewInt(int64(zetasMontgomery[i]))) != 0 {
|
||||||
|
t.Fatalf("zetasMontgomery[%d] = %d, expected %d", i, zMont, zetasMontgomery[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestFieldAdd(t *testing.T) {
|
func TestFieldAdd(t *testing.T) {
|
||||||
for a := fieldElement(q - 1000); a < q; a++ {
|
for a := fieldElement(q - 1000); a < q; a++ {
|
||||||
for b := fieldElement(q - 1000); b < q; b++ {
|
for b := fieldElement(q - 1000); b < q; b++ {
|
||||||
@ -108,6 +172,37 @@ func TestInverseBarrettNTT(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// this is the real use case for NTT:
|
||||||
|
//
|
||||||
|
// - convert to NTT
|
||||||
|
// - multiply in NTT
|
||||||
|
// - inverse NTT
|
||||||
|
func TestInverseNTTWithMultiply(t *testing.T) {
|
||||||
|
r1 := randomRingElement()
|
||||||
|
r2 := randomRingElement()
|
||||||
|
|
||||||
|
// Montgomery Method
|
||||||
|
r11 := r1
|
||||||
|
r111 := ntt(r11)
|
||||||
|
r22 := r2
|
||||||
|
r222 := ntt(r22)
|
||||||
|
r31 := nttMul(r111, r222)
|
||||||
|
r32 := inverseNTT(r31)
|
||||||
|
|
||||||
|
// Barrett Method
|
||||||
|
b11 := barrettNTT(r1)
|
||||||
|
b22 := barrettNTT(r2)
|
||||||
|
r33 := nttBarrettMul(b11, b22)
|
||||||
|
r34 := inverseBarrettNTT(r33)
|
||||||
|
|
||||||
|
// Check if the results are equal
|
||||||
|
for i := range r32 {
|
||||||
|
if r32[i] != r34[i] {
|
||||||
|
t.Errorf("expected %v, got %v", r34[i], r32[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestInfinityNorm(t *testing.T) {
|
func TestInfinityNorm(t *testing.T) {
|
||||||
cases := []struct {
|
cases := []struct {
|
||||||
input fieldElement
|
input fieldElement
|
||||||
|
@ -25,7 +25,7 @@ import (
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
// ML-DSA global constants.
|
// ML-DSA global constants.
|
||||||
n = 256
|
n = 256 // # of coefficients in the polynomials
|
||||||
q = 8380417 // 2^23 - 2^13 + 1
|
q = 8380417 // 2^23 - 2^13 + 1
|
||||||
qMinus1Div2 = (q - 1) / 2
|
qMinus1Div2 = (q - 1) / 2
|
||||||
d = 13 // # of dropped bits from t
|
d = 13 // # of dropped bits from t
|
||||||
|
Loading…
x
Reference in New Issue
Block a user