mirror of
https://github.com/emmansun/gmsm.git
synced 2025-04-21 17:56:19 +08:00
MAGIC - avoid endless loop, validate public key
This commit is contained in:
parent
0702a8a2ac
commit
2dd11a9e9a
20
sm2/sm2.go
20
sm2/sm2.go
@ -153,6 +153,8 @@ func randFieldElement(c elliptic.Curve, rand io.Reader) (k *big.Int, err error)
|
|||||||
}
|
}
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////////
|
||||||
|
const maxRetryLimit = 100
|
||||||
|
|
||||||
func kdf(z []byte, len int) ([]byte, bool) {
|
func kdf(z []byte, len int) ([]byte, bool) {
|
||||||
limit := (len + sm3.Size - 1) >> sm3.SizeBitSize
|
limit := (len + sm3.Size - 1) >> sm3.SizeBitSize
|
||||||
md := sm3.New()
|
md := sm3.New()
|
||||||
@ -201,14 +203,22 @@ func Encrypt(random io.Reader, pub *ecdsa.PublicKey, msg []byte, opts *Encrypter
|
|||||||
x1, y1 := curve.ScalarBaseMult(k.Bytes())
|
x1, y1 := curve.ScalarBaseMult(k.Bytes())
|
||||||
c1 := opts.PointMarshalMode.mashal(curve, x1, y1)
|
c1 := opts.PointMarshalMode.mashal(curve, x1, y1)
|
||||||
|
|
||||||
//A3, skipped
|
//A3, requirement is to check if h*P is infinite point, h is 1
|
||||||
|
if !curve.IsOnCurve(pub.X, pub.Y) {
|
||||||
|
return nil, errors.New("SM2: invalid public key")
|
||||||
|
}
|
||||||
|
|
||||||
//A4, calculate k * P (point of Public Key)
|
//A4, calculate k * P (point of Public Key)
|
||||||
x2, y2 := curve.ScalarMult(pub.X, pub.Y, k.Bytes())
|
x2, y2 := curve.ScalarMult(pub.X, pub.Y, k.Bytes())
|
||||||
|
|
||||||
//A5, calculate t=KDF(x2||y2, klen)
|
//A5, calculate t=KDF(x2||y2, klen)
|
||||||
|
var kdfCount int = 0
|
||||||
t, success := kdf(append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen)
|
t, success := kdf(append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen)
|
||||||
if !success {
|
if !success {
|
||||||
fmt.Println("A5, failed to get valid t")
|
kdfCount++
|
||||||
|
if kdfCount > maxRetryLimit {
|
||||||
|
return nil, fmt.Errorf("SM2: A5, failed to calculate valid t, tried %v times", kdfCount)
|
||||||
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -245,7 +255,7 @@ func GenerateKey(rand io.Reader) (*PrivateKey, error) {
|
|||||||
func Decrypt(priv *PrivateKey, ciphertext []byte) ([]byte, error) {
|
func Decrypt(priv *PrivateKey, ciphertext []byte) ([]byte, error) {
|
||||||
ciphertextLen := len(ciphertext)
|
ciphertextLen := len(ciphertext)
|
||||||
if ciphertextLen <= 1+(priv.Params().BitSize/8)+sm3.Size {
|
if ciphertextLen <= 1+(priv.Params().BitSize/8)+sm3.Size {
|
||||||
return nil, errors.New("invalid ciphertext length")
|
return nil, errors.New("SM2: invalid ciphertext length")
|
||||||
}
|
}
|
||||||
curve := priv.Curve
|
curve := priv.Curve
|
||||||
// B1, get C1, and check C1
|
// B1, get C1, and check C1
|
||||||
@ -263,7 +273,7 @@ func Decrypt(priv *PrivateKey, ciphertext []byte) ([]byte, error) {
|
|||||||
msgLen := len(c2)
|
msgLen := len(c2)
|
||||||
t, success := kdf(append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen)
|
t, success := kdf(append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen)
|
||||||
if !success {
|
if !success {
|
||||||
return nil, errors.New("invalid cipher text")
|
return nil, errors.New("SM2: invalid cipher text")
|
||||||
}
|
}
|
||||||
|
|
||||||
//B5, calculate msg = c2 ^ t
|
//B5, calculate msg = c2 ^ t
|
||||||
@ -277,7 +287,7 @@ func Decrypt(priv *PrivateKey, ciphertext []byte) ([]byte, error) {
|
|||||||
u := calculateC3(curve, x2, y2, msg)
|
u := calculateC3(curve, x2, y2, msg)
|
||||||
for i := 0; i < sm3.Size; i++ {
|
for i := 0; i < sm3.Size; i++ {
|
||||||
if c3[i] != u[i] {
|
if c3[i] != u[i] {
|
||||||
return nil, errors.New("invalid hash value")
|
return nil, errors.New("SM2: invalid hash value")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user