diff --git a/internal/sm2ec/p256_asm_amd64.s b/internal/sm2ec/p256_asm_amd64.s
index 6966498..579893d 100644
--- a/internal/sm2ec/p256_asm_amd64.s
+++ b/internal/sm2ec/p256_asm_amd64.s
@@ -152,13 +152,9 @@ move_avx2:
 	VPANDN (32*1)(x_ptr), Y12, Y1
 	VPANDN (32*2)(x_ptr), Y12, Y2
 
-	VMOVDQU (32*0)(y_ptr), Y3
-	VMOVDQU (32*1)(y_ptr), Y4
-	VMOVDQU (32*2)(y_ptr), Y5
-
-	VPAND Y12, Y3, Y3
-	VPAND Y12, Y4, Y4
-	VPAND Y12, Y5, Y5
+	VPAND (32*0)(y_ptr), Y12, Y3
+	VPAND (32*1)(y_ptr), Y12, Y4
+	VPAND (32*2)(y_ptr), Y12, Y5
 
 	VPXOR Y3, Y0, Y0
 	VPXOR Y4, Y1, Y1
@@ -963,7 +959,7 @@ TEXT ·p256FromMont(SB),NOSPLIT,$0
 
 	RET
 /* ---------------------------------------*/
-// func p256Select(res *SM2P256Point, table *p256Table, idx int)
+// func p256Select(res *SM2P256Point, table *p256Table, idx, limit int)
 TEXT ·p256Select(SB),NOSPLIT,$0
 	//MOVQ idx+16(FP),AX
 	MOVQ table+8(FP),DI
@@ -984,7 +980,7 @@ TEXT ·p256Select(SB),NOSPLIT,$0
 	PXOR X3, X3
 	PXOR X4, X4
 	PXOR X5, X5
-	MOVQ $16, AX
+	MOVQ limit+24(FP),AX
 
 	MOVOU X15, X13
 
@@ -1035,7 +1031,7 @@ select_avx2:
 	MOVL idx+16(FP), X14 // x14 = idx
 	VPBROADCASTD X14, Y14
 
-	MOVQ $16, AX
+	MOVQ limit+24(FP),AX
 
 	VMOVDQU Y15, Y13
 	VPXOR Y0, Y0, Y0
@@ -1047,16 +1043,12 @@ loop_select_avx2:
 	VPADDD Y15, Y13, Y13
 	VPCMPEQD Y14, Y12, Y12
 
-	VMOVDQU (32*0)(DI), Y3
-	VMOVDQU (32*1)(DI), Y4
-	VMOVDQU (32*2)(DI), Y5
+	VPAND (32*0)(DI), Y12, Y3
+	VPAND (32*1)(DI), Y12, Y4
+	VPAND (32*2)(DI), Y12, Y5
 
 	ADDQ $(32*3), DI
 
-	VPAND Y12, Y3, Y3
-	VPAND Y12, Y4, Y4
-	VPAND Y12, Y5, Y5
-
 	VPXOR Y3, Y0, Y0
 	VPXOR Y4, Y1, Y1
 	VPXOR Y5, Y2, Y2
@@ -1163,22 +1155,17 @@ loop_select_base_avx2:
 	VPADDD Y15, Y13, Y13
 	VPCMPEQD Y14, Y12, Y12
 
-	VMOVDQU (32*0)(DI), Y2
-	VMOVDQU (32*1)(DI), Y3
-	VMOVDQU (32*2)(DI), Y4
-	VMOVDQU (32*3)(DI), Y5
-
-	ADDQ $(32*4), DI
-
-	VPAND Y12, Y2, Y2
-	VPAND Y12, Y3, Y3
+	VPAND (32*0)(DI), Y12, Y2
+	VPAND (32*1)(DI), Y12, Y3
 
 	VMOVDQU Y13, Y12
 	VPADDD Y15, Y13, Y13
 	VPCMPEQD Y14, Y12, Y12
 
-	VPAND Y12, Y4, Y4
-	VPAND Y12, Y5, Y5
+	VPAND (32*2)(DI), Y12, Y4
+	VPAND (32*3)(DI), Y12, Y5
+
+	ADDQ $(32*4), DI
 
 	VPXOR Y2, Y0, Y0
 	VPXOR Y3, Y1, Y1
@@ -3097,10 +3084,6 @@ pointaddaffine_avx2:
 	p256PointAddAffineInline()
 
 	// The result is not valid if (sel == 0), conditional choose
-	VMOVDQU xout(32*0), Y0
-	VMOVDQU yout(32*0), Y1
-	VMOVDQU zout(32*0), Y2
-
 	MOVL BX, X6
 	MOVL CX, X7
 
@@ -3116,17 +3099,13 @@ pointaddaffine_avx2:
 	VMOVDQU Y6, Y15
 	VPANDN Y9, Y15, Y15
 
-	VMOVDQU x1in(32*0), Y9
-	VMOVDQU y1in(32*0), Y10
-	VMOVDQU z1in(32*0), Y11
+	VPAND xout(32*0), Y15, Y0
+	VPAND yout(32*0), Y15, Y1
+	VPAND zout(32*0), Y15, Y2
 
-	VPAND Y15, Y0, Y0
-	VPAND Y15, Y1, Y1
-	VPAND Y15, Y2, Y2
-
-	VPAND Y6, Y9, Y9
-	VPAND Y6, Y10, Y10
-	VPAND Y6, Y11, Y11
+	VPAND x1in(32*0), Y6, Y9
+	VPAND y1in(32*0), Y6, Y10
+	VPAND z1in(32*0), Y6, Y11
 
 	VPXOR Y9, Y0, Y0
 	VPXOR Y10, Y1, Y1
@@ -3136,17 +3115,13 @@ pointaddaffine_avx2:
 	VPCMPEQD Y9, Y9, Y9
 	VPANDN Y9, Y7, Y15
 
-	VMOVDQU x2in(32*0), Y9
-	VMOVDQU y2in(32*0), Y10
-	VMOVDQU p256one<>+0x00(SB), Y11
-
 	VPAND Y15, Y0, Y0
 	VPAND Y15, Y1, Y1
 	VPAND Y15, Y2, Y2
 
-	VPAND Y7, Y9, Y9
-	VPAND Y7, Y10, Y10
-	VPAND Y7, Y11, Y11
+	VPAND x2in(32*0), Y7, Y9
+	VPAND y2in(32*0), Y7, Y10
+	VPAND p256one<>+0x00(SB), Y7, Y11
 
 	VPXOR Y9, Y0, Y0
 	VPXOR Y10, Y1, Y1
@@ -3622,8 +3597,8 @@ TEXT ·p256PointDoubleAsm(SB),NOSPLIT,$256-16
 	calY()        \
 	storeTmpY()   \
 
-//func p256PointDouble5TimesAsm(res, in *SM2P256Point)
-TEXT ·p256PointDouble5TimesAsm(SB),NOSPLIT,$256-16
+//func p256PointDouble6TimesAsm(res, in *SM2P256Point)
+TEXT ·p256PointDouble6TimesAsm(SB),NOSPLIT,$256-16
 	// Move input to stack in order to free registers
 	MOVQ res+0(FP), AX
 	MOVQ in+8(FP), BX
@@ -3632,7 +3607,8 @@ TEXT ·p256PointDouble5TimesAsm(SB),NOSPLIT,$256-16
 	// Store pointer to result
 	MOVQ AX, rptr
 
-	// Begin point double 1-4 rounds
+	// Begin point double 1-5 rounds
+	p256PointDoubleRound()
 	p256PointDoubleRound()
 	p256PointDoubleRound()
 	p256PointDoubleRound()
diff --git a/internal/sm2ec/p256_asm_arm64.s b/internal/sm2ec/p256_asm_arm64.s
index 015b023..5cf41a3 100644
--- a/internal/sm2ec/p256_asm_arm64.s
+++ b/internal/sm2ec/p256_asm_arm64.s
@@ -291,8 +291,9 @@ TEXT ·p256FromMont(SB),NOSPLIT,$0
 
 	RET
 /* ---------------------------------------*/
-// func p256Select(res *SM2P256Point, table *p256Table, idx int)
+// func p256Select(res *SM2P256Point, table *p256Table, idx, limit int)
 TEXT ·p256Select(SB),NOSPLIT,$0
+	MOVD limit+24(FP), const3
 	MOVD idx+16(FP), const0
 	MOVD table+8(FP), b_ptr
 	MOVD res+0(FP), res_ptr
@@ -334,7 +335,7 @@ loop_select:
 	CSEL EQ, acc2, t2, t2
 	CSEL EQ, acc3, t3, t3
 
-	CMP $16, const1
+	CMP const3, const1
 	BNE loop_select
 
 	STP (x0, x1), 0*16(res_ptr)
@@ -1619,7 +1620,8 @@ TEXT ·p256PointDouble5TimesAsm(SB),NOSPLIT,$136-16
 	CALL sm2P256Subinternal<>(SB)
 	STx(y3out)
 
-	// Begin point double rounds 2 - 5
+	// Begin point double rounds 2 - 6
+	p256PointDoubleRound()
 	p256PointDoubleRound()
 	p256PointDoubleRound()
 	p256PointDoubleRound()
diff --git a/internal/sm2ec/sm2ec_test.go b/internal/sm2ec/sm2ec_test.go
index 23c4dca..11ae4ed 100644
--- a/internal/sm2ec/sm2ec_test.go
+++ b/internal/sm2ec/sm2ec_test.go
@@ -2,6 +2,7 @@ package sm2ec
 
 import (
 	"bytes"
+	"crypto/rand"
 	"encoding/hex"
 	"fmt"
 	"math/big"
@@ -14,6 +15,7 @@ var r0 = bigFromHex("010000000000000000")
 var sm2Prime = bigFromHex("FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFF")
 var sm2n = bigFromHex("FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFF7203DF6B21C6052B53BBF40939D54123")
 var nistP256Prime = bigFromDecimal("115792089210356248762697446949407573530086143415290314195533631308867097853951")
+var nistP256N = bigFromDecimal("115792089210356248762697446949407573529996955224135760342422259061068512044369")
 
 func generateMontgomeryDomain(in *big.Int, p *big.Int) *big.Int {
 	tmp := new(big.Int)
@@ -237,6 +239,9 @@ func TestScalarMult(t *testing.T) {
 	t.Run("N+1", func(t *testing.T) {
 		checkScalar(t, new(big.Int).Add(sm2n, big.NewInt(1)).Bytes())
 	})
+	t.Run("N+58", func(t *testing.T) {
+		checkScalar(t, new(big.Int).Add(sm2n, big.NewInt(58)).Bytes())
+	})
 	t.Run("all1s", func(t *testing.T) {
 		s := new(big.Int).Lsh(big.NewInt(1), uint(bitLen))
 		s.Sub(s, big.NewInt(1))
@@ -256,6 +261,7 @@ func TestScalarMult(t *testing.T) {
 			checkScalar(t, big.NewInt(int64(i)).FillBytes(make([]byte, byteLen)))
 		})
 	}
+
 	// Test N-64...N+64 since they risk overlapping with precomputed table values
 	// in the final additions.
 	for i := int64(-64); i <= 64; i++ {
@@ -263,6 +269,7 @@ func TestScalarMult(t *testing.T) {
 			checkScalar(t, new(big.Int).Add(sm2n, big.NewInt(i)).Bytes())
 		})
 	}
+
 }
 
 func fatalIfErr(t *testing.T, err error) {
@@ -271,3 +278,25 @@ func fatalIfErr(t *testing.T, err error) {
 		t.Fatal(err)
 	}
 }
+
+func BenchmarkScalarBaseMult(b *testing.B) {
+	p := NewSM2P256Point().SetGenerator()
+	scalar := make([]byte, 32)
+	rand.Read(scalar)
+	b.ReportAllocs()
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		p.ScalarBaseMult(scalar)
+	}
+}
+
+func BenchmarkScalarMult(b *testing.B) {
+	p := NewSM2P256Point().SetGenerator()
+	scalar := make([]byte, 32)
+	rand.Read(scalar)
+	b.ReportAllocs()
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		p.ScalarMult(p, scalar)
+	}
+}
diff --git a/internal/sm2ec/sm2p256_asm.go b/internal/sm2ec/sm2p256_asm.go
index 6632ca8..c562955 100644
--- a/internal/sm2ec/sm2p256_asm.go
+++ b/internal/sm2ec/sm2p256_asm.go
@@ -356,13 +356,13 @@ func p256OrdLittleToBig(res *[32]byte, in *p256OrdElement)
-// p256Table is a table of the first 16 multiples of a point. Points are stored
-// at an index offset of -1 so [8]P is at index 7, P is at 0, and [16]P is at 15.
+// p256Table is a table of the first 32 multiples of a point. Points are stored
+// at an index offset of -1 so [16]P is at index 15, P is at 0, and [32]P is at 31.
 // [0]P is the point at infinity and it's not stored.
-type p256Table [16]SM2P256Point
+type p256Table [32]SM2P256Point
 
 // p256Select sets res to the point at index idx in the table.
-// idx must be in [0, 15]. It executes in constant time.
+// idx must be in [0, limit-1]. It executes in constant time.
 //
 //go:noescape
-func p256Select(res *SM2P256Point, table *p256Table, idx int)
+func p256Select(res *SM2P256Point, table *p256Table, idx, limit int)
 
 // p256AffinePoint is a point in affine coordinates (x, y). x and y are still
 // Montgomery domain elements. The point can't be the point at infinity.
@@ -416,15 +416,30 @@ func p256PointAddAsm(res, in1, in2 *SM2P256Point) int
 //go:noescape
 func p256PointDoubleAsm(res, in *SM2P256Point)
 
-// Point doubling 5 times. in can be the point at infinity.
+// Point doubling 6 times. in can be the point at infinity.
 //
 //go:noescape
-func p256PointDouble5TimesAsm(res, in *SM2P256Point)
+func p256PointDouble6TimesAsm(res, in *SM2P256Point)
 
 // p256OrdElement is a P-256 scalar field element in [0, ord(G)-1] in the
 // Montgomery domain (with R 2²⁵⁶) as four uint64 limbs in little-endian order.
 type p256OrdElement [4]uint64
 
+// p256OrdReduce ensures s is in the range [0, ord(G)-1].
+func p256OrdReduce(s *p256OrdElement) {
+	// Since 2 * ord(G) > 2²⁵⁶, we can just conditionally subtract ord(G),
+	// keeping the result if it doesn't underflow.
+	t0, b := bits.Sub64(s[0], 0x53bbf40939d54123, 0)
+	t1, b := bits.Sub64(s[1], 0x7203df6b21c6052b, b)
+	t2, b := bits.Sub64(s[2], 0xffffffffffffffff, b)
+	t3, b := bits.Sub64(s[3], 0xfffffffeffffffff, b)
+	tMask := b - 1 // zero if subtraction underflowed
+	s[0] ^= (t0 ^ s[0]) & tMask
+	s[1] ^= (t1 ^ s[1]) & tMask
+	s[2] ^= (t2 ^ s[2]) & tMask
+	s[3] ^= (t3 ^ s[3]) & tMask
+}
+
 // Add sets q = p1 + p2, and returns q. The points may overlap.
 func (q *SM2P256Point) Add(r1, r2 *SM2P256Point) *SM2P256Point {
 	var sum, double SM2P256Point
@@ -454,7 +469,7 @@ func (r *SM2P256Point) ScalarBaseMult(scalar []byte) (*SM2P256Point, error) {
 	}
 	scalarReversed := new(p256OrdElement)
 	p256OrdBigToLittle(scalarReversed, toElementArray(scalar))
-
+	p256OrdReduce(scalarReversed)
 	r.p256BaseMult(scalarReversed)
 	return r, nil
 }
@@ -468,7 +483,7 @@ func (r *SM2P256Point) ScalarMult(q *SM2P256Point, scalar []byte) (*SM2P256Point
 	}
 	scalarReversed := new(p256OrdElement)
 	p256OrdBigToLittle(scalarReversed, toElementArray(scalar))
-
+	p256OrdReduce(scalarReversed)
 	r.Set(q).p256ScalarMult(scalarReversed)
 	return r, nil
 }
@@ -804,10 +819,14 @@ func (p *SM2P256Point) p256BaseMult(scalar *p256OrdElement) {
 	zero := sel
 
 	for i := 1; i < 43; i++ {
-		if index < 192 {
-			wvalue = ((scalar[index/64] >> (index % 64)) + (scalar[index/64+1] << (64 - (index % 64)))) & 0x7f
+		if index >= 192 {
+			wvalue = (scalar[3] >> (index & 63)) & 0x7f
+		} else if index >= 128 {
+			wvalue = ((scalar[2] >> (index & 63)) + (scalar[3] << (64 - (index & 63)))) & 0x7f
+		} else if index >= 64 {
+			wvalue = ((scalar[1] >> (index & 63)) + (scalar[2] << (64 - (index & 63)))) & 0x7f
 		} else {
-			wvalue = (scalar[index/64] >> (index % 64)) & 0x7f
+			wvalue = ((scalar[0] >> (index & 63)) + (scalar[1] << (64 - (index & 63)))) & 0x7f
 		}
 		index += 6
 		sel, sign = boothW6(uint(wvalue))
@@ -822,90 +841,88 @@ func (p *SM2P256Point) p256ScalarMult(scalar *p256OrdElement) {
 	// precomp is a table of precomputed points that stores powers of p
-	// from p^1 to p^16.
+	// from p^1 to p^32.
 	var precomp p256Table
-	var t0, t1, t2, t3 SM2P256Point
+	var t0, t1 SM2P256Point
 
 	// Prepare the table
 	precomp[0] = *p // 1
 
-	p256PointDoubleAsm(&t0, p)
-	p256PointDoubleAsm(&t1, &t0)
-	p256PointDoubleAsm(&t2, &t1)
-	p256PointDoubleAsm(&t3, &t2)
-	precomp[1] = t0 // 2
-	precomp[3] = t1 // 4
-	precomp[7] = t2 // 8
-	precomp[15] = t3 // 16
+	p256PointDoubleAsm(&precomp[1], p) //2
+	p256PointAddAsm(&precomp[2], &precomp[1], p) //3
+	p256PointDoubleAsm(&precomp[3], &precomp[1]) //4
+	p256PointAddAsm(&precomp[4], &precomp[3], p) //5
+	p256PointDoubleAsm(&precomp[5], &precomp[2]) //6
+	p256PointAddAsm(&precomp[6], &precomp[5], p) //7
+	p256PointDoubleAsm(&precomp[7], &precomp[3]) //8
+	p256PointAddAsm(&precomp[8], &precomp[7], p) //9
+	p256PointDoubleAsm(&precomp[9], &precomp[4]) //10
+	p256PointAddAsm(&precomp[10], &precomp[9], p) //11
+	p256PointDoubleAsm(&precomp[11], &precomp[5]) //12
+	p256PointAddAsm(&precomp[12], &precomp[11], p) //13
+	p256PointDoubleAsm(&precomp[13], &precomp[6]) //14
+	p256PointAddAsm(&precomp[14], &precomp[13], p) //15
+	p256PointDoubleAsm(&precomp[15], &precomp[7]) //16
 
-	p256PointAddAsm(&t0, &t0, p)
-	p256PointAddAsm(&t1, &t1, p)
-	p256PointAddAsm(&t2, &t2, p)
-	precomp[2] = t0 // 3
-	precomp[4] = t1 // 5
-	precomp[8] = t2 // 9
-
-	p256PointDoubleAsm(&t0, &t0)
-	p256PointDoubleAsm(&t1, &t1)
-	precomp[5] = t0 // 6
-	precomp[9] = t1 // 10
-
-	p256PointAddAsm(&t2, &t0, p)
-	p256PointAddAsm(&t1, &t1, p)
-	precomp[6] = t2 // 7
-	precomp[10] = t1 // 11
-
-	p256PointDoubleAsm(&t0, &t0)
-	p256PointDoubleAsm(&t2, &t2)
-	precomp[11] = t0 // 12
-	precomp[13] = t2 // 14
-
-	p256PointAddAsm(&t0, &t0, p)
-	p256PointAddAsm(&t2, &t2, p)
-	precomp[12] = t0 // 13
-	precomp[14] = t2 // 15
+	p256PointAddAsm(&precomp[16], &precomp[15], p) //17
+	p256PointDoubleAsm(&precomp[17], &precomp[8]) //18
+	p256PointAddAsm(&precomp[18], &precomp[17], p) //19
+	p256PointDoubleAsm(&precomp[19], &precomp[9]) //20
+	p256PointAddAsm(&precomp[20], &precomp[19], p) //21
+	p256PointDoubleAsm(&precomp[21], &precomp[10]) //22
+	p256PointAddAsm(&precomp[22], &precomp[21], p) //23
+	p256PointDoubleAsm(&precomp[23], &precomp[11]) //24
+	p256PointAddAsm(&precomp[24], &precomp[23], p) //25
+	p256PointDoubleAsm(&precomp[25], &precomp[12]) //26
+	p256PointAddAsm(&precomp[26], &precomp[25], p) //27
+	p256PointDoubleAsm(&precomp[27], &precomp[13]) //28
+	p256PointAddAsm(&precomp[28], &precomp[27], p) //29
+	p256PointDoubleAsm(&precomp[29], &precomp[14]) //30
+	p256PointAddAsm(&precomp[30], &precomp[29], p) //31
+	p256PointDoubleAsm(&precomp[31], &precomp[15]) //32
 
 	// Start scanning the window from top bit
-	index := uint(254)
+	index := uint(251)
 	var sel, sign int
 
-	wvalue := (scalar[index/64] >> (index % 64)) & 0x3f
-	sel, _ = boothW5(uint(wvalue))
+	wvalue := (scalar[index/64] >> (index % 64)) & 0x7f
+	sel, _ = boothW6(uint(wvalue))
 
-	p256Select(p, &precomp, sel)
+	p256Select(p, &precomp, sel, 32)
 	zero := sel
 
-	for index > 4 {
-		index -= 5
+	for index > 5 {
+		index -= 6
 
-		p256PointDouble5TimesAsm(p, p)
+		p256PointDouble6TimesAsm(p, p)
 
-		if index < 192 {
-			wvalue = ((scalar[index/64] >> (index % 64)) + (scalar[index/64+1] << (64 - (index % 64)))) & 0x3f
+		if index >= 192 {
+			wvalue = (scalar[3] >> (index & 63)) & 0x7f
+		} else if index >= 128 {
+			wvalue = ((scalar[2] >> (index & 63)) + (scalar[3] << (64 - (index & 63)))) & 0x7f
+		} else if index >= 64 {
+			wvalue = ((scalar[1] >> (index & 63)) + (scalar[2] << (64 - (index & 63)))) & 0x7f
 		} else {
-			wvalue = (scalar[index/64] >> (index % 64)) & 0x3f
+			wvalue = ((scalar[0] >> (index & 63)) + (scalar[1] << (64 - (index & 63)))) & 0x7f
 		}
 
-		sel, sign = boothW5(uint(wvalue))
+		sel, sign = boothW6(uint(wvalue))
 
-		p256Select(&t0, &precomp, sel)
+		p256Select(&t0, &precomp, sel, 32)
 		p256NegCond(&t0.y, sign)
 		p256PointAddAsm(&t1, p, &t0)
 		p256MovCond(&t1, &t1, p, sel)
 		p256MovCond(p, &t1, &t0, zero)
 		zero |= sel
 	}
-	p256PointDouble5TimesAsm(p, p)
+	p256PointDouble6TimesAsm(p, p)
 
-	wvalue = (scalar[0] << 1) & 0x3f
-	sel, sign = boothW5(uint(wvalue))
+	wvalue = (scalar[0] << 1) & 0x7f
+	sel, sign = boothW6(uint(wvalue))
 
-	p256Select(&t0, &precomp, sel)
+	p256Select(&t0, &precomp, sel, 32)
 	p256NegCond(&t0.y, sign)
-	// t0 = p when scalar = N - 6
-	pointsEqual := p256PointAddAsm(&t1, p, &t0)
-	p256PointDoubleAsm(&t2, p)
-	p256MovCond(&t1, &t2, &t1, pointsEqual)
+	p256PointAddAsm(&t1, p, &t0)
 	p256MovCond(&t1, &t1, p, sel)
 	p256MovCond(p, &t1, &t0, zero)
 }
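
A standalone sketch, not part of the patch: it mirrors the conditional subtraction behind the new p256OrdReduce on a plain [4]uint64 and checks it against math/big for the same N+58 edge case the new TestScalarMult subtest exercises. The names ordReduce and the main wrapper are illustrative only, and it assumes a 64-bit platform where big.Word is 64 bits wide.

package main

import (
	"fmt"
	"math/big"
	"math/bits"
)

// ordReduce performs the same limb arithmetic as p256OrdReduce on a plain
// little-endian [4]uint64. The constants are the limbs of the SM2 group
// order n; because 2*n > 2^256, a single borrow-checked subtraction is
// enough to bring any 256-bit value into [0, n-1].
func ordReduce(s *[4]uint64) {
	t0, b := bits.Sub64(s[0], 0x53bbf40939d54123, 0)
	t1, b := bits.Sub64(s[1], 0x7203df6b21c6052b, b)
	t2, b := bits.Sub64(s[2], 0xffffffffffffffff, b)
	t3, b := bits.Sub64(s[3], 0xfffffffeffffffff, b)
	mask := b - 1 // all-ones if the subtraction did not borrow, zero otherwise
	s[0] ^= (t0 ^ s[0]) & mask
	s[1] ^= (t1 ^ s[1]) & mask
	s[2] ^= (t2 ^ s[2]) & mask
	s[3] ^= (t3 ^ s[3]) & mask
}

func main() {
	n, _ := new(big.Int).SetString("FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFF7203DF6B21C6052B53BBF40939D54123", 16)
	x := new(big.Int).Add(n, big.NewInt(58)) // same edge case as the new N+58 subtest

	// Load x into four little-endian uint64 limbs (assumes 64-bit big.Word).
	var s [4]uint64
	for i, w := range x.Bits() {
		s[i] = uint64(w)
	}
	ordReduce(&s)

	// Rebuild the reduced value and compare with big.Int's answer.
	got := new(big.Int)
	for i := 3; i >= 0; i-- {
		got.Lsh(got, 64)
		got.Or(got, new(big.Int).SetUint64(s[i]))
	}
	fmt.Println(got.Cmp(new(big.Int).Mod(x, n)) == 0) // expect true (both equal 58)
}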