key exchange refactoring, reduce duplicated code

This commit is contained in:
Sun Yimin 2022-06-20 09:42:48 +08:00 committed by GitHub
parent 23914a86c3
commit e06e5ef47b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 114 additions and 181 deletions

View File

@ -23,9 +23,9 @@ type KeyExchange struct {
genSignature bool // control the optional sign/verify step triggered by responsder genSignature bool // control the optional sign/verify step triggered by responsder
keyLength int // key length keyLength int // key length
privateKey *PrivateKey // owner's encryption private key privateKey *PrivateKey // owner's encryption private key
uid []byte // owner uid z []byte // owner identifiable id
peerUID []byte // peer uid
peerPub *ecdsa.PublicKey // peer public key peerPub *ecdsa.PublicKey // peer public key
peerZ []byte // peer identifiable id
r *big.Int // random which will be used to compute secret r *big.Int // random which will be used to compute secret
secret Point // generated secret which will be passed to peer secret Point // generated secret which will be passed to peer
peerSecret Point // received peer's secret peerSecret Point // received peer's secret
@ -41,14 +41,12 @@ func (ke *KeyExchange) GetKey() []byte {
} }
// NewKeyExchange create one new KeyExchange object // NewKeyExchange create one new KeyExchange object
func NewKeyExchange(priv *PrivateKey, peerPub *ecdsa.PublicKey, uid, peerUID []byte, keyLen int, genSignature bool) *KeyExchange { func NewKeyExchange(priv *PrivateKey, peerPub *ecdsa.PublicKey, uid, peerUID []byte, keyLen int, genSignature bool) (ke *KeyExchange, err error) {
ke := &KeyExchange{} ke = &KeyExchange{}
ke.genSignature = genSignature ke.genSignature = genSignature
ke.peerPub = peerPub ke.peerPub = peerPub
ke.keyLength = keyLen ke.keyLength = keyLen
ke.privateKey = priv ke.privateKey = priv
ke.uid = uid
ke.peerUID = peerUID
w := (priv.Params().N.BitLen()+1)/2 - 1 w := (priv.Params().N.BitLen()+1)/2 - 1
x2 := big.NewInt(2) x2 := big.NewInt(2)
ke.w2 = x2 ke.w2 = x2
@ -56,7 +54,15 @@ func NewKeyExchange(priv *PrivateKey, peerPub *ecdsa.PublicKey, uid, peerUID []b
x2minus1 := (&big.Int{}).Sub(x2, big.NewInt(1)) x2minus1 := (&big.Int{}).Sub(x2, big.NewInt(1))
ke.w2Minus1 = x2minus1 ke.w2Minus1 = x2minus1
return ke ke.z, err = calculateZA(&ke.privateKey.PublicKey, uid)
if err != nil {
return nil, err
}
ke.peerZ, err = calculateZA(ke.peerPub, peerUID)
if err != nil {
return nil, err
}
return
} }
func initKeyExchange(ke *KeyExchange, r *big.Int) { func initKeyExchange(ke *KeyExchange, r *big.Int) {
@ -74,10 +80,52 @@ func (ke *KeyExchange) InitKeyExchange(rand io.Reader) (*Point, error) {
return &ke.secret, nil return &ke.secret, nil
} }
func (ke *KeyExchange) sign(isResponder bool, prefix byte) []byte {
var buffer []byte
hash := sm3.New()
hash.Write(toBytes(ke.privateKey, ke.v.X))
if isResponder {
hash.Write(ke.peerZ)
hash.Write(ke.z)
hash.Write(toBytes(ke.privateKey, ke.peerSecret.X))
hash.Write(toBytes(ke.privateKey, ke.peerSecret.Y))
hash.Write(toBytes(ke.privateKey, ke.secret.X))
hash.Write(toBytes(ke.privateKey, ke.secret.Y))
} else {
hash.Write(ke.z)
hash.Write(ke.peerZ)
hash.Write(toBytes(ke.privateKey, ke.secret.X))
hash.Write(toBytes(ke.privateKey, ke.secret.Y))
hash.Write(toBytes(ke.privateKey, ke.peerSecret.X))
hash.Write(toBytes(ke.privateKey, ke.peerSecret.Y))
}
buffer = hash.Sum(nil)
hash.Reset()
hash.Write([]byte{prefix})
hash.Write(toBytes(ke.privateKey, ke.v.Y))
hash.Write(buffer)
return hash.Sum(nil)
}
func (ke *KeyExchange) generateKey(isResponder bool) {
var buffer []byte
buffer = append(buffer, toBytes(ke.privateKey, ke.v.X)...)
buffer = append(buffer, toBytes(ke.privateKey, ke.v.Y)...)
if isResponder {
buffer = append(buffer, ke.peerZ...)
buffer = append(buffer, ke.z...)
} else {
buffer = append(buffer, ke.z...)
buffer = append(buffer, ke.peerZ...)
}
key, _ := sm3.Kdf(buffer, ke.keyLength)
ke.key = key
}
func respondKeyExchange(ke *KeyExchange, r *big.Int, rA *Point) (*Point, []byte, error) { func respondKeyExchange(ke *KeyExchange, r *big.Int, rA *Point) (*Point, []byte, error) {
ke.secret.X, ke.secret.Y = ke.privateKey.ScalarBaseMult(r.Bytes()) ke.secret.X, ke.secret.Y = ke.privateKey.ScalarBaseMult(r.Bytes())
ke.r = r ke.r = r
// Calculate tB
t := (&big.Int{}).And(ke.w2Minus1, ke.secret.X) t := (&big.Int{}).And(ke.w2Minus1, ke.secret.X)
t.Add(ke.w2, t) t.Add(ke.w2, t)
t.Mul(t, ke.r) t.Mul(t, ke.r)
@ -94,42 +142,13 @@ func respondKeyExchange(ke *KeyExchange, r *big.Int, rA *Point) (*Point, []byte,
return nil, nil, errors.New("sm2: key exchange fail") return nil, nil, errors.New("sm2: key exchange fail")
} }
var buffer []byte ke.generateKey(true)
zA, err := calculateZA(ke.peerPub, ke.peerUID)
if err != nil {
return nil, nil, err
}
zB, err := calculateZA(&ke.privateKey.PublicKey, ke.uid)
if err != nil {
return nil, nil, err
}
buffer = append(buffer, toBytes(ke.privateKey, ke.v.X)...)
buffer = append(buffer, toBytes(ke.privateKey, ke.v.Y)...)
buffer = append(buffer, zA...)
buffer = append(buffer, zB...)
key, _ := sm3.Kdf(buffer, ke.keyLength)
ke.key = key
if !ke.genSignature { if !ke.genSignature {
return &ke.secret, nil, nil return &ke.secret, nil, nil
} }
hash := sm3.New() return &ke.secret, ke.sign(true, 0x02), nil
hash.Write(toBytes(ke.privateKey, ke.v.X))
hash.Write(zA)
hash.Write(zB)
hash.Write(toBytes(ke.privateKey, ke.peerSecret.X))
hash.Write(toBytes(ke.privateKey, ke.peerSecret.Y))
hash.Write(toBytes(ke.privateKey, ke.secret.X))
hash.Write(toBytes(ke.privateKey, ke.secret.Y))
buffer = hash.Sum(nil)
hash.Reset()
hash.Write([]byte{0x02})
hash.Write(toBytes(ke.privateKey, ke.v.Y))
hash.Write(buffer)
buffer = hash.Sum(nil)
return &ke.secret, buffer, nil
} }
// RepondKeyExchange when responder receive rA, for responder's step B1-B8 // RepondKeyExchange when responder receive rA, for responder's step B1-B8
@ -150,10 +169,8 @@ func (ke *KeyExchange) ConfirmResponder(rB *Point, sB []byte) ([]byte, error) {
if !ke.privateKey.IsOnCurve(rB.X, rB.Y) { if !ke.privateKey.IsOnCurve(rB.X, rB.Y) {
return nil, errors.New("sm2: received invalid random from responder") return nil, errors.New("sm2: received invalid random from responder")
} }
hash := sm3.New()
ke.peerSecret = *rB ke.peerSecret = *rB
// Calcualte tA
t := (&big.Int{}).And(ke.w2Minus1, ke.secret.X) t := (&big.Int{}).And(ke.w2Minus1, ke.secret.X)
t.Add(ke.w2, t) t.Add(ke.w2, t)
t.Mul(t, ke.r) t.Mul(t, ke.r)
@ -170,84 +187,19 @@ func (ke *KeyExchange) ConfirmResponder(rB *Point, sB []byte) ([]byte, error) {
if ke.v.X.Sign() == 0 && ke.v.Y.Sign() == 0 { if ke.v.X.Sign() == 0 && ke.v.Y.Sign() == 0 {
return nil, errors.New("sm2: key exchange fail") return nil, errors.New("sm2: key exchange fail")
} }
ke.generateKey(false)
var buffer []byte
zA, err := calculateZA(&ke.privateKey.PublicKey, ke.uid)
if err != nil {
return nil, err
}
zB, err := calculateZA(ke.peerPub, ke.peerUID)
if err != nil {
return nil, err
}
buffer = append(buffer, toBytes(ke.privateKey, ke.v.X)...)
buffer = append(buffer, toBytes(ke.privateKey, ke.v.Y)...)
buffer = append(buffer, zA...)
buffer = append(buffer, zB...)
key, _ := sm3.Kdf(buffer, ke.keyLength)
ke.key = key
if len(sB) > 0 { if len(sB) > 0 {
hash.Write(toBytes(ke.privateKey, ke.v.X)) buffer := ke.sign(false, 0x02)
hash.Write(zA)
hash.Write(zB)
hash.Write(toBytes(ke.privateKey, ke.secret.X))
hash.Write(toBytes(ke.privateKey, ke.secret.Y))
hash.Write(toBytes(ke.privateKey, ke.peerSecret.X))
hash.Write(toBytes(ke.privateKey, ke.peerSecret.Y))
buffer = hash.Sum(nil)
hash.Reset()
hash.Write([]byte{0x02})
hash.Write(toBytes(ke.privateKey, ke.v.Y))
hash.Write(buffer)
buffer = hash.Sum(nil)
hash.Reset()
if goSubtle.ConstantTimeCompare(buffer, sB) != 1 { if goSubtle.ConstantTimeCompare(buffer, sB) != 1 {
return nil, errors.New("sm2: verify responder's signature fail") return nil, errors.New("sm2: verify responder's signature fail")
} }
} }
hash.Write(toBytes(ke.privateKey, ke.v.X)) return ke.sign(false, 0x03), nil
hash.Write(zA)
hash.Write(zB)
hash.Write(toBytes(ke.privateKey, ke.secret.X))
hash.Write(toBytes(ke.privateKey, ke.secret.Y))
hash.Write(toBytes(ke.privateKey, ke.peerSecret.X))
hash.Write(toBytes(ke.privateKey, ke.peerSecret.Y))
buffer = hash.Sum(nil)
hash.Reset()
hash.Write([]byte{0x03})
hash.Write(toBytes(ke.privateKey, ke.v.Y))
hash.Write(buffer)
buffer = hash.Sum(nil)
return buffer, nil
} }
// ConfirmInitiator for responder's step B10 // ConfirmInitiator for responder's step B10
func (ke *KeyExchange) ConfirmInitiator(s1 []byte) error { func (ke *KeyExchange) ConfirmInitiator(s1 []byte) error {
hash := sm3.New() buffer := ke.sign(true, 0x03)
var buffer []byte
zB, err := calculateZA(&ke.privateKey.PublicKey, ke.uid)
if err != nil {
return err
}
zA, err := calculateZA(ke.peerPub, ke.peerUID)
if err != nil {
return err
}
hash.Write(toBytes(ke.privateKey, ke.v.X))
hash.Write(zA)
hash.Write(zB)
hash.Write(toBytes(ke.privateKey, ke.peerSecret.X))
hash.Write(toBytes(ke.privateKey, ke.peerSecret.Y))
hash.Write(toBytes(ke.privateKey, ke.secret.X))
hash.Write(toBytes(ke.privateKey, ke.secret.Y))
buffer = hash.Sum(nil)
hash.Reset()
hash.Write([]byte{0x03})
hash.Write(toBytes(ke.privateKey, ke.v.Y))
hash.Write(buffer)
buffer = hash.Sum(nil)
if goSubtle.ConstantTimeCompare(buffer, s1) != 1 { if goSubtle.ConstantTimeCompare(buffer, s1) != 1 {
return errors.New("sm2: verify initiator's signature fail") return errors.New("sm2: verify initiator's signature fail")
} }

View File

@ -547,6 +547,51 @@ func (ke *KeyExchange) InitKeyExchange(rand io.Reader, hid byte) (*bn256.G1, err
return ke.secret, nil return ke.secret, nil
} }
func (ke *KeyExchange) sign(isResponder bool, prefix byte) []byte {
var buffer []byte
hash := sm3.New()
hash.Write(ke.g2.Marshal())
hash.Write(ke.g3.Marshal())
if isResponder {
hash.Write(ke.peerUID)
hash.Write(ke.uid)
hash.Write(ke.peerSecret.Marshal())
hash.Write(ke.secret.Marshal())
} else {
hash.Write(ke.uid)
hash.Write(ke.peerUID)
hash.Write(ke.secret.Marshal())
hash.Write(ke.peerSecret.Marshal())
}
buffer = hash.Sum(nil)
hash.Reset()
hash.Write([]byte{prefix})
hash.Write(ke.g1.Marshal())
hash.Write(buffer)
return hash.Sum(nil)
}
func (ke *KeyExchange) generateKey(isResponder bool) {
var buffer []byte
if isResponder {
buffer = append(buffer, ke.peerUID...)
buffer = append(buffer, ke.uid...)
buffer = append(buffer, ke.peerSecret.Marshal()...)
buffer = append(buffer, ke.secret.Marshal()...)
} else {
buffer = append(buffer, ke.uid...)
buffer = append(buffer, ke.peerUID...)
buffer = append(buffer, ke.secret.Marshal()...)
buffer = append(buffer, ke.peerSecret.Marshal()...)
}
buffer = append(buffer, ke.g1.Marshal()...)
buffer = append(buffer, ke.g2.Marshal()...)
buffer = append(buffer, ke.g3.Marshal()...)
key, _ := sm3.Kdf(buffer, ke.keyLength)
ke.key = key
}
func respondKeyExchange(ke *KeyExchange, hid byte, r *big.Int, rA *bn256.G1) (*bn256.G1, []byte, error) { func respondKeyExchange(ke *KeyExchange, hid byte, r *big.Int, rA *bn256.G1) (*bn256.G1, []byte, error) {
if !rA.IsOnCurve() { if !rA.IsOnCurve() {
return nil, nil, errors.New("sm9: received invalid random from initiator") return nil, nil, errors.New("sm9: received invalid random from initiator")
@ -562,36 +607,13 @@ func respondKeyExchange(ke *KeyExchange, hid byte, r *big.Int, rA *bn256.G1) (*b
ke.g3.ScalarMult(ke.g1, r) ke.g3.ScalarMult(ke.g1, r)
ke.g2 = ke.privateKey.EncryptMasterPublicKey.ScalarBaseMult(r) ke.g2 = ke.privateKey.EncryptMasterPublicKey.ScalarBaseMult(r)
var buffer []byte ke.generateKey(true)
buffer = append(buffer, ke.peerUID...)
buffer = append(buffer, ke.uid...)
buffer = append(buffer, ke.peerSecret.Marshal()...)
buffer = append(buffer, ke.secret.Marshal()...)
buffer = append(buffer, ke.g1.Marshal()...)
buffer = append(buffer, ke.g2.Marshal()...)
buffer = append(buffer, ke.g3.Marshal()...)
key, _ := sm3.Kdf(buffer, ke.keyLength)
ke.key = key
if !ke.genSignature { if !ke.genSignature {
return ke.secret, nil, nil return ke.secret, nil, nil
} }
hash := sm3.New() return ke.secret, ke.sign(true, 0x82), nil
hash.Write(ke.g2.Marshal())
hash.Write(ke.g3.Marshal())
hash.Write(ke.peerUID)
hash.Write(ke.uid)
hash.Write(ke.peerSecret.Marshal())
hash.Write(ke.secret.Marshal())
buffer = hash.Sum(nil)
hash.Reset()
hash.Write([]byte{0x82})
hash.Write(ke.g1.Marshal())
hash.Write(buffer)
buffer = hash.Sum(nil)
return ke.secret, buffer, nil
} }
// RepondKeyExchange when responder receive rA, for responder's step B1-B7 // RepondKeyExchange when responder receive rA, for responder's step B1-B7
@ -608,7 +630,6 @@ func (ke *KeyExchange) ConfirmResponder(rB *bn256.G1, sB []byte) ([]byte, error)
if !rB.IsOnCurve() { if !rB.IsOnCurve() {
return nil, errors.New("sm9: received invalid random from responder") return nil, errors.New("sm9: received invalid random from responder")
} }
hash := sm3.New()
// step 5 // step 5
ke.peerSecret = rB ke.peerSecret = rB
ke.g1 = ke.privateKey.EncryptMasterPublicKey.ScalarBaseMult(ke.r) ke.g1 = ke.privateKey.EncryptMasterPublicKey.ScalarBaseMult(ke.r)
@ -616,60 +637,20 @@ func (ke *KeyExchange) ConfirmResponder(rB *bn256.G1, sB []byte) ([]byte, error)
ke.g3 = &bn256.GT{} ke.g3 = &bn256.GT{}
ke.g3.ScalarMult(ke.g2, ke.r) ke.g3.ScalarMult(ke.g2, ke.r)
// step 6, verify signature // step 6, verify signature
var temp []byte
var buffer []byte
if len(sB) > 0 { if len(sB) > 0 {
hash.Write(ke.g2.Marshal()) signature := ke.sign(false, 0x82)
hash.Write(ke.g3.Marshal())
hash.Write(ke.uid)
hash.Write(ke.peerUID)
hash.Write(ke.secret.Marshal())
hash.Write(ke.peerSecret.Marshal())
temp = hash.Sum(nil)
hash.Reset()
hash.Write([]byte{0x82})
hash.Write(ke.g1.Marshal())
hash.Write(temp)
signature := hash.Sum(nil)
hash.Reset()
if goSubtle.ConstantTimeCompare(signature, sB) != 1 { if goSubtle.ConstantTimeCompare(signature, sB) != 1 {
return nil, errors.New("sm9: verify responder's signature fail") return nil, errors.New("sm9: verify responder's signature fail")
} }
} }
buffer = append(buffer, ke.uid...) ke.generateKey(false)
buffer = append(buffer, ke.peerUID...)
buffer = append(buffer, ke.secret.Marshal()...)
buffer = append(buffer, ke.peerSecret.Marshal()...)
buffer = append(buffer, ke.g1.Marshal()...)
buffer = append(buffer, ke.g2.Marshal()...)
buffer = append(buffer, ke.g3.Marshal()...)
key, _ := sm3.Kdf(buffer, ke.keyLength) return ke.sign(false, 0x83), nil
ke.key = key
hash.Write([]byte{0x83})
hash.Write(ke.g1.Marshal())
hash.Write(temp)
return hash.Sum(nil), nil
} }
// ConfirmInitiator for responder's step B8 // ConfirmInitiator for responder's step B8
func (ke *KeyExchange) ConfirmInitiator(s1 []byte) error { func (ke *KeyExchange) ConfirmInitiator(s1 []byte) error {
hash := sm3.New() buffer := ke.sign(true, 0x83)
var buffer []byte
hash.Write(ke.g2.Marshal())
hash.Write(ke.g3.Marshal())
hash.Write(ke.peerUID)
hash.Write(ke.uid)
hash.Write(ke.peerSecret.Marshal())
hash.Write(ke.secret.Marshal())
buffer = hash.Sum(nil)
hash.Reset()
hash.Write([]byte{0x83})
hash.Write(ke.g1.Marshal())
hash.Write(buffer)
buffer = hash.Sum(nil)
if goSubtle.ConstantTimeCompare(buffer, s1) != 1 { if goSubtle.ConstantTimeCompare(buffer, s1) != 1 {
return errors.New("sm9: verify initiator's signature fail") return errors.New("sm9: verify initiator's signature fail")
} }