diff --git a/sm2/sm2_dsa.go b/sm2/sm2_dsa.go index f550f30..614b6f2 100644 --- a/sm2/sm2_dsa.go +++ b/sm2/sm2_dsa.go @@ -172,6 +172,18 @@ func NewPrivateKey(key []byte) (*PrivateKey, error) { return priv, nil } +// ParseRawPrivateKey parses a private key encoded as a fixed-length big-endian +// integer, according to SEC 1, Version 2.0, Section 2.3.6 (sometimes referred +// to as the raw format). It returns an error if the value is not reduced modulo +// the curve's order minus one, or if it's zero. +// +// Note that private keys are more commonly encoded in ASN.1 or PKCS#8 format, +// which can be parsed with [smx509.ParseECPrivateKey] or +// [smx509.ParsePKCS8PrivateKey] (and [encoding/pem]). +func ParseRawPrivateKey(data []byte) (*PrivateKey, error) { + return NewPrivateKey(data) +} + // NewPrivateKeyFromInt creates a new SM2 private key from a given big integer. // It returns an error if the provided key is nil. func NewPrivateKeyFromInt(key *big.Int) (*PrivateKey, error) { @@ -208,6 +220,17 @@ func NewPublicKey(key []byte) (*ecdsa.PublicKey, error) { return k, nil } +// ParseUncompressedPublicKey parses a public key encoded as an uncompressed +// point according to SEC 1, Version 2.0, Section 2.3.3 (also known as the X9.62 +// uncompressed format). It returns an error if the point is not in uncompressed +// form, is not on the curve, or is the point at infinity. +// +// Note that public keys are more commonly encoded in DER (or PEM) format, which +// can be parsed with [smx509.ParsePKIXPublicKey] (and [encoding/pem]). +func ParseUncompressedPublicKey(data []byte) (*ecdsa.PublicKey, error) { + return NewPublicKey(data) +} + var defaultUID = []byte{0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38} // CalculateZA ZA = H256(ENTLA || IDA || a || b || xG || yG || xA || yA). diff --git a/sm2/sm2_dsa_test.go b/sm2/sm2_dsa_test.go index 3de5a4d..50a85f3 100644 --- a/sm2/sm2_dsa_test.go +++ b/sm2/sm2_dsa_test.go @@ -15,37 +15,37 @@ import ( "github.com/emmansun/gmsm/sm3" ) -func TestNewPrivateKey(t *testing.T) { +func TestParseRawPrivateKey(t *testing.T) { c := p256() // test nil - _, err := NewPrivateKey(nil) + _, err := ParseRawPrivateKey(nil) if err == nil || err.Error() != "sm2: invalid private key size" { t.Errorf("should throw sm2: invalid private key size") } // test all zero key := make([]byte, c.N.Size()) - _, err = NewPrivateKey(key) + _, err = ParseRawPrivateKey(key) if err == nil || err != errInvalidPrivateKey { t.Errorf("should throw errInvalidPrivateKey") } // test N-1 - _, err = NewPrivateKey(c.nMinus1.Bytes(c.N)) + _, err = ParseRawPrivateKey(c.nMinus1.Bytes(c.N)) if err == nil || err != errInvalidPrivateKey { t.Errorf("should throw errInvalidPrivateKey") } // test N - _, err = NewPrivateKey(P256().Params().N.Bytes()) + _, err = ParseRawPrivateKey(P256().Params().N.Bytes()) if err == nil || err != errInvalidPrivateKey { t.Errorf("should throw errInvalidPrivateKey") } // test 1 key[31] = 1 - _, err = NewPrivateKey(key) + _, err = ParseRawPrivateKey(key) if err != nil { t.Fatal(err) } // test N-2 - _, err = NewPrivateKey(c.nMinus2) + _, err = ParseRawPrivateKey(c.nMinus2) if err != nil { t.Error(err) } @@ -82,27 +82,27 @@ func TestNewPrivateKeyFromInt(t *testing.T) { } } -func TestNewPublicKey(t *testing.T) { +func TestParseUncompressedPublicKey(t *testing.T) { // test nil - _, err := NewPublicKey(nil) + _, err := ParseUncompressedPublicKey(nil) if err == nil || err.Error() != "sm2: invalid public key" { t.Errorf("should throw sm2: invalid public key") } // test without point format prefix byte keypoints, _ := hex.DecodeString("8356e642a40ebd18d29ba3532fbd9f3bbee8f027c3f6f39a5ba2f870369f9988981f5efe55d1c5cdf6c0ef2b070847a14f7fdf4272a8df09c442f3058af94ba1") - _, err = NewPublicKey(keypoints) + _, err = ParseUncompressedPublicKey(keypoints) if err == nil || err.Error() != "sm2: invalid public key" { t.Errorf("should throw sm2: invalid public key") } // test correct point keypoints, _ = hex.DecodeString("048356e642a40ebd18d29ba3532fbd9f3bbee8f027c3f6f39a5ba2f870369f9988981f5efe55d1c5cdf6c0ef2b070847a14f7fdf4272a8df09c442f3058af94ba1") - _, err = NewPublicKey(keypoints) + _, err = ParseUncompressedPublicKey(keypoints) if err != nil { t.Fatal(err) } // test point not on curve keypoints, _ = hex.DecodeString("048356e642a40ebd18d29ba3532fbd9f3bbee8f027c3f6f39a5ba2f870369f9988981f5efe55d1c5cdf6c0ef2b070847a14f7fdf4272a8df09c442f3058af94ba2") - _, err = NewPublicKey(keypoints) + _, err = ParseUncompressedPublicKey(keypoints) if err == nil || err.Error() != "point not on SM2 P256 curve" { t.Errorf("should throw point not on SM2 P256 curve, got %v", err) }