sm2ec: fix ScalarMult issue when scalar = N - 6

This commit is contained in:
Sun Yimin 2023-06-13 13:24:43 +08:00 committed by GitHub
parent 40dba3a488
commit 7f54c1e1a5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 15 additions and 4 deletions

View File

@ -227,6 +227,9 @@ func TestScalarMult(t *testing.T) {
t.Run("1", func(t *testing.T) { t.Run("1", func(t *testing.T) {
checkScalar(t, big.NewInt(1).FillBytes(make([]byte, byteLen))) checkScalar(t, big.NewInt(1).FillBytes(make([]byte, byteLen)))
}) })
t.Run("N-6", func(t *testing.T) {
checkScalar(t, new(big.Int).Sub(sm2n, big.NewInt(6)).Bytes())
})
t.Run("N-1", func(t *testing.T) { t.Run("N-1", func(t *testing.T) {
checkScalar(t, new(big.Int).Sub(sm2n, big.NewInt(1)).Bytes()) checkScalar(t, new(big.Int).Sub(sm2n, big.NewInt(1)).Bytes())
}) })
@ -248,10 +251,15 @@ func TestScalarMult(t *testing.T) {
checkScalar(t, s.FillBytes(make([]byte, byteLen))) checkScalar(t, s.FillBytes(make([]byte, byteLen)))
}) })
} }
// Test N+1...N+32 since they risk overlapping with precomputed table values for i := 0; i <= 64; i++ {
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
checkScalar(t, big.NewInt(int64(i)).FillBytes(make([]byte, byteLen)))
})
}
// Test N-64...N+64 since they risk overlapping with precomputed table values
// in the final additions. // in the final additions.
for i := int64(2); i <= 32; i++ { for i := int64(-64); i <= 64; i++ {
t.Run(fmt.Sprintf("N+%d", i), func(t *testing.T) { t.Run(fmt.Sprintf("N%+d", i), func(t *testing.T) {
checkScalar(t, new(big.Int).Add(sm2n, big.NewInt(i)).Bytes()) checkScalar(t, new(big.Int).Add(sm2n, big.NewInt(i)).Bytes())
}) })
} }

View File

@ -902,7 +902,10 @@ func (p *SM2P256Point) p256ScalarMult(scalar *p256OrdElement) {
p256Select(&t0, &precomp, sel) p256Select(&t0, &precomp, sel)
p256NegCond(&t0.y, sign) p256NegCond(&t0.y, sign)
p256PointAddAsm(&t1, p, &t0) // t0 = p when scalar = N - 6
pointsEqual := p256PointAddAsm(&t1, p, &t0)
p256PointDoubleAsm(&t2, p)
p256MovCond(&t1, &t2, &t1, pointsEqual)
p256MovCond(&t1, &t1, p, sel) p256MovCond(&t1, &t1, p, sel)
p256MovCond(p, &t1, &t0, zero) p256MovCond(p, &t1, &t0, zero)
} }