diff --git a/sm2/sm2_keyexchange.go b/sm2/sm2_keyexchange.go index e000d55..c0af2a5 100644 --- a/sm2/sm2_keyexchange.go +++ b/sm2/sm2_keyexchange.go @@ -23,9 +23,9 @@ type KeyExchange struct { genSignature bool // control the optional sign/verify step triggered by responsder keyLength int // key length privateKey *PrivateKey // owner's encryption private key - uid []byte // owner uid - peerUID []byte // peer uid + z []byte // owner identifiable id peerPub *ecdsa.PublicKey // peer public key + peerZ []byte // peer identifiable id r *big.Int // random which will be used to compute secret secret Point // generated secret which will be passed to peer peerSecret Point // received peer's secret @@ -41,14 +41,12 @@ func (ke *KeyExchange) GetKey() []byte { } // NewKeyExchange create one new KeyExchange object -func NewKeyExchange(priv *PrivateKey, peerPub *ecdsa.PublicKey, uid, peerUID []byte, keyLen int, genSignature bool) *KeyExchange { - ke := &KeyExchange{} +func NewKeyExchange(priv *PrivateKey, peerPub *ecdsa.PublicKey, uid, peerUID []byte, keyLen int, genSignature bool) (ke *KeyExchange, err error) { + ke = &KeyExchange{} ke.genSignature = genSignature ke.peerPub = peerPub ke.keyLength = keyLen ke.privateKey = priv - ke.uid = uid - ke.peerUID = peerUID w := (priv.Params().N.BitLen()+1)/2 - 1 x2 := big.NewInt(2) ke.w2 = x2 @@ -56,7 +54,15 @@ func NewKeyExchange(priv *PrivateKey, peerPub *ecdsa.PublicKey, uid, peerUID []b x2minus1 := (&big.Int{}).Sub(x2, big.NewInt(1)) ke.w2Minus1 = x2minus1 - return ke + ke.z, err = calculateZA(&ke.privateKey.PublicKey, uid) + if err != nil { + return nil, err + } + ke.peerZ, err = calculateZA(ke.peerPub, peerUID) + if err != nil { + return nil, err + } + return } func initKeyExchange(ke *KeyExchange, r *big.Int) { @@ -74,10 +80,52 @@ func (ke *KeyExchange) InitKeyExchange(rand io.Reader) (*Point, error) { return &ke.secret, nil } +func (ke *KeyExchange) sign(isResponder bool, prefix byte) []byte { + var buffer []byte + hash := sm3.New() + hash.Write(toBytes(ke.privateKey, ke.v.X)) + if isResponder { + hash.Write(ke.peerZ) + hash.Write(ke.z) + hash.Write(toBytes(ke.privateKey, ke.peerSecret.X)) + hash.Write(toBytes(ke.privateKey, ke.peerSecret.Y)) + hash.Write(toBytes(ke.privateKey, ke.secret.X)) + hash.Write(toBytes(ke.privateKey, ke.secret.Y)) + } else { + hash.Write(ke.z) + hash.Write(ke.peerZ) + hash.Write(toBytes(ke.privateKey, ke.secret.X)) + hash.Write(toBytes(ke.privateKey, ke.secret.Y)) + hash.Write(toBytes(ke.privateKey, ke.peerSecret.X)) + hash.Write(toBytes(ke.privateKey, ke.peerSecret.Y)) + } + buffer = hash.Sum(nil) + hash.Reset() + hash.Write([]byte{prefix}) + hash.Write(toBytes(ke.privateKey, ke.v.Y)) + hash.Write(buffer) + return hash.Sum(nil) +} + +func (ke *KeyExchange) generateKey(isResponder bool) { + var buffer []byte + buffer = append(buffer, toBytes(ke.privateKey, ke.v.X)...) + buffer = append(buffer, toBytes(ke.privateKey, ke.v.Y)...) + if isResponder { + buffer = append(buffer, ke.peerZ...) + buffer = append(buffer, ke.z...) + } else { + buffer = append(buffer, ke.z...) + buffer = append(buffer, ke.peerZ...) + } + key, _ := sm3.Kdf(buffer, ke.keyLength) + ke.key = key +} + func respondKeyExchange(ke *KeyExchange, r *big.Int, rA *Point) (*Point, []byte, error) { ke.secret.X, ke.secret.Y = ke.privateKey.ScalarBaseMult(r.Bytes()) ke.r = r - + // Calculate tB t := (&big.Int{}).And(ke.w2Minus1, ke.secret.X) t.Add(ke.w2, t) t.Mul(t, ke.r) @@ -94,42 +142,13 @@ func respondKeyExchange(ke *KeyExchange, r *big.Int, rA *Point) (*Point, []byte, return nil, nil, errors.New("sm2: key exchange fail") } - var buffer []byte - zA, err := calculateZA(ke.peerPub, ke.peerUID) - if err != nil { - return nil, nil, err - } - zB, err := calculateZA(&ke.privateKey.PublicKey, ke.uid) - if err != nil { - return nil, nil, err - } - buffer = append(buffer, toBytes(ke.privateKey, ke.v.X)...) - buffer = append(buffer, toBytes(ke.privateKey, ke.v.Y)...) - buffer = append(buffer, zA...) - buffer = append(buffer, zB...) - key, _ := sm3.Kdf(buffer, ke.keyLength) - ke.key = key + ke.generateKey(true) if !ke.genSignature { return &ke.secret, nil, nil } - hash := sm3.New() - hash.Write(toBytes(ke.privateKey, ke.v.X)) - hash.Write(zA) - hash.Write(zB) - hash.Write(toBytes(ke.privateKey, ke.peerSecret.X)) - hash.Write(toBytes(ke.privateKey, ke.peerSecret.Y)) - hash.Write(toBytes(ke.privateKey, ke.secret.X)) - hash.Write(toBytes(ke.privateKey, ke.secret.Y)) - buffer = hash.Sum(nil) - hash.Reset() - hash.Write([]byte{0x02}) - hash.Write(toBytes(ke.privateKey, ke.v.Y)) - hash.Write(buffer) - buffer = hash.Sum(nil) - - return &ke.secret, buffer, nil + return &ke.secret, ke.sign(true, 0x02), nil } // RepondKeyExchange when responder receive rA, for responder's step B1-B8 @@ -150,10 +169,8 @@ func (ke *KeyExchange) ConfirmResponder(rB *Point, sB []byte) ([]byte, error) { if !ke.privateKey.IsOnCurve(rB.X, rB.Y) { return nil, errors.New("sm2: received invalid random from responder") } - hash := sm3.New() - ke.peerSecret = *rB - + // Calcualte tA t := (&big.Int{}).And(ke.w2Minus1, ke.secret.X) t.Add(ke.w2, t) t.Mul(t, ke.r) @@ -170,84 +187,19 @@ func (ke *KeyExchange) ConfirmResponder(rB *Point, sB []byte) ([]byte, error) { if ke.v.X.Sign() == 0 && ke.v.Y.Sign() == 0 { return nil, errors.New("sm2: key exchange fail") } - - var buffer []byte - zA, err := calculateZA(&ke.privateKey.PublicKey, ke.uid) - if err != nil { - return nil, err - } - zB, err := calculateZA(ke.peerPub, ke.peerUID) - if err != nil { - return nil, err - } - buffer = append(buffer, toBytes(ke.privateKey, ke.v.X)...) - buffer = append(buffer, toBytes(ke.privateKey, ke.v.Y)...) - buffer = append(buffer, zA...) - buffer = append(buffer, zB...) - key, _ := sm3.Kdf(buffer, ke.keyLength) - ke.key = key - + ke.generateKey(false) if len(sB) > 0 { - hash.Write(toBytes(ke.privateKey, ke.v.X)) - hash.Write(zA) - hash.Write(zB) - hash.Write(toBytes(ke.privateKey, ke.secret.X)) - hash.Write(toBytes(ke.privateKey, ke.secret.Y)) - hash.Write(toBytes(ke.privateKey, ke.peerSecret.X)) - hash.Write(toBytes(ke.privateKey, ke.peerSecret.Y)) - buffer = hash.Sum(nil) - hash.Reset() - hash.Write([]byte{0x02}) - hash.Write(toBytes(ke.privateKey, ke.v.Y)) - hash.Write(buffer) - buffer = hash.Sum(nil) - hash.Reset() + buffer := ke.sign(false, 0x02) if goSubtle.ConstantTimeCompare(buffer, sB) != 1 { return nil, errors.New("sm2: verify responder's signature fail") } } - hash.Write(toBytes(ke.privateKey, ke.v.X)) - hash.Write(zA) - hash.Write(zB) - hash.Write(toBytes(ke.privateKey, ke.secret.X)) - hash.Write(toBytes(ke.privateKey, ke.secret.Y)) - hash.Write(toBytes(ke.privateKey, ke.peerSecret.X)) - hash.Write(toBytes(ke.privateKey, ke.peerSecret.Y)) - buffer = hash.Sum(nil) - hash.Reset() - hash.Write([]byte{0x03}) - hash.Write(toBytes(ke.privateKey, ke.v.Y)) - hash.Write(buffer) - buffer = hash.Sum(nil) - - return buffer, nil + return ke.sign(false, 0x03), nil } // ConfirmInitiator for responder's step B10 func (ke *KeyExchange) ConfirmInitiator(s1 []byte) error { - hash := sm3.New() - var buffer []byte - zB, err := calculateZA(&ke.privateKey.PublicKey, ke.uid) - if err != nil { - return err - } - zA, err := calculateZA(ke.peerPub, ke.peerUID) - if err != nil { - return err - } - hash.Write(toBytes(ke.privateKey, ke.v.X)) - hash.Write(zA) - hash.Write(zB) - hash.Write(toBytes(ke.privateKey, ke.peerSecret.X)) - hash.Write(toBytes(ke.privateKey, ke.peerSecret.Y)) - hash.Write(toBytes(ke.privateKey, ke.secret.X)) - hash.Write(toBytes(ke.privateKey, ke.secret.Y)) - buffer = hash.Sum(nil) - hash.Reset() - hash.Write([]byte{0x03}) - hash.Write(toBytes(ke.privateKey, ke.v.Y)) - hash.Write(buffer) - buffer = hash.Sum(nil) + buffer := ke.sign(true, 0x03) if goSubtle.ConstantTimeCompare(buffer, s1) != 1 { return errors.New("sm2: verify initiator's signature fail") } diff --git a/sm9/sm9.go b/sm9/sm9.go index 8b4add7..e3f03b7 100644 --- a/sm9/sm9.go +++ b/sm9/sm9.go @@ -547,6 +547,51 @@ func (ke *KeyExchange) InitKeyExchange(rand io.Reader, hid byte) (*bn256.G1, err return ke.secret, nil } +func (ke *KeyExchange) sign(isResponder bool, prefix byte) []byte { + var buffer []byte + hash := sm3.New() + hash.Write(ke.g2.Marshal()) + hash.Write(ke.g3.Marshal()) + if isResponder { + hash.Write(ke.peerUID) + hash.Write(ke.uid) + hash.Write(ke.peerSecret.Marshal()) + hash.Write(ke.secret.Marshal()) + } else { + hash.Write(ke.uid) + hash.Write(ke.peerUID) + hash.Write(ke.secret.Marshal()) + hash.Write(ke.peerSecret.Marshal()) + } + buffer = hash.Sum(nil) + hash.Reset() + hash.Write([]byte{prefix}) + hash.Write(ke.g1.Marshal()) + hash.Write(buffer) + return hash.Sum(nil) +} + +func (ke *KeyExchange) generateKey(isResponder bool) { + var buffer []byte + if isResponder { + buffer = append(buffer, ke.peerUID...) + buffer = append(buffer, ke.uid...) + buffer = append(buffer, ke.peerSecret.Marshal()...) + buffer = append(buffer, ke.secret.Marshal()...) + } else { + buffer = append(buffer, ke.uid...) + buffer = append(buffer, ke.peerUID...) + buffer = append(buffer, ke.secret.Marshal()...) + buffer = append(buffer, ke.peerSecret.Marshal()...) + } + buffer = append(buffer, ke.g1.Marshal()...) + buffer = append(buffer, ke.g2.Marshal()...) + buffer = append(buffer, ke.g3.Marshal()...) + + key, _ := sm3.Kdf(buffer, ke.keyLength) + ke.key = key +} + func respondKeyExchange(ke *KeyExchange, hid byte, r *big.Int, rA *bn256.G1) (*bn256.G1, []byte, error) { if !rA.IsOnCurve() { return nil, nil, errors.New("sm9: received invalid random from initiator") @@ -562,36 +607,13 @@ func respondKeyExchange(ke *KeyExchange, hid byte, r *big.Int, rA *bn256.G1) (*b ke.g3.ScalarMult(ke.g1, r) ke.g2 = ke.privateKey.EncryptMasterPublicKey.ScalarBaseMult(r) - var buffer []byte - buffer = append(buffer, ke.peerUID...) - buffer = append(buffer, ke.uid...) - buffer = append(buffer, ke.peerSecret.Marshal()...) - buffer = append(buffer, ke.secret.Marshal()...) - buffer = append(buffer, ke.g1.Marshal()...) - buffer = append(buffer, ke.g2.Marshal()...) - buffer = append(buffer, ke.g3.Marshal()...) - - key, _ := sm3.Kdf(buffer, ke.keyLength) - ke.key = key + ke.generateKey(true) if !ke.genSignature { return ke.secret, nil, nil } - hash := sm3.New() - hash.Write(ke.g2.Marshal()) - hash.Write(ke.g3.Marshal()) - hash.Write(ke.peerUID) - hash.Write(ke.uid) - hash.Write(ke.peerSecret.Marshal()) - hash.Write(ke.secret.Marshal()) - buffer = hash.Sum(nil) - hash.Reset() - hash.Write([]byte{0x82}) - hash.Write(ke.g1.Marshal()) - hash.Write(buffer) - buffer = hash.Sum(nil) - return ke.secret, buffer, nil + return ke.secret, ke.sign(true, 0x82), nil } // RepondKeyExchange when responder receive rA, for responder's step B1-B7 @@ -608,7 +630,6 @@ func (ke *KeyExchange) ConfirmResponder(rB *bn256.G1, sB []byte) ([]byte, error) if !rB.IsOnCurve() { return nil, errors.New("sm9: received invalid random from responder") } - hash := sm3.New() // step 5 ke.peerSecret = rB ke.g1 = ke.privateKey.EncryptMasterPublicKey.ScalarBaseMult(ke.r) @@ -616,60 +637,20 @@ func (ke *KeyExchange) ConfirmResponder(rB *bn256.G1, sB []byte) ([]byte, error) ke.g3 = &bn256.GT{} ke.g3.ScalarMult(ke.g2, ke.r) // step 6, verify signature - var temp []byte - var buffer []byte if len(sB) > 0 { - hash.Write(ke.g2.Marshal()) - hash.Write(ke.g3.Marshal()) - hash.Write(ke.uid) - hash.Write(ke.peerUID) - hash.Write(ke.secret.Marshal()) - hash.Write(ke.peerSecret.Marshal()) - temp = hash.Sum(nil) - hash.Reset() - hash.Write([]byte{0x82}) - hash.Write(ke.g1.Marshal()) - hash.Write(temp) - signature := hash.Sum(nil) - hash.Reset() + signature := ke.sign(false, 0x82) if goSubtle.ConstantTimeCompare(signature, sB) != 1 { return nil, errors.New("sm9: verify responder's signature fail") } } - buffer = append(buffer, ke.uid...) - buffer = append(buffer, ke.peerUID...) - buffer = append(buffer, ke.secret.Marshal()...) - buffer = append(buffer, ke.peerSecret.Marshal()...) - buffer = append(buffer, ke.g1.Marshal()...) - buffer = append(buffer, ke.g2.Marshal()...) - buffer = append(buffer, ke.g3.Marshal()...) + ke.generateKey(false) - key, _ := sm3.Kdf(buffer, ke.keyLength) - ke.key = key - - hash.Write([]byte{0x83}) - hash.Write(ke.g1.Marshal()) - hash.Write(temp) - - return hash.Sum(nil), nil + return ke.sign(false, 0x83), nil } // ConfirmInitiator for responder's step B8 func (ke *KeyExchange) ConfirmInitiator(s1 []byte) error { - hash := sm3.New() - var buffer []byte - hash.Write(ke.g2.Marshal()) - hash.Write(ke.g3.Marshal()) - hash.Write(ke.peerUID) - hash.Write(ke.uid) - hash.Write(ke.peerSecret.Marshal()) - hash.Write(ke.secret.Marshal()) - buffer = hash.Sum(nil) - hash.Reset() - hash.Write([]byte{0x83}) - hash.Write(ke.g1.Marshal()) - hash.Write(buffer) - buffer = hash.Sum(nil) + buffer := ke.sign(true, 0x83) if goSubtle.ConstantTimeCompare(buffer, s1) != 1 { return errors.New("sm9: verify initiator's signature fail") }