diff --git a/mldsa/compress.go b/mldsa/compress.go index 5f45caf..8c555d6 100644 --- a/mldsa/compress.go +++ b/mldsa/compress.go @@ -99,8 +99,11 @@ func useHint(h, r fieldElement, gamma2 uint32) fieldElement { } func vectorMakeHint(ct0, cs2, w, hint []ringElement, gamma2 uint32) { + _ = hint[len(ct0)-1] // Bounds check elimination hint. + _ = cs2[len(ct0)-1] // Bounds check elimination hint. + _ = w[len(ct0)-1] // Bounds check elimination hint. for i := range ct0 { - for j := range ct0[i] { + for j := range n { hint[i][j] = makeHint(ct0[i][j], cs2[i][j], w[i][j], gamma2) } } diff --git a/mldsa/compress_test.go b/mldsa/compress_test.go index a60a5fe..ea92bc0 100644 --- a/mldsa/compress_test.go +++ b/mldsa/compress_test.go @@ -44,20 +44,20 @@ func TestPower2Round(t *testing.T) { r1, r0 := power2Round(fieldElement(i)) expectedR1, expectedR0 := _power2Round(uint32(i)) if r1 != fieldElement(expectedR1) { - t.Errorf("power2Round(%d) = %d, want %d", i, r1, expectedR1) + t.Errorf("power2Round(%d) r1= %d, want %d", i, r1, expectedR1) } if r0 != fieldElement(expectedR0) { - t.Errorf("power2Round(%d) = %d, want %d", i, r0, expectedR0) + t.Errorf("power2Round(%d) r0= %d, want %d", i, r0, expectedR0) } } for i := q - 1001; i < q; i++ { r1, r0 := power2Round(fieldElement(i)) expectedR1, expectedR0 := _power2Round(uint32(i)) if r1 != fieldElement(expectedR1) { - t.Errorf("power2Round(%d) = %d, want %d", i, r1, expectedR1) + t.Errorf("power2Round(%d) r1= %d, want %d", i, r1, expectedR1) } if r0 != fieldElement(expectedR0) { - t.Errorf("power2Round(%d) = %d, want %d", i, r0, expectedR0) + t.Errorf("power2Round(%d) r0= %d, want %d", i, r0, expectedR0) } } } diff --git a/mldsa/encoder.go b/mldsa/encoder.go index 938c56f..ce8cf82 100644 --- a/mldsa/encoder.go +++ b/mldsa/encoder.go @@ -30,11 +30,11 @@ func simpleBitPack10Bits(s []byte, f ringElement) []byte { x |= uint64(f[i+1]) << 10 x |= uint64(f[i+2]) << 20 x |= uint64(f[i+3]) << 30 + b[4] = uint8(x >> 32) b[0] = uint8(x) b[1] = uint8(x >> 8) b[2] = uint8(x >> 16) b[3] = uint8(x >> 24) - b[4] = uint8(x >> 32) b = b[5:] } return s @@ -45,7 +45,7 @@ func simpleBitPack10Bits(s []byte, f ringElement) []byte { func simpleBitUnpack10Bits(b []byte, f *ringElement) { const mask = 0x3FF for i := 0; i < n; i += 4 { - x := uint64(b[0]) | (uint64(b[1]) << 8) | (uint64(b[2]) << 16) | (uint64(b[3]) << 24) | (uint64(b[4]) << 32) + x := (uint64(b[4]) << 32) | uint64(b[0]) | (uint64(b[1]) << 8) | (uint64(b[2]) << 16) | (uint64(b[3]) << 24) b = b[5:] f[i] = fieldElement(x & mask) f[i+1] = fieldElement((x >> 10) & mask) @@ -80,9 +80,9 @@ func simpleBitPack4Bits(s []byte, f ringElement) []byte { // i.e. Use 6 bits from each coefficient and pack them into bytes // So every 4 coefficients fit into 3 bytes. // -// |c0||c1||c2||c3| -// | /| /\ / -// |6 2|4 4|2 6| +// |c0||c1||c2||c3| +// | /| /\ / +// |6 2|4 4|2 6| // // This is used to encode w1 when signing with ML-DSA-44 func simpleBitPack6Bits(s []byte, f ringElement) []byte { @@ -93,9 +93,9 @@ func simpleBitPack6Bits(s []byte, f ringElement) []byte { x |= uint64(f[i+1]) << 6 x |= uint64(f[i+2]) << 12 x |= uint64(f[i+3]) << 18 + b[2] = uint8(x >> 16) b[0] = uint8(x) b[1] = uint8(x >> 8) - b[2] = uint8(x >> 16) b = b[3:] } @@ -126,9 +126,10 @@ func bitPackSigned2(s []byte, f ringElement) []byte { x |= uint32(fieldSub(2, f[i+5])) << 15 x |= uint32(fieldSub(2, f[i+6])) << 18 x |= uint32(fieldSub(2, f[i+7])) << 21 + b[2] = uint8(x >> 16) b[0] = uint8(x) b[1] = uint8(x >> 8) - b[2] = uint8(x >> 16) + b = b[3:] } return s @@ -140,7 +141,7 @@ func bitUnpackSigned2(b []byte) (ringElement, error) { const bitsMask = 0x7 var f ringElement for i := 0; i < n; i += 8 { - x := uint32(b[0]) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) + x := (uint32(b[2]) << 16) | uint32(b[0]) | (uint32(b[1]) << 8) msbs := x & 0o44444444 mask := (msbs >> 1) | (msbs >> 2) if subtle.ConstantTimeEq(int32(mask&x), 0) == 0 { @@ -160,7 +161,6 @@ func bitUnpackSigned2(b []byte) (ringElement, error) { return f, nil } - // bitPackSigned4 encodes a polynomial into a byte string, assuming that all // coefficients are in the range -4..4. // See FIPS 204, Algorithm 17, BitPack(w, a, b). (a = 4, b = 4) @@ -185,7 +185,7 @@ func bitUnpackSigned4(b []byte) (ringElement, error) { const bitsMask = 0xF var f ringElement for i := 0; i < n; i += 8 { - x := uint32(b[0]) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24) + x := (uint32(b[3]) << 24) | uint32(b[0]) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) // None of the nibbles may be >= 9. So if the MSB of any nibble is set, // none of the other bits may be set. First, select all the MSBs. msbs := x & 0x88888888 @@ -208,7 +208,6 @@ func bitUnpackSigned4(b []byte) (ringElement, error) { return f, nil } - // bitPackSigned4196 encodes a polynomial f into a byte slice, assuming that all // coefficients are in the range (-2^12 + 1)..2^12. // See FIPS 204, Algorithm 17, BitPack(w, a, b). where a = 2^12 - 1, b = 2^12. @@ -237,6 +236,8 @@ func bitPackSigned4096(s []byte, f ringElement) []byte { x2 |= uint64(fieldSub(r, f[i+5])) << 1 x2 |= uint64(fieldSub(r, f[i+6])) << 14 x2 |= uint64(fieldSub(r, f[i+7])) << 27 + + b[12] = uint8(x2 >> 32) b[0] = uint8(x1) b[1] = uint8(x1 >> 8) b[2] = uint8(x1 >> 16) @@ -249,7 +250,6 @@ func bitPackSigned4096(s []byte, f ringElement) []byte { b[9] = uint8(x2 >> 8) b[10] = uint8(x2 >> 16) b[11] = uint8(x2 >> 24) - b[12] = uint8(x2 >> 32) b = b[13:] } @@ -262,8 +262,8 @@ func bitUnpackSigned4096(b []byte, f *ringElement) error { const bitsMask = 0x1FFF // 2^13-1 const r = 4096 // 2^12 for i := 0; i < n; i += 8 { + x2 := (uint64(b[12]) << 32) | uint64(b[8]) | (uint64(b[9]) << 8) | (uint64(b[10]) << 16) | (uint64(b[11]) << 24) x1 := uint64(b[0]) | (uint64(b[1]) << 8) | (uint64(b[2]) << 16) | (uint64(b[3]) << 24) | (uint64(b[4]) << 32) | (uint64(b[5]) << 40) | (uint64(b[6]) << 48) | (uint64(b[7]) << 56) - x2 := uint64(b[8]) | (uint64(b[9]) << 8) | (uint64(b[10]) << 16) | (uint64(b[11]) << 24) | (uint64(b[12]) << 32) b = b[13:] f[i] = fieldSub(r, fieldElement(x1&bitsMask)) f[i+1] = fieldSub(r, fieldElement((x1>>13)&bitsMask)) @@ -302,6 +302,8 @@ func bitPackSignedTwoPower17(s []byte, f ringElement) []byte { x2 = uint64(fieldSub(r, f[i+3])) x1 |= x2 << 54 x2 >>= 10 + + b[8] = uint8(x2) b[0] = uint8(x1) b[1] = uint8(x1 >> 8) b[2] = uint8(x1 >> 16) @@ -310,7 +312,6 @@ func bitPackSignedTwoPower17(s []byte, f ringElement) []byte { b[5] = uint8(x1 >> 40) b[6] = uint8(x1 >> 48) b[7] = uint8(x1 >> 56) - b[8] = uint8(x2) b = b[9:] } @@ -323,9 +324,10 @@ func bitUnpackSignedTwoPower17(b []byte, f *ringElement) { const bitsMask = 0x3FFFF // 2^18-1 const r = 131072 // 2^17 for i := 0; i < n; i += 4 { - x1 := uint64(b[0]) | (uint64(b[1]) << 8) | (uint64(b[2]) << 16) | (uint64(b[3]) << 24) | (uint64(b[4]) << 32) | (uint64(b[5]) << 40) | (uint64(b[6]) << 48) | (uint64(b[7]) << 56) x2 := uint64(b[8]) + x1 := uint64(b[0]) | (uint64(b[1]) << 8) | (uint64(b[2]) << 16) | (uint64(b[3]) << 24) | (uint64(b[4]) << 32) | (uint64(b[5]) << 40) | (uint64(b[6]) << 48) | (uint64(b[7]) << 56) b = b[9:] + f[i] = fieldSub(r, fieldElement(x1&bitsMask)) f[i+1] = fieldSub(r, fieldElement((x1>>18)&bitsMask)) f[i+2] = fieldSub(r, fieldElement((x1>>36)&bitsMask)) @@ -358,6 +360,9 @@ func bitPackSignedTwoPower19(s []byte, f ringElement) []byte { x2 = uint64(fieldSub(r, f[i+3])) x1 |= x2 << 60 x2 >>= 4 + + b[9] = uint8(x2 >> 8) + b[8] = uint8(x2) b[0] = uint8(x1) b[1] = uint8(x1 >> 8) b[2] = uint8(x1 >> 16) @@ -366,9 +371,7 @@ func bitPackSignedTwoPower19(s []byte, f ringElement) []byte { b[5] = uint8(x1 >> 40) b[6] = uint8(x1 >> 48) b[7] = uint8(x1 >> 56) - b[8] = uint8(x2) - b[9] = uint8(x2 >> 8) - + b = b[10:] } return s @@ -382,8 +385,9 @@ func bitUnpackSignedTwoPower19(b []byte, f *ringElement) { const bitsMask = 0xFFFFF // 2^20-1 const r = 524288 // 2^19 for i := 0; i < n; i += 4 { + x2 := (uint64(b[9]) << 8) | uint64(b[8]) x1 := uint64(b[0]) | (uint64(b[1]) << 8) | (uint64(b[2]) << 16) | (uint64(b[3]) << 24) | (uint64(b[4]) << 32) | (uint64(b[5]) << 40) | (uint64(b[6]) << 48) | (uint64(b[7]) << 56) - x2 := uint64(b[8]) | (uint64(b[9]) << 8) + b = b[10:] f[i] = fieldSub(r, fieldElement(x1&bitsMask)) f[i+1] = fieldSub(r, fieldElement((x1>>20)&bitsMask)) @@ -398,7 +402,7 @@ func hintBitPack(s []byte, hint []ringElement, omega int) []byte { s, b := alias.SliceForAppend(s, omega+k) index := 0 for i := range k { - for j := 0; j < n; j++ { + for j := range n { if hint[i][j] != 0 { b[index] = byte(j) index++ @@ -428,7 +432,8 @@ func hintBitUnpack(b []byte, hint []ringElement, omega int) bool { hint[i][bi] = 1 } } - for i := index; i < omega; i++ { + b = b[index:omega] + for i := range b { if b[i] != 0 { return false } diff --git a/mldsa/field.go b/mldsa/field.go index aafbafb..4139c98 100644 --- a/mldsa/field.go +++ b/mldsa/field.go @@ -45,6 +45,7 @@ const ( r = 4193792 // 2^32 mod q ) +// See FIPS 204, Algorithm 49, MontgomeryReduce() func fieldReduce(a uint64) fieldElement { t := uint32(a) * qNegInv return fieldReduceOnce(uint32((a + uint64(t)*q) >> 32)) diff --git a/mldsa/sample.go b/mldsa/sample.go index 5a4e05b..e7c8de8 100644 --- a/mldsa/sample.go +++ b/mldsa/sample.go @@ -32,7 +32,7 @@ func rejNTTPoly(rho []byte, s, r byte) nttElement { a[j] = fieldElement(d) j++ } - if j >= len(a) { + if j >= n { return a } } @@ -73,14 +73,14 @@ func rejBoundedPoly(rho []byte, eta int, highByte, lowByte byte) ringElement { if subtle.ConstantTimeByteEq(z0, 15) == 0 { a[j] = fieldSub(2, fieldElement(constantMod5(uint32(z0)))) j++ - if j >= len(a) { + if j >= n { break } } if subtle.ConstantTimeByteEq(z1, 15) == 0 { a[j] = fieldSub(2, fieldElement(constantMod5(uint32(z1)))) j++ - if j >= len(a) { + if j >= n { break } } @@ -88,14 +88,14 @@ func rejBoundedPoly(rho []byte, eta int, highByte, lowByte byte) ringElement { if subtle.ConstantTimeLessOrEq(int(z0), 8) == 1 { a[j] = fieldSub(4, fieldElement(z0)) j++ - if j >= len(a) { + if j >= n { break } } if subtle.ConstantTimeLessOrEq(int(z1), 8) == 1 { a[j] = fieldSub(4, fieldElement(z1)) j++ - if j >= len(a) { + if j >= n { break } }