diff --git a/sm2/sm2_keyexchange.go b/sm2/sm2_keyexchange.go index 8958bb8..d12cb9a 100644 --- a/sm2/sm2_keyexchange.go +++ b/sm2/sm2_keyexchange.go @@ -26,12 +26,47 @@ type KeyExchange struct { w2 *big.Int // internal state which will be used when compute the key and signature, 2^w w2Minus1 *big.Int // internal state which will be used when compute the key and signature, 2^w – 1 v *ecdsa.PublicKey // internal state which will be used when compute the key and signature, u/v - key []byte // shared key will be used after key agreement } -// GetSharedKey return shared key after key agreement -func (ke *KeyExchange) GetSharedKey() []byte { - return ke.key +// Destroy clear all internal state and Ephemeral private/public keys. +func (ke *KeyExchange) Destroy() { + if ke.z != nil { + for v := range ke.z { + ke.z[v] = 0 + } + } + if ke.peerZ != nil { + for v := range ke.peerZ { + ke.peerZ[v] = 0 + } + } + if ke.r != nil { + ke.r.SetInt64(0) + } + if ke.secret != nil { + if ke.secret.X != nil { + ke.secret.X.SetInt64(0) + } + if ke.secret.Y != nil { + ke.secret.Y.SetInt64(0) + } + } + if ke.peerSecret != nil { + if ke.peerSecret.X != nil { + ke.peerSecret.X.SetInt64(0) + } + if ke.peerSecret.Y != nil { + ke.peerSecret.Y.SetInt64(0) + } + } + if ke.v != nil { + if ke.v.X != nil { + ke.v.X.SetInt64(0) + } + if ke.v.Y != nil { + ke.v.Y.SetInt64(0) + } + } } // NewKeyExchange create one new KeyExchange object @@ -150,7 +185,7 @@ func (ke *KeyExchange) sign(isResponder bool, prefix byte) []byte { return hash.Sum(nil) } -func (ke *KeyExchange) generateSharedKey(isResponder bool) { +func (ke *KeyExchange) generateSharedKey(isResponder bool) ([]byte, error) { var buffer []byte buffer = append(buffer, toBytes(ke.privateKey, ke.v.X)...) buffer = append(buffer, toBytes(ke.privateKey, ke.v.Y)...) @@ -161,39 +196,54 @@ func (ke *KeyExchange) generateSharedKey(isResponder bool) { buffer = append(buffer, ke.z...) buffer = append(buffer, ke.peerZ...) } - key, _ := sm3.Kdf(buffer, ke.keyLength) - ke.key = key + key, ok := sm3.Kdf(buffer, ke.keyLength) + if !ok { + return nil, errors.New("sm2: internal error, kdf failed") + } + return key, nil +} + +// avf is the associative value function. +func (ke *KeyExchange) avf(x *big.Int) *big.Int { + t := (&big.Int{}).And(ke.w2Minus1, x) + t.Add(ke.w2, t) + return t +} + +func (ke *KeyExchange) implicitSig() []byte { + // Calculate x2` + t := ke.avf(ke.secret.X) + + // Calculate tB + t.Mul(t, ke.r) + t.Add(t, ke.privateKey.D) + t.Mod(t, ke.privateKey.Params().N) + return t.Bytes() +} + +func (ke *KeyExchange) basePoint() (*big.Int, *big.Int) { + // x1` = 2^w + (x & (2^w – 1)) + x1 := ke.avf(ke.peerSecret.X) + // Point(x3, y3) = peerPub + [x1](peerSecret) + x, y := ke.privateKey.ScalarMult(ke.peerSecret.X, ke.peerSecret.Y, x1.Bytes()) + x, y = ke.privateKey.Add(ke.peerPub.X, ke.peerPub.Y, x, y) + return x, y } func respondKeyExchange(ke *KeyExchange, r *big.Int) (*ecdsa.PublicKey, []byte, error) { // secret = RB = [r]G ke.secret.X, ke.secret.Y = ke.privateKey.ScalarBaseMult(r.Bytes()) ke.r = r - // Calculate x2` - t := (&big.Int{}).And(ke.w2Minus1, ke.secret.X) - t.Add(ke.w2, t) - - // Calculate tB - t.Mul(t, ke.r) - t.Add(t, ke.privateKey.D) - t.Mod(t, ke.privateKey.Params().N) - - // x1` = 2^w + (x & (2^w – 1)) - x1 := (&big.Int{}).And(ke.w2Minus1, ke.peerSecret.X) - x1.Add(ke.w2, x1) // Point(x3, y3) = peerPub + [x1](peerSecret) - x3, y3 := ke.privateKey.ScalarMult(ke.peerSecret.X, ke.peerSecret.Y, x1.Bytes()) - x3, y3 = ke.privateKey.Add(ke.peerPub.X, ke.peerPub.Y, x3, y3) + x, y := ke.basePoint() // V = [h*tB](Point(x3, y3)) - ke.v.X, ke.v.Y = ke.privateKey.ScalarMult(x3, y3, t.Bytes()) + ke.v.X, ke.v.Y = ke.privateKey.ScalarMult(x, y, ke.implicitSig()) if ke.v.X.Sign() == 0 && ke.v.Y.Sign() == 0 { return nil, nil, errors.New("sm2: key exchange failed, V is infinity point") } - ke.generateSharedKey(true) - if !ke.genSignature { return ke.secret, nil, nil } @@ -220,59 +270,52 @@ func (ke *KeyExchange) RepondKeyExchange(rand io.Reader, rA *ecdsa.PublicKey) (* return respondKeyExchange(ke, r) } -// ConfirmResponder for initiator's step A4-A10, returns optional signature. +// ConfirmResponder for initiator's step A4-A10, returns keying data and optional signature. // // It will check if there are peer's public key and validate the peer's Ephemeral Public Key. // // If the peer's signature is not empty, then it will also validate the peer's // signature and return generated signature depends on KeyExchange.genSignature value. -func (ke *KeyExchange) ConfirmResponder(rB *ecdsa.PublicKey, sB []byte) ([]byte, error) { +func (ke *KeyExchange) ConfirmResponder(rB *ecdsa.PublicKey, sB []byte) ([]byte, []byte, error) { if ke.peerPub == nil { - return nil, errors.New("sm2: no peer public key given") + return nil, nil, errors.New("sm2: no peer public key given") } if !ke.privateKey.IsOnCurve(rB.X, rB.Y) { - return nil, errors.New("sm2: invalid responder's ephemeral public key") + return nil, nil, errors.New("sm2: invalid responder's ephemeral public key") } ke.peerSecret = rB - // Calculate tA - t := (&big.Int{}).And(ke.w2Minus1, ke.secret.X) - t.Add(ke.w2, t) - t.Mul(t, ke.r) - t.Add(t, ke.privateKey.D) - t.Mod(t, ke.privateKey.Params().N) - - // x2` = 2^w + (x & (2^w – 1)) - x2 := (&big.Int{}).And(ke.w2Minus1, ke.peerSecret.X) - x2.Add(ke.w2, x2) - - // Point(x3, y3) = peerPub + [x1](peerSecret) - x3, y3 := ke.privateKey.ScalarMult(ke.peerSecret.X, ke.peerSecret.Y, x2.Bytes()) - x3, y3 = ke.privateKey.Add(ke.peerPub.X, ke.peerPub.Y, x3, y3) + x, y := ke.basePoint() // U = [h*tA](Point(x3, y3)) - ke.v.X, ke.v.Y = ke.privateKey.ScalarMult(x3, y3, t.Bytes()) + ke.v.X, ke.v.Y = ke.privateKey.ScalarMult(x, y, ke.implicitSig()) if ke.v.X.Sign() == 0 && ke.v.Y.Sign() == 0 { - return nil, errors.New("sm2: key exchange failed, U is infinity point") + return nil, nil, errors.New("sm2: key exchange failed, U is infinity point") } - ke.generateSharedKey(false) if len(sB) > 0 { buffer := ke.sign(false, 0x02) if subtle.ConstantTimeCompare(buffer, sB) != 1 { - return nil, errors.New("sm2: invalid responder's signature") + return nil, nil, errors.New("sm2: invalid responder's signature") } } - if !ke.genSignature { - return nil, nil + key, err := ke.generateSharedKey(false) + if err != nil { + return nil, nil, err } - return ke.sign(false, 0x03), nil + + if !ke.genSignature { + return key, nil, nil + } + return key, ke.sign(false, 0x03), nil } // ConfirmInitiator for responder's step B10 -func (ke *KeyExchange) ConfirmInitiator(s1 []byte) error { - buffer := ke.sign(true, 0x03) - if subtle.ConstantTimeCompare(buffer, s1) != 1 { - return errors.New("sm2: invalid initiator's signature") +func (ke *KeyExchange) ConfirmInitiator(s1 []byte) ([]byte, error) { + if s1 != nil { + buffer := ke.sign(true, 0x03) + if subtle.ConstantTimeCompare(buffer, s1) != 1 { + return nil, errors.New("sm2: invalid initiator's signature") + } } - return nil + return ke.generateSharedKey(true) } diff --git a/sm2/sm2_keyexchange_test.go b/sm2/sm2_keyexchange_test.go index 047094a..f3f7369 100644 --- a/sm2/sm2_keyexchange_test.go +++ b/sm2/sm2_keyexchange_test.go @@ -20,6 +20,10 @@ func TestKeyExchangeSample(t *testing.T) { if err != nil { t.Fatal(err) } + defer func() { + initiator.Destroy() + responder.Destroy() + }() rA, err := initiator.InitKeyExchange(rand.Reader) if err != nil { t.Fatal(err) @@ -30,17 +34,17 @@ func TestKeyExchangeSample(t *testing.T) { t.Fatal(err) } - s1, err := initiator.ConfirmResponder(rB, s2) + key1, s1, err := initiator.ConfirmResponder(rB, s2) if err != nil { t.Fatal(err) } - err = responder.ConfirmInitiator(s1) + key2, err := responder.ConfirmInitiator(s1) if err != nil { t.Fatal(err) } - if hex.EncodeToString(initiator.key) != hex.EncodeToString(responder.key) { + if hex.EncodeToString(key1) != hex.EncodeToString(key2) { t.Errorf("got different key") } } @@ -56,6 +60,10 @@ func TestKeyExchangeSimplest(t *testing.T) { if err != nil { t.Fatal(err) } + defer func() { + initiator.Destroy() + responder.Destroy() + }() rA, err := initiator.InitKeyExchange(rand.Reader) if err != nil { t.Fatal(err) @@ -69,7 +77,7 @@ func TestKeyExchangeSimplest(t *testing.T) { t.Errorf("should be no siganature") } - s1, err := initiator.ConfirmResponder(rB, nil) + key1, s1, err := initiator.ConfirmResponder(rB, nil) if err != nil { t.Fatal(err) } @@ -77,7 +85,12 @@ func TestKeyExchangeSimplest(t *testing.T) { t.Errorf("should be no siganature") } - if hex.EncodeToString(initiator.GetSharedKey()) != hex.EncodeToString(responder.GetSharedKey()) { + key2, err := responder.ConfirmInitiator(nil) + if err != nil { + t.Fatal(err) + } + + if hex.EncodeToString(key1) != hex.EncodeToString(key2) { t.Errorf("got different key") } } @@ -97,7 +110,10 @@ func TestSetPeerParameters(t *testing.T) { if err != nil { t.Fatal(err) } - + defer func() { + initiator.Destroy() + responder.Destroy() + }() rA, err := initiator.InitKeyExchange(rand.Reader) if err != nil { t.Fatal(err) @@ -124,17 +140,17 @@ func TestSetPeerParameters(t *testing.T) { t.Fatal(err) } - s1, err := initiator.ConfirmResponder(rB, s2) + key1, s1, err := initiator.ConfirmResponder(rB, s2) if err != nil { t.Fatal(err) } - err = responder.ConfirmInitiator(s1) + key2, err := responder.ConfirmInitiator(s1) if err != nil { t.Fatal(err) } - if hex.EncodeToString(initiator.key) != hex.EncodeToString(responder.key) { + if hex.EncodeToString(key1) != hex.EncodeToString(key2) { t.Errorf("got different key") } } @@ -153,7 +169,10 @@ func TestKeyExchange_SetPeerParameters(t *testing.T) { if err != nil { t.Fatal(err) } - + defer func() { + initiator.Destroy() + responder.Destroy() + }() rA, err := initiator.InitKeyExchange(rand.Reader) if err != nil { t.Fatal(err) @@ -174,17 +193,17 @@ func TestKeyExchange_SetPeerParameters(t *testing.T) { t.Fatal(err) } - s1, err := initiator.ConfirmResponder(rB, s2) + key1, s1, err := initiator.ConfirmResponder(rB, s2) if err != nil { t.Fatal(err) } - err = responder.ConfirmInitiator(s1) + key2, err := responder.ConfirmInitiator(s1) if err != nil { t.Fatal(err) } - if hex.EncodeToString(initiator.key) != hex.EncodeToString(responder.key) { + if hex.EncodeToString(key1) != hex.EncodeToString(key2) { t.Errorf("got different key") } } @@ -203,7 +222,10 @@ func TestKeyExchange_SetPeerParameters_ErrCase(t *testing.T) { if err != nil { t.Fatal(err) } - + defer func() { + initiator.Destroy() + responder.Destroy() + }() rA, err := initiator.InitKeyExchange(rand.Reader) if err != nil { t.Fatal(err) @@ -213,7 +235,7 @@ func TestKeyExchange_SetPeerParameters_ErrCase(t *testing.T) { t.Fatal(err) } - _, err = initiator.ConfirmResponder(rB, s2) + _, _, err = initiator.ConfirmResponder(rB, s2) if err == nil { t.Fatal(errors.New("expect call ConfirmResponder got a error, but not")) } diff --git a/sm9/sm9.go b/sm9/sm9.go index 4000011..535ea49 100644 --- a/sm9/sm9.go +++ b/sm9/sm9.go @@ -469,7 +469,6 @@ type KeyExchange struct { g1 *bn256.GT // internal state which will be used when compute the key and signature g2 *bn256.GT // internal state which will be used when compute the key and signature g3 *bn256.GT // internal state which will be used when compute the key and signature - key []byte // shared key will be used after key agreement } // NewKeyExchange create one new KeyExchange object @@ -483,9 +482,26 @@ func NewKeyExchange(priv *EncryptPrivateKey, uid, peerUID []byte, keyLen int, ge return ke } -// GetSharedKey return key after key agreement -func (ke *KeyExchange) GetSharedKey() []byte { - return ke.key +// Destroy clear all internal state and Ephemeral private/public keys +func (ke *KeyExchange) Destroy() { + if ke.r != nil { + ke.r.SetInt64(0) + } + if ke.secret != nil { + ke.secret.Set(bn256.Gen1) + } + if ke.peerSecret != nil { + ke.peerSecret.Set(bn256.Gen1) + } + if ke.g1 != nil { + ke.g1.SetOne() + } + if ke.g2 != nil { + ke.g2.SetOne() + } + if ke.g3 != nil { + ke.g3.SetOne() + } } func initKeyExchange(ke *KeyExchange, hid byte, r *big.Int) { @@ -529,7 +545,7 @@ func (ke *KeyExchange) sign(isResponder bool, prefix byte) []byte { return hash.Sum(nil) } -func (ke *KeyExchange) generateSharedKey(isResponder bool) { +func (ke *KeyExchange) generateSharedKey(isResponder bool) ([]byte, error) { var buffer []byte if isResponder { buffer = append(buffer, ke.peerUID...) @@ -546,8 +562,11 @@ func (ke *KeyExchange) generateSharedKey(isResponder bool) { buffer = append(buffer, ke.g2.Marshal()...) buffer = append(buffer, ke.g3.Marshal()...) - key, _ := sm3.Kdf(buffer, ke.keyLength) - ke.key = key + key, ok := sm3.Kdf(buffer, ke.keyLength) + if !ok { + return nil, errors.New("sm9: internal error, kdf failed") + } + return key, nil } func respondKeyExchange(ke *KeyExchange, hid byte, r *big.Int, rA *bn256.G1) (*bn256.G1, []byte, error) { @@ -565,8 +584,6 @@ func respondKeyExchange(ke *KeyExchange, hid byte, r *big.Int, rA *bn256.G1) (*b ke.g3.ScalarMult(ke.g1, r) ke.g2 = ke.privateKey.EncryptMasterPublicKey.ScalarBaseMult(r) - ke.generateSharedKey(true) - if !ke.genSignature { return ke.secret, nil, nil } @@ -584,9 +601,9 @@ func (ke *KeyExchange) RepondKeyExchange(rand io.Reader, hid byte, rA *bn256.G1) } // ConfirmResponder for initiator's step A5-A7 -func (ke *KeyExchange) ConfirmResponder(rB *bn256.G1, sB []byte) ([]byte, error) { +func (ke *KeyExchange) ConfirmResponder(rB *bn256.G1, sB []byte) ([]byte, []byte, error) { if !rB.IsOnCurve() { - return nil, errors.New("sm9: invalid responder's ephemeral public key") + return nil, nil, errors.New("sm9: invalid responder's ephemeral public key") } // step 5 ke.peerSecret = rB @@ -598,21 +615,26 @@ func (ke *KeyExchange) ConfirmResponder(rB *bn256.G1, sB []byte) ([]byte, error) if len(sB) > 0 { signature := ke.sign(false, 0x82) if goSubtle.ConstantTimeCompare(signature, sB) != 1 { - return nil, errors.New("sm9: invalid responder's signature") + return nil, nil, errors.New("sm9: invalid responder's signature") } } - ke.generateSharedKey(false) - if !ke.genSignature { - return nil, nil + key, err := ke.generateSharedKey(false) + if err != nil { + return nil, nil, err } - return ke.sign(false, 0x83), nil + if !ke.genSignature { + return key, nil, nil + } + return key, ke.sign(false, 0x83), nil } // ConfirmInitiator for responder's step B8 -func (ke *KeyExchange) ConfirmInitiator(s1 []byte) error { - buffer := ke.sign(true, 0x83) - if goSubtle.ConstantTimeCompare(buffer, s1) != 1 { - return errors.New("sm9: invalid initiator's signature") +func (ke *KeyExchange) ConfirmInitiator(s1 []byte) ([]byte, error) { + if s1 != nil { + buffer := ke.sign(true, 0x83) + if goSubtle.ConstantTimeCompare(buffer, s1) != 1 { + return nil, errors.New("sm9: invalid initiator's signature") + } } - return nil + return ke.generateSharedKey(true) } diff --git a/sm9/sm9_test.go b/sm9/sm9_test.go index 9873cfa..c59f9cb 100644 --- a/sm9/sm9_test.go +++ b/sm9/sm9_test.go @@ -156,7 +156,10 @@ func TestKeyExchangeSample(t *testing.T) { t.Fatal(err) } responder := NewKeyExchange(userKey, userB, userA, 16, true) - + defer func() { + initiator.Destroy() + responder.Destroy() + }() // A1-A4 initKeyExchange(initiator, hid, bigFromHex("5879DD1D51E175946F23B1B41E93BA31C584AE59A426EC1046A4D03B06C8")) @@ -169,29 +172,30 @@ func TestKeyExchangeSample(t *testing.T) { if err != nil { t.Fatal(err) } - if hex.EncodeToString(responder.key) != expectedKey { - t.Errorf("not expected key %v\n", hex.EncodeToString(responder.key)) - } + if hex.EncodeToString(sigB) != expectedSignatureB { t.Errorf("not expected signature B") } // A5 -A8 - sigA, err := initiator.ConfirmResponder(rB, sigB) + key1, sigA, err := initiator.ConfirmResponder(rB, sigB) if err != nil { t.Fatal(err) } - if hex.EncodeToString(initiator.key) != expectedKey { - t.Errorf("not expected key %v\n", hex.EncodeToString(initiator.key)) + if hex.EncodeToString(key1) != expectedKey { + t.Errorf("not expected key %v\n", hex.EncodeToString(key1)) } if hex.EncodeToString(sigA) != expectedSignatureA { t.Errorf("not expected signature A") } // B8 - err = responder.ConfirmInitiator(sigA) + key2, err := responder.ConfirmInitiator(sigA) if err != nil { t.Fatal(err) } + if hex.EncodeToString(key2) != expectedKey { + t.Errorf("not expected key %v\n", hex.EncodeToString(key2)) + } } func TestKeyExchange(t *testing.T) { @@ -214,7 +218,10 @@ func TestKeyExchange(t *testing.T) { t.Fatal(err) } responder := NewKeyExchange(userKey, userB, userA, 16, true) - + defer func() { + initiator.Destroy() + responder.Destroy() + }() // A1-A4 rA, err := initiator.InitKeyExchange(rand.Reader, hid) if err != nil { @@ -228,18 +235,18 @@ func TestKeyExchange(t *testing.T) { } // A5 -A8 - sigA, err := initiator.ConfirmResponder(rB, sigB) + key1, sigA, err := initiator.ConfirmResponder(rB, sigB) if err != nil { t.Fatal(err) } // B8 - err = responder.ConfirmInitiator(sigA) + key2, err := responder.ConfirmInitiator(sigA) if err != nil { t.Fatal(err) } - if hex.EncodeToString(initiator.GetSharedKey()) != hex.EncodeToString(responder.GetSharedKey()) { + if hex.EncodeToString(key1) != hex.EncodeToString(key2) { t.Errorf("got different key") } } @@ -264,7 +271,10 @@ func TestKeyExchangeWithoutSignature(t *testing.T) { t.Fatal(err) } responder := NewKeyExchange(userKey, userB, userA, 16, false) - + defer func() { + initiator.Destroy() + responder.Destroy() + }() // A1-A4 rA, err := initiator.InitKeyExchange(rand.Reader, hid) if err != nil { @@ -281,7 +291,7 @@ func TestKeyExchangeWithoutSignature(t *testing.T) { } // A5 -A8 - sigA, err := initiator.ConfirmResponder(rB, sigB) + key1, sigA, err := initiator.ConfirmResponder(rB, sigB) if err != nil { t.Fatal(err) } @@ -289,7 +299,12 @@ func TestKeyExchangeWithoutSignature(t *testing.T) { t.Errorf("should no signature") } - if hex.EncodeToString(initiator.GetSharedKey()) != hex.EncodeToString(responder.GetSharedKey()) { + key2, err := responder.ConfirmInitiator(nil) + if err != nil { + t.Fatal(err) + } + + if hex.EncodeToString(key1) != hex.EncodeToString(key2) { t.Errorf("got different key") } }