sm2ec: p256ScalarMult change to use w=6

This commit is contained in:
Sun Yimin 2023-06-14 17:30:58 +08:00 committed by GitHub
parent 7f54c1e1a5
commit de14139590
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 145 additions and 121 deletions

View File

@ -152,13 +152,9 @@ move_avx2:
VPANDN (32*1)(x_ptr), Y12, Y1
VPANDN (32*2)(x_ptr), Y12, Y2
VMOVDQU (32*0)(y_ptr), Y3
VMOVDQU (32*1)(y_ptr), Y4
VMOVDQU (32*2)(y_ptr), Y5
VPAND Y12, Y3, Y3
VPAND Y12, Y4, Y4
VPAND Y12, Y5, Y5
VPAND (32*0)(y_ptr), Y12, Y3
VPAND (32*1)(y_ptr), Y12, Y4
VPAND (32*2)(y_ptr), Y12, Y5
VPXOR Y3, Y0, Y0
VPXOR Y4, Y1, Y1
@ -963,7 +959,7 @@ TEXT ·p256FromMont(SB),NOSPLIT,$0
RET
/* ---------------------------------------*/
// func p256Select(res *SM2P256Point, table *p256Table, idx int)
// func p256Select(res *SM2P256Point, table *p256Table, idx, limit int)
TEXT ·p256Select(SB),NOSPLIT,$0
//MOVQ idx+16(FP),AX
MOVQ table+8(FP),DI
@ -984,7 +980,7 @@ TEXT ·p256Select(SB),NOSPLIT,$0
PXOR X3, X3
PXOR X4, X4
PXOR X5, X5
MOVQ $16, AX
MOVQ limit+24(FP),AX
MOVOU X15, X13
@ -1035,7 +1031,7 @@ select_avx2:
MOVL idx+16(FP), X14 // x14 = idx
VPBROADCASTD X14, Y14
MOVQ $16, AX
MOVQ limit+24(FP),AX
VMOVDQU Y15, Y13
VPXOR Y0, Y0, Y0
@ -1047,16 +1043,12 @@ loop_select_avx2:
VPADDD Y15, Y13, Y13
VPCMPEQD Y14, Y12, Y12
VMOVDQU (32*0)(DI), Y3
VMOVDQU (32*1)(DI), Y4
VMOVDQU (32*2)(DI), Y5
VPAND (32*0)(DI), Y12, Y3
VPAND (32*1)(DI), Y12, Y4
VPAND (32*2)(DI), Y12, Y5
ADDQ $(32*3), DI
VPAND Y12, Y3, Y3
VPAND Y12, Y4, Y4
VPAND Y12, Y5, Y5
VPXOR Y3, Y0, Y0
VPXOR Y4, Y1, Y1
VPXOR Y5, Y2, Y2
@ -1163,22 +1155,17 @@ loop_select_base_avx2:
VPADDD Y15, Y13, Y13
VPCMPEQD Y14, Y12, Y12
VMOVDQU (32*0)(DI), Y2
VMOVDQU (32*1)(DI), Y3
VMOVDQU (32*2)(DI), Y4
VMOVDQU (32*3)(DI), Y5
ADDQ $(32*4), DI
VPAND Y12, Y2, Y2
VPAND Y12, Y3, Y3
VPAND (32*0)(DI), Y12, Y2
VPAND (32*1)(DI), Y12, Y3
VMOVDQU Y13, Y12
VPADDD Y15, Y13, Y13
VPCMPEQD Y14, Y12, Y12
VPAND Y12, Y4, Y4
VPAND Y12, Y5, Y5
VPAND (32*2)(DI), Y12, Y4
VPAND (32*3)(DI), Y12, Y5
ADDQ $(32*4), DI
VPXOR Y2, Y0, Y0
VPXOR Y3, Y1, Y1
@ -3097,10 +3084,6 @@ pointaddaffine_avx2:
p256PointAddAffineInline()
// The result is not valid if (sel == 0), conditional choose
VMOVDQU xout(32*0), Y0
VMOVDQU yout(32*0), Y1
VMOVDQU zout(32*0), Y2
MOVL BX, X6
MOVL CX, X7
@ -3116,17 +3099,13 @@ pointaddaffine_avx2:
VMOVDQU Y6, Y15
VPANDN Y9, Y15, Y15
VMOVDQU x1in(32*0), Y9
VMOVDQU y1in(32*0), Y10
VMOVDQU z1in(32*0), Y11
VPAND xout(32*0), Y15, Y0
VPAND yout(32*0), Y15, Y1
VPAND zout(32*0), Y15, Y2
VPAND Y15, Y0, Y0
VPAND Y15, Y1, Y1
VPAND Y15, Y2, Y2
VPAND Y6, Y9, Y9
VPAND Y6, Y10, Y10
VPAND Y6, Y11, Y11
VPAND x1in(32*0), Y6, Y9
VPAND y1in(32*0), Y6, Y10
VPAND z1in(32*0), Y6, Y11
VPXOR Y9, Y0, Y0
VPXOR Y10, Y1, Y1
@ -3136,17 +3115,13 @@ pointaddaffine_avx2:
VPCMPEQD Y9, Y9, Y9
VPANDN Y9, Y7, Y15
VMOVDQU x2in(32*0), Y9
VMOVDQU y2in(32*0), Y10
VMOVDQU p256one<>+0x00(SB), Y11
VPAND Y15, Y0, Y0
VPAND Y15, Y1, Y1
VPAND Y15, Y2, Y2
VPAND Y7, Y9, Y9
VPAND Y7, Y10, Y10
VPAND Y7, Y11, Y11
VPAND x2in(32*0), Y7, Y9
VPAND y2in(32*0), Y7, Y10
VPAND p256one<>+0x00(SB), Y7, Y11
VPXOR Y9, Y0, Y0
VPXOR Y10, Y1, Y1
@ -3622,8 +3597,8 @@ TEXT ·p256PointDoubleAsm(SB),NOSPLIT,$256-16
calY() \
storeTmpY() \
//func p256PointDouble5TimesAsm(res, in *SM2P256Point)
TEXT ·p256PointDouble5TimesAsm(SB),NOSPLIT,$256-16
//func p256PointDouble6TimesAsm(res, in *SM2P256Point)
TEXT ·p256PointDouble6TimesAsm(SB),NOSPLIT,$256-16
// Move input to stack in order to free registers
MOVQ res+0(FP), AX
MOVQ in+8(FP), BX
@ -3632,7 +3607,8 @@ TEXT ·p256PointDouble5TimesAsm(SB),NOSPLIT,$256-16
// Store pointer to result
MOVQ AX, rptr
// Begin point double 1-4 rounds
// Begin point double 1-5 rounds
p256PointDoubleRound()
p256PointDoubleRound()
p256PointDoubleRound()
p256PointDoubleRound()

View File

@ -291,8 +291,9 @@ TEXT ·p256FromMont(SB),NOSPLIT,$0
RET
/* ---------------------------------------*/
// func p256Select(res *SM2P256Point, table *p256Table, idx int)
// func p256Select(res *SM2P256Point, table *p256Table, idx, limit int)
TEXT ·p256Select(SB),NOSPLIT,$0
MOVD limit+24(FP), const3
MOVD idx+16(FP), const0
MOVD table+8(FP), b_ptr
MOVD res+0(FP), res_ptr
@ -334,7 +335,7 @@ loop_select:
CSEL EQ, acc2, t2, t2
CSEL EQ, acc3, t3, t3
CMP $16, const1
CMP const3, const1
BNE loop_select
STP (x0, x1), 0*16(res_ptr)
@ -1619,7 +1620,8 @@ TEXT ·p256PointDouble5TimesAsm(SB),NOSPLIT,$136-16
CALL sm2P256Subinternal<>(SB)
STx(y3out)
// Begin point double rounds 2 - 5
// Begin point double rounds 2 - 6
p256PointDoubleRound()
p256PointDoubleRound()
p256PointDoubleRound()
p256PointDoubleRound()

View File

@ -2,6 +2,7 @@ package sm2ec
import (
"bytes"
"crypto/rand"
"encoding/hex"
"fmt"
"math/big"
@ -14,6 +15,7 @@ var r0 = bigFromHex("010000000000000000")
var sm2Prime = bigFromHex("FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFF")
var sm2n = bigFromHex("FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFF7203DF6B21C6052B53BBF40939D54123")
var nistP256Prime = bigFromDecimal("115792089210356248762697446949407573530086143415290314195533631308867097853951")
var nistP256N = bigFromDecimal("115792089210356248762697446949407573529996955224135760342422259061068512044369")
func generateMontgomeryDomain(in *big.Int, p *big.Int) *big.Int {
tmp := new(big.Int)
@ -237,6 +239,9 @@ func TestScalarMult(t *testing.T) {
t.Run("N+1", func(t *testing.T) {
checkScalar(t, new(big.Int).Add(sm2n, big.NewInt(1)).Bytes())
})
t.Run("N+58", func(t *testing.T) {
checkScalar(t, new(big.Int).Add(sm2n, big.NewInt(58)).Bytes())
})
t.Run("all1s", func(t *testing.T) {
s := new(big.Int).Lsh(big.NewInt(1), uint(bitLen))
s.Sub(s, big.NewInt(1))
@ -256,6 +261,7 @@ func TestScalarMult(t *testing.T) {
checkScalar(t, big.NewInt(int64(i)).FillBytes(make([]byte, byteLen)))
})
}
// Test N-64...N+64 since they risk overlapping with precomputed table values
// in the final additions.
for i := int64(-64); i <= 64; i++ {
@ -263,6 +269,7 @@ func TestScalarMult(t *testing.T) {
checkScalar(t, new(big.Int).Add(sm2n, big.NewInt(i)).Bytes())
})
}
}
func fatalIfErr(t *testing.T, err error) {
@ -271,3 +278,25 @@ func fatalIfErr(t *testing.T, err error) {
t.Fatal(err)
}
}
// BenchmarkScalarBaseMult measures fixed-base scalar multiplication
// ([k]G) with a random 32-byte scalar.
func BenchmarkScalarBaseMult(b *testing.B) {
	p := NewSM2P256Point().SetGenerator()
	scalar := make([]byte, 32)
	// rand.Read's error was previously ignored; fail the benchmark
	// explicitly rather than timing a zero scalar on failure.
	if _, err := rand.Read(scalar); err != nil {
		b.Fatal(err)
	}
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		p.ScalarBaseMult(scalar)
	}
}
// BenchmarkScalarMult measures variable-base scalar multiplication
// ([k]P) with a random 32-byte scalar.
func BenchmarkScalarMult(b *testing.B) {
	p := NewSM2P256Point().SetGenerator()
	scalar := make([]byte, 32)
	// rand.Read's error was previously ignored; fail the benchmark
	// explicitly rather than timing a zero scalar on failure.
	if _, err := rand.Read(scalar); err != nil {
		b.Fatal(err)
	}
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		p.ScalarMult(p, scalar)
	}
}

View File

@ -356,13 +356,13 @@ func p256OrdLittleToBig(res *[32]byte, in *p256OrdElement)
// p256Table is a table of the first 32 multiples of a point. Points are stored
// at an index offset of -1 so [8]P is at index 7, P is at 0, and [32]P is at 31.
// [0]P is the point at infinity and it's not stored.
type p256Table [16]SM2P256Point
type p256Table [32]SM2P256Point
// p256Select sets res to the point at index idx in the table.
// idx must be in [0, 15]. It executes in constant time.
// idx must be in [0, limit-1]. It executes in constant time.
//
//go:noescape
func p256Select(res *SM2P256Point, table *p256Table, idx int)
func p256Select(res *SM2P256Point, table *p256Table, idx, limit int)
// p256AffinePoint is a point in affine coordinates (x, y). x and y are still
// Montgomery domain elements. The point can't be the point at infinity.
@ -416,15 +416,30 @@ func p256PointAddAsm(res, in1, in2 *SM2P256Point) int
//go:noescape
func p256PointDoubleAsm(res, in *SM2P256Point)
// Point doubling 5 times. in can be the point at infinity.
// Point doubling 6 times. in can be the point at infinity.
//
//go:noescape
func p256PointDouble5TimesAsm(res, in *SM2P256Point)
func p256PointDouble6TimesAsm(res, in *SM2P256Point)
// p256OrdElement is a scalar field element in [0, ord(G)-1] in the
// Montgomery domain (with R 2²⁵⁶) as four uint64 limbs in little-endian order.
type p256OrdElement [4]uint64

// p256OrdReduce ensures s is in the range [0, ord(G)-1].
func p256OrdReduce(s *p256OrdElement) {
	// Because 2 * ord(G) > 2²⁵⁶, at most one conditional subtraction of
	// ord(G) is needed. Compute s - ord(G) limb by limb and select the
	// difference only when no borrow propagated out of the top limb,
	// keeping the operation constant time.
	d0, borrow := bits.Sub64(s[0], 0x53bbf40939d54123, 0)
	d1, borrow := bits.Sub64(s[1], 0x7203df6b21c6052b, borrow)
	d2, borrow := bits.Sub64(s[2], 0xffffffffffffffff, borrow)
	d3, borrow := bits.Sub64(s[3], 0xfffffffeffffffff, borrow)
	// keep is all-ones when s >= ord(G) (no borrow), zero otherwise.
	keep := borrow - 1
	s[0] = (s[0] &^ keep) | (d0 & keep)
	s[1] = (s[1] &^ keep) | (d1 & keep)
	s[2] = (s[2] &^ keep) | (d2 & keep)
	s[3] = (s[3] &^ keep) | (d3 & keep)
}
// Add sets q = p1 + p2, and returns q. The points may overlap.
func (q *SM2P256Point) Add(r1, r2 *SM2P256Point) *SM2P256Point {
var sum, double SM2P256Point
@ -454,7 +469,7 @@ func (r *SM2P256Point) ScalarBaseMult(scalar []byte) (*SM2P256Point, error) {
}
scalarReversed := new(p256OrdElement)
p256OrdBigToLittle(scalarReversed, toElementArray(scalar))
p256OrdReduce(scalarReversed)
r.p256BaseMult(scalarReversed)
return r, nil
}
@ -468,7 +483,7 @@ func (r *SM2P256Point) ScalarMult(q *SM2P256Point, scalar []byte) (*SM2P256Point
}
scalarReversed := new(p256OrdElement)
p256OrdBigToLittle(scalarReversed, toElementArray(scalar))
p256OrdReduce(scalarReversed)
r.Set(q).p256ScalarMult(scalarReversed)
return r, nil
}
@ -804,10 +819,14 @@ func (p *SM2P256Point) p256BaseMult(scalar *p256OrdElement) {
zero := sel
for i := 1; i < 43; i++ {
if index < 192 {
wvalue = ((scalar[index/64] >> (index % 64)) + (scalar[index/64+1] << (64 - (index % 64)))) & 0x7f
if index >= 192 {
wvalue = (scalar[3] >> (index & 63)) & 0x7f
} else if index >= 128 {
wvalue = ((scalar[2] >> (index & 63)) + (scalar[3] << (64 - (index & 63)))) & 0x7f
} else if index >= 64 {
wvalue = ((scalar[1] >> (index & 63)) + (scalar[2] << (64 - (index & 63)))) & 0x7f
} else {
wvalue = (scalar[index/64] >> (index % 64)) & 0x7f
wvalue = ((scalar[0] >> (index & 63)) + (scalar[1] << (64 - (index & 63)))) & 0x7f
}
index += 6
sel, sign = boothW6(uint(wvalue))
@ -822,90 +841,88 @@ func (p *SM2P256Point) p256BaseMult(scalar *p256OrdElement) {
func (p *SM2P256Point) p256ScalarMult(scalar *p256OrdElement) {
// precomp is a table of precomputed points that stores powers of p
// from p^1 to p^16.
// from p^1 to p^32.
var precomp p256Table
var t0, t1, t2, t3 SM2P256Point
var t0, t1 SM2P256Point
// Prepare the table
precomp[0] = *p // 1
p256PointDoubleAsm(&t0, p)
p256PointDoubleAsm(&t1, &t0)
p256PointDoubleAsm(&t2, &t1)
p256PointDoubleAsm(&t3, &t2)
precomp[1] = t0 // 2
precomp[3] = t1 // 4
precomp[7] = t2 // 8
precomp[15] = t3 // 16
p256PointDoubleAsm(&precomp[1], p) //2
p256PointAddAsm(&precomp[2], &precomp[1], p) //3
p256PointDoubleAsm(&precomp[3], &precomp[1]) //4
p256PointAddAsm(&precomp[4], &precomp[3], p) //5
p256PointDoubleAsm(&precomp[5], &precomp[2]) //6
p256PointAddAsm(&precomp[6], &precomp[5], p) //7
p256PointDoubleAsm(&precomp[7], &precomp[3]) //8
p256PointAddAsm(&precomp[8], &precomp[7], p) //9
p256PointDoubleAsm(&precomp[9], &precomp[4]) //10
p256PointAddAsm(&precomp[10], &precomp[9], p) //11
p256PointDoubleAsm(&precomp[11], &precomp[5]) //12
p256PointAddAsm(&precomp[12], &precomp[11], p) //13
p256PointDoubleAsm(&precomp[13], &precomp[6]) //14
p256PointAddAsm(&precomp[14], &precomp[13], p) //15
p256PointDoubleAsm(&precomp[15], &precomp[7]) //16
p256PointAddAsm(&t0, &t0, p)
p256PointAddAsm(&t1, &t1, p)
p256PointAddAsm(&t2, &t2, p)
precomp[2] = t0 // 3
precomp[4] = t1 // 5
precomp[8] = t2 // 9
p256PointDoubleAsm(&t0, &t0)
p256PointDoubleAsm(&t1, &t1)
precomp[5] = t0 // 6
precomp[9] = t1 // 10
p256PointAddAsm(&t2, &t0, p)
p256PointAddAsm(&t1, &t1, p)
precomp[6] = t2 // 7
precomp[10] = t1 // 11
p256PointDoubleAsm(&t0, &t0)
p256PointDoubleAsm(&t2, &t2)
precomp[11] = t0 // 12
precomp[13] = t2 // 14
p256PointAddAsm(&t0, &t0, p)
p256PointAddAsm(&t2, &t2, p)
precomp[12] = t0 // 13
precomp[14] = t2 // 15
p256PointAddAsm(&precomp[16], &precomp[15], p) //17
p256PointDoubleAsm(&precomp[17], &precomp[8]) //18
p256PointAddAsm(&precomp[18], &precomp[17], p) //19
p256PointDoubleAsm(&precomp[19], &precomp[9]) //20
p256PointAddAsm(&precomp[20], &precomp[19], p) //21
p256PointDoubleAsm(&precomp[21], &precomp[10]) //22
p256PointAddAsm(&precomp[22], &precomp[21], p) //23
p256PointDoubleAsm(&precomp[23], &precomp[11]) //24
p256PointAddAsm(&precomp[24], &precomp[23], p) //25
p256PointDoubleAsm(&precomp[25], &precomp[12]) //26
p256PointAddAsm(&precomp[26], &precomp[25], p) //27
p256PointDoubleAsm(&precomp[27], &precomp[13]) //28
p256PointAddAsm(&precomp[28], &precomp[27], p) //29
p256PointDoubleAsm(&precomp[29], &precomp[14]) //30
p256PointAddAsm(&precomp[30], &precomp[29], p) //31
p256PointDoubleAsm(&precomp[31], &precomp[15]) //32
// Start scanning the window from top bit
index := uint(254)
index := uint(251)
var sel, sign int
wvalue := (scalar[index/64] >> (index % 64)) & 0x3f
sel, _ = boothW5(uint(wvalue))
wvalue := (scalar[index/64] >> (index % 64)) & 0x7f
sel, _ = boothW6(uint(wvalue))
p256Select(p, &precomp, sel)
p256Select(p, &precomp, sel, 32)
zero := sel
for index > 4 {
index -= 5
for index > 5 {
index -= 6
p256PointDouble5TimesAsm(p, p)
p256PointDouble6TimesAsm(p, p)
if index < 192 {
wvalue = ((scalar[index/64] >> (index % 64)) + (scalar[index/64+1] << (64 - (index % 64)))) & 0x3f
if index >= 192 {
wvalue = (scalar[3] >> (index & 63)) & 0x7f
} else if index >= 128 {
wvalue = ((scalar[2] >> (index & 63)) + (scalar[3] << (64 - (index & 63)))) & 0x7f
} else if index >= 64 {
wvalue = ((scalar[1] >> (index & 63)) + (scalar[2] << (64 - (index & 63)))) & 0x7f
} else {
wvalue = (scalar[index/64] >> (index % 64)) & 0x3f
wvalue = ((scalar[0] >> (index & 63)) + (scalar[1] << (64 - (index & 63)))) & 0x7f
}
sel, sign = boothW5(uint(wvalue))
sel, sign = boothW6(uint(wvalue))
p256Select(&t0, &precomp, sel)
p256Select(&t0, &precomp, sel, 32)
p256NegCond(&t0.y, sign)
p256PointAddAsm(&t1, p, &t0)
p256MovCond(&t1, &t1, p, sel)
p256MovCond(p, &t1, &t0, zero)
zero |= sel
}
p256PointDouble5TimesAsm(p, p)
p256PointDouble6TimesAsm(p, p)
wvalue = (scalar[0] << 1) & 0x3f
sel, sign = boothW5(uint(wvalue))
wvalue = (scalar[0] << 1) & 0x7f
sel, sign = boothW6(uint(wvalue))
p256Select(&t0, &precomp, sel)
p256Select(&t0, &precomp, sel, 32)
p256NegCond(&t0.y, sign)
// t0 = p when scalar = N - 6
pointsEqual := p256PointAddAsm(&t1, p, &t0)
p256PointDoubleAsm(&t2, p)
p256MovCond(&t1, &t2, &t1, pointsEqual)
p256PointAddAsm(&t1, p, &t0)
p256MovCond(&t1, &t1, p, sel)
p256MovCond(p, &t1, &t0, zero)
}