diff --git a/sm2/sm2.go b/sm2/sm2.go index 81dce98..cad4f3c 100644 --- a/sm2/sm2.go +++ b/sm2/sm2.go @@ -153,6 +153,8 @@ func randFieldElement(c elliptic.Curve, rand io.Reader) (k *big.Int, err error) } /////////////////////////////////////////////////////////////////////////////////// +const maxRetryLimit = 100 + func kdf(z []byte, len int) ([]byte, bool) { limit := (len + sm3.Size - 1) >> sm3.SizeBitSize md := sm3.New() @@ -201,14 +203,22 @@ func Encrypt(random io.Reader, pub *ecdsa.PublicKey, msg []byte, opts *Encrypter x1, y1 := curve.ScalarBaseMult(k.Bytes()) c1 := opts.PointMarshalMode.mashal(curve, x1, y1) - //A3, skipped + //A3, requirement is to check if h*P is infinite point, h is 1 + if !curve.IsOnCurve(pub.X, pub.Y) { + return nil, errors.New("SM2: invalid public key") + } + //A4, calculate k * P (point of Public Key) x2, y2 := curve.ScalarMult(pub.X, pub.Y, k.Bytes()) //A5, calculate t=KDF(x2||y2, klen) + var kdfCount int = 0 t, success := kdf(append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen) if !success { - fmt.Println("A5, failed to get valid t") + kdfCount++ + if kdfCount > maxRetryLimit { + return nil, fmt.Errorf("SM2: A5, failed to calculate valid t, tried %v times", kdfCount) + } continue } @@ -245,7 +255,7 @@ func GenerateKey(rand io.Reader) (*PrivateKey, error) { func Decrypt(priv *PrivateKey, ciphertext []byte) ([]byte, error) { ciphertextLen := len(ciphertext) if ciphertextLen <= 1+(priv.Params().BitSize/8)+sm3.Size { - return nil, errors.New("invalid ciphertext length") + return nil, errors.New("SM2: invalid ciphertext length") } curve := priv.Curve // B1, get C1, and check C1 @@ -263,7 +273,7 @@ func Decrypt(priv *PrivateKey, ciphertext []byte) ([]byte, error) { msgLen := len(c2) t, success := kdf(append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen) if !success { - return nil, errors.New("invalid cipher text") + return nil, errors.New("SM2: invalid cipher text") } //B5, calculate msg = c2 ^ t @@ -277,7 +287,7 @@ func Decrypt(priv *PrivateKey, ciphertext []byte) ([]byte, error) { u := calculateC3(curve, x2, y2, msg) for i := 0; i < sm3.Size; i++ { if c3[i] != u[i] { - return nil, errors.New("invalid hash value") + return nil, errors.New("SM2: invalid hash value") } }