sm2: ErrDecryption, avoid adaptive attacks

This commit is contained in:
Sun Yimin 2022-12-06 08:39:16 +08:00 committed by GitHub
parent 60c3caf9db
commit 32acdfea7e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 13 additions and 9 deletions

View File

@ -317,6 +317,10 @@ func Decrypt(priv *PrivateKey, ciphertext []byte) ([]byte, error) {
return decrypt(priv, ciphertext, nil) return decrypt(priv, ciphertext, nil)
} }
// ErrDecryption represents a failure to decrypt a message.
// It is deliberately vague to avoid adaptive attacks.
var ErrDecryption = errors.New("sm2: decryption error")
func decrypt(priv *PrivateKey, ciphertext []byte, opts *DecrypterOpts) ([]byte, error) { func decrypt(priv *PrivateKey, ciphertext []byte, opts *DecrypterOpts) ([]byte, error) {
ciphertextLen := len(ciphertext) ciphertextLen := len(ciphertext)
if ciphertextLen <= 1+(priv.Params().BitSize/8)+sm3.Size { if ciphertextLen <= 1+(priv.Params().BitSize/8)+sm3.Size {
@ -333,22 +337,22 @@ func decrypt(priv *PrivateKey, ciphertext []byte, opts *DecrypterOpts) ([]byte,
func decryptSM2EC(c *sm2Curve, priv *PrivateKey, ciphertext []byte, opts *DecrypterOpts) ([]byte, error) { func decryptSM2EC(c *sm2Curve, priv *PrivateKey, ciphertext []byte, opts *DecrypterOpts) ([]byte, error) {
C1, c2, c3, err := parseCiphertext(c, ciphertext, opts) C1, c2, c3, err := parseCiphertext(c, ciphertext, opts)
if err != nil { if err != nil {
return nil, err return nil, ErrDecryption
} }
d, err := bigmod.NewNat().SetBytes(priv.D.Bytes(), c.N) d, err := bigmod.NewNat().SetBytes(priv.D.Bytes(), c.N)
if err != nil { if err != nil {
return nil, err return nil, ErrDecryption
} }
C2, err := C1.ScalarMult(C1, d.Bytes(c.N)) C2, err := C1.ScalarMult(C1, d.Bytes(c.N))
if err != nil { if err != nil {
return nil, err return nil, ErrDecryption
} }
C2Bytes := C2.Bytes()[1:] C2Bytes := C2.Bytes()[1:]
msgLen := len(c2) msgLen := len(c2)
msg := kdf.Kdf(sm3.New(), C2Bytes, msgLen) msg := kdf.Kdf(sm3.New(), C2Bytes, msgLen)
if subtle.ConstantTimeAllZero(c2) { if subtle.ConstantTimeAllZero(c2) {
return nil, errors.New("sm2: invalid cipher text") return nil, ErrDecryption
} }
//B5, calculate msg = c2 ^ t //B5, calculate msg = c2 ^ t
@ -363,7 +367,7 @@ func decryptSM2EC(c *sm2Curve, priv *PrivateKey, ciphertext []byte, opts *Decryp
if _subtle.ConstantTimeCompare(u, c3) == 1 { if _subtle.ConstantTimeCompare(u, c3) == 1 {
return msg, nil return msg, nil
} }
return nil, errors.New("sm2: invalid plaintext digest") return nil, ErrDecryption
} }
func parseCiphertext(c *sm2Curve, ciphertext []byte, opts *DecrypterOpts) (*_sm2ec.SM2P256Point, []byte, []byte, error) { func parseCiphertext(c *sm2Curve, ciphertext []byte, opts *DecrypterOpts) (*_sm2ec.SM2P256Point, []byte, []byte, error) {

View File

@ -399,7 +399,7 @@ func AdjustCiphertextSplicingOrder(ciphertext []byte, from, to ciphertextSplicin
func decryptASN1(priv *PrivateKey, ciphertext []byte) ([]byte, error) { func decryptASN1(priv *PrivateKey, ciphertext []byte) ([]byte, error) {
x1, y1, c2, c3, err := unmarshalASN1Ciphertext(ciphertext) x1, y1, c2, c3, err := unmarshalASN1Ciphertext(ciphertext)
if err != nil { if err != nil {
return nil, err return nil, ErrDecryption
} }
return rawDecrypt(priv, x1, y1, c2, c3) return rawDecrypt(priv, x1, y1, c2, c3)
} }
@ -410,7 +410,7 @@ func rawDecrypt(priv *PrivateKey, x1, y1 *big.Int, c2, c3 []byte) ([]byte, error
msgLen := len(c2) msgLen := len(c2)
msg := kdf.Kdf(sm3.New(), append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen) msg := kdf.Kdf(sm3.New(), append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen)
if subtle.ConstantTimeAllZero(c2) { if subtle.ConstantTimeAllZero(c2) {
return nil, errors.New("sm2: invalid cipher text") return nil, ErrDecryption
} }
//B5, calculate msg = c2 ^ t //B5, calculate msg = c2 ^ t
@ -420,7 +420,7 @@ func rawDecrypt(priv *PrivateKey, x1, y1 *big.Int, c2, c3 []byte) ([]byte, error
if _subtle.ConstantTimeCompare(u, c3) == 1 { if _subtle.ConstantTimeCompare(u, c3) == 1 {
return msg, nil return msg, nil
} }
return nil, errors.New("sm2: invalid plaintext digest") return nil, ErrDecryption
} }
func decryptLegacy(priv *PrivateKey, ciphertext []byte, opts *DecrypterOpts) ([]byte, error) { func decryptLegacy(priv *PrivateKey, ciphertext []byte, opts *DecrypterOpts) ([]byte, error) {
@ -439,7 +439,7 @@ func decryptLegacy(priv *PrivateKey, ciphertext []byte, opts *DecrypterOpts) ([]
// B1, get C1, and check C1 // B1, get C1, and check C1
x1, y1, c3Start, err := bytes2Point(curve, ciphertext) x1, y1, c3Start, err := bytes2Point(curve, ciphertext)
if err != nil { if err != nil {
return nil, err return nil, ErrDecryption
} }
//B4, calculate t=KDF(x2||y2, klen) //B4, calculate t=KDF(x2||y2, klen)