mldsa: bounds check elimination

This commit is contained in:
Sun Yimin 2025-06-04 11:01:59 +08:00 committed by GitHub
parent 5084ea06e3
commit 0ec4ddf58f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 40 additions and 31 deletions

View File

@ -99,8 +99,11 @@ func useHint(h, r fieldElement, gamma2 uint32) fieldElement {
} }
func vectorMakeHint(ct0, cs2, w, hint []ringElement, gamma2 uint32) { func vectorMakeHint(ct0, cs2, w, hint []ringElement, gamma2 uint32) {
_ = hint[len(ct0)-1] // Bounds check elimination hint.
_ = cs2[len(ct0)-1] // Bounds check elimination hint.
_ = w[len(ct0)-1] // Bounds check elimination hint.
for i := range ct0 { for i := range ct0 {
for j := range ct0[i] { for j := range n {
hint[i][j] = makeHint(ct0[i][j], cs2[i][j], w[i][j], gamma2) hint[i][j] = makeHint(ct0[i][j], cs2[i][j], w[i][j], gamma2)
} }
} }

View File

@ -44,20 +44,20 @@ func TestPower2Round(t *testing.T) {
r1, r0 := power2Round(fieldElement(i)) r1, r0 := power2Round(fieldElement(i))
expectedR1, expectedR0 := _power2Round(uint32(i)) expectedR1, expectedR0 := _power2Round(uint32(i))
if r1 != fieldElement(expectedR1) { if r1 != fieldElement(expectedR1) {
t.Errorf("power2Round(%d) = %d, want %d", i, r1, expectedR1) t.Errorf("power2Round(%d) r1= %d, want %d", i, r1, expectedR1)
} }
if r0 != fieldElement(expectedR0) { if r0 != fieldElement(expectedR0) {
t.Errorf("power2Round(%d) = %d, want %d", i, r0, expectedR0) t.Errorf("power2Round(%d) r0= %d, want %d", i, r0, expectedR0)
} }
} }
for i := q - 1001; i < q; i++ { for i := q - 1001; i < q; i++ {
r1, r0 := power2Round(fieldElement(i)) r1, r0 := power2Round(fieldElement(i))
expectedR1, expectedR0 := _power2Round(uint32(i)) expectedR1, expectedR0 := _power2Round(uint32(i))
if r1 != fieldElement(expectedR1) { if r1 != fieldElement(expectedR1) {
t.Errorf("power2Round(%d) = %d, want %d", i, r1, expectedR1) t.Errorf("power2Round(%d) r1= %d, want %d", i, r1, expectedR1)
} }
if r0 != fieldElement(expectedR0) { if r0 != fieldElement(expectedR0) {
t.Errorf("power2Round(%d) = %d, want %d", i, r0, expectedR0) t.Errorf("power2Round(%d) r0= %d, want %d", i, r0, expectedR0)
} }
} }
} }

View File

@ -30,11 +30,11 @@ func simpleBitPack10Bits(s []byte, f ringElement) []byte {
x |= uint64(f[i+1]) << 10 x |= uint64(f[i+1]) << 10
x |= uint64(f[i+2]) << 20 x |= uint64(f[i+2]) << 20
x |= uint64(f[i+3]) << 30 x |= uint64(f[i+3]) << 30
b[4] = uint8(x >> 32)
b[0] = uint8(x) b[0] = uint8(x)
b[1] = uint8(x >> 8) b[1] = uint8(x >> 8)
b[2] = uint8(x >> 16) b[2] = uint8(x >> 16)
b[3] = uint8(x >> 24) b[3] = uint8(x >> 24)
b[4] = uint8(x >> 32)
b = b[5:] b = b[5:]
} }
return s return s
@ -45,7 +45,7 @@ func simpleBitPack10Bits(s []byte, f ringElement) []byte {
func simpleBitUnpack10Bits(b []byte, f *ringElement) { func simpleBitUnpack10Bits(b []byte, f *ringElement) {
const mask = 0x3FF const mask = 0x3FF
for i := 0; i < n; i += 4 { for i := 0; i < n; i += 4 {
x := uint64(b[0]) | (uint64(b[1]) << 8) | (uint64(b[2]) << 16) | (uint64(b[3]) << 24) | (uint64(b[4]) << 32) x := (uint64(b[4]) << 32) | uint64(b[0]) | (uint64(b[1]) << 8) | (uint64(b[2]) << 16) | (uint64(b[3]) << 24)
b = b[5:] b = b[5:]
f[i] = fieldElement(x & mask) f[i] = fieldElement(x & mask)
f[i+1] = fieldElement((x >> 10) & mask) f[i+1] = fieldElement((x >> 10) & mask)
@ -80,9 +80,9 @@ func simpleBitPack4Bits(s []byte, f ringElement) []byte {
// i.e. Use 6 bits from each coefficient and pack them into bytes // i.e. Use 6 bits from each coefficient and pack them into bytes
// So every 4 coefficients fit into 3 bytes. // So every 4 coefficients fit into 3 bytes.
// //
// |c0||c1||c2||c3| // |c0||c1||c2||c3|
// | /| /\ / // | /| /\ /
// |6 2|4 4|2 6| // |6 2|4 4|2 6|
// //
// This is used to encode w1 when signing with ML-DSA-44 // This is used to encode w1 when signing with ML-DSA-44
func simpleBitPack6Bits(s []byte, f ringElement) []byte { func simpleBitPack6Bits(s []byte, f ringElement) []byte {
@ -93,9 +93,9 @@ func simpleBitPack6Bits(s []byte, f ringElement) []byte {
x |= uint64(f[i+1]) << 6 x |= uint64(f[i+1]) << 6
x |= uint64(f[i+2]) << 12 x |= uint64(f[i+2]) << 12
x |= uint64(f[i+3]) << 18 x |= uint64(f[i+3]) << 18
b[2] = uint8(x >> 16)
b[0] = uint8(x) b[0] = uint8(x)
b[1] = uint8(x >> 8) b[1] = uint8(x >> 8)
b[2] = uint8(x >> 16)
b = b[3:] b = b[3:]
} }
@ -126,9 +126,10 @@ func bitPackSigned2(s []byte, f ringElement) []byte {
x |= uint32(fieldSub(2, f[i+5])) << 15 x |= uint32(fieldSub(2, f[i+5])) << 15
x |= uint32(fieldSub(2, f[i+6])) << 18 x |= uint32(fieldSub(2, f[i+6])) << 18
x |= uint32(fieldSub(2, f[i+7])) << 21 x |= uint32(fieldSub(2, f[i+7])) << 21
b[2] = uint8(x >> 16)
b[0] = uint8(x) b[0] = uint8(x)
b[1] = uint8(x >> 8) b[1] = uint8(x >> 8)
b[2] = uint8(x >> 16)
b = b[3:] b = b[3:]
} }
return s return s
@ -140,7 +141,7 @@ func bitUnpackSigned2(b []byte) (ringElement, error) {
const bitsMask = 0x7 const bitsMask = 0x7
var f ringElement var f ringElement
for i := 0; i < n; i += 8 { for i := 0; i < n; i += 8 {
x := uint32(b[0]) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) x := (uint32(b[2]) << 16) | uint32(b[0]) | (uint32(b[1]) << 8)
msbs := x & 0o44444444 msbs := x & 0o44444444
mask := (msbs >> 1) | (msbs >> 2) mask := (msbs >> 1) | (msbs >> 2)
if subtle.ConstantTimeEq(int32(mask&x), 0) == 0 { if subtle.ConstantTimeEq(int32(mask&x), 0) == 0 {
@ -160,7 +161,6 @@ func bitUnpackSigned2(b []byte) (ringElement, error) {
return f, nil return f, nil
} }
// bitPackSigned4 encodes a polynomial into a byte string, assuming that all // bitPackSigned4 encodes a polynomial into a byte string, assuming that all
// coefficients are in the range -4..4. // coefficients are in the range -4..4.
// See FIPS 204, Algorithm 17, BitPack(w, a, b). (a = 4, b = 4) // See FIPS 204, Algorithm 17, BitPack(w, a, b). (a = 4, b = 4)
@ -185,7 +185,7 @@ func bitUnpackSigned4(b []byte) (ringElement, error) {
const bitsMask = 0xF const bitsMask = 0xF
var f ringElement var f ringElement
for i := 0; i < n; i += 8 { for i := 0; i < n; i += 8 {
x := uint32(b[0]) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24) x := (uint32(b[3]) << 24) | uint32(b[0]) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16)
// None of the nibbles may be >= 9. So if the MSB of any nibble is set, // None of the nibbles may be >= 9. So if the MSB of any nibble is set,
// none of the other bits may be set. First, select all the MSBs. // none of the other bits may be set. First, select all the MSBs.
msbs := x & 0x88888888 msbs := x & 0x88888888
@ -208,7 +208,6 @@ func bitUnpackSigned4(b []byte) (ringElement, error) {
return f, nil return f, nil
} }
// bitPackSigned4096 encodes a polynomial f into a byte slice, assuming that all // bitPackSigned4096 encodes a polynomial f into a byte slice, assuming that all
// coefficients are in the range (-2^12 + 1)..2^12. // coefficients are in the range (-2^12 + 1)..2^12.
// See FIPS 204, Algorithm 17, BitPack(w, a, b), where a = 2^12 - 1, b = 2^12. // See FIPS 204, Algorithm 17, BitPack(w, a, b), where a = 2^12 - 1, b = 2^12.
@ -237,6 +236,8 @@ func bitPackSigned4096(s []byte, f ringElement) []byte {
x2 |= uint64(fieldSub(r, f[i+5])) << 1 x2 |= uint64(fieldSub(r, f[i+5])) << 1
x2 |= uint64(fieldSub(r, f[i+6])) << 14 x2 |= uint64(fieldSub(r, f[i+6])) << 14
x2 |= uint64(fieldSub(r, f[i+7])) << 27 x2 |= uint64(fieldSub(r, f[i+7])) << 27
b[12] = uint8(x2 >> 32)
b[0] = uint8(x1) b[0] = uint8(x1)
b[1] = uint8(x1 >> 8) b[1] = uint8(x1 >> 8)
b[2] = uint8(x1 >> 16) b[2] = uint8(x1 >> 16)
@ -249,7 +250,6 @@ func bitPackSigned4096(s []byte, f ringElement) []byte {
b[9] = uint8(x2 >> 8) b[9] = uint8(x2 >> 8)
b[10] = uint8(x2 >> 16) b[10] = uint8(x2 >> 16)
b[11] = uint8(x2 >> 24) b[11] = uint8(x2 >> 24)
b[12] = uint8(x2 >> 32)
b = b[13:] b = b[13:]
} }
@ -262,8 +262,8 @@ func bitUnpackSigned4096(b []byte, f *ringElement) error {
const bitsMask = 0x1FFF // 2^13-1 const bitsMask = 0x1FFF // 2^13-1
const r = 4096 // 2^12 const r = 4096 // 2^12
for i := 0; i < n; i += 8 { for i := 0; i < n; i += 8 {
x2 := (uint64(b[12]) << 32) | uint64(b[8]) | (uint64(b[9]) << 8) | (uint64(b[10]) << 16) | (uint64(b[11]) << 24)
x1 := uint64(b[0]) | (uint64(b[1]) << 8) | (uint64(b[2]) << 16) | (uint64(b[3]) << 24) | (uint64(b[4]) << 32) | (uint64(b[5]) << 40) | (uint64(b[6]) << 48) | (uint64(b[7]) << 56) x1 := uint64(b[0]) | (uint64(b[1]) << 8) | (uint64(b[2]) << 16) | (uint64(b[3]) << 24) | (uint64(b[4]) << 32) | (uint64(b[5]) << 40) | (uint64(b[6]) << 48) | (uint64(b[7]) << 56)
x2 := uint64(b[8]) | (uint64(b[9]) << 8) | (uint64(b[10]) << 16) | (uint64(b[11]) << 24) | (uint64(b[12]) << 32)
b = b[13:] b = b[13:]
f[i] = fieldSub(r, fieldElement(x1&bitsMask)) f[i] = fieldSub(r, fieldElement(x1&bitsMask))
f[i+1] = fieldSub(r, fieldElement((x1>>13)&bitsMask)) f[i+1] = fieldSub(r, fieldElement((x1>>13)&bitsMask))
@ -302,6 +302,8 @@ func bitPackSignedTwoPower17(s []byte, f ringElement) []byte {
x2 = uint64(fieldSub(r, f[i+3])) x2 = uint64(fieldSub(r, f[i+3]))
x1 |= x2 << 54 x1 |= x2 << 54
x2 >>= 10 x2 >>= 10
b[8] = uint8(x2)
b[0] = uint8(x1) b[0] = uint8(x1)
b[1] = uint8(x1 >> 8) b[1] = uint8(x1 >> 8)
b[2] = uint8(x1 >> 16) b[2] = uint8(x1 >> 16)
@ -310,7 +312,6 @@ func bitPackSignedTwoPower17(s []byte, f ringElement) []byte {
b[5] = uint8(x1 >> 40) b[5] = uint8(x1 >> 40)
b[6] = uint8(x1 >> 48) b[6] = uint8(x1 >> 48)
b[7] = uint8(x1 >> 56) b[7] = uint8(x1 >> 56)
b[8] = uint8(x2)
b = b[9:] b = b[9:]
} }
@ -323,9 +324,10 @@ func bitUnpackSignedTwoPower17(b []byte, f *ringElement) {
const bitsMask = 0x3FFFF // 2^18-1 const bitsMask = 0x3FFFF // 2^18-1
const r = 131072 // 2^17 const r = 131072 // 2^17
for i := 0; i < n; i += 4 { for i := 0; i < n; i += 4 {
x1 := uint64(b[0]) | (uint64(b[1]) << 8) | (uint64(b[2]) << 16) | (uint64(b[3]) << 24) | (uint64(b[4]) << 32) | (uint64(b[5]) << 40) | (uint64(b[6]) << 48) | (uint64(b[7]) << 56)
x2 := uint64(b[8]) x2 := uint64(b[8])
x1 := uint64(b[0]) | (uint64(b[1]) << 8) | (uint64(b[2]) << 16) | (uint64(b[3]) << 24) | (uint64(b[4]) << 32) | (uint64(b[5]) << 40) | (uint64(b[6]) << 48) | (uint64(b[7]) << 56)
b = b[9:] b = b[9:]
f[i] = fieldSub(r, fieldElement(x1&bitsMask)) f[i] = fieldSub(r, fieldElement(x1&bitsMask))
f[i+1] = fieldSub(r, fieldElement((x1>>18)&bitsMask)) f[i+1] = fieldSub(r, fieldElement((x1>>18)&bitsMask))
f[i+2] = fieldSub(r, fieldElement((x1>>36)&bitsMask)) f[i+2] = fieldSub(r, fieldElement((x1>>36)&bitsMask))
@ -358,6 +360,9 @@ func bitPackSignedTwoPower19(s []byte, f ringElement) []byte {
x2 = uint64(fieldSub(r, f[i+3])) x2 = uint64(fieldSub(r, f[i+3]))
x1 |= x2 << 60 x1 |= x2 << 60
x2 >>= 4 x2 >>= 4
b[9] = uint8(x2 >> 8)
b[8] = uint8(x2)
b[0] = uint8(x1) b[0] = uint8(x1)
b[1] = uint8(x1 >> 8) b[1] = uint8(x1 >> 8)
b[2] = uint8(x1 >> 16) b[2] = uint8(x1 >> 16)
@ -366,9 +371,7 @@ func bitPackSignedTwoPower19(s []byte, f ringElement) []byte {
b[5] = uint8(x1 >> 40) b[5] = uint8(x1 >> 40)
b[6] = uint8(x1 >> 48) b[6] = uint8(x1 >> 48)
b[7] = uint8(x1 >> 56) b[7] = uint8(x1 >> 56)
b[8] = uint8(x2)
b[9] = uint8(x2 >> 8)
b = b[10:] b = b[10:]
} }
return s return s
@ -382,8 +385,9 @@ func bitUnpackSignedTwoPower19(b []byte, f *ringElement) {
const bitsMask = 0xFFFFF // 2^20-1 const bitsMask = 0xFFFFF // 2^20-1
const r = 524288 // 2^19 const r = 524288 // 2^19
for i := 0; i < n; i += 4 { for i := 0; i < n; i += 4 {
x2 := (uint64(b[9]) << 8) | uint64(b[8])
x1 := uint64(b[0]) | (uint64(b[1]) << 8) | (uint64(b[2]) << 16) | (uint64(b[3]) << 24) | (uint64(b[4]) << 32) | (uint64(b[5]) << 40) | (uint64(b[6]) << 48) | (uint64(b[7]) << 56) x1 := uint64(b[0]) | (uint64(b[1]) << 8) | (uint64(b[2]) << 16) | (uint64(b[3]) << 24) | (uint64(b[4]) << 32) | (uint64(b[5]) << 40) | (uint64(b[6]) << 48) | (uint64(b[7]) << 56)
x2 := uint64(b[8]) | (uint64(b[9]) << 8)
b = b[10:] b = b[10:]
f[i] = fieldSub(r, fieldElement(x1&bitsMask)) f[i] = fieldSub(r, fieldElement(x1&bitsMask))
f[i+1] = fieldSub(r, fieldElement((x1>>20)&bitsMask)) f[i+1] = fieldSub(r, fieldElement((x1>>20)&bitsMask))
@ -398,7 +402,7 @@ func hintBitPack(s []byte, hint []ringElement, omega int) []byte {
s, b := alias.SliceForAppend(s, omega+k) s, b := alias.SliceForAppend(s, omega+k)
index := 0 index := 0
for i := range k { for i := range k {
for j := 0; j < n; j++ { for j := range n {
if hint[i][j] != 0 { if hint[i][j] != 0 {
b[index] = byte(j) b[index] = byte(j)
index++ index++
@ -428,7 +432,8 @@ func hintBitUnpack(b []byte, hint []ringElement, omega int) bool {
hint[i][bi] = 1 hint[i][bi] = 1
} }
} }
for i := index; i < omega; i++ { b = b[index:omega]
for i := range b {
if b[i] != 0 { if b[i] != 0 {
return false return false
} }

View File

@ -45,6 +45,7 @@ const (
r = 4193792 // 2^32 mod q r = 4193792 // 2^32 mod q
) )
// See FIPS 204, Algorithm 49, MontgomeryReduce()
func fieldReduce(a uint64) fieldElement { func fieldReduce(a uint64) fieldElement {
t := uint32(a) * qNegInv t := uint32(a) * qNegInv
return fieldReduceOnce(uint32((a + uint64(t)*q) >> 32)) return fieldReduceOnce(uint32((a + uint64(t)*q) >> 32))

View File

@ -32,7 +32,7 @@ func rejNTTPoly(rho []byte, s, r byte) nttElement {
a[j] = fieldElement(d) a[j] = fieldElement(d)
j++ j++
} }
if j >= len(a) { if j >= n {
return a return a
} }
} }
@ -73,14 +73,14 @@ func rejBoundedPoly(rho []byte, eta int, highByte, lowByte byte) ringElement {
if subtle.ConstantTimeByteEq(z0, 15) == 0 { if subtle.ConstantTimeByteEq(z0, 15) == 0 {
a[j] = fieldSub(2, fieldElement(constantMod5(uint32(z0)))) a[j] = fieldSub(2, fieldElement(constantMod5(uint32(z0))))
j++ j++
if j >= len(a) { if j >= n {
break break
} }
} }
if subtle.ConstantTimeByteEq(z1, 15) == 0 { if subtle.ConstantTimeByteEq(z1, 15) == 0 {
a[j] = fieldSub(2, fieldElement(constantMod5(uint32(z1)))) a[j] = fieldSub(2, fieldElement(constantMod5(uint32(z1))))
j++ j++
if j >= len(a) { if j >= n {
break break
} }
} }
@ -88,14 +88,14 @@ func rejBoundedPoly(rho []byte, eta int, highByte, lowByte byte) ringElement {
if subtle.ConstantTimeLessOrEq(int(z0), 8) == 1 { if subtle.ConstantTimeLessOrEq(int(z0), 8) == 1 {
a[j] = fieldSub(4, fieldElement(z0)) a[j] = fieldSub(4, fieldElement(z0))
j++ j++
if j >= len(a) { if j >= n {
break break
} }
} }
if subtle.ConstantTimeLessOrEq(int(z1), 8) == 1 { if subtle.ConstantTimeLessOrEq(int(z1), 8) == 1 {
a[j] = fieldSub(4, fieldElement(z1)) a[j] = fieldSub(4, fieldElement(z1))
j++ j++
if j >= len(a) { if j >= n {
break break
} }
} }