mldsa: bounds check elimination

This commit is contained in:
Sun Yimin 2025-06-04 11:01:59 +08:00 committed by GitHub
parent 5084ea06e3
commit 0ec4ddf58f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 40 additions and 31 deletions

View File

@ -99,8 +99,11 @@ func useHint(h, r fieldElement, gamma2 uint32) fieldElement {
}
func vectorMakeHint(ct0, cs2, w, hint []ringElement, gamma2 uint32) {
_ = hint[len(ct0)-1] // Bounds check elimination hint.
_ = cs2[len(ct0)-1] // Bounds check elimination hint.
_ = w[len(ct0)-1] // Bounds check elimination hint.
for i := range ct0 {
for j := range ct0[i] {
for j := range n {
hint[i][j] = makeHint(ct0[i][j], cs2[i][j], w[i][j], gamma2)
}
}

View File

@ -44,20 +44,20 @@ func TestPower2Round(t *testing.T) {
r1, r0 := power2Round(fieldElement(i))
expectedR1, expectedR0 := _power2Round(uint32(i))
if r1 != fieldElement(expectedR1) {
t.Errorf("power2Round(%d) = %d, want %d", i, r1, expectedR1)
t.Errorf("power2Round(%d) r1= %d, want %d", i, r1, expectedR1)
}
if r0 != fieldElement(expectedR0) {
t.Errorf("power2Round(%d) = %d, want %d", i, r0, expectedR0)
t.Errorf("power2Round(%d) r0= %d, want %d", i, r0, expectedR0)
}
}
for i := q - 1001; i < q; i++ {
r1, r0 := power2Round(fieldElement(i))
expectedR1, expectedR0 := _power2Round(uint32(i))
if r1 != fieldElement(expectedR1) {
t.Errorf("power2Round(%d) = %d, want %d", i, r1, expectedR1)
t.Errorf("power2Round(%d) r1= %d, want %d", i, r1, expectedR1)
}
if r0 != fieldElement(expectedR0) {
t.Errorf("power2Round(%d) = %d, want %d", i, r0, expectedR0)
t.Errorf("power2Round(%d) r0= %d, want %d", i, r0, expectedR0)
}
}
}

View File

@ -30,11 +30,11 @@ func simpleBitPack10Bits(s []byte, f ringElement) []byte {
x |= uint64(f[i+1]) << 10
x |= uint64(f[i+2]) << 20
x |= uint64(f[i+3]) << 30
b[4] = uint8(x >> 32)
b[0] = uint8(x)
b[1] = uint8(x >> 8)
b[2] = uint8(x >> 16)
b[3] = uint8(x >> 24)
b[4] = uint8(x >> 32)
b = b[5:]
}
return s
@ -45,7 +45,7 @@ func simpleBitPack10Bits(s []byte, f ringElement) []byte {
func simpleBitUnpack10Bits(b []byte, f *ringElement) {
const mask = 0x3FF
for i := 0; i < n; i += 4 {
x := uint64(b[0]) | (uint64(b[1]) << 8) | (uint64(b[2]) << 16) | (uint64(b[3]) << 24) | (uint64(b[4]) << 32)
x := (uint64(b[4]) << 32) | uint64(b[0]) | (uint64(b[1]) << 8) | (uint64(b[2]) << 16) | (uint64(b[3]) << 24)
b = b[5:]
f[i] = fieldElement(x & mask)
f[i+1] = fieldElement((x >> 10) & mask)
@ -80,9 +80,9 @@ func simpleBitPack4Bits(s []byte, f ringElement) []byte {
// i.e. Use 6 bits from each coefficient and pack them into bytes
// So every 4 coefficients fit into 3 bytes.
//
// |c0||c1||c2||c3|
// | /| /\ /
// |6 2|4 4|2 6|
// |c0||c1||c2||c3|
// | /| /\ /
// |6 2|4 4|2 6|
//
// This is used to encode w1 when signing with ML-DSA-44
func simpleBitPack6Bits(s []byte, f ringElement) []byte {
@ -93,9 +93,9 @@ func simpleBitPack6Bits(s []byte, f ringElement) []byte {
x |= uint64(f[i+1]) << 6
x |= uint64(f[i+2]) << 12
x |= uint64(f[i+3]) << 18
b[2] = uint8(x >> 16)
b[0] = uint8(x)
b[1] = uint8(x >> 8)
b[2] = uint8(x >> 16)
b = b[3:]
}
@ -126,9 +126,10 @@ func bitPackSigned2(s []byte, f ringElement) []byte {
x |= uint32(fieldSub(2, f[i+5])) << 15
x |= uint32(fieldSub(2, f[i+6])) << 18
x |= uint32(fieldSub(2, f[i+7])) << 21
b[2] = uint8(x >> 16)
b[0] = uint8(x)
b[1] = uint8(x >> 8)
b[2] = uint8(x >> 16)
b = b[3:]
}
return s
@ -140,7 +141,7 @@ func bitUnpackSigned2(b []byte) (ringElement, error) {
const bitsMask = 0x7
var f ringElement
for i := 0; i < n; i += 8 {
x := uint32(b[0]) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16)
x := (uint32(b[2]) << 16) | uint32(b[0]) | (uint32(b[1]) << 8)
msbs := x & 0o44444444
mask := (msbs >> 1) | (msbs >> 2)
if subtle.ConstantTimeEq(int32(mask&x), 0) == 0 {
@ -160,7 +161,6 @@ func bitUnpackSigned2(b []byte) (ringElement, error) {
return f, nil
}
// bitPackSigned4 encodes a polynomial into a byte string, assuming that all
// coefficients are in the range -4..4.
// See FIPS 204, Algorithm 17, BitPack(w, a, b). (a = 4, b = 4)
@ -185,7 +185,7 @@ func bitUnpackSigned4(b []byte) (ringElement, error) {
const bitsMask = 0xF
var f ringElement
for i := 0; i < n; i += 8 {
x := uint32(b[0]) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24)
x := (uint32(b[3]) << 24) | uint32(b[0]) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16)
// None of the nibbles may be >= 9. So if the MSB of any nibble is set,
// none of the other bits may be set. First, select all the MSBs.
msbs := x & 0x88888888
@ -208,7 +208,6 @@ func bitUnpackSigned4(b []byte) (ringElement, error) {
return f, nil
}
// bitPackSigned4096 encodes a polynomial f into a byte slice, assuming that all
// coefficients are in the range (-2^12 + 1)..2^12.
// See FIPS 204, Algorithm 17, BitPack(w, a, b), where a = 2^12 - 1, b = 2^12.
@ -237,6 +236,8 @@ func bitPackSigned4096(s []byte, f ringElement) []byte {
x2 |= uint64(fieldSub(r, f[i+5])) << 1
x2 |= uint64(fieldSub(r, f[i+6])) << 14
x2 |= uint64(fieldSub(r, f[i+7])) << 27
b[12] = uint8(x2 >> 32)
b[0] = uint8(x1)
b[1] = uint8(x1 >> 8)
b[2] = uint8(x1 >> 16)
@ -249,7 +250,6 @@ func bitPackSigned4096(s []byte, f ringElement) []byte {
b[9] = uint8(x2 >> 8)
b[10] = uint8(x2 >> 16)
b[11] = uint8(x2 >> 24)
b[12] = uint8(x2 >> 32)
b = b[13:]
}
@ -262,8 +262,8 @@ func bitUnpackSigned4096(b []byte, f *ringElement) error {
const bitsMask = 0x1FFF // 2^13-1
const r = 4096 // 2^12
for i := 0; i < n; i += 8 {
x2 := (uint64(b[12]) << 32) | uint64(b[8]) | (uint64(b[9]) << 8) | (uint64(b[10]) << 16) | (uint64(b[11]) << 24)
x1 := uint64(b[0]) | (uint64(b[1]) << 8) | (uint64(b[2]) << 16) | (uint64(b[3]) << 24) | (uint64(b[4]) << 32) | (uint64(b[5]) << 40) | (uint64(b[6]) << 48) | (uint64(b[7]) << 56)
x2 := uint64(b[8]) | (uint64(b[9]) << 8) | (uint64(b[10]) << 16) | (uint64(b[11]) << 24) | (uint64(b[12]) << 32)
b = b[13:]
f[i] = fieldSub(r, fieldElement(x1&bitsMask))
f[i+1] = fieldSub(r, fieldElement((x1>>13)&bitsMask))
@ -302,6 +302,8 @@ func bitPackSignedTwoPower17(s []byte, f ringElement) []byte {
x2 = uint64(fieldSub(r, f[i+3]))
x1 |= x2 << 54
x2 >>= 10
b[8] = uint8(x2)
b[0] = uint8(x1)
b[1] = uint8(x1 >> 8)
b[2] = uint8(x1 >> 16)
@ -310,7 +312,6 @@ func bitPackSignedTwoPower17(s []byte, f ringElement) []byte {
b[5] = uint8(x1 >> 40)
b[6] = uint8(x1 >> 48)
b[7] = uint8(x1 >> 56)
b[8] = uint8(x2)
b = b[9:]
}
@ -323,9 +324,10 @@ func bitUnpackSignedTwoPower17(b []byte, f *ringElement) {
const bitsMask = 0x3FFFF // 2^18-1
const r = 131072 // 2^17
for i := 0; i < n; i += 4 {
x1 := uint64(b[0]) | (uint64(b[1]) << 8) | (uint64(b[2]) << 16) | (uint64(b[3]) << 24) | (uint64(b[4]) << 32) | (uint64(b[5]) << 40) | (uint64(b[6]) << 48) | (uint64(b[7]) << 56)
x2 := uint64(b[8])
x1 := uint64(b[0]) | (uint64(b[1]) << 8) | (uint64(b[2]) << 16) | (uint64(b[3]) << 24) | (uint64(b[4]) << 32) | (uint64(b[5]) << 40) | (uint64(b[6]) << 48) | (uint64(b[7]) << 56)
b = b[9:]
f[i] = fieldSub(r, fieldElement(x1&bitsMask))
f[i+1] = fieldSub(r, fieldElement((x1>>18)&bitsMask))
f[i+2] = fieldSub(r, fieldElement((x1>>36)&bitsMask))
@ -358,6 +360,9 @@ func bitPackSignedTwoPower19(s []byte, f ringElement) []byte {
x2 = uint64(fieldSub(r, f[i+3]))
x1 |= x2 << 60
x2 >>= 4
b[9] = uint8(x2 >> 8)
b[8] = uint8(x2)
b[0] = uint8(x1)
b[1] = uint8(x1 >> 8)
b[2] = uint8(x1 >> 16)
@ -366,9 +371,7 @@ func bitPackSignedTwoPower19(s []byte, f ringElement) []byte {
b[5] = uint8(x1 >> 40)
b[6] = uint8(x1 >> 48)
b[7] = uint8(x1 >> 56)
b[8] = uint8(x2)
b[9] = uint8(x2 >> 8)
b = b[10:]
}
return s
@ -382,8 +385,9 @@ func bitUnpackSignedTwoPower19(b []byte, f *ringElement) {
const bitsMask = 0xFFFFF // 2^20-1
const r = 524288 // 2^19
for i := 0; i < n; i += 4 {
x2 := (uint64(b[9]) << 8) | uint64(b[8])
x1 := uint64(b[0]) | (uint64(b[1]) << 8) | (uint64(b[2]) << 16) | (uint64(b[3]) << 24) | (uint64(b[4]) << 32) | (uint64(b[5]) << 40) | (uint64(b[6]) << 48) | (uint64(b[7]) << 56)
x2 := uint64(b[8]) | (uint64(b[9]) << 8)
b = b[10:]
f[i] = fieldSub(r, fieldElement(x1&bitsMask))
f[i+1] = fieldSub(r, fieldElement((x1>>20)&bitsMask))
@ -398,7 +402,7 @@ func hintBitPack(s []byte, hint []ringElement, omega int) []byte {
s, b := alias.SliceForAppend(s, omega+k)
index := 0
for i := range k {
for j := 0; j < n; j++ {
for j := range n {
if hint[i][j] != 0 {
b[index] = byte(j)
index++
@ -428,7 +432,8 @@ func hintBitUnpack(b []byte, hint []ringElement, omega int) bool {
hint[i][bi] = 1
}
}
for i := index; i < omega; i++ {
b = b[index:omega]
for i := range b {
if b[i] != 0 {
return false
}

View File

@ -45,6 +45,7 @@ const (
r = 4193792 // 2^32 mod q
)
// See FIPS 204, Algorithm 49, MontgomeryReduce()
func fieldReduce(a uint64) fieldElement {
t := uint32(a) * qNegInv
return fieldReduceOnce(uint32((a + uint64(t)*q) >> 32))

View File

@ -32,7 +32,7 @@ func rejNTTPoly(rho []byte, s, r byte) nttElement {
a[j] = fieldElement(d)
j++
}
if j >= len(a) {
if j >= n {
return a
}
}
@ -73,14 +73,14 @@ func rejBoundedPoly(rho []byte, eta int, highByte, lowByte byte) ringElement {
if subtle.ConstantTimeByteEq(z0, 15) == 0 {
a[j] = fieldSub(2, fieldElement(constantMod5(uint32(z0))))
j++
if j >= len(a) {
if j >= n {
break
}
}
if subtle.ConstantTimeByteEq(z1, 15) == 0 {
a[j] = fieldSub(2, fieldElement(constantMod5(uint32(z1))))
j++
if j >= len(a) {
if j >= n {
break
}
}
@ -88,14 +88,14 @@ func rejBoundedPoly(rho []byte, eta int, highByte, lowByte byte) ringElement {
if subtle.ConstantTimeLessOrEq(int(z0), 8) == 1 {
a[j] = fieldSub(4, fieldElement(z0))
j++
if j >= len(a) {
if j >= n {
break
}
}
if subtle.ConstantTimeLessOrEq(int(z1), 8) == 1 {
a[j] = fieldSub(4, fieldElement(z1))
j++
if j >= len(a) {
if j >= n {
break
}
}