From 167f0e0b11a633c21a60551a1e29ed180ce849bc Mon Sep 17 00:00:00 2001 From: Sun Yimin Date: Fri, 15 Dec 2023 13:06:53 +0800 Subject: [PATCH] sm2: #189, #190, #191 --- docs/sm2.md | 40 +++++++++---- ecdh/ecdh_test.go | 11 ++++ ecdh/sm2ec.go | 34 +++++------ sm2/example_test.go | 78 +++++++++++++------------ sm2/sm2.go | 138 ++++++++++++++++++++++++++++++++++---------- sm2/sm2_test.go | 115 ++++++++++++++++++++++++++++++++++-- 6 files changed, 315 insertions(+), 101 deletions(-) diff --git a/docs/sm2.md b/docs/sm2.md index 063713b..1ca8555 100644 --- a/docs/sm2.md +++ b/docs/sm2.md @@ -55,11 +55,15 @@ func getPublicKey(pemContent []byte) (any, error) { 有些应用可能会直接存储公钥的曲线点X, Y 坐标值,这时候,您可以通过以下类似方法构造公钥(假设输入的是点的非压缩序列化字节数组): ```go - // Create public key from point (uncompressed) - publicKeyCopy := new(ecdsa.PublicKey) - publicKeyCopy.Curve = sm2.P256() - publicKeyCopy.X, publicKeyCopy.Y = elliptic.Unmarshal(publicKeyCopy.Curve, pointBytes) - +func ExampleNewPublicKey() { + keypoints, _ := hex.DecodeString("048356e642a40ebd18d29ba3532fbd9f3bbee8f027c3f6f39a5ba2f870369f9988981f5efe55d1c5cdf6c0ef2b070847a14f7fdf4272a8df09c442f3058af94ba1") + pub, err := sm2.NewPublicKey(keypoints) + if err != nil { + log.Fatalf("fail to new public key %v", err) + } + fmt.Printf("%x\n", elliptic.Marshal(sm2.P256(), pub.X, pub.Y)) + // Output: 048356e642a40ebd18d29ba3532fbd9f3bbee8f027c3f6f39a5ba2f870369f9988981f5efe55d1c5cdf6c0ef2b070847a14f7fdf4272a8df09c442f3058af94ba1 +} ``` 当然,您也可以使用ecdh包下的方法```ecdh.P256().NewPublicKey```来构造,目前只支持非压缩方式。 @@ -87,13 +91,25 @@ func getPublicKey(pemContent []byte) (any, error) { 有些系统可能会直接存储、得到私钥的字节数组,那么您可以使用如下方法来构造私钥: ```go - bytes, _ := hex.DecodeString("4e85afbc996fdc67b4f05880bd9c0d037932649215ae10cf7085720b6571054c") - d := new(big.Int).SetBytes(bytes) - // Create private key from *big.Int - priv := new(PrivateKey) - priv.Curve = sm2.P256() - priv.D = d - priv.PublicKey.X, priv.PublicKey.Y = priv.ScalarBaseMult(priv.D.Bytes()) +func ExampleNewPrivateKey() { + keyBytes, _ := hex.DecodeString("6c5a0a0b2eed3cbec3e4f1252bfe0e28c504a1c6bf1999eebb0af9ef0f8e6c85") + priv, err := sm2.NewPrivateKey(keyBytes) + if err != nil { + log.Fatalf("fail to new private key %v", err) + } + fmt.Printf("%x\n", priv.D.Bytes()) + // Output: 6c5a0a0b2eed3cbec3e4f1252bfe0e28c504a1c6bf1999eebb0af9ef0f8e6c85 +} + +func ExampleNewPrivateKeyFromInt() { + key := big.NewInt(0x123456) + priv, err := sm2.NewPrivateKeyFromInt(key) + if err != nil { + log.Fatalf("fail to new private key %v", err) + } + fmt.Printf("%x\n", priv.D.Bytes()) + // Output: 123456 +} ``` 当然,你也可以使用ecdh包的方法```ecdh.P256().NewPrivateKey```来构造私钥,您要确保输入的字节数组是256位(16字节)的,如果不是,请先自行处理。 diff --git a/ecdh/ecdh_test.go b/ecdh/ecdh_test.go index e8b436d..fbaadfc 100644 --- a/ecdh/ecdh_test.go +++ b/ecdh/ecdh_test.go @@ -32,6 +32,17 @@ func hexDecode(t *testing.T, s string) []byte { return b } +func TestNewPrivateKeyWithOrderMinus1(t *testing.T) { + _, err := ecdh.P256().NewPrivateKey([]byte{ + 0xff, 0xff, 0xff, 0xfe, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0x72, 0x03, 0xdf, 0x6b, 0x21, 0xc6, 0x05, 0x2b, + 0x53, 0xbb, 0xf4, 0x09, 0x39, 0xd5, 0x41, 0x22}) + if err == nil || err.Error() != "ecdh: invalid private key" { + t.Errorf("expected invalid private key") + } +} + func TestECDH(t *testing.T) { aliceKey, err := ecdh.P256().GenerateKey(rand.Reader) if err != nil { diff --git a/ecdh/sm2ec.go b/ecdh/sm2ec.go index 64fd18a..b6cc3ed 100644 --- a/ecdh/sm2ec.go +++ b/ecdh/sm2ec.go @@ -13,12 +13,12 @@ import ( ) type sm2Curve struct { - name string - newPoint func() *sm2ec.SM2P256Point - scalarOrder []byte - constantA []byte - constantB []byte - generator []byte + name string + newPoint func() *sm2ec.SM2P256Point + scalarOrderMinus1 []byte + constantA []byte + constantB []byte + generator []byte } func (c *sm2Curve) String() string { @@ -26,7 +26,7 @@ func (c *sm2Curve) String() string { } func (c *sm2Curve) GenerateKey(rand io.Reader) (*PrivateKey, error) { - key := make([]byte, len(c.scalarOrder)) + key := make([]byte, len(c.scalarOrderMinus1)) randutil.MaybeReadByte(rand) for { @@ -48,10 +48,10 @@ func (c *sm2Curve) GenerateKey(rand io.Reader) (*PrivateKey, error) { } func (c *sm2Curve) NewPrivateKey(key []byte) (*PrivateKey, error) { - if len(key) != len(c.scalarOrder) { + if len(key) != len(c.scalarOrderMinus1) { return nil, errors.New("ecdh: invalid private key size") } - if subtle.ConstantTimeAllZero(key) || !isLess(key, c.scalarOrder) { + if subtle.ConstantTimeAllZero(key) || !isLess(key, c.scalarOrderMinus1) { return nil, errInvalidPrivateKey } return &PrivateKey{ @@ -181,19 +181,19 @@ func (c *sm2Curve) sm2za(md hash.Hash, pub *PublicKey, uid []byte) ([]byte, erro func P256() Curve { return sm2P256 } var sm2P256 = &sm2Curve{ - name: "sm2p256v1", - newPoint: sm2ec.NewSM2P256Point, - scalarOrder: sm2P256Order, - generator: sm2Generator, - constantA: sm2ConstantA, - constantB: sm2ConstantB, + name: "sm2p256v1", + newPoint: sm2ec.NewSM2P256Point, + scalarOrderMinus1: sm2P256OrderMinus1, + generator: sm2Generator, + constantA: sm2ConstantA, + constantB: sm2ConstantB, } -var sm2P256Order = []byte{ +var sm2P256OrderMinus1 = []byte{ 0xff, 0xff, 0xff, 0xfe, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x72, 0x03, 0xdf, 0x6b, 0x21, 0xc6, 0x05, 0x2b, - 0x53, 0xbb, 0xf4, 0x09, 0x39, 0xd5, 0x41, 0x23} + 0x53, 0xbb, 0xf4, 0x09, 0x39, 0xd5, 0x41, 0x22} var sm2Generator = []byte{ 0x32, 0xc4, 0xae, 0x2c, 0x1f, 0x19, 0x81, 0x19, 0x5f, 0x99, 0x4, 0x46, 0x6a, 0x39, 0xc9, 0x94, diff --git a/sm2/example_test.go b/sm2/example_test.go index 367352b..e7ab626 100644 --- a/sm2/example_test.go +++ b/sm2/example_test.go @@ -1,7 +1,6 @@ package sm2_test import ( - "crypto/ecdsa" "crypto/elliptic" "crypto/rand" "encoding/hex" @@ -16,29 +15,34 @@ import ( "golang.org/x/crypto/cryptobyte/asn1" ) -// This example method is just for reference, it's NOT a standard method for key transmission. -// In general, private key will be encoded/formatted with PKCS8, public key will be encoded/formatted with a SubjectPublicKeyInfo structure -// (see RFC 5280, Section 4.1). -func Example_createKeysFromRawValue() { - key, _ := sm2.GenerateKey(rand.Reader) +func ExampleNewPrivateKey() { + keyBytes, _ := hex.DecodeString("6c5a0a0b2eed3cbec3e4f1252bfe0e28c504a1c6bf1999eebb0af9ef0f8e6c85") + priv, err := sm2.NewPrivateKey(keyBytes) + if err != nil { + log.Fatalf("fail to new private key %v", err) + } + fmt.Printf("%x\n", priv.D.Bytes()) + // Output: 6c5a0a0b2eed3cbec3e4f1252bfe0e28c504a1c6bf1999eebb0af9ef0f8e6c85 +} - d := new(big.Int).SetBytes(key.D.Bytes()) // here we do NOT check if the d is in (0, N) or not - // Create private key from *big.Int - keyCopy := new(sm2.PrivateKey) - keyCopy.Curve = sm2.P256() - keyCopy.D = d - keyCopy.PublicKey.X, keyCopy.PublicKey.Y = keyCopy.ScalarBaseMult(keyCopy.D.Bytes()) - if !key.Equal(keyCopy) { - log.Fatalf("private key and copy should be equal") +func ExampleNewPrivateKeyFromInt() { + key := big.NewInt(0x123456) + priv, err := sm2.NewPrivateKeyFromInt(key) + if err != nil { + log.Fatalf("fail to new private key %v", err) } - pointBytes := elliptic.Marshal(key.Curve, key.X, key.Y) - // Create public key from point (uncompressed) - publicKeyCopy := new(ecdsa.PublicKey) - publicKeyCopy.Curve = sm2.P256() - publicKeyCopy.X, publicKeyCopy.Y = elliptic.Unmarshal(publicKeyCopy.Curve, pointBytes) - if !key.PublicKey.Equal(publicKeyCopy) { - log.Fatalf("public key and copy should be equal") + fmt.Printf("%x\n", priv.D.Bytes()) + // Output: 123456 +} + +func ExampleNewPublicKey() { + keypoints, _ := hex.DecodeString("048356e642a40ebd18d29ba3532fbd9f3bbee8f027c3f6f39a5ba2f870369f9988981f5efe55d1c5cdf6c0ef2b070847a14f7fdf4272a8df09c442f3058af94ba1") + pub, err := sm2.NewPublicKey(keypoints) + if err != nil { + log.Fatalf("fail to new public key %v", err) } + fmt.Printf("%x\n", elliptic.Marshal(sm2.P256(), pub.X, pub.Y)) + // Output: 048356e642a40ebd18d29ba3532fbd9f3bbee8f027c3f6f39a5ba2f870369f9988981f5efe55d1c5cdf6c0ef2b070847a14f7fdf4272a8df09c442f3058af94ba1 } // This method provide a sample to handle ASN1 ciphertext ends with extra bytes. @@ -91,11 +95,10 @@ func ExamplePrivateKey_Sign_forceSM2() { toSign := []byte("ShangMi SM2 Sign Standard") // real private key should be from secret storage privKey, _ := hex.DecodeString("6c5a0a0b2eed3cbec3e4f1252bfe0e28c504a1c6bf1999eebb0af9ef0f8e6c85") - d := new(big.Int).SetBytes(privKey) - testkey := new(sm2.PrivateKey) - testkey.Curve = sm2.P256() - testkey.D = d - testkey.PublicKey.X, testkey.PublicKey.Y = testkey.ScalarBaseMult(testkey.D.Bytes()) + testkey, err := sm2.NewPrivateKey(privKey) + if err != nil { + log.Fatalf("fail to new private key %v", err) + } // force SM2 sign standard and use default UID sig, err := testkey.Sign(rand.Reader, toSign, sm2.DefaultSM2SignerOpts) @@ -112,9 +115,10 @@ func ExamplePrivateKey_Sign_forceSM2() { func ExampleVerifyASN1WithSM2() { // real public key should be from cert or public key pem file keypoints, _ := hex.DecodeString("048356e642a40ebd18d29ba3532fbd9f3bbee8f027c3f6f39a5ba2f870369f9988981f5efe55d1c5cdf6c0ef2b070847a14f7fdf4272a8df09c442f3058af94ba1") - testkey := new(ecdsa.PublicKey) - testkey.Curve = sm2.P256() - testkey.X, testkey.Y = elliptic.Unmarshal(testkey.Curve, keypoints) + testkey, err := sm2.NewPublicKey(keypoints) + if err != nil { + log.Fatalf("fail to new public key %v", err) + } toSign := []byte("ShangMi SM2 Sign Standard") signature, _ := hex.DecodeString("304402205b3a799bd94c9063120d7286769220af6b0fa127009af3e873c0e8742edc5f890220097968a4c8b040fd548d1456b33f470cabd8456bfea53e8a828f92f6d4bdcd77") @@ -128,9 +132,10 @@ func ExampleVerifyASN1WithSM2() { func ExampleEncryptASN1() { // real public key should be from cert or public key pem file keypoints, _ := hex.DecodeString("048356e642a40ebd18d29ba3532fbd9f3bbee8f027c3f6f39a5ba2f870369f9988981f5efe55d1c5cdf6c0ef2b070847a14f7fdf4272a8df09c442f3058af94ba1") - testkey := new(ecdsa.PublicKey) - testkey.Curve = sm2.P256() - testkey.X, testkey.Y = elliptic.Unmarshal(testkey.Curve, keypoints) + testkey, err := sm2.NewPublicKey(keypoints) + if err != nil { + log.Fatalf("fail to new public key %v", err) + } secretMessage := []byte("send reinforcements, we're going to advance") @@ -153,11 +158,10 @@ func ExamplePrivateKey_Decrypt() { // real private key should be from secret storage privKey, _ := hex.DecodeString("6c5a0a0b2eed3cbec3e4f1252bfe0e28c504a1c6bf1999eebb0af9ef0f8e6c85") - d := new(big.Int).SetBytes(privKey) - testkey := new(sm2.PrivateKey) - testkey.Curve = sm2.P256() - testkey.D = d - testkey.PublicKey.X, testkey.PublicKey.Y = testkey.ScalarBaseMult(testkey.D.Bytes()) + testkey, err := sm2.NewPrivateKey(privKey) + if err != nil { + log.Fatalf("fail to new private key %v", err) + } plaintext, err := testkey.Decrypt(nil, ciphertext, nil) if err != nil { diff --git a/sm2/sm2.go b/sm2/sm2.go index 962a4e0..82c148c 100644 --- a/sm2/sm2.go +++ b/sm2/sm2.go @@ -44,6 +44,9 @@ const ( // It implemented both crypto.Decrypter and crypto.Signer interfaces. type PrivateKey struct { ecdsa.PrivateKey + // inverseOfkeyPlus1 is set under inverseOfkeyPlus1Once + inverseOfkeyPlus1 *bigmod.Nat + inverseOfkeyPlus1Once sync.Once } type pointMarshalMode byte @@ -239,7 +242,7 @@ func encryptSM2EC(c *sm2Curve, pub *ecdsa.PublicKey, random io.Reader, msg []byt } var retryCount int = 0 for { - k, C1, err := randomPoint(c, random) + k, C1, err := randomPoint(c, random, false) if err != nil { return nil, err } @@ -311,7 +314,7 @@ func GenerateKey(rand io.Reader) (*PrivateKey, error) { randutil.MaybeReadByte(rand) c := p256() - k, Q, err := randomPoint(c, rand) + k, Q, err := randomPoint(c, rand, true) if err != nil { return nil, err } @@ -326,6 +329,61 @@ func GenerateKey(rand io.Reader) (*PrivateKey, error) { return priv, nil } +// NewPrivateKey checks that key is valid and returns a SM2 PrivateKey. +// +// key - the private key byte slice, the length must be 32 for SM2. +func NewPrivateKey(key []byte) (*PrivateKey, error) { + c := p256() + if len(key) != c.N.Size() { + return nil, errors.New("sm2: invalid private key size") + } + k, err := bigmod.NewNat().SetBytes(key, c.N) + if err != nil || k.IsZero() == 1 || k.Equal(c.nMinus1) == 1 { + return nil, errInvalidPrivateKey + } + p, err := c.newPoint().ScalarBaseMult(k.Bytes(c.N)) + if err != nil { + return nil, err + } + priv := new(PrivateKey) + priv.PublicKey.Curve = c.curve + priv.D = new(big.Int).SetBytes(k.Bytes(c.N)) + priv.PublicKey.X, priv.PublicKey.Y, err = c.pointToAffine(p) + if err != nil { + return nil, err + } + return priv, nil +} + +// NewPrivateKeyFromInt checks that key is valid and returns a SM2 PrivateKey. +func NewPrivateKeyFromInt(key *big.Int) (*PrivateKey, error) { + if key == nil { + return nil, errors.New("sm2: invalid private key size") + } + keyBytes := make([]byte, p256().N.Size()) + return NewPrivateKey(key.FillBytes(keyBytes)) +} + +// NewPublicKey checks that key is valid and returns a PublicKey. +func NewPublicKey(key []byte) (*ecdsa.PublicKey, error) { + c := p256() + // Reject the point at infinity and compressed encodings. + if len(key) == 0 || key[0] != 4 { + return nil, errors.New("sm2: invalid public key") + } + // SetBytes also checks that the point is on the curve. + p, err := c.newPoint().SetBytes(key) + if err != nil { + return nil, err + } + k := new(ecdsa.PublicKey) + k.X, k.Y, err = c.pointToAffine(p) + if err != nil { + return nil, err + } + return k, nil +} + // Decrypt sm2 decrypt implementation by default DecrypterOpts{C1C3C2}. // Compliance with GB/T 32918.4-2016. func Decrypt(priv *PrivateKey, ciphertext []byte) ([]byte, error) { @@ -486,7 +544,7 @@ func CalculateZA(pub *ecdsa.PublicKey, uid []byte) ([]byte, error) { } // CalculateSM2Hash calculates hash value for data including uid and public key parameters -// according standards. +// according standards. // // uid can be nil, then it will use default uid (1234567812345678) func CalculateSM2Hash(pub *ecdsa.PublicKey, data, uid []byte) ([]byte, error) { @@ -533,35 +591,52 @@ func SignASN1(rand io.Reader, priv *PrivateKey, hash []byte, opts crypto.SignerO } } -func signSM2EC(c *sm2Curve, priv *PrivateKey, rand io.Reader, hash []byte) (sig []byte, err error) { - e := bigmod.NewNat() - hashToNat(c, e, hash) +func (priv *PrivateKey) inverseOfPrivateKeyPlus1(c *sm2Curve) (*bigmod.Nat, error) { var ( - k, r, s, dp1Inv, oneNat *bigmod.Nat - R *_sm2ec.SM2P256Point + err error + dp1Inv, oneNat *bigmod.Nat + dp1Bytes []byte + ) + priv.inverseOfkeyPlus1Once.Do(func() { + oneNat, _ = bigmod.NewNat().SetBytes(one.Bytes(), c.N) + dp1Inv, err = bigmod.NewNat().SetBytes(priv.D.Bytes(), c.N) + if err == nil { + dp1Inv.Add(oneNat, c.N) + if dp1Inv.IsZero() == 1 { // make sure private key is NOT N-1 + err = errInvalidPrivateKey + } else { + dp1Bytes, err = _sm2ec.P256OrdInverse(dp1Inv.Bytes(c.N)) + if err == nil { + priv.inverseOfkeyPlus1, err = bigmod.NewNat().SetBytes(dp1Bytes, c.N) + } + } + } + }) + if err != nil { + return nil, errInvalidPrivateKey + } + return priv.inverseOfkeyPlus1, nil +} + +func signSM2EC(c *sm2Curve, priv *PrivateKey, rand io.Reader, hash []byte) (sig []byte, err error) { + // get/compute inv(d+1) + dp1Inv, err := priv.inverseOfPrivateKeyPlus1(c) + if err != nil { + return nil, err + } + + var ( + k, r, s *bigmod.Nat + R *_sm2ec.SM2P256Point ) - oneNat, err = bigmod.NewNat().SetBytes(one.Bytes(), c.N) - if err != nil { - return nil, err - } - dp1Inv, err = bigmod.NewNat().SetBytes(priv.D.Bytes(), c.N) - if err != nil { - return nil, err - } - dp1Inv.Add(oneNat, c.N) - dp1Bytes, err := _sm2ec.P256OrdInverse(dp1Inv.Bytes(c.N)) - if err != nil { - return nil, err - } - dp1Inv, err = bigmod.NewNat().SetBytes(dp1Bytes, c.N) - if err != nil { - panic("sm2: internal error: P256OrdInverse produced an invalid value") - } + // hash to int + e := bigmod.NewNat() + hashToNat(c, e, hash) for { for { - k, R, err = randomPoint(c, rand) + k, R, err = randomPoint(c, rand, false) if err != nil { return nil, err } @@ -792,7 +867,7 @@ func curveToECDH(c elliptic.Curve) ecdh.Curve { // randomPoint returns a random scalar and the corresponding point using the // procedure given in FIPS 186-4, Appendix B.5.2 (rejection sampling). -func randomPoint(c *sm2Curve, rand io.Reader) (k *bigmod.Nat, p *_sm2ec.SM2P256Point, err error) { +func randomPoint(c *sm2Curve, rand io.Reader, checkOrderMinus1 bool) (k *bigmod.Nat, p *_sm2ec.SM2P256Point, err error) { k = bigmod.NewNat() for { b := make([]byte, c.N.Size()) @@ -813,11 +888,10 @@ func randomPoint(c *sm2Curve, rand io.Reader) (k *bigmod.Nat, p *_sm2ec.SM2P256P b[0] >>= excess } - // FIPS 186-4 makes us check k <= N - 2 and then add one. - // Checking 0 < k <= N - 1 is strictly equivalent. + // Checking 0 < k <= N - 2. // None of this matters anyway because the chance of selecting // zero is cryptographically negligible. - if _, err = k.SetBytes(b, c.N); err == nil && k.IsZero() == 0 { + if _, err = k.SetBytes(b, c.N); err == nil && k.IsZero() == 0 && (!checkOrderMinus1 || k.Equal(c.nMinus1) == 0) { break } @@ -838,6 +912,7 @@ type sm2Curve struct { newPoint func() *_sm2ec.SM2P256Point curve elliptic.Curve N *bigmod.Modulus + nMinus1 *bigmod.Nat nMinus2 []byte } @@ -891,4 +966,7 @@ func precomputeParams(c *sm2Curve, curve elliptic.Curve) { c.curve = curve c.N, _ = bigmod.NewModulusFromBig(params.N) c.nMinus2 = new(big.Int).Sub(params.N, big.NewInt(2)).Bytes() + c.nMinus1, _ = bigmod.NewNat().SetBytes(new(big.Int).Sub(params.N, big.NewInt(1)).Bytes(), c.N) } + +var errInvalidPrivateKey = errors.New("sm2: invalid private key") diff --git a/sm2/sm2_test.go b/sm2/sm2_test.go index 803719e..3abbd58 100644 --- a/sm2/sm2_test.go +++ b/sm2/sm2_test.go @@ -16,6 +16,99 @@ import ( "github.com/emmansun/gmsm/sm3" ) +func TestNewPrivateKey(t *testing.T) { + c := p256() + // test nil + _, err := NewPrivateKey(nil) + if err == nil || err.Error() != "sm2: invalid private key size" { + t.Errorf("should throw sm2: invalid private key size") + } + // test all zero + key := make([]byte, c.N.Size()) + _, err = NewPrivateKey(key) + if err == nil || err != errInvalidPrivateKey { + t.Errorf("should throw errInvalidPrivateKey") + } + // test N-1 + _, err = NewPrivateKey(c.nMinus1.Bytes(c.N)) + if err == nil || err != errInvalidPrivateKey { + t.Errorf("should throw errInvalidPrivateKey") + } + // test N + _, err = NewPrivateKey(P256().Params().N.Bytes()) + if err == nil || err != errInvalidPrivateKey { + t.Errorf("should throw errInvalidPrivateKey") + } + // test 1 + key[31] = 1 + _, err = NewPrivateKey(key) + if err != nil { + t.Fatal(err) + } + // test N-2 + _, err = NewPrivateKey(c.nMinus2) + if err != nil { + t.Error(err) + } +} + +func TestNewPrivateKeyFromInt(t *testing.T) { + // test nil + _, err := NewPrivateKeyFromInt(nil) + if err == nil || err.Error() != "sm2: invalid private key size" { + t.Errorf("should throw sm2: invalid private key size") + } + // test 1 + _, err = NewPrivateKeyFromInt(big.NewInt(1)) + if err != nil { + t.Fatal(err) + } + // test N + _, err = NewPrivateKeyFromInt(P256().Params().N) + if err == nil || err != errInvalidPrivateKey { + t.Errorf("should throw errInvalidPrivateKey") + } + + // test N + 1 + _, err = NewPrivateKeyFromInt(new(big.Int).Add(P256().Params().N, big.NewInt(1))) + if err == nil || err != errInvalidPrivateKey { + t.Errorf("should throw errInvalidPrivateKey") + } + + c := p256() + // test N - 1 + _, err = NewPrivateKeyFromInt(new(big.Int).SetBytes(c.nMinus1.Bytes(c.N))) + if err == nil || err != errInvalidPrivateKey { + t.Errorf("should throw errInvalidPrivateKey") + } +} + +func TestNewPublicKey(t *testing.T) { + // test nil + _, err := NewPublicKey(nil) + if err == nil || err.Error() != "sm2: invalid public key" { + t.Errorf("should throw sm2: invalid public key") + } + // test without point format prefix byte + keypoints, _ := hex.DecodeString("8356e642a40ebd18d29ba3532fbd9f3bbee8f027c3f6f39a5ba2f870369f9988981f5efe55d1c5cdf6c0ef2b070847a14f7fdf4272a8df09c442f3058af94ba1") + _, err = NewPublicKey(keypoints) + if err == nil || err.Error() != "sm2: invalid public key" { + t.Errorf("should throw sm2: invalid public key") + } + // test correct point + keypoints, _ = hex.DecodeString("048356e642a40ebd18d29ba3532fbd9f3bbee8f027c3f6f39a5ba2f870369f9988981f5efe55d1c5cdf6c0ef2b070847a14f7fdf4272a8df09c442f3058af94ba1") + _, err = NewPublicKey(keypoints) + if err != nil { + t.Fatal(err) + } + // test point not on curve + keypoints, _ = hex.DecodeString("048356e642a40ebd18d29ba3532fbd9f3bbee8f027c3f6f39a5ba2f870369f9988981f5efe55d1c5cdf6c0ef2b070847a14f7fdf4272a8df09c442f3058af94ba2") + _, err = NewPublicKey(keypoints) + if err == nil || err.Error() != "point not on SM2 P256 curve" { + t.Errorf("should throw point not on SM2 P256 curve, got %v", err) + } +} + func TestSplicingOrder(t *testing.T) { priv, _ := GenerateKey(rand.Reader) tests := []struct { @@ -335,6 +428,18 @@ func TestInvalidCiphertext(t *testing.T) { } } +func TestPrivateKeyPlus1WithOrderMinus1(t *testing.T) { + priv := new(PrivateKey) + priv.D = new(big.Int).Sub(P256().Params().N, big.NewInt(1)) + priv.Curve = P256() + priv.PublicKey.X, priv.PublicKey.Y = P256().ScalarBaseMult(priv.D.Bytes()) + + _, err := priv.inverseOfPrivateKeyPlus1(p256()) + if err == nil || err != errInvalidPrivateKey { + t.Errorf("expected invalid private key error") + } +} + func TestSignVerify(t *testing.T) { priv, _ := GenerateKey(rand.Reader) tests := []struct { @@ -535,7 +640,7 @@ func TestEqual(t *testing.T) { t.Errorf("private.Public() is not Equal to public: %q", public) } if !private.Equal(private) { - t.Errorf("private key is not equal to itself: %q", private) + t.Errorf("private key is not equal to itself") } otherPriv, _ := GenerateKey(rand.Reader) @@ -571,7 +676,7 @@ func TestRandomPoint(t *testing.T) { // A sequence of all ones will generate 2^N-1, which should be rejected. // (Unless, for example, we are masking too many bits.) r := io.MultiReader(bytes.NewReader(bytes.Repeat([]byte{0xff}, 100)), rand.Reader) - if k, p, err := randomPoint(c, r); err != nil { + if k, p, err := randomPoint(c, r, false); err != nil { t.Fatal(err) } else if k.IsZero() == 1 { t.Error("k is zero") @@ -585,7 +690,7 @@ func TestRandomPoint(t *testing.T) { // A sequence of all zeroes will generate zero, which should be rejected. r = io.MultiReader(bytes.NewReader(bytes.Repeat([]byte{0}, 100)), rand.Reader) - if k, p, err := randomPoint(c, r); err != nil { + if k, p, err := randomPoint(c, r, false); err != nil { t.Fatal(err) } else if k.IsZero() == 1 { t.Error("k is zero") @@ -625,12 +730,12 @@ func BenchmarkSign_SM2(b *testing.B) { if err != nil { b.Fatal(err) } - hashed := []byte("testing") + hashed := sm3.Sum([]byte("testing")) b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { - sig, err := SignASN1(rand.Reader, priv, hashed, nil) + sig, err := SignASN1(rand.Reader, priv, hashed[:], nil) if err != nil { b.Fatal(err) }