MAGIC - avoid endless loop, validate public key

This commit is contained in:
Emman 2021-02-25 10:53:53 +08:00
parent 0702a8a2ac
commit 2dd11a9e9a

View File

@ -153,6 +153,8 @@ func randFieldElement(c elliptic.Curve, rand io.Reader) (k *big.Int, err error)
}
///////////////////////////////////////////////////////////////////////////////////
const maxRetryLimit = 100
func kdf(z []byte, len int) ([]byte, bool) {
limit := (len + sm3.Size - 1) >> sm3.SizeBitSize
md := sm3.New()
@ -201,14 +203,22 @@ func Encrypt(random io.Reader, pub *ecdsa.PublicKey, msg []byte, opts *Encrypter
x1, y1 := curve.ScalarBaseMult(k.Bytes())
c1 := opts.PointMarshalMode.mashal(curve, x1, y1)
//A3, skipped
//A3, requirement is to check if h*P is infinite point, h is 1
if !curve.IsOnCurve(pub.X, pub.Y) {
return nil, errors.New("SM2: invalid public key")
}
//A4, calculate k * P (point of Public Key)
x2, y2 := curve.ScalarMult(pub.X, pub.Y, k.Bytes())
//A5, calculate t=KDF(x2||y2, klen)
var kdfCount int = 0
t, success := kdf(append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen)
if !success {
fmt.Println("A5, failed to get valid t")
kdfCount++
if kdfCount > maxRetryLimit {
return nil, fmt.Errorf("SM2: A5, failed to calculate valid t, tried %v times", kdfCount)
}
continue
}
@ -245,7 +255,7 @@ func GenerateKey(rand io.Reader) (*PrivateKey, error) {
func Decrypt(priv *PrivateKey, ciphertext []byte) ([]byte, error) {
ciphertextLen := len(ciphertext)
if ciphertextLen <= 1+(priv.Params().BitSize/8)+sm3.Size {
return nil, errors.New("invalid ciphertext length")
return nil, errors.New("SM2: invalid ciphertext length")
}
curve := priv.Curve
// B1, get C1, and check C1
@ -263,7 +273,7 @@ func Decrypt(priv *PrivateKey, ciphertext []byte) ([]byte, error) {
msgLen := len(c2)
t, success := kdf(append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen)
if !success {
return nil, errors.New("invalid cipher text")
return nil, errors.New("SM2: invalid cipher text")
}
//B5, calculate msg = c2 ^ t
@ -277,7 +287,7 @@ func Decrypt(priv *PrivateKey, ciphertext []byte) ([]byte, error) {
u := calculateC3(curve, x2, y2, msg)
for i := 0; i < sm3.Size; i++ {
if c3[i] != u[i] {
return nil, errors.New("invalid hash value")
return nil, errors.New("SM2: invalid hash value")
}
}