diff --git a/sm2/sm2.go b/sm2/sm2.go index 0f856b9..82b2a9a 100644 --- a/sm2/sm2.go +++ b/sm2/sm2.go @@ -149,12 +149,20 @@ func (*SM2SignerOption) HashFunc() crypto.Hash { // FromECPrivateKey convert an ecdsa private key to SM2 private key func (priv *PrivateKey) FromECPrivateKey(key *ecdsa.PrivateKey) (*PrivateKey, error) { if key.Curve != P256() { - return nil, errors.New("It's NOT a sm2 curve private key") + return nil, errors.New("SM2: it's NOT a sm2 curve private key") } priv.PrivateKey = *key return priv, nil } +func (priv *PrivateKey) Equal(x crypto.PrivateKey) bool { + xx, ok := x.(*PrivateKey) + if !ok { + return false + } + return priv.PublicKey.Equal(&xx.PublicKey) && priv.D.Cmp(xx.D) == 0 +} + // Sign signs digest with priv, reading randomness from rand. The opts argument // is not currently used but, in keeping with the crypto.Signer interface, // should be the hash function used to digest the message. diff --git a/sm2/sm2_test.go b/sm2/sm2_test.go index 07e0578..6ea1b2e 100644 --- a/sm2/sm2_test.go +++ b/sm2/sm2_test.go @@ -1,6 +1,7 @@ package sm2 import ( + "crypto" "crypto/ecdsa" "crypto/elliptic" "crypto/rand" @@ -304,6 +305,30 @@ func TestINDCCA(t *testing.T) { } } +func TestEqual(t *testing.T) { + private, _ := GenerateKey(rand.Reader) + public := &private.PublicKey + + if !public.Equal(public) { + t.Errorf("public key is not equal to itself: %q", public) + } + if !public.Equal(crypto.Signer(private).Public()) { + t.Errorf("private.Public() is not Equal to public: %q", public) + } + if !private.Equal(private) { + t.Errorf("private key is not equal to itself: %q", private) + } + + otherPriv, _ := GenerateKey(rand.Reader) + otherPub := &otherPriv.PublicKey + if public.Equal(otherPub) { + t.Errorf("different public keys are Equal") + } + if private.Equal(otherPriv) { + t.Errorf("different private keys are Equal") + } +} + func BenchmarkGenerateKey_SM2(b *testing.B) { b.ReportAllocs() b.ResetTimer() @@ -403,8 +428,13 @@ func BenchmarkVerify_SM2(b *testing.B) { } func benchmarkEncrypt(b *testing.B, curve elliptic.Curve, plaintext string) { + priv, err := ecdsa.GenerateKey(curve, rand.Reader) + if err != nil { + b.Fatal(err) + } + b.ReportAllocs() + b.ResetTimer() for i := 0; i < b.N; i++ { - priv, _ := ecdsa.GenerateKey(curve, rand.Reader) Encrypt(rand.Reader, &priv.PublicKey, []byte(plaintext), nil) } }