diff --git a/internal/sm2ec/sm2ec_test.go b/internal/sm2ec/sm2ec_test.go
index c143db1..23c4dca 100644
--- a/internal/sm2ec/sm2ec_test.go
+++ b/internal/sm2ec/sm2ec_test.go
@@ -227,6 +227,9 @@ func TestScalarMult(t *testing.T) {
 	t.Run("1", func(t *testing.T) {
 		checkScalar(t, big.NewInt(1).FillBytes(make([]byte, byteLen)))
 	})
+	t.Run("N-6", func(t *testing.T) {
+		checkScalar(t, new(big.Int).Sub(sm2n, big.NewInt(6)).Bytes())
+	})
 	t.Run("N-1", func(t *testing.T) {
 		checkScalar(t, new(big.Int).Sub(sm2n, big.NewInt(1)).Bytes())
 	})
@@ -248,10 +251,15 @@ func TestScalarMult(t *testing.T) {
 			checkScalar(t, s.FillBytes(make([]byte, byteLen)))
 		})
 	}
-	// Test N+1...N+32 since they risk overlapping with precomputed table values
+	for i := 0; i <= 64; i++ {
+		t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
+			checkScalar(t, big.NewInt(int64(i)).FillBytes(make([]byte, byteLen)))
+		})
+	}
+	// Test N-64...N+64 since they risk overlapping with precomputed table values
 	// in the final additions.
-	for i := int64(2); i <= 32; i++ {
-		t.Run(fmt.Sprintf("N+%d", i), func(t *testing.T) {
+	for i := int64(-64); i <= 64; i++ {
+		t.Run(fmt.Sprintf("N%+d", i), func(t *testing.T) {
 			checkScalar(t, new(big.Int).Add(sm2n, big.NewInt(i)).Bytes())
 		})
 	}
diff --git a/internal/sm2ec/sm2p256_asm.go b/internal/sm2ec/sm2p256_asm.go
index c3fdfcd..6632ca8 100644
--- a/internal/sm2ec/sm2p256_asm.go
+++ b/internal/sm2ec/sm2p256_asm.go
@@ -902,7 +902,10 @@ func (p *SM2P256Point) p256ScalarMult(scalar *p256OrdElement) {
 
 	p256Select(&t0, &precomp, sel)
 	p256NegCond(&t0.y, sign)
-	p256PointAddAsm(&t1, p, &t0)
+	// t0 = p when scalar = N - 6
+	pointsEqual := p256PointAddAsm(&t1, p, &t0)
+	p256PointDoubleAsm(&t2, p)
+	p256MovCond(&t1, &t2, &t1, pointsEqual)
 	p256MovCond(&t1, &t1, p, sel)
 	p256MovCond(p, &t1, &t0, zero)
 }