From c32a9849f89772bafcc6147e21abfb2b951c4617 Mon Sep 17 00:00:00 2001 From: Sun Yimin Date: Thu, 13 Mar 2025 16:50:28 +0800 Subject: [PATCH] sm9: refactoring #314 --- pkcs8/example_test.go | 3 +- sm9/example_test.go | 22 ++---- sm9/sm9.go | 63 ++++++++++++--- sm9/sm9_key.go | 176 ++++++++++++++++++++++-------------------- sm9/sm9_key_test.go | 37 ++++----- sm9/sm9_test.go | 103 ++++++++++-------------- smx509/pkcs8.go | 28 +------ 7 files changed, 214 insertions(+), 218 deletions(-) diff --git a/pkcs8/example_test.go b/pkcs8/example_test.go index 2531892..e472b8a 100644 --- a/pkcs8/example_test.go +++ b/pkcs8/example_test.go @@ -271,8 +271,7 @@ func ExampleMarshalPrivateKey_withoutPasswordSM9MasterSignKey() { var b cryptobyte.Builder b.AddASN1BigInt(new(big.Int).SetBytes(kb)) kb, _ = b.Bytes() - testkey := new(sm9.SignMasterPrivateKey) - err := testkey.UnmarshalASN1(kb) + testkey, err := sm9.UnmarshalSignMasterPrivateKeyASN1(kb) if err != nil { panic(err) } diff --git a/sm9/example_test.go b/sm9/example_test.go index 5def299..9573cf6 100644 --- a/sm9/example_test.go +++ b/sm9/example_test.go @@ -17,10 +17,9 @@ func ExampleSignPrivateKey_Sign() { var b cryptobyte.Builder b.AddASN1BigInt(new(big.Int).SetBytes(kb)) kb, _ = b.Bytes() - masterkey := new(sm9.SignMasterPrivateKey) - err := masterkey.UnmarshalASN1(kb) + masterkey, err := sm9.UnmarshalSignMasterPrivateKeyASN1(kb) if err != nil { - fmt.Fprintf(os.Stderr, "Error from UnmarshalASN1: %s\n", err) + fmt.Fprintf(os.Stderr, "Error from UnmarshalSignMasterPrivateKeyASN1: %s\n", err) return } hid := byte(0x01) @@ -46,9 +45,8 @@ func ExampleSignPrivateKey_Sign() { func ExampleVerifyASN1() { // get master public key, can be from pem - masterPubKey := new(sm9.SignMasterPublicKey) keyBytes, _ := hex.DecodeString("03818200049f64080b3084f733e48aff4b41b565011ce0711c5e392cfb0ab1b6791b94c40829dba116152d1f786ce843ed24a3b573414d2177386a92dd8f14d65696ea5e3269850938abea0112b57329f447e3a0cbad3e2fdb1a77f335e89e1408d0ef1c2541e00a53dda532da1a7ce027b7a46f741006e85f5cdff0730e75c05fb4e3216d") - err := masterPubKey.UnmarshalASN1(keyBytes) + masterPubKey, err := sm9.UnmarshalSignMasterPublicKeyASN1(keyBytes) if err != nil { fmt.Fprintf(os.Stderr, "Error from UnmarshalASN1: %s\n", err) return @@ -67,7 +65,7 @@ func ExampleSignMasterPublicKey_Verify() { // get master public key, can be from pem masterPubKey := new(sm9.SignMasterPublicKey) keyBytes, _ := hex.DecodeString("03818200049f64080b3084f733e48aff4b41b565011ce0711c5e392cfb0ab1b6791b94c40829dba116152d1f786ce843ed24a3b573414d2177386a92dd8f14d65696ea5e3269850938abea0112b57329f447e3a0cbad3e2fdb1a77f335e89e1408d0ef1c2541e00a53dda532da1a7ce027b7a46f741006e85f5cdff0730e75c05fb4e3216d") - err := masterPubKey.UnmarshalASN1(keyBytes) + masterPubKey, err := sm9.UnmarshalSignMasterPublicKeyASN1(keyBytes) if err != nil { fmt.Fprintf(os.Stderr, "Error from UnmarshalASN1: %s\n", err) return @@ -85,8 +83,7 @@ func ExampleSignMasterPublicKey_Verify() { func ExampleEncryptPrivateKey_UnwrapKey() { // real user encrypt private key should be from secret storage, e.g. password protected pkcs8 file kb, _ := hex.DecodeString("038182000494736acd2c8c8796cc4785e938301a139a059d3537b6414140b2d31eecf41683115bae85f5d8bc6c3dbd9e5342979acccf3c2f4f28420b1cb4f8c0b59a19b1587aa5e47570da7600cd760a0cf7beaf71c447f3844753fe74fa7ba92ca7d3b55f27538a62e7f7bfb51dce08704796d94c9d56734f119ea44732b50e31cdeb75c1") - userKey := new(sm9.EncryptPrivateKey) - err := userKey.UnmarshalASN1(kb) + userKey, err := sm9.UnmarshalEncryptPrivateKeyASN1(kb) if err != nil { fmt.Fprintf(os.Stderr, "Error from UnmarshalASN1: %s\n", err) return @@ -104,9 +101,8 @@ func ExampleEncryptPrivateKey_UnwrapKey() { func ExampleEncryptMasterPublicKey_WrapKey() { // get master public key, can be from pem - masterPubKey := new(sm9.EncryptMasterPublicKey) keyBytes, _ := hex.DecodeString("03420004787ed7b8a51f3ab84e0a66003f32da5c720b17eca7137d39abc66e3c80a892ff769de61791e5adc4b9ff85a31354900b202871279a8c49dc3f220f644c57a7b1") - err := masterPubKey.UnmarshalASN1(keyBytes) + masterPubKey, err := sm9.UnmarshalEncryptMasterPublicKeyASN1(keyBytes) if err != nil { fmt.Fprintf(os.Stderr, "Error from UnmarshalASN1: %s\n", err) return @@ -127,8 +123,7 @@ func ExampleEncryptMasterPublicKey_WrapKey() { func ExampleEncryptPrivateKey_Decrypt() { // real user encrypt private key should be from secret storage, e.g. password protected pkcs8 file kb, _ := hex.DecodeString("038182000494736acd2c8c8796cc4785e938301a139a059d3537b6414140b2d31eecf41683115bae85f5d8bc6c3dbd9e5342979acccf3c2f4f28420b1cb4f8c0b59a19b1587aa5e47570da7600cd760a0cf7beaf71c447f3844753fe74fa7ba92ca7d3b55f27538a62e7f7bfb51dce08704796d94c9d56734f119ea44732b50e31cdeb75c1") - userKey := new(sm9.EncryptPrivateKey) - err := userKey.UnmarshalASN1(kb) + userKey, err := sm9.UnmarshalEncryptPrivateKeyASN1(kb) if err != nil { fmt.Fprintf(os.Stderr, "Error from UnmarshalASN1: %s\n", err) return @@ -146,9 +141,8 @@ func ExampleEncryptPrivateKey_Decrypt() { func ExampleEncryptMasterPublicKey_Encrypt() { // get master public key, can be from pem - masterPubKey := new(sm9.EncryptMasterPublicKey) keyBytes, _ := hex.DecodeString("03420004787ed7b8a51f3ab84e0a66003f32da5c720b17eca7137d39abc66e3c80a892ff769de61791e5adc4b9ff85a31354900b202871279a8c49dc3f220f644c57a7b1") - err := masterPubKey.UnmarshalASN1(keyBytes) + masterPubKey, err := sm9.UnmarshalEncryptMasterPublicKeyASN1(keyBytes) if err != nil { fmt.Fprintf(os.Stderr, "Error from UnmarshalASN1: %s\n", err) return diff --git a/sm9/sm9.go b/sm9/sm9.go index ec8d0ef..af27353 100644 --- a/sm9/sm9.go +++ b/sm9/sm9.go @@ -349,37 +349,82 @@ func (priv *EncryptPrivateKey) DecryptASN1(uid, ciphertext []byte) ([]byte, erro return DecryptASN1(priv, uid, ciphertext) } -// KeyExchange represents key exchange struct, include internal stat in whole key exchange flow. +// KeyExchange defines an interface for key exchange protocols. +// It provides methods for initializing, responding, and confirming key exchanges. +// +// InitKeyExchange initializes the key exchange process. +// It takes a random number generator and a byte identifier as input, and returns +// the initial data for the key exchange and an error, if any. +// +// RespondKeyExchange responds to an initiated key exchange. +// It takes a random number generator, a byte identifier, and the peer's initial data +// as input, and returns the response data, additional data for confirmation, and an error, if any. +// +// ConfirmResponder confirms the key exchange from the responder's side. +// It takes the responder's response data and additional data as input, and returns +// the confirmation data and an error, if any. +// +// ConfirmInitiator confirms the key exchange from the initiator's side. +// It takes the peer's confirmation data as input, and returns the final confirmation data +// and an error, if any. +// KeyExchange defines an interface for key exchange operations. +// It provides methods to initialize, respond, and confirm key exchanges, +// as well as a method to destroy the key exchange instance. +type KeyExchange interface { + // Destroy cleans up any resources associated with the key exchange instance. + Destroy() + + // InitKeyExchange initializes the key exchange process. + // It takes a random number generator and a byte identifier as input, + // and returns the initial data for the key exchange or an error. + InitKeyExchange(rand io.Reader, hid byte) ([]byte, error) + + // RespondKeyExchange responds to an initiated key exchange. + // It takes a random number generator, a byte identifier, and the peer's initial data as input, + // and returns the response data, additional data, or an error. + RespondKeyExchange(rand io.Reader, hid byte, peerData []byte) ([]byte, []byte, error) + + // ConfirmResponder confirms the responder's part of the key exchange. + // It takes the responder's response data and additional data as input, + // and returns the confirmation data or an error. + ConfirmResponder(rB, sB []byte) ([]byte, []byte, error) + + // ConfirmInitiator confirms the initiator's part of the key exchange. + // It takes the peer's data as input and returns the confirmation data or an error. + ConfirmInitiator(peerData []byte) ([]byte, error) +} + +// keyExchange represents key exchange struct, include internal stat in whole key exchange flow. // Initiator's flow will be: NewKeyExchange -> InitKeyExchange -> transmission -> ConfirmResponder // Responder's flow will be: NewKeyExchange -> waiting ... -> RepondKeyExchange -> transmission -> ConfirmInitiator -type KeyExchange struct { +type keyExchange struct { ke *sm9.KeyExchange } -func (priv *EncryptPrivateKey) NewKeyExchange(uid, peerUID []byte, keyLen int, genSignature bool) *KeyExchange { - return &KeyExchange{ke: priv.privateKey.NewKeyExchange(uid, peerUID, keyLen, genSignature)} +func (priv *EncryptPrivateKey) NewKeyExchange(uid, peerUID []byte, keyLen int, genSignature bool) *keyExchange { + return &keyExchange{ke: priv.privateKey.NewKeyExchange(uid, peerUID, keyLen, genSignature)} } -func (ke *KeyExchange) Destroy() { +func (ke *keyExchange) Destroy() { ke.ke.Destroy() } // InitKeyExchange generates random with responder uid, for initiator's step A1-A4 -func (ke *KeyExchange) InitKeyExchange(rand io.Reader, hid byte) ([]byte, error) { +func (ke *keyExchange) InitKeyExchange(rand io.Reader, hid byte) ([]byte, error) { return ke.ke.InitKeyExchange(rand, hid) } // RespondKeyExchange when responder receive rA, for responder's step B1-B7 -func (ke *KeyExchange) RespondKeyExchange(rand io.Reader, hid byte, peerData []byte) ([]byte, []byte, error) { +func (ke *keyExchange) RespondKeyExchange(rand io.Reader, hid byte, peerData []byte) ([]byte, []byte, error) { return ke.ke.RespondKeyExchange(rand, hid, peerData) } // ConfirmResponder for initiator's step A5-A7 -func (ke *KeyExchange) ConfirmResponder(rB, sB []byte) ([]byte, []byte, error) { +func (ke *keyExchange) ConfirmResponder(rB, sB []byte) ([]byte, []byte, error) { return ke.ke.ConfirmResponder(rB, sB) } // ConfirmInitiator for responder's step B8 -func (ke *KeyExchange) ConfirmInitiator(peerData []byte) ([]byte, error) { +func (ke *keyExchange) ConfirmInitiator(peerData []byte) ([]byte, error) { return ke.ke.ConfirmInitiator(peerData) } diff --git a/sm9/sm9_key.go b/sm9/sm9_key.go index 2d2ae53..5568890 100644 --- a/sm9/sm9_key.go +++ b/sm9/sm9_key.go @@ -72,8 +72,8 @@ func (master *SignMasterPrivateKey) MarshalASN1() ([]byte, error) { return b.Bytes() } -// UnmarshalASN1 unmarsal der data to sign master private key -func (master *SignMasterPrivateKey) UnmarshalASN1(der []byte) error { +// UnmarshalSignMasterPrivateKeyASN1 unmarsal der data to sign master private key +func UnmarshalSignMasterPrivateKeyASN1(der []byte) (*SignMasterPrivateKey, error) { input := cryptobyte.String(der) d := &big.Int{} var inner cryptobyte.String @@ -83,20 +83,21 @@ func (master *SignMasterPrivateKey) UnmarshalASN1(der []byte) error { if !input.ReadASN1(&inner, cryptobyte_asn1.SEQUENCE) || !input.Empty() || !inner.ReadASN1Integer(d) { - return errors.New("sm9: invalid sign master private key asn1 data") + return nil, errors.New("sm9: invalid sign master private key asn1 data") } // Just parse it, didn't validate it if !inner.Empty() && (!inner.ReadASN1BitStringAsBytes(&pubBytes) || !inner.Empty()) { - return errors.New("sm9: invalid sign master public key asn1 data") + return nil, errors.New("sm9: invalid sign master public key asn1 data") } } else if !input.ReadASN1Integer(d) || !input.Empty() { - return errors.New("sm9: invalid sign master private key asn1 data") + return nil, errors.New("sm9: invalid sign master private key asn1 data") } - master.privateKey, err = sm9.NewSignMasterPrivateKey(d.Bytes()) + + privateKey, err := sm9.NewSignMasterPrivateKey(d.Bytes()) if err != nil { - return err + return nil, err } - return nil + return &SignMasterPrivateKey{privateKey: privateKey}, nil } // GenerateUserKey generate an user dsa key. @@ -141,16 +142,16 @@ func (pub *SignMasterPublicKey) MarshalCompressedASN1() ([]byte, error) { return b.Bytes() } -// UnmarshalRaw unmarsal raw bytes data to sign master public key -func (pub *SignMasterPublicKey) UnmarshalRaw(bytes []byte) error { - if pub.publicKey == nil { - pub.publicKey = new(sm9.SignMasterPublicKey) - } - return pub.publicKey.UnmarshalRaw(bytes) +// UnmarshalSignMasterPublicKeyRaw unmarsal raw bytes data to sign master public key +func UnmarshalSignMasterPublicKeyRaw(bytes []byte) (pub *SignMasterPublicKey, err error) { + pub = new(SignMasterPublicKey) + pub.publicKey = new(sm9.SignMasterPublicKey) + err = pub.publicKey.UnmarshalRaw(bytes) + return } -// UnmarshalASN1 unmarsal der data to sign master public key -func (pub *SignMasterPublicKey) UnmarshalASN1(der []byte) error { +// UnmarshalSignMasterPublicKeyASN1 unmarsal der data to sign master public key +func UnmarshalSignMasterPublicKeyASN1(der []byte) (*SignMasterPublicKey, error) { var bytes []byte var inner cryptobyte.String input := cryptobyte.String(der) @@ -159,21 +160,21 @@ func (pub *SignMasterPublicKey) UnmarshalASN1(der []byte) error { !input.Empty() || !inner.ReadASN1BitStringAsBytes(&bytes) || !inner.Empty() { - return errors.New("sm9: invalid sign master public key asn1 data") + return nil, errors.New("sm9: invalid sign master public key asn1 data") } } else if !input.ReadASN1BitStringAsBytes(&bytes) || !input.Empty() { - return errors.New("sm9: invalid sign master public key asn1 data") + return nil, errors.New("sm9: invalid sign master public key asn1 data") } - return pub.UnmarshalRaw(bytes) + return UnmarshalSignMasterPublicKeyRaw(bytes) } -// ParseFromPEM just for GMSSL, there are no Algorithm pkix.AlgorithmIdentifier -func (pub *SignMasterPublicKey) ParseFromPEM(data []byte) error { +// ParseSignMasterPublicKeyPEM just for GMSSL, there are no Algorithm pkix.AlgorithmIdentifier +func ParseSignMasterPublicKeyPEM(data []byte) (*SignMasterPublicKey, error) { block, _ := pem.Decode([]byte(data)) if block == nil { - return errors.New("sm9: failed to parse PEM block") + return nil, errors.New("sm9: failed to parse PEM block") } - return pub.UnmarshalASN1(block.Bytes) + return UnmarshalSignMasterPublicKeyASN1(block.Bytes) } func (priv *SignPrivateKey) Equal(x *SignPrivateKey) bool { @@ -189,8 +190,8 @@ func (priv *SignPrivateKey) MasterPublic() *SignMasterPublicKey { return &SignMasterPublicKey{priv.privateKey.MasterPublic()} } -// SetMasterPublicKey bind the sign master public key to it. -func (priv *SignPrivateKey) SetMasterPublicKey(pub *SignMasterPublicKey) { +// setMasterPublicKey bind the sign master public key to it. +func (priv *SignPrivateKey) setMasterPublicKey(pub *SignMasterPublicKey) { priv.privateKey.SetMasterPublicKey(pub.publicKey) } @@ -210,18 +211,21 @@ func (priv *SignPrivateKey) MarshalCompressedASN1() ([]byte, error) { return b.Bytes() } -// UnmarshalRaw unmarsal raw bytes data to sign user private key +// UnmarshalSignPrivateKeyRaw unmarsal raw bytes data to sign user private key // Note, priv's SignMasterPublicKey should be handled separately. -func (priv *SignPrivateKey) UnmarshalRaw(bytes []byte) error { - if priv.privateKey == nil { - priv.privateKey = new(sm9.SignPrivateKey) +func UnmarshalSignPrivateKeyRaw(bytes []byte) (*SignPrivateKey, error) { + priv := new(SignPrivateKey) + priv.privateKey = new(sm9.SignPrivateKey) + err := priv.privateKey.UnmarshalRaw(bytes) + if err != nil { + return nil, err } - return priv.privateKey.UnmarshalRaw(bytes) + return priv, nil } -// UnmarshalASN1 unmarsal der data to sign user private key +// UnmarshalSignPrivateKeyASN1 unmarsal der data to sign user private key // Note, priv's SignMasterPublicKey should be handled separately. -func (priv *SignPrivateKey) UnmarshalASN1(der []byte) error { +func UnmarshalSignPrivateKeyASN1(der []byte) (*SignPrivateKey, error) { var bytes []byte var pubBytes []byte var inner cryptobyte.String @@ -230,27 +234,27 @@ func (priv *SignPrivateKey) UnmarshalASN1(der []byte) error { if !input.ReadASN1(&inner, cryptobyte_asn1.SEQUENCE) || !input.Empty() || !inner.ReadASN1BitStringAsBytes(&bytes) { - return errors.New("sm9: invalid sign user private key asn1 data") + return nil, errors.New("sm9: invalid sign user private key asn1 data") } if !inner.Empty() && (!inner.ReadASN1BitStringAsBytes(&pubBytes) || !inner.Empty()) { - return errors.New("sm9: invalid sign master public key asn1 data") + return nil,errors.New("sm9: invalid sign master public key asn1 data") } } else if !input.ReadASN1BitStringAsBytes(&bytes) || !input.Empty() { - return errors.New("sm9: invalid sign user private key asn1 data") + return nil, errors.New("sm9: invalid sign user private key asn1 data") } - err := priv.UnmarshalRaw(bytes) + + priv, err := UnmarshalSignPrivateKeyRaw(bytes) if err != nil { - return err + return nil, err } if len(pubBytes) > 0 { - masterPK := new(SignMasterPublicKey) - err = masterPK.UnmarshalRaw(pubBytes) + masterPK, err := UnmarshalSignMasterPublicKeyRaw(pubBytes) if err != nil { - return err + return nil, err } - priv.SetMasterPublicKey(masterPK) + priv.setMasterPublicKey(masterPK) } - return nil + return priv, nil } // GenerateEncryptMasterKey generates a master public and private key pair for encryption usage. @@ -297,8 +301,8 @@ func (master *EncryptMasterPrivateKey) MarshalASN1() ([]byte, error) { return b.Bytes() } -// UnmarshalASN1 unmarsal der data to encrypt master private key -func (master *EncryptMasterPrivateKey) UnmarshalASN1(der []byte) error { +// UnmarshalEncryptMasterPrivateKeyASN1 unmarsal der data to encrypt master private key +func UnmarshalEncryptMasterPrivateKeyASN1(der []byte) (*EncryptMasterPrivateKey, error) { input := cryptobyte.String(der) d := &big.Int{} var inner cryptobyte.String @@ -307,21 +311,20 @@ func (master *EncryptMasterPrivateKey) UnmarshalASN1(der []byte) error { if !input.ReadASN1(&inner, cryptobyte_asn1.SEQUENCE) || !input.Empty() || !inner.ReadASN1Integer(d) { - return errors.New("sm9: invalid encrypt master private key asn1 data") + return nil, errors.New("sm9: invalid encrypt master private key asn1 data") } // Just parse it, did't validate it if !inner.Empty() && (!inner.ReadASN1BitStringAsBytes(&pubBytes) || !inner.Empty()) { - return errors.New("sm9: invalid encrypt master public key asn1 data") + return nil, errors.New("sm9: invalid encrypt master public key asn1 data") } } else if !input.ReadASN1Integer(d) || !input.Empty() { - return errors.New("sm9: invalid encrypt master private key asn1 data") + return nil, errors.New("sm9: invalid encrypt master private key asn1 data") } - var err error - master.privateKey, err = sm9.NewEncryptMasterPrivateKey(d.Bytes()) + privateKey, err := sm9.NewEncryptMasterPrivateKey(d.Bytes()) if err != nil { - return err + return nil, err } - return nil + return &EncryptMasterPrivateKey{privateKey: privateKey}, nil } // Equal compares the receiver EncryptMasterPublicKey with another EncryptMasterPublicKey @@ -352,25 +355,28 @@ func (pub *EncryptMasterPublicKey) MarshalCompressedASN1() ([]byte, error) { return b.Bytes() } -// UnmarshalRaw unmarsal raw bytes data to encrypt master public key -func (pub *EncryptMasterPublicKey) UnmarshalRaw(bytes []byte) error { - if pub.publicKey == nil { - pub.publicKey = new(sm9.EncryptMasterPublicKey) +// UnmarshalEncryptMasterPublicKeyRaw unmarsal raw bytes data to encrypt master public key +func UnmarshalEncryptMasterPublicKeyRaw(bytes []byte) (*EncryptMasterPublicKey, error) { + pub := new(EncryptMasterPublicKey) + pub.publicKey = new(sm9.EncryptMasterPublicKey) + err := pub.publicKey.UnmarshalRaw(bytes) + if err != nil { + return nil, err } - return pub.publicKey.UnmarshalRaw(bytes) + return pub, nil } -// ParseFromPEM just for GMSSL, there are no Algorithm pkix.AlgorithmIdentifier -func (pub *EncryptMasterPublicKey) ParseFromPEM(data []byte) error { +// ParseEncryptMasterPublicKeyPEM just for GMSSL, there are no Algorithm pkix.AlgorithmIdentifier +func ParseEncryptMasterPublicKeyPEM(data []byte) (*EncryptMasterPublicKey, error) { block, _ := pem.Decode([]byte(data)) if block == nil { - return errors.New("sm9: failed to parse PEM block") + return nil, errors.New("sm9: failed to parse PEM block") } - return pub.UnmarshalASN1(block.Bytes) + return UnmarshalEncryptMasterPublicKeyASN1(block.Bytes) } -// UnmarshalASN1 unmarsal der data to encrypt master public key -func (pub *EncryptMasterPublicKey) UnmarshalASN1(der []byte) error { +// UnmarshalEncryptMasterPublicKeyASN1 unmarsal der data to encrypt master public key +func UnmarshalEncryptMasterPublicKeyASN1(der []byte) (*EncryptMasterPublicKey, error) { var bytes []byte var inner cryptobyte.String input := cryptobyte.String(der) @@ -379,12 +385,12 @@ func (pub *EncryptMasterPublicKey) UnmarshalASN1(der []byte) error { !input.Empty() || !inner.ReadASN1BitStringAsBytes(&bytes) || !inner.Empty() { - return errors.New("sm9: invalid encrypt master public key asn1 data") + return nil, errors.New("sm9: invalid encrypt master public key asn1 data") } } else if !input.ReadASN1BitStringAsBytes(&bytes) || !input.Empty() { - return errors.New("sm9: invalid encrypt master public key asn1 data") + return nil, errors.New("sm9: invalid encrypt master public key asn1 data") } - return pub.UnmarshalRaw(bytes) + return UnmarshalEncryptMasterPublicKeyRaw(bytes) } // MasterPublic returns the master public key corresponding to priv. @@ -392,8 +398,8 @@ func (priv *EncryptPrivateKey) MasterPublic() *EncryptMasterPublicKey { return &EncryptMasterPublicKey{priv.privateKey.MasterPublic()} } -// SetMasterPublicKey bind the encrypt master public key to it. -func (priv *EncryptPrivateKey) SetMasterPublicKey(pub *EncryptMasterPublicKey) { +// setMasterPublicKey bind the encrypt master public key to it. +func (priv *EncryptPrivateKey) setMasterPublicKey(pub *EncryptMasterPublicKey) { priv.privateKey.SetMasterPublicKey(pub.publicKey) } @@ -413,18 +419,21 @@ func (priv *EncryptPrivateKey) MarshalCompressedASN1() ([]byte, error) { return b.Bytes() } -// UnmarshalRaw unmarsal raw bytes data to encrypt user private key +// UnmarshalEncryptPrivateKeyRaw unmarsal raw bytes data to encrypt user private key // Note, priv's EncryptMasterPublicKey should be handled separately. -func (priv *EncryptPrivateKey) UnmarshalRaw(bytes []byte) error { - if priv.privateKey == nil { - priv.privateKey = new(sm9.EncryptPrivateKey) +func UnmarshalEncryptPrivateKeyRaw(bytes []byte) (*EncryptPrivateKey, error) { + priv := new(EncryptPrivateKey) + priv.privateKey = new(sm9.EncryptPrivateKey) + err := priv.privateKey.UnmarshalRaw(bytes) + if err != nil { + return nil, err } - return priv.privateKey.UnmarshalRaw(bytes) + return priv, nil } -// UnmarshalASN1 unmarsal der data to encrypt user private key +// UnmarshalEncryptPrivateKeyASN1 unmarsal der data to encrypt user private key // Note, priv's EncryptMasterPublicKey should be handled separately. -func (priv *EncryptPrivateKey) UnmarshalASN1(der []byte) error { +func UnmarshalEncryptPrivateKeyASN1(der []byte) (*EncryptPrivateKey, error) { var bytes []byte var pubBytes []byte var inner cryptobyte.String @@ -433,27 +442,26 @@ func (priv *EncryptPrivateKey) UnmarshalASN1(der []byte) error { if !input.ReadASN1(&inner, cryptobyte_asn1.SEQUENCE) || !input.Empty() || !inner.ReadASN1BitStringAsBytes(&bytes) { - return errors.New("sm9: invalid encrypt user private key asn1 data") + return nil, errors.New("sm9: invalid encrypt user private key asn1 data") } if !inner.Empty() && (!inner.ReadASN1BitStringAsBytes(&pubBytes) || !inner.Empty()) { - return errors.New("sm9: invalid encrypt master public key asn1 data") + return nil, errors.New("sm9: invalid encrypt master public key asn1 data") } } else if !input.ReadASN1BitStringAsBytes(&bytes) || !input.Empty() { - return errors.New("sm9: invalid encrypt user private key asn1 data") + return nil, errors.New("sm9: invalid encrypt user private key asn1 data") } - err := priv.UnmarshalRaw(bytes) + priv, err := UnmarshalEncryptPrivateKeyRaw(bytes) if err != nil { - return err + return nil, err } if len(pubBytes) > 0 { - masterPK := new(EncryptMasterPublicKey) - err = masterPK.UnmarshalRaw(pubBytes) + masterPK, err := UnmarshalEncryptMasterPublicKeyRaw(pubBytes) if err != nil { - return err + return nil, err } - priv.SetMasterPublicKey(masterPK) + priv.setMasterPublicKey(masterPK) } - return nil + return priv, nil } // Equal compares the receiver EncryptPrivateKey with another EncryptPrivateKey x diff --git a/sm9/sm9_key_test.go b/sm9/sm9_key_test.go index cab7cc4..c057c22 100644 --- a/sm9/sm9_key_test.go +++ b/sm9/sm9_key_test.go @@ -19,8 +19,8 @@ func TestSignMasterPrivateKeyMarshalASN1(t *testing.T) { if err != nil { t.Fatal(err) } - masterKey2 := new(SignMasterPrivateKey) - err = masterKey2.UnmarshalASN1(der) + + masterKey2, err := UnmarshalSignMasterPrivateKeyASN1(der) if err != nil { t.Fatal(err) } @@ -38,8 +38,7 @@ func TestSignMasterPublicKeyMarshalASN1(t *testing.T) { if err != nil { t.Fatal(err) } - pub2 := new(SignMasterPublicKey) - err = pub2.UnmarshalASN1(der) + pub2, err := UnmarshalSignMasterPublicKeyASN1(der) if err != nil { t.Fatal(err) } @@ -57,8 +56,7 @@ func TestSignMasterPublicKeyMarshalCompressedASN1(t *testing.T) { if err != nil { t.Fatal(err) } - pub2 := new(SignMasterPublicKey) - err = pub2.UnmarshalASN1(der) + pub2, err := UnmarshalSignMasterPublicKeyASN1(der) if err != nil { t.Fatal(err) } @@ -82,8 +80,7 @@ func TestSignUserPrivateKeyMarshalASN1(t *testing.T) { if err != nil { t.Fatal(err) } - userKey2 := new(SignPrivateKey) - err = userKey2.UnmarshalASN1(der) + userKey2, err := UnmarshalSignPrivateKeyASN1(der) if err != nil { t.Fatal(err) } @@ -107,8 +104,7 @@ func TestSignUserPrivateKeyMarshalCompressedASN1(t *testing.T) { if err != nil { t.Fatal(err) } - userKey2 := new(SignPrivateKey) - err = userKey2.UnmarshalASN1(der) + userKey2, err := UnmarshalSignPrivateKeyASN1(der) if err != nil { t.Fatal(err) } @@ -126,8 +122,7 @@ func TestEncryptMasterPrivateKeyMarshalASN1(t *testing.T) { if err != nil { t.Fatal(err) } - masterKey2 := new(EncryptMasterPrivateKey) - err = masterKey2.UnmarshalASN1(der) + masterKey2, err := UnmarshalEncryptMasterPrivateKeyASN1(der) if err != nil { t.Fatal(err) } @@ -145,8 +140,7 @@ func TestEncryptMasterPublicKeyMarshalASN1(t *testing.T) { if err != nil { t.Fatal(err) } - pub2 := new(EncryptMasterPublicKey) - err = pub2.UnmarshalASN1(der) + pub2, err := UnmarshalEncryptMasterPublicKeyASN1(der) if err != nil { t.Fatal(err) } @@ -164,8 +158,7 @@ func TestEncryptMasterPublicKeyMarshalCompressedASN1(t *testing.T) { if err != nil { t.Fatal(err) } - pub2 := new(EncryptMasterPublicKey) - err = pub2.UnmarshalASN1(der) + pub2, err := UnmarshalEncryptMasterPublicKeyASN1(der) if err != nil { t.Fatal(err) } @@ -189,8 +182,7 @@ func TestEncryptUserPrivateKeyMarshalASN1(t *testing.T) { if err != nil { t.Fatal(err) } - userKey2 := new(EncryptPrivateKey) - err = userKey2.UnmarshalASN1(der) + userKey2, err := UnmarshalEncryptPrivateKeyASN1(der) if err != nil { t.Fatal(err) } @@ -214,8 +206,7 @@ func TestEncryptUserPrivateKeyMarshalCompressedASN1(t *testing.T) { if err != nil { t.Fatal(err) } - userKey2 := new(EncryptPrivateKey) - err = userKey2.UnmarshalASN1(der) + userKey2, err := UnmarshalEncryptPrivateKeyASN1(der) if err != nil { t.Fatal(err) } @@ -266,8 +257,7 @@ Ri1gDhueE6gkoeZ4HHUu1wfhRbKRF8okwSO933f/ZSpLlYu1P7/ckw== ` func TestParseSM9SignMasterPublicKey(t *testing.T) { - key := new(SignMasterPublicKey) - err := key.ParseFromPEM([]byte(sm9SignMasterPublicKeyFromGMSSL)) + key, err := ParseSignMasterPublicKeyPEM([]byte(sm9SignMasterPublicKeyFromGMSSL)) if err != nil { t.Fatal(err) } @@ -297,8 +287,7 @@ tYwoUdCETdYJwxiKXlI1jytVTuuT2Q== ` func TestParseSM9EncryptMasterPublicKey(t *testing.T) { - key := new(EncryptMasterPublicKey) - err := key.ParseFromPEM([]byte(sm9EncMasterPublicKeyFromGMSSL)) + key, err := ParseEncryptMasterPublicKeyPEM([]byte(sm9EncMasterPublicKeyFromGMSSL)) if err != nil { t.Fatal(err) } diff --git a/sm9/sm9_test.go b/sm9/sm9_test.go index 3466547..1fb601d 100644 --- a/sm9/sm9_test.go +++ b/sm9/sm9_test.go @@ -1,14 +1,16 @@ -package sm9 +package sm9_test import ( "bytes" "crypto/rand" "encoding/hex" "testing" + + "github.com/emmansun/gmsm/sm9" ) func TestSignASN1(t *testing.T) { - masterKey, err := GenerateSignMasterKey(rand.Reader) + masterKey, err := sm9.GenerateSignMasterKey(rand.Reader) hashed := []byte("Chinese IBS standard") uid := []byte("emmansun") hid := byte(0x01) @@ -32,29 +34,8 @@ func TestSignASN1(t *testing.T) { } } -func TestParseInvalidASN1(t *testing.T) { - tests := []struct { - name string - sigHex string - }{ - // TODO: Add test cases. - {"invalid point format", "30660420723a8b38dd2441c2aa1c3ec092eaa34996c53bf9ca7515272395c012ab6e6e070342000C389fc45b711d9dfd9d91958f64d89d3528cf577c6dc2bc792c2969188e76865e16c2d85419f8f923a0e77c7f269c0eeb97b6c4d7e2735189180ec719a380fe1d"}, - {"invalid point encoding length", "30660420723a8b38dd2441c2aa1c3ec092eaa34996c53bf9ca7515272395c012ab6e6e0703420004389fc45b711d9dfd9d91958f64d89d3528cf577c6dc2bc792c2969188e76865e16c2d85419f8f923a0e77c7f269c0eeb97b6c4d7e2735189180ec719a380fe"}, - } - for _, tt := range tests { - sig, err := hex.DecodeString(tt.sigHex) - if err != nil { - t.Fatal(err) - } - _, _, err = parseSignature(sig) - if err == nil { - t.Errorf("%s should be failed", tt.name) - } - } -} - func TestWrapKey(t *testing.T) { - masterKey, err := GenerateEncryptMasterKey(rand.Reader) + masterKey, err := sm9.GenerateEncryptMasterKey(rand.Reader) hid := byte(0x01) uid := []byte("emmansun") if err != nil { @@ -80,7 +61,7 @@ func TestWrapKey(t *testing.T) { } func TestWrapKeyASN1(t *testing.T) { - masterKey, err := GenerateEncryptMasterKey(rand.Reader) + masterKey, err := sm9.GenerateEncryptMasterKey(rand.Reader) hid := byte(0x01) uid := []byte("emmansun") if err != nil { @@ -95,12 +76,12 @@ func TestWrapKeyASN1(t *testing.T) { t.Fatal(err) } - key1, cipher, err := UnmarshalSM9KeyPackage(keyPackage) + key1, cipher, err := sm9.UnmarshalSM9KeyPackage(keyPackage) if err != nil { t.Fatal(err) } - key2, err := UnwrapKey(userKey, uid, cipher, 16) + key2, err := sm9.UnwrapKey(userKey, uid, cipher, 16) if err != nil { t.Fatal(err) } @@ -112,7 +93,7 @@ func TestWrapKeyASN1(t *testing.T) { func TestEncryptDecrypt(t *testing.T) { plaintext := []byte("Chinese IBE standard") - masterKey, err := GenerateEncryptMasterKey(rand.Reader) + masterKey, err := sm9.GenerateEncryptMasterKey(rand.Reader) hid := byte(0x01) uid := []byte("emmansun") if err != nil { @@ -122,16 +103,16 @@ func TestEncryptDecrypt(t *testing.T) { if err != nil { t.Fatal(err) } - encTypes := []EncrypterOpts{ - DefaultEncrypterOpts, SM4ECBEncrypterOpts, SM4CBCEncrypterOpts, SM4CFBEncrypterOpts, SM4OFBEncrypterOpts, + encTypes := []sm9.EncrypterOpts{ + sm9.DefaultEncrypterOpts, sm9.SM4ECBEncrypterOpts, sm9.SM4CBCEncrypterOpts, sm9.SM4CFBEncrypterOpts, sm9.SM4OFBEncrypterOpts, } for _, opts := range encTypes { - cipher, err := Encrypt(rand.Reader, masterKey.Public(), uid, hid, plaintext, opts) + cipher, err := sm9.Encrypt(rand.Reader, masterKey.Public(), uid, hid, plaintext, opts) if err != nil { t.Fatal(err) } - got, err := Decrypt(userKey, uid, cipher, opts) + got, err := sm9.Decrypt(userKey, uid, cipher, opts) if err != nil { t.Fatal(err) } @@ -151,18 +132,18 @@ func TestEncryptDecrypt(t *testing.T) { } func TestEncryptEmptyPlaintext(t *testing.T) { - masterKey, err := GenerateEncryptMasterKey(rand.Reader) + masterKey, err := sm9.GenerateEncryptMasterKey(rand.Reader) hid := byte(0x01) uid := []byte("emmansun") if err != nil { t.Fatal(err) } - encTypes := []EncrypterOpts{ - DefaultEncrypterOpts, SM4ECBEncrypterOpts, SM4CBCEncrypterOpts, SM4CFBEncrypterOpts, SM4OFBEncrypterOpts, + encTypes := []sm9.EncrypterOpts{ + sm9.DefaultEncrypterOpts, sm9.SM4ECBEncrypterOpts, sm9.SM4CBCEncrypterOpts, sm9.SM4CFBEncrypterOpts, sm9.SM4OFBEncrypterOpts, } for _, opts := range encTypes { - _, err := Encrypt(rand.Reader, masterKey.Public(), uid, hid, nil, opts) - if err != ErrEmptyPlaintext { + _, err := sm9.Encrypt(rand.Reader, masterKey.Public(), uid, hid, nil, opts) + if err != sm9.ErrEmptyPlaintext { t.Fatalf("should be ErrEmptyPlaintext") } } @@ -170,7 +151,7 @@ func TestEncryptEmptyPlaintext(t *testing.T) { func TestEncryptDecryptASN1(t *testing.T) { plaintext := []byte("Chinese IBE standard") - masterKey, err := GenerateEncryptMasterKey(rand.Reader) + masterKey, err := sm9.GenerateEncryptMasterKey(rand.Reader) hid := byte(0x01) uid := []byte("emmansun") if err != nil { @@ -180,16 +161,16 @@ func TestEncryptDecryptASN1(t *testing.T) { if err != nil { t.Fatal(err) } - encTypes := []EncrypterOpts{ - DefaultEncrypterOpts, SM4ECBEncrypterOpts, SM4CBCEncrypterOpts, SM4CFBEncrypterOpts, SM4OFBEncrypterOpts, + encTypes := []sm9.EncrypterOpts{ + sm9.DefaultEncrypterOpts, sm9.SM4ECBEncrypterOpts, sm9.SM4CBCEncrypterOpts, sm9.SM4CFBEncrypterOpts, sm9.SM4OFBEncrypterOpts, } for _, opts := range encTypes { - cipher, err := EncryptASN1(rand.Reader, masterKey.Public(), uid, hid, plaintext, opts) + cipher, err := sm9.EncryptASN1(rand.Reader, masterKey.Public(), uid, hid, plaintext, opts) if err != nil { t.Fatal(err) } - got, err := DecryptASN1(userKey, uid, cipher) + got, err := sm9.DecryptASN1(userKey, uid, cipher) if err != nil { t.Fatal(err) } @@ -210,7 +191,7 @@ func TestEncryptDecryptASN1(t *testing.T) { } func TestUnmarshalSM9KeyPackage(t *testing.T) { - masterKey, err := GenerateEncryptMasterKey(rand.Reader) + masterKey, err := sm9.GenerateEncryptMasterKey(rand.Reader) hid := byte(0x01) uid := []byte("emmansun") if err != nil { @@ -225,12 +206,12 @@ func TestUnmarshalSM9KeyPackage(t *testing.T) { t.Fatal(err) } - key, cipher, err := UnmarshalSM9KeyPackage(p) + key, cipher, err := sm9.UnmarshalSM9KeyPackage(p) if err != nil { t.Fatal(err) } - key2, err := UnwrapKey(userKey, uid, cipher, 16) + key2, err := sm9.UnwrapKey(userKey, uid, cipher, 16) if err != nil { t.Fatal(err) } @@ -244,7 +225,7 @@ func TestKeyExchange(t *testing.T) { hid := byte(0x02) userA := []byte("Alice") userB := []byte("Bob") - masterKey, err := GenerateEncryptMasterKey(rand.Reader) + masterKey, err := sm9.GenerateEncryptMasterKey(rand.Reader) if err != nil { t.Fatal(err) } @@ -297,7 +278,7 @@ func TestKeyExchangeWithoutSignature(t *testing.T) { hid := byte(0x02) userA := []byte("Alice") userB := []byte("Bob") - masterKey, err := GenerateEncryptMasterKey(rand.Reader) + masterKey, err := sm9.GenerateEncryptMasterKey(rand.Reader) if err != nil { t.Fatal(err) } @@ -356,7 +337,7 @@ func BenchmarkSign(b *testing.B) { uid := []byte("emmansun") hid := byte(0x01) - masterKey, err := GenerateSignMasterKey(rand.Reader) + masterKey, err := sm9.GenerateSignMasterKey(rand.Reader) if err != nil { b.Fatal(err) } @@ -364,12 +345,12 @@ func BenchmarkSign(b *testing.B) { if err != nil { b.Fatal(err) } - SignASN1(rand.Reader, userKey, hashed) // fire precompute + sm9.SignASN1(rand.Reader, userKey, hashed) // fire precompute b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { - sig, err := SignASN1(rand.Reader, userKey, hashed) + sig, err := sm9.SignASN1(rand.Reader, userKey, hashed) if err != nil { b.Fatal(err) } @@ -383,7 +364,7 @@ func BenchmarkVerify(b *testing.B) { uid := []byte("emmansun") hid := byte(0x01) - masterKey, err := GenerateSignMasterKey(rand.Reader) + masterKey, err := sm9.GenerateSignMasterKey(rand.Reader) if err != nil { b.Fatal(err) } @@ -391,14 +372,14 @@ func BenchmarkVerify(b *testing.B) { if err != nil { b.Fatal(err) } - sig, err := SignASN1(rand.Reader, userKey, hashed) + sig, err := sm9.SignASN1(rand.Reader, userKey, hashed) if err != nil { b.Fatal(err) } b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { - if !VerifyASN1(masterKey.Public(), uid, hid, hashed, sig) { + if !sm9.VerifyASN1(masterKey.Public(), uid, hid, hashed, sig) { b.Fatal("verify failed") } } @@ -406,7 +387,7 @@ func BenchmarkVerify(b *testing.B) { func BenchmarkEncrypt(b *testing.B) { plaintext := []byte("Chinese IBE standard") - masterKey, err := GenerateEncryptMasterKey(rand.Reader) + masterKey, err := sm9.GenerateEncryptMasterKey(rand.Reader) hid := byte(0x01) uid := []byte("emmansun") if err != nil { @@ -415,7 +396,7 @@ func BenchmarkEncrypt(b *testing.B) { b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { - cipher, err := Encrypt(rand.Reader, masterKey.Public(), uid, hid, plaintext, nil) + cipher, err := sm9.Encrypt(rand.Reader, masterKey.Public(), uid, hid, plaintext, nil) if err != nil { b.Fatal(err) } @@ -426,7 +407,7 @@ func BenchmarkEncrypt(b *testing.B) { func BenchmarkDecrypt(b *testing.B) { plaintext := []byte("Chinese IBE standard") - masterKey, err := GenerateEncryptMasterKey(rand.Reader) + masterKey, err := sm9.GenerateEncryptMasterKey(rand.Reader) hid := byte(0x01) uid := []byte("emmansun") if err != nil { @@ -436,14 +417,14 @@ func BenchmarkDecrypt(b *testing.B) { if err != nil { b.Fatal(err) } - cipher, err := Encrypt(rand.Reader, masterKey.Public(), uid, hid, plaintext, nil) + cipher, err := sm9.Encrypt(rand.Reader, masterKey.Public(), uid, hid, plaintext, nil) if err != nil { b.Fatal(err) } b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { - got, err := Decrypt(userKey, uid, cipher, nil) + got, err := sm9.Decrypt(userKey, uid, cipher, nil) if err != nil { b.Fatal(err) } @@ -455,7 +436,7 @@ func BenchmarkDecrypt(b *testing.B) { func BenchmarkDecryptASN1(b *testing.B) { plaintext := []byte("Chinese IBE standard") - masterKey, err := GenerateEncryptMasterKey(rand.Reader) + masterKey, err := sm9.GenerateEncryptMasterKey(rand.Reader) hid := byte(0x01) uid := []byte("emmansun") if err != nil { @@ -465,14 +446,14 @@ func BenchmarkDecryptASN1(b *testing.B) { if err != nil { b.Fatal(err) } - cipher, err := EncryptASN1(rand.Reader, masterKey.Public(), uid, hid, plaintext, nil) + cipher, err := sm9.EncryptASN1(rand.Reader, masterKey.Public(), uid, hid, plaintext, nil) if err != nil { b.Fatal(err) } b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { - got, err := DecryptASN1(userKey, uid, cipher) + got, err := sm9.DecryptASN1(userKey, uid, cipher) if err != nil { b.Fatal(err) } diff --git a/smx509/pkcs8.go b/smx509/pkcs8.go index 062cb73..a38d440 100644 --- a/smx509/pkcs8.go +++ b/smx509/pkcs8.go @@ -90,20 +90,10 @@ func ParsePKCS8PrivateKey(der []byte) (key any, err error) { func parseSM9PrivateKey(privKey pkcs8) (key any, err error) { switch { case privKey.Algo.Algorithm.Equal(oidSM9Sign): - sm9SignKey := new(sm9.SignPrivateKey) - err = sm9SignKey.UnmarshalASN1(privKey.PrivateKey) - if err != nil { - return - } - key = sm9SignKey + key, err = sm9.UnmarshalSignPrivateKeyASN1(privKey.PrivateKey) return case privKey.Algo.Algorithm.Equal(oidSM9Enc): - sm9EncKey := new(sm9.EncryptPrivateKey) - err = sm9EncKey.UnmarshalASN1(privKey.PrivateKey) - if err != nil { - return - } - key = sm9EncKey + key, err = sm9.UnmarshalEncryptPrivateKeyASN1(privKey.PrivateKey) return default: bytes := privKey.Algo.Parameters.FullBytes @@ -114,20 +104,10 @@ func parseSM9PrivateKey(privKey pkcs8) (key any, err error) { } switch { case oidSM9Sign.Equal(*detailOID): - sm9SignMasterKey := new(sm9.SignMasterPrivateKey) - err = sm9SignMasterKey.UnmarshalASN1(privKey.PrivateKey) - if err != nil { - return - } - key = sm9SignMasterKey + key, err = sm9.UnmarshalSignMasterPrivateKeyASN1(privKey.PrivateKey) return case oidSM9Enc.Equal(*detailOID): - sm9EncMasterKey := new(sm9.EncryptMasterPrivateKey) - err = sm9EncMasterKey.UnmarshalASN1(privKey.PrivateKey) - if err != nil { - return - } - key = sm9EncMasterKey + key, err = sm9.UnmarshalEncryptMasterPrivateKeyASN1(privKey.PrivateKey) return } return nil, errors.New("not support yet")