diff --git a/internal/sm2ec/sm2ec_test.go b/internal/sm2ec/sm2ec_test.go index 74dbe1b..ec8b01e 100644 --- a/internal/sm2ec/sm2ec_test.go +++ b/internal/sm2ec/sm2ec_test.go @@ -201,6 +201,41 @@ func TestEquivalents(t *testing.T) { } } +func TestBasicScalarMult(t *testing.T) { + testvector := []struct { + name string + scalar *big.Int + expected string + }{ + { + "32", + big.NewInt(32), + "0425d3debd0950d180a6d5c2b5817f2329791734cd03e5565ca32641e56024666c92d99a70679d61efb938c406dd5cb0e10458895120e208b4d39e100303fa10a2", + }, + { + "N-3", + new(big.Int).Sub(sm2n, big.NewInt(3)), + "04a97f7cd4b3c993b4be2daa8cdb41e24ca13f6bd945302244e26918f1d0509ebfacf4a2267397710a333a313f758deaf083bff11932fbad6e555322fc8ba70919", + }, + } + p := NewSM2P256Point().SetGenerator() + + for _, test := range testvector { + scalar := make([]byte, 32) + test.scalar.FillBytes(scalar) + p1, err := NewSM2P256Point().ScalarBaseMult(scalar) + fatalIfErr(t, err) + p2, err := NewSM2P256Point().ScalarMult(p, scalar) + fatalIfErr(t, err) + if hex.EncodeToString(p1.Bytes()) != test.expected { + t.Errorf("%s ScalarBaseMult fail, got %x", test.name, p1.Bytes()) + } + if hex.EncodeToString(p2.Bytes()) != test.expected { + t.Errorf("%s ScalarMult fail, got %x", test.name, p2.Bytes()) + } + } +} + func TestScalarMult(t *testing.T) { G := NewSM2P256Point().SetGenerator() checkScalar := func(t *testing.T, scalar []byte) { @@ -261,7 +296,7 @@ func TestScalarMult(t *testing.T) { checkScalar(t, big.NewInt(int64(i)).FillBytes(make([]byte, byteLen))) }) } - + // Test N-64...N+64 since they risk overlapping with precomputed table values // in the final additions. for i := int64(-64); i <= 64; i++ { @@ -269,7 +304,7 @@ func TestScalarMult(t *testing.T) { checkScalar(t, new(big.Int).Add(sm2n, big.NewInt(i)).Bytes()) }) } - + } func fatalIfErr(t *testing.T, err error) {