diff --git a/go.mod b/go.mod index 54a20b6..d226f55 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,6 @@ module github.com/emmansun/gmsm go 1.14 require ( - golang.org/x/crypto v0.0.0-20201221181555-eec23a3978ad - golang.org/x/sys v0.0.0-20210305230114-8fe3ee5dd75b + golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a + golang.org/x/sys v0.0.0-20210603125802-9665404d3644 ) diff --git a/sm2/sm2.go b/sm2/sm2.go index 20fda07..05278b3 100644 --- a/sm2/sm2.go +++ b/sm2/sm2.go @@ -7,7 +7,6 @@ import ( "crypto/ecdsa" "crypto/elliptic" "crypto/sha512" - "encoding/asn1" "encoding/binary" "errors" "fmt" @@ -17,6 +16,8 @@ import ( "sync" "github.com/emmansun/gmsm/sm3" + "golang.org/x/crypto/cryptobyte" + "golang.org/x/crypto/cryptobyte/asn1" ) const ( @@ -63,11 +64,6 @@ type EncrypterOpts struct { PointMarshalMode pointMarshalMode } -// Signer SM2 special signer -type Signer interface { - SignWithSM2(rand io.Reader, uid, msg []byte) ([]byte, error) -} - func (mode pointMarshalMode) mashal(curve elliptic.Curve, x, y *big.Int) []byte { switch mode { case MarshalCompressed: @@ -81,6 +77,38 @@ func (mode pointMarshalMode) mashal(curve elliptic.Curve, x, y *big.Int) []byte var defaultEncrypterOpts = EncrypterOpts{MarshalUncompressed} +// directSigning is a standard Hash value that signals that no pre-hashing +// should be performed. +var directSigning crypto.Hash = 0 + +// Signer SM2 special signer +type Signer interface { + SignWithSM2(rand io.Reader, uid, msg []byte) ([]byte, error) +} + +type SM2SignerOption struct { + UID []byte + ForceGMSign bool +} + +// NewSM2SignerOption create a SM2 specific signer option +// forceGMSign - if use GM specific sign logic, if yes, should pass raw message to sign +// uid - if forceGMSign is true, then you can pass uid, if no uid is provided, system will use default one +func NewSM2SignerOption(forceGMSign bool, uid []byte) *SM2SignerOption { + opt := &SM2SignerOption{ + UID: uid, + ForceGMSign: forceGMSign, + } + if forceGMSign && len(uid) == 0 { + opt.UID = defaultUID + } + return opt +} + +func (*SM2SignerOption) HashFunc() crypto.Hash { + return directSigning +} + // FromECPrivateKey convert an ecdsa private key to SM2 private key func (priv *PrivateKey) FromECPrivateKey(key *ecdsa.PrivateKey) (*PrivateKey, error) { if key.Curve != P256() { @@ -98,22 +126,27 @@ func (priv *PrivateKey) FromECPrivateKey(key *ecdsa.PrivateKey) (*PrivateKey, er // where the private part is kept in, for example, a hardware module. Common // uses should use the Sign function in this package directly. func (priv *PrivateKey) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) { - r, s, err := Sign(rand, &priv.PrivateKey, digest) + var r, s *big.Int + var err error + if sm2Opts, ok := opts.(*SM2SignerOption); ok && sm2Opts.ForceGMSign { + r, s, err = SignWithSM2(rand, &priv.PrivateKey, sm2Opts.UID, digest) + } else { + r, s, err = Sign(rand, &priv.PrivateKey, digest) + } if err != nil { return nil, err } - - return asn1.Marshal(ecdsaSignature{r, s}) + var b cryptobyte.Builder + b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) { + b.AddASN1BigInt(r) + b.AddASN1BigInt(s) + }) + return b.Bytes() } // SignWithSM2 signs uid, msg with SignWithSM2 method. func (priv *PrivateKey) SignWithSM2(rand io.Reader, uid, msg []byte) ([]byte, error) { - r, s, err := SignWithSM2(rand, &priv.PrivateKey, uid, msg) - if err != nil { - return nil, err - } - - return asn1.Marshal(ecdsaSignature{r, s}) + return priv.Sign(rand, msg, NewSM2SignerOption(true, uid)) } // Decrypt decrypts msg. The opts argument should be appropriate for @@ -464,6 +497,15 @@ func SignWithSM2(rand io.Reader, priv *ecdsa.PrivateKey, uid, msg []byte) (r, s return Sign(rand, priv, md.Sum(nil)) } +// SignASN1 signs a hash (which should be the result of hashing a larger message) +// using the private key, priv. If the hash is longer than the bit-length of the +// private key's curve order, the hash will be truncated to that length. It +// returns the ASN.1 encoded signature. The security of the private key +// depends on the entropy of rand. +func SignASN1(rand io.Reader, priv *PrivateKey, hash []byte, opts crypto.SignerOpts) ([]byte, error) { + return priv.Sign(rand, hash, opts) +} + // Verify verifies the signature in r, s of hash using the public key, pub. Its // return value records whether the signature is valid. func Verify(pub *ecdsa.PublicKey, hash []byte, r, s *big.Int) bool { @@ -500,6 +542,24 @@ func Verify(pub *ecdsa.PublicKey, hash []byte, r, s *big.Int) bool { return ecdsa.Verify(pub, hash, r, s) } +// VerifyASN1 verifies the ASN.1 encoded signature, sig, of hash using the +// public key, pub. Its return value records whether the signature is valid. +func VerifyASN1(pub *ecdsa.PublicKey, hash, sig []byte) bool { + var ( + r, s = &big.Int{}, &big.Int{} + inner cryptobyte.String + ) + input := cryptobyte.String(sig) + if !input.ReadASN1(&inner, asn1.SEQUENCE) || + !input.Empty() || + !inner.ReadASN1Integer(r) || + !inner.ReadASN1Integer(s) || + !inner.Empty() { + return false + } + return Verify(pub, hash, r, s) +} + // VerifyWithSM2 verifies the signature in r, s of hash using the public key, pub. Its // return value records whether the signature is valid. func VerifyWithSM2(pub *ecdsa.PublicKey, uid, msg []byte, r, s *big.Int) bool { @@ -516,6 +576,24 @@ func VerifyWithSM2(pub *ecdsa.PublicKey, uid, msg []byte, r, s *big.Int) bool { return Verify(pub, md.Sum(nil), r, s) } +// VerifyASN1WithSM2 verifies the signature in r, s of hash using the public key, pub. Its +// return value records whether the signature is valid. +func VerifyASN1WithSM2(pub *ecdsa.PublicKey, uid, msg, sig []byte) bool { + var ( + r, s = &big.Int{}, &big.Int{} + inner cryptobyte.String + ) + input := cryptobyte.String(sig) + if !input.ReadASN1(&inner, asn1.SEQUENCE) || + !input.Empty() || + !inner.ReadASN1Integer(r) || + !inner.ReadASN1Integer(s) || + !inner.Empty() { + return false + } + return VerifyWithSM2(pub, uid, msg, r, s) +} + type zr struct { io.Reader } diff --git a/smx509/x509.go b/smx509/x509.go index 580c827..e8e53d0 100644 --- a/smx509/x509.go +++ b/smx509/x509.go @@ -177,10 +177,6 @@ type dsaAlgorithmParameters struct { P, Q, G *big.Int } -type ecdsaSignature struct { - R, S *big.Int -} - type validity struct { NotBefore, NotAfter time.Time } @@ -576,17 +572,8 @@ func (c *Certificate) CheckSignature(algo x509.SignatureAlgorithm, signed, signa if key.Curve != sm2.P256() { return c.Certificate.CheckSignature(algo, signed, signature) } - ecdsaSig := new(ecdsaSignature) - if rest, err := asn1.Unmarshal(signature, ecdsaSig); err != nil { - return err - } else if len(rest) != 0 { - return errors.New("x509: trailing data after ECDSA signature") - } - if ecdsaSig.R.Sign() <= 0 || ecdsaSig.S.Sign() <= 0 { - return errors.New("x509: ECDSA signature contained zero or negative values") - } - if !sm2.VerifyWithSM2(key, nil, signed, ecdsaSig.R, ecdsaSig.S) { - return errors.New("x509: ECDSA verification failure") + if !sm2.VerifyASN1WithSM2(key, nil, signed, signature) { + return errors.New("x509: SM2 verification failure") } return nil } @@ -2306,17 +2293,8 @@ func (c *CertificateRequest) CheckSignature() error { // a crypto.PublicKey. func checkSignature(c *x509.CertificateRequest, publicKey *ecdsa.PublicKey) (err error) { signed := c.RawTBSCertificateRequest - ecdsaSig := new(ecdsaSignature) - if rest, err := asn1.Unmarshal(c.Signature, ecdsaSig); err != nil { - return err - } else if len(rest) != 0 { - return errors.New("x509: trailing data after ECDSA signature") - } - if ecdsaSig.R.Sign() <= 0 || ecdsaSig.S.Sign() <= 0 { - return errors.New("x509: ECDSA signature contained zero or negative values") - } - if !sm2.VerifyWithSM2(publicKey, nil, signed, ecdsaSig.R, ecdsaSig.S) { - return errors.New("x509: ECDSA verification failure") + if !sm2.VerifyASN1WithSM2(publicKey, nil, signed, c.Signature) { + return errors.New("x509: SM2 verification failure") } return } diff --git a/smx509/x509_test.go b/smx509/x509_test.go index 40728c8..6cd4711 100644 --- a/smx509/x509_test.go +++ b/smx509/x509_test.go @@ -11,7 +11,6 @@ import ( "crypto/x509/pkix" "encoding/asn1" "encoding/base64" - "encoding/hex" "encoding/json" "encoding/pem" "errors" @@ -237,24 +236,14 @@ func TestCreateCertificateRequest(t *testing.T) { } func TestSignByAliVerifyAtLocal(t *testing.T) { - var rs = &ecdsaSignature{} dig, err := base64.StdEncoding.DecodeString(signature) if err != nil { t.Fatal(err) } - rest, err := asn1.Unmarshal(dig, rs) - if err != nil { - t.Fatal(err) - } - if len(rest) != 0 { - t.Errorf("rest len=%d", len(rest)) - } - - fmt.Printf("r=%s, s=%s\n", hex.EncodeToString(rs.R.Bytes()), hex.EncodeToString(rs.S.Bytes())) pub, err := getPublicKey([]byte(publicKeyPemFromAliKmsForSign)) pub1 := pub.(*ecdsa.PublicKey) hashValue, _ := base64.StdEncoding.DecodeString(hashBase64) - result := sm2.Verify(pub1, hashValue, rs.R, rs.S) + result := sm2.VerifyASN1(pub1, hashValue, dig) if !result { t.Error("Verify fail") }