sm2/sm9: key exchange, support to destroy internal state

This commit is contained in:
Sun Yimin 2022-08-24 15:15:58 +08:00 committed by GitHub
parent 35154c8b53
commit 21a9793600
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 207 additions and 105 deletions

View File

@ -26,12 +26,47 @@ type KeyExchange struct {
w2 *big.Int // internal state which will be used when compute the key and signature, 2^w w2 *big.Int // internal state which will be used when compute the key and signature, 2^w
w2Minus1 *big.Int // internal state which will be used when compute the key and signature, 2^w 1 w2Minus1 *big.Int // internal state which will be used when compute the key and signature, 2^w 1
v *ecdsa.PublicKey // internal state which will be used when compute the key and signature, u/v v *ecdsa.PublicKey // internal state which will be used when compute the key and signature, u/v
key []byte // shared key will be used after key agreement
} }
// GetSharedKey return shared key after key agreement // Destroy clear all internal state and Ephemeral private/public keys.
func (ke *KeyExchange) GetSharedKey() []byte { func (ke *KeyExchange) Destroy() {
return ke.key if ke.z != nil {
for v := range ke.z {
ke.z[v] = 0
}
}
if ke.peerZ != nil {
for v := range ke.peerZ {
ke.peerZ[v] = 0
}
}
if ke.r != nil {
ke.r.SetInt64(0)
}
if ke.secret != nil {
if ke.secret.X != nil {
ke.secret.X.SetInt64(0)
}
if ke.secret.Y != nil {
ke.secret.Y.SetInt64(0)
}
}
if ke.peerSecret != nil {
if ke.peerSecret.X != nil {
ke.peerSecret.X.SetInt64(0)
}
if ke.peerSecret.Y != nil {
ke.peerSecret.Y.SetInt64(0)
}
}
if ke.v != nil {
if ke.v.X != nil {
ke.v.X.SetInt64(0)
}
if ke.v.Y != nil {
ke.v.Y.SetInt64(0)
}
}
} }
// NewKeyExchange create one new KeyExchange object // NewKeyExchange create one new KeyExchange object
@ -150,7 +185,7 @@ func (ke *KeyExchange) sign(isResponder bool, prefix byte) []byte {
return hash.Sum(nil) return hash.Sum(nil)
} }
func (ke *KeyExchange) generateSharedKey(isResponder bool) { func (ke *KeyExchange) generateSharedKey(isResponder bool) ([]byte, error) {
var buffer []byte var buffer []byte
buffer = append(buffer, toBytes(ke.privateKey, ke.v.X)...) buffer = append(buffer, toBytes(ke.privateKey, ke.v.X)...)
buffer = append(buffer, toBytes(ke.privateKey, ke.v.Y)...) buffer = append(buffer, toBytes(ke.privateKey, ke.v.Y)...)
@ -161,39 +196,54 @@ func (ke *KeyExchange) generateSharedKey(isResponder bool) {
buffer = append(buffer, ke.z...) buffer = append(buffer, ke.z...)
buffer = append(buffer, ke.peerZ...) buffer = append(buffer, ke.peerZ...)
} }
key, _ := sm3.Kdf(buffer, ke.keyLength) key, ok := sm3.Kdf(buffer, ke.keyLength)
ke.key = key if !ok {
return nil, errors.New("sm2: internal error, kdf failed")
}
return key, nil
}
// avf is the associative value function.
func (ke *KeyExchange) avf(x *big.Int) *big.Int {
t := (&big.Int{}).And(ke.w2Minus1, x)
t.Add(ke.w2, t)
return t
}
func (ke *KeyExchange) implicitSig() []byte {
// Calculate x2`
t := ke.avf(ke.secret.X)
// Calculate tB
t.Mul(t, ke.r)
t.Add(t, ke.privateKey.D)
t.Mod(t, ke.privateKey.Params().N)
return t.Bytes()
}
func (ke *KeyExchange) basePoint() (*big.Int, *big.Int) {
// x1` = 2^w + (x & (2^w 1))
x1 := ke.avf(ke.peerSecret.X)
// Point(x3, y3) = peerPub + [x1](peerSecret)
x, y := ke.privateKey.ScalarMult(ke.peerSecret.X, ke.peerSecret.Y, x1.Bytes())
x, y = ke.privateKey.Add(ke.peerPub.X, ke.peerPub.Y, x, y)
return x, y
} }
func respondKeyExchange(ke *KeyExchange, r *big.Int) (*ecdsa.PublicKey, []byte, error) { func respondKeyExchange(ke *KeyExchange, r *big.Int) (*ecdsa.PublicKey, []byte, error) {
// secret = RB = [r]G // secret = RB = [r]G
ke.secret.X, ke.secret.Y = ke.privateKey.ScalarBaseMult(r.Bytes()) ke.secret.X, ke.secret.Y = ke.privateKey.ScalarBaseMult(r.Bytes())
ke.r = r ke.r = r
// Calculate x2`
t := (&big.Int{}).And(ke.w2Minus1, ke.secret.X)
t.Add(ke.w2, t)
// Calculate tB
t.Mul(t, ke.r)
t.Add(t, ke.privateKey.D)
t.Mod(t, ke.privateKey.Params().N)
// x1` = 2^w + (x & (2^w 1))
x1 := (&big.Int{}).And(ke.w2Minus1, ke.peerSecret.X)
x1.Add(ke.w2, x1)
// Point(x3, y3) = peerPub + [x1](peerSecret) // Point(x3, y3) = peerPub + [x1](peerSecret)
x3, y3 := ke.privateKey.ScalarMult(ke.peerSecret.X, ke.peerSecret.Y, x1.Bytes()) x, y := ke.basePoint()
x3, y3 = ke.privateKey.Add(ke.peerPub.X, ke.peerPub.Y, x3, y3)
// V = [h*tB](Point(x3, y3)) // V = [h*tB](Point(x3, y3))
ke.v.X, ke.v.Y = ke.privateKey.ScalarMult(x3, y3, t.Bytes()) ke.v.X, ke.v.Y = ke.privateKey.ScalarMult(x, y, ke.implicitSig())
if ke.v.X.Sign() == 0 && ke.v.Y.Sign() == 0 { if ke.v.X.Sign() == 0 && ke.v.Y.Sign() == 0 {
return nil, nil, errors.New("sm2: key exchange failed, V is infinity point") return nil, nil, errors.New("sm2: key exchange failed, V is infinity point")
} }
ke.generateSharedKey(true)
if !ke.genSignature { if !ke.genSignature {
return ke.secret, nil, nil return ke.secret, nil, nil
} }
@ -220,59 +270,52 @@ func (ke *KeyExchange) RepondKeyExchange(rand io.Reader, rA *ecdsa.PublicKey) (*
return respondKeyExchange(ke, r) return respondKeyExchange(ke, r)
} }
// ConfirmResponder for initiator's step A4-A10, returns optional signature. // ConfirmResponder for initiator's step A4-A10, returns keying data and optional signature.
// //
// It will check if there are peer's public key and validate the peer's Ephemeral Public Key. // It will check if there are peer's public key and validate the peer's Ephemeral Public Key.
// //
// If the peer's signature is not empty, then it will also validate the peer's // If the peer's signature is not empty, then it will also validate the peer's
// signature and return generated signature depends on KeyExchange.genSignature value. // signature and return generated signature depends on KeyExchange.genSignature value.
func (ke *KeyExchange) ConfirmResponder(rB *ecdsa.PublicKey, sB []byte) ([]byte, error) { func (ke *KeyExchange) ConfirmResponder(rB *ecdsa.PublicKey, sB []byte) ([]byte, []byte, error) {
if ke.peerPub == nil { if ke.peerPub == nil {
return nil, errors.New("sm2: no peer public key given") return nil, nil, errors.New("sm2: no peer public key given")
} }
if !ke.privateKey.IsOnCurve(rB.X, rB.Y) { if !ke.privateKey.IsOnCurve(rB.X, rB.Y) {
return nil, errors.New("sm2: invalid responder's ephemeral public key") return nil, nil, errors.New("sm2: invalid responder's ephemeral public key")
} }
ke.peerSecret = rB ke.peerSecret = rB
// Calculate tA
t := (&big.Int{}).And(ke.w2Minus1, ke.secret.X)
t.Add(ke.w2, t)
t.Mul(t, ke.r)
t.Add(t, ke.privateKey.D)
t.Mod(t, ke.privateKey.Params().N)
// x2` = 2^w + (x & (2^w 1))
x2 := (&big.Int{}).And(ke.w2Minus1, ke.peerSecret.X)
x2.Add(ke.w2, x2)
// Point(x3, y3) = peerPub + [x1](peerSecret)
x3, y3 := ke.privateKey.ScalarMult(ke.peerSecret.X, ke.peerSecret.Y, x2.Bytes())
x3, y3 = ke.privateKey.Add(ke.peerPub.X, ke.peerPub.Y, x3, y3)
x, y := ke.basePoint()
// U = [h*tA](Point(x3, y3)) // U = [h*tA](Point(x3, y3))
ke.v.X, ke.v.Y = ke.privateKey.ScalarMult(x3, y3, t.Bytes()) ke.v.X, ke.v.Y = ke.privateKey.ScalarMult(x, y, ke.implicitSig())
if ke.v.X.Sign() == 0 && ke.v.Y.Sign() == 0 { if ke.v.X.Sign() == 0 && ke.v.Y.Sign() == 0 {
return nil, errors.New("sm2: key exchange failed, U is infinity point") return nil, nil, errors.New("sm2: key exchange failed, U is infinity point")
} }
ke.generateSharedKey(false)
if len(sB) > 0 { if len(sB) > 0 {
buffer := ke.sign(false, 0x02) buffer := ke.sign(false, 0x02)
if subtle.ConstantTimeCompare(buffer, sB) != 1 { if subtle.ConstantTimeCompare(buffer, sB) != 1 {
return nil, errors.New("sm2: invalid responder's signature") return nil, nil, errors.New("sm2: invalid responder's signature")
} }
} }
if !ke.genSignature { key, err := ke.generateSharedKey(false)
return nil, nil if err != nil {
return nil, nil, err
} }
return ke.sign(false, 0x03), nil
if !ke.genSignature {
return key, nil, nil
}
return key, ke.sign(false, 0x03), nil
} }
// ConfirmInitiator for responder's step B10 // ConfirmInitiator for responder's step B10
func (ke *KeyExchange) ConfirmInitiator(s1 []byte) error { func (ke *KeyExchange) ConfirmInitiator(s1 []byte) ([]byte, error) {
buffer := ke.sign(true, 0x03) if s1 != nil {
if subtle.ConstantTimeCompare(buffer, s1) != 1 { buffer := ke.sign(true, 0x03)
return errors.New("sm2: invalid initiator's signature") if subtle.ConstantTimeCompare(buffer, s1) != 1 {
return nil, errors.New("sm2: invalid initiator's signature")
}
} }
return nil return ke.generateSharedKey(true)
} }

View File

@ -20,6 +20,10 @@ func TestKeyExchangeSample(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer func() {
initiator.Destroy()
responder.Destroy()
}()
rA, err := initiator.InitKeyExchange(rand.Reader) rA, err := initiator.InitKeyExchange(rand.Reader)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -30,17 +34,17 @@ func TestKeyExchangeSample(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
s1, err := initiator.ConfirmResponder(rB, s2) key1, s1, err := initiator.ConfirmResponder(rB, s2)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = responder.ConfirmInitiator(s1) key2, err := responder.ConfirmInitiator(s1)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if hex.EncodeToString(initiator.key) != hex.EncodeToString(responder.key) { if hex.EncodeToString(key1) != hex.EncodeToString(key2) {
t.Errorf("got different key") t.Errorf("got different key")
} }
} }
@ -56,6 +60,10 @@ func TestKeyExchangeSimplest(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer func() {
initiator.Destroy()
responder.Destroy()
}()
rA, err := initiator.InitKeyExchange(rand.Reader) rA, err := initiator.InitKeyExchange(rand.Reader)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -69,7 +77,7 @@ func TestKeyExchangeSimplest(t *testing.T) {
t.Errorf("should be no siganature") t.Errorf("should be no siganature")
} }
s1, err := initiator.ConfirmResponder(rB, nil) key1, s1, err := initiator.ConfirmResponder(rB, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -77,7 +85,12 @@ func TestKeyExchangeSimplest(t *testing.T) {
t.Errorf("should be no siganature") t.Errorf("should be no siganature")
} }
if hex.EncodeToString(initiator.GetSharedKey()) != hex.EncodeToString(responder.GetSharedKey()) { key2, err := responder.ConfirmInitiator(nil)
if err != nil {
t.Fatal(err)
}
if hex.EncodeToString(key1) != hex.EncodeToString(key2) {
t.Errorf("got different key") t.Errorf("got different key")
} }
} }
@ -97,7 +110,10 @@ func TestSetPeerParameters(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer func() {
initiator.Destroy()
responder.Destroy()
}()
rA, err := initiator.InitKeyExchange(rand.Reader) rA, err := initiator.InitKeyExchange(rand.Reader)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -124,17 +140,17 @@ func TestSetPeerParameters(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
s1, err := initiator.ConfirmResponder(rB, s2) key1, s1, err := initiator.ConfirmResponder(rB, s2)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = responder.ConfirmInitiator(s1) key2, err := responder.ConfirmInitiator(s1)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if hex.EncodeToString(initiator.key) != hex.EncodeToString(responder.key) { if hex.EncodeToString(key1) != hex.EncodeToString(key2) {
t.Errorf("got different key") t.Errorf("got different key")
} }
} }
@ -153,7 +169,10 @@ func TestKeyExchange_SetPeerParameters(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer func() {
initiator.Destroy()
responder.Destroy()
}()
rA, err := initiator.InitKeyExchange(rand.Reader) rA, err := initiator.InitKeyExchange(rand.Reader)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -174,17 +193,17 @@ func TestKeyExchange_SetPeerParameters(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
s1, err := initiator.ConfirmResponder(rB, s2) key1, s1, err := initiator.ConfirmResponder(rB, s2)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = responder.ConfirmInitiator(s1) key2, err := responder.ConfirmInitiator(s1)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if hex.EncodeToString(initiator.key) != hex.EncodeToString(responder.key) { if hex.EncodeToString(key1) != hex.EncodeToString(key2) {
t.Errorf("got different key") t.Errorf("got different key")
} }
} }
@ -203,7 +222,10 @@ func TestKeyExchange_SetPeerParameters_ErrCase(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer func() {
initiator.Destroy()
responder.Destroy()
}()
rA, err := initiator.InitKeyExchange(rand.Reader) rA, err := initiator.InitKeyExchange(rand.Reader)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -213,7 +235,7 @@ func TestKeyExchange_SetPeerParameters_ErrCase(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
_, err = initiator.ConfirmResponder(rB, s2) _, _, err = initiator.ConfirmResponder(rB, s2)
if err == nil { if err == nil {
t.Fatal(errors.New("expect call ConfirmResponder got a error, but not")) t.Fatal(errors.New("expect call ConfirmResponder got a error, but not"))
} }

View File

@ -469,7 +469,6 @@ type KeyExchange struct {
g1 *bn256.GT // internal state which will be used when compute the key and signature g1 *bn256.GT // internal state which will be used when compute the key and signature
g2 *bn256.GT // internal state which will be used when compute the key and signature g2 *bn256.GT // internal state which will be used when compute the key and signature
g3 *bn256.GT // internal state which will be used when compute the key and signature g3 *bn256.GT // internal state which will be used when compute the key and signature
key []byte // shared key will be used after key agreement
} }
// NewKeyExchange create one new KeyExchange object // NewKeyExchange create one new KeyExchange object
@ -483,9 +482,26 @@ func NewKeyExchange(priv *EncryptPrivateKey, uid, peerUID []byte, keyLen int, ge
return ke return ke
} }
// GetSharedKey return key after key agreement // Destroy clear all internal state and Ephemeral private/public keys
func (ke *KeyExchange) GetSharedKey() []byte { func (ke *KeyExchange) Destroy() {
return ke.key if ke.r != nil {
ke.r.SetInt64(0)
}
if ke.secret != nil {
ke.secret.Set(bn256.Gen1)
}
if ke.peerSecret != nil {
ke.peerSecret.Set(bn256.Gen1)
}
if ke.g1 != nil {
ke.g1.SetOne()
}
if ke.g2 != nil {
ke.g2.SetOne()
}
if ke.g3 != nil {
ke.g3.SetOne()
}
} }
func initKeyExchange(ke *KeyExchange, hid byte, r *big.Int) { func initKeyExchange(ke *KeyExchange, hid byte, r *big.Int) {
@ -529,7 +545,7 @@ func (ke *KeyExchange) sign(isResponder bool, prefix byte) []byte {
return hash.Sum(nil) return hash.Sum(nil)
} }
func (ke *KeyExchange) generateSharedKey(isResponder bool) { func (ke *KeyExchange) generateSharedKey(isResponder bool) ([]byte, error) {
var buffer []byte var buffer []byte
if isResponder { if isResponder {
buffer = append(buffer, ke.peerUID...) buffer = append(buffer, ke.peerUID...)
@ -546,8 +562,11 @@ func (ke *KeyExchange) generateSharedKey(isResponder bool) {
buffer = append(buffer, ke.g2.Marshal()...) buffer = append(buffer, ke.g2.Marshal()...)
buffer = append(buffer, ke.g3.Marshal()...) buffer = append(buffer, ke.g3.Marshal()...)
key, _ := sm3.Kdf(buffer, ke.keyLength) key, ok := sm3.Kdf(buffer, ke.keyLength)
ke.key = key if !ok {
return nil, errors.New("sm9: internal error, kdf failed")
}
return key, nil
} }
func respondKeyExchange(ke *KeyExchange, hid byte, r *big.Int, rA *bn256.G1) (*bn256.G1, []byte, error) { func respondKeyExchange(ke *KeyExchange, hid byte, r *big.Int, rA *bn256.G1) (*bn256.G1, []byte, error) {
@ -565,8 +584,6 @@ func respondKeyExchange(ke *KeyExchange, hid byte, r *big.Int, rA *bn256.G1) (*b
ke.g3.ScalarMult(ke.g1, r) ke.g3.ScalarMult(ke.g1, r)
ke.g2 = ke.privateKey.EncryptMasterPublicKey.ScalarBaseMult(r) ke.g2 = ke.privateKey.EncryptMasterPublicKey.ScalarBaseMult(r)
ke.generateSharedKey(true)
if !ke.genSignature { if !ke.genSignature {
return ke.secret, nil, nil return ke.secret, nil, nil
} }
@ -584,9 +601,9 @@ func (ke *KeyExchange) RepondKeyExchange(rand io.Reader, hid byte, rA *bn256.G1)
} }
// ConfirmResponder for initiator's step A5-A7 // ConfirmResponder for initiator's step A5-A7
func (ke *KeyExchange) ConfirmResponder(rB *bn256.G1, sB []byte) ([]byte, error) { func (ke *KeyExchange) ConfirmResponder(rB *bn256.G1, sB []byte) ([]byte, []byte, error) {
if !rB.IsOnCurve() { if !rB.IsOnCurve() {
return nil, errors.New("sm9: invalid responder's ephemeral public key") return nil, nil, errors.New("sm9: invalid responder's ephemeral public key")
} }
// step 5 // step 5
ke.peerSecret = rB ke.peerSecret = rB
@ -598,21 +615,26 @@ func (ke *KeyExchange) ConfirmResponder(rB *bn256.G1, sB []byte) ([]byte, error)
if len(sB) > 0 { if len(sB) > 0 {
signature := ke.sign(false, 0x82) signature := ke.sign(false, 0x82)
if goSubtle.ConstantTimeCompare(signature, sB) != 1 { if goSubtle.ConstantTimeCompare(signature, sB) != 1 {
return nil, errors.New("sm9: invalid responder's signature") return nil, nil, errors.New("sm9: invalid responder's signature")
} }
} }
ke.generateSharedKey(false) key, err := ke.generateSharedKey(false)
if !ke.genSignature { if err != nil {
return nil, nil return nil, nil, err
} }
return ke.sign(false, 0x83), nil if !ke.genSignature {
return key, nil, nil
}
return key, ke.sign(false, 0x83), nil
} }
// ConfirmInitiator for responder's step B8 // ConfirmInitiator for responder's step B8
func (ke *KeyExchange) ConfirmInitiator(s1 []byte) error { func (ke *KeyExchange) ConfirmInitiator(s1 []byte) ([]byte, error) {
buffer := ke.sign(true, 0x83) if s1 != nil {
if goSubtle.ConstantTimeCompare(buffer, s1) != 1 { buffer := ke.sign(true, 0x83)
return errors.New("sm9: invalid initiator's signature") if goSubtle.ConstantTimeCompare(buffer, s1) != 1 {
return nil, errors.New("sm9: invalid initiator's signature")
}
} }
return nil return ke.generateSharedKey(true)
} }

View File

@ -156,7 +156,10 @@ func TestKeyExchangeSample(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
responder := NewKeyExchange(userKey, userB, userA, 16, true) responder := NewKeyExchange(userKey, userB, userA, 16, true)
defer func() {
initiator.Destroy()
responder.Destroy()
}()
// A1-A4 // A1-A4
initKeyExchange(initiator, hid, bigFromHex("5879DD1D51E175946F23B1B41E93BA31C584AE59A426EC1046A4D03B06C8")) initKeyExchange(initiator, hid, bigFromHex("5879DD1D51E175946F23B1B41E93BA31C584AE59A426EC1046A4D03B06C8"))
@ -169,29 +172,30 @@ func TestKeyExchangeSample(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if hex.EncodeToString(responder.key) != expectedKey {
t.Errorf("not expected key %v\n", hex.EncodeToString(responder.key))
}
if hex.EncodeToString(sigB) != expectedSignatureB { if hex.EncodeToString(sigB) != expectedSignatureB {
t.Errorf("not expected signature B") t.Errorf("not expected signature B")
} }
// A5 -A8 // A5 -A8
sigA, err := initiator.ConfirmResponder(rB, sigB) key1, sigA, err := initiator.ConfirmResponder(rB, sigB)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if hex.EncodeToString(initiator.key) != expectedKey { if hex.EncodeToString(key1) != expectedKey {
t.Errorf("not expected key %v\n", hex.EncodeToString(initiator.key)) t.Errorf("not expected key %v\n", hex.EncodeToString(key1))
} }
if hex.EncodeToString(sigA) != expectedSignatureA { if hex.EncodeToString(sigA) != expectedSignatureA {
t.Errorf("not expected signature A") t.Errorf("not expected signature A")
} }
// B8 // B8
err = responder.ConfirmInitiator(sigA) key2, err := responder.ConfirmInitiator(sigA)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if hex.EncodeToString(key2) != expectedKey {
t.Errorf("not expected key %v\n", hex.EncodeToString(key2))
}
} }
func TestKeyExchange(t *testing.T) { func TestKeyExchange(t *testing.T) {
@ -214,7 +218,10 @@ func TestKeyExchange(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
responder := NewKeyExchange(userKey, userB, userA, 16, true) responder := NewKeyExchange(userKey, userB, userA, 16, true)
defer func() {
initiator.Destroy()
responder.Destroy()
}()
// A1-A4 // A1-A4
rA, err := initiator.InitKeyExchange(rand.Reader, hid) rA, err := initiator.InitKeyExchange(rand.Reader, hid)
if err != nil { if err != nil {
@ -228,18 +235,18 @@ func TestKeyExchange(t *testing.T) {
} }
// A5 -A8 // A5 -A8
sigA, err := initiator.ConfirmResponder(rB, sigB) key1, sigA, err := initiator.ConfirmResponder(rB, sigB)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
// B8 // B8
err = responder.ConfirmInitiator(sigA) key2, err := responder.ConfirmInitiator(sigA)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if hex.EncodeToString(initiator.GetSharedKey()) != hex.EncodeToString(responder.GetSharedKey()) { if hex.EncodeToString(key1) != hex.EncodeToString(key2) {
t.Errorf("got different key") t.Errorf("got different key")
} }
} }
@ -264,7 +271,10 @@ func TestKeyExchangeWithoutSignature(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
responder := NewKeyExchange(userKey, userB, userA, 16, false) responder := NewKeyExchange(userKey, userB, userA, 16, false)
defer func() {
initiator.Destroy()
responder.Destroy()
}()
// A1-A4 // A1-A4
rA, err := initiator.InitKeyExchange(rand.Reader, hid) rA, err := initiator.InitKeyExchange(rand.Reader, hid)
if err != nil { if err != nil {
@ -281,7 +291,7 @@ func TestKeyExchangeWithoutSignature(t *testing.T) {
} }
// A5 -A8 // A5 -A8
sigA, err := initiator.ConfirmResponder(rB, sigB) key1, sigA, err := initiator.ConfirmResponder(rB, sigB)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -289,7 +299,12 @@ func TestKeyExchangeWithoutSignature(t *testing.T) {
t.Errorf("should no signature") t.Errorf("should no signature")
} }
if hex.EncodeToString(initiator.GetSharedKey()) != hex.EncodeToString(responder.GetSharedKey()) { key2, err := responder.ConfirmInitiator(nil)
if err != nil {
t.Fatal(err)
}
if hex.EncodeToString(key1) != hex.EncodeToString(key2) {
t.Errorf("got different key") t.Errorf("got different key")
} }
} }