mirror of
https://github.com/emmansun/gmsm.git
synced 2025-04-26 20:26:19 +08:00
MAGIC - avoid endless loop, validate public key
This commit is contained in:
parent
0702a8a2ac
commit
2dd11a9e9a
20
sm2/sm2.go
20
sm2/sm2.go
@ -153,6 +153,8 @@ func randFieldElement(c elliptic.Curve, rand io.Reader) (k *big.Int, err error)
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////
|
||||
const maxRetryLimit = 100
|
||||
|
||||
func kdf(z []byte, len int) ([]byte, bool) {
|
||||
limit := (len + sm3.Size - 1) >> sm3.SizeBitSize
|
||||
md := sm3.New()
|
||||
@ -201,14 +203,22 @@ func Encrypt(random io.Reader, pub *ecdsa.PublicKey, msg []byte, opts *Encrypter
|
||||
x1, y1 := curve.ScalarBaseMult(k.Bytes())
|
||||
c1 := opts.PointMarshalMode.mashal(curve, x1, y1)
|
||||
|
||||
//A3, skipped
|
||||
//A3, requirement is to check if h*P is infinite point, h is 1
|
||||
if !curve.IsOnCurve(pub.X, pub.Y) {
|
||||
return nil, errors.New("SM2: invalid public key")
|
||||
}
|
||||
|
||||
//A4, calculate k * P (point of Public Key)
|
||||
x2, y2 := curve.ScalarMult(pub.X, pub.Y, k.Bytes())
|
||||
|
||||
//A5, calculate t=KDF(x2||y2, klen)
|
||||
var kdfCount int = 0
|
||||
t, success := kdf(append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen)
|
||||
if !success {
|
||||
fmt.Println("A5, failed to get valid t")
|
||||
kdfCount++
|
||||
if kdfCount > maxRetryLimit {
|
||||
return nil, fmt.Errorf("SM2: A5, failed to calculate valid t, tried %v times", kdfCount)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
@ -245,7 +255,7 @@ func GenerateKey(rand io.Reader) (*PrivateKey, error) {
|
||||
func Decrypt(priv *PrivateKey, ciphertext []byte) ([]byte, error) {
|
||||
ciphertextLen := len(ciphertext)
|
||||
if ciphertextLen <= 1+(priv.Params().BitSize/8)+sm3.Size {
|
||||
return nil, errors.New("invalid ciphertext length")
|
||||
return nil, errors.New("SM2: invalid ciphertext length")
|
||||
}
|
||||
curve := priv.Curve
|
||||
// B1, get C1, and check C1
|
||||
@ -263,7 +273,7 @@ func Decrypt(priv *PrivateKey, ciphertext []byte) ([]byte, error) {
|
||||
msgLen := len(c2)
|
||||
t, success := kdf(append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen)
|
||||
if !success {
|
||||
return nil, errors.New("invalid cipher text")
|
||||
return nil, errors.New("SM2: invalid cipher text")
|
||||
}
|
||||
|
||||
//B5, calculate msg = c2 ^ t
|
||||
@ -277,7 +287,7 @@ func Decrypt(priv *PrivateKey, ciphertext []byte) ([]byte, error) {
|
||||
u := calculateC3(curve, x2, y2, msg)
|
||||
for i := 0; i < sm3.Size; i++ {
|
||||
if c3[i] != u[i] {
|
||||
return nil, errors.New("invalid hash value")
|
||||
return nil, errors.New("SM2: invalid hash value")
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user