diff --git a/sm2/sm2.go b/sm2/sm2.go index cad4f3c..0bb2688 100644 --- a/sm2/sm2.go +++ b/sm2/sm2.go @@ -192,6 +192,10 @@ func Encrypt(random io.Reader, pub *ecdsa.PublicKey, msg []byte, opts *Encrypter if opts == nil { opts = &defaultEncrypterOpts } + //A3, requirement is to check if h*P is infinite point, h is 1 + if (pub.X.Sign() == 0 && pub.Y.Sign() == 0) || !curve.IsOnCurve(pub.X, pub.Y) { + return nil, errors.New("SM2: invalid public key") + } for { //A1, generate random k k, err := randFieldElement(curve, random) @@ -203,11 +207,6 @@ func Encrypt(random io.Reader, pub *ecdsa.PublicKey, msg []byte, opts *Encrypter x1, y1 := curve.ScalarBaseMult(k.Bytes()) c1 := opts.PointMarshalMode.mashal(curve, x1, y1) - //A3, requirement is to check if h*P is infinite point, h is 1 - if !curve.IsOnCurve(pub.X, pub.Y) { - return nil, errors.New("SM2: invalid public key") - } - //A4, calculate k * P (point of Public Key) x2, y2 := curve.ScalarMult(pub.X, pub.Y, k.Bytes())