From cf0c739dcf373daa359bf07ee5ff7ef66c4be587 Mon Sep 17 00:00:00 2001 From: Sun Yimin Date: Fri, 3 Feb 2023 10:25:03 +0800 Subject: [PATCH] smx509: change CreateCertificate's template and parent parameter type to any --- smx509/name_constraints_test.go | 2 +- smx509/verify_test.go | 4 +-- smx509/x509.go | 62 +++++++++++++++++++++++---------- smx509/x509_additional_test.go | 49 ++++++++++++++++++++++++++ smx509/x509_test.go | 8 ++--- 5 files changed, 99 insertions(+), 26 deletions(-) diff --git a/smx509/name_constraints_test.go b/smx509/name_constraints_test.go index cf7df0b..92a7255 100644 --- a/smx509/name_constraints_test.go +++ b/smx509/name_constraints_test.go @@ -1624,7 +1624,7 @@ func makeConstraintsCACert(constraints constraintsSpec, name string, key *ecdsa. if parent == nil { parent = template } - derBytes, err := CreateCertificate(rand.Reader, template.asX509(), parent.asX509(), &key.PublicKey, parentKey) + derBytes, err := CreateCertificate(rand.Reader, template, parent, &key.PublicKey, parentKey) if err != nil { return nil, err } diff --git a/smx509/verify_test.go b/smx509/verify_test.go index 0adc191..61ddca5 100644 --- a/smx509/verify_test.go +++ b/smx509/verify_test.go @@ -1952,7 +1952,7 @@ func genCertEdge(t *testing.T, subject string, key crypto.Signer, mutateTmpl fun signer = key } - d, err := CreateCertificate(rand.Reader, tmpl.asX509(), issuer.asX509(), key.Public(), signer) + d, err := CreateCertificate(rand.Reader, tmpl, issuer, key.Public(), signer) if err != nil { t.Fatalf("failed to generate test cert: %s", err) } @@ -2598,7 +2598,7 @@ func TestVerifyEKURootAsLeaf(t *testing.T) { DNSNames: []string{"localhost"}, ExtKeyUsage: tc.rootEKUs, } - rootDER, err := CreateCertificate(rand.Reader, tmpl.asX509(), tmpl.asX509(), k.Public(), k) + rootDER, err := CreateCertificate(rand.Reader, tmpl, tmpl, k.Public(), k) if err != nil { t.Fatalf("failed to create certificate: %s", err) } diff --git a/smx509/x509.go b/smx509/x509.go index c293615..de3739f 100644 --- a/smx509/x509.go +++ b/smx509/x509.go @@ -658,6 +658,7 @@ func (c *Certificate) asX509() *x509.Certificate { return (*x509.Certificate)(c) } +// ToX509 convert smx509.Certificate reference to x509.Certificate func (c *Certificate) ToX509() *x509.Certificate { return c.asX509() } @@ -1374,8 +1375,9 @@ var emptyASN1Subject = []byte{0x30, 0} // - UnknownExtKeyUsage // // The certificate is signed by parent. If parent is equal to template then the -// certificate is self-signed. The parameter pub is the public key of the -// certificate to be generated and priv is the private key of the signer. +// certificate is self-signed, both parent and template should be *x509.Certificate or +// *smx509.Certificate type. The parameter pub is the public key of the certificate to +// be generated and priv is the private key of the signer. // // The returned slice is the certificate in DER encoding. // @@ -1389,13 +1391,23 @@ var emptyASN1Subject = []byte{0x30, 0} // // If SubjectKeyId from template is empty and the template is a CA, SubjectKeyId // will be generated from the hash of the public key. -func CreateCertificate(rand io.Reader, template, parent *x509.Certificate, pub, priv interface{}) ([]byte, error) { +func CreateCertificate(rand io.Reader, template, parent, pub, priv interface{}) ([]byte, error) { + realTemplate, err := toCertificate(template) + if err != nil { + return nil, fmt.Errorf("x509: unsupported template parameter type: %T", template) + } + + realParent, err := toCertificate(parent) + if err != nil { + return nil, fmt.Errorf("x509: unsupported parent parameter type: %T", parent) + } + key, ok := priv.(crypto.Signer) if !ok { return nil, errors.New("x509: certificate private key does not implement crypto.Signer") } - if template.SerialNumber == nil { + if realTemplate.SerialNumber == nil { return nil, errors.New("x509: no SerialNumber given") } @@ -1404,15 +1416,15 @@ func CreateCertificate(rand io.Reader, template, parent *x509.Certificate, pub, // We _should_ also restrict serials to <= 20 octets, but it turns out a lot of people // get this wrong, in part because the encoding can itself alter the length of the // serial. For now we accept these non-conformant serials. - if template.SerialNumber.Sign() == -1 { + if realTemplate.SerialNumber.Sign() == -1 { return nil, errors.New("x509: serial number must be positive") } - if template.BasicConstraintsValid && !template.IsCA && template.MaxPathLen != -1 && (template.MaxPathLen != 0 || template.MaxPathLenZero) { + if realTemplate.BasicConstraintsValid && !realTemplate.IsCA && realTemplate.MaxPathLen != -1 && (realTemplate.MaxPathLen != 0 || realTemplate.MaxPathLenZero) { return nil, errors.New("x509: only CAs are allowed to specify MaxPathLen") } - hashFunc, signatureAlgorithm, err := signingParamsForPublicKey(key.Public(), template.SignatureAlgorithm) + hashFunc, signatureAlgorithm, err := signingParamsForPublicKey(key.Public(), realTemplate.SignatureAlgorithm) if err != nil { return nil, err } @@ -1426,23 +1438,23 @@ func CreateCertificate(rand io.Reader, template, parent *x509.Certificate, pub, return nil, fmt.Errorf("x509: unsupported public key type: %T", pub) } - asn1Issuer, err := subjectBytes(parent) + asn1Issuer, err := subjectBytes(realParent) if err != nil { return nil, err } - asn1Subject, err := subjectBytes(template) + asn1Subject, err := subjectBytes(realTemplate) if err != nil { return nil, err } - authorityKeyId := template.AuthorityKeyId - if !bytes.Equal(asn1Issuer, asn1Subject) && len(parent.SubjectKeyId) > 0 { - authorityKeyId = parent.SubjectKeyId + authorityKeyId := realTemplate.AuthorityKeyId + if !bytes.Equal(asn1Issuer, asn1Subject) && len(realParent.SubjectKeyId) > 0 { + authorityKeyId = realParent.SubjectKeyId } - subjectKeyId := template.SubjectKeyId - if len(subjectKeyId) == 0 && template.IsCA { + subjectKeyId := realTemplate.SubjectKeyId + if len(subjectKeyId) == 0 && realTemplate.IsCA { // SubjectKeyId generated using method 1 in RFC 5280, Section 4.2.1.2: // (1) The keyIdentifier is composed of the 160-bit SHA-1 hash of the // value of the BIT STRING subjectPublicKey (excluding the tag, @@ -1458,11 +1470,11 @@ func CreateCertificate(rand io.Reader, template, parent *x509.Certificate, pub, if privPub, ok := key.Public().(privateKey); !ok { return nil, errors.New("x509: internal error: supported public key does not implement Equal") - } else if parent.PublicKey != nil && !privPub.Equal(parent.PublicKey) { + } else if realParent.PublicKey != nil && !privPub.Equal(realParent.PublicKey) { return nil, errors.New("x509: provided PrivateKey doesn't match parent's PublicKey") } - extensions, err := buildCertExtensions(template, bytes.Equal(asn1Subject, emptyASN1Subject), authorityKeyId, subjectKeyId) + extensions, err := buildCertExtensions(realTemplate, bytes.Equal(asn1Subject, emptyASN1Subject), authorityKeyId, subjectKeyId) if err != nil { return nil, err } @@ -1470,10 +1482,10 @@ func CreateCertificate(rand io.Reader, template, parent *x509.Certificate, pub, encodedPublicKey := asn1.BitString{BitLength: len(publicKeyBytes) * 8, Bytes: publicKeyBytes} c := tbsCertificate{ Version: 2, - SerialNumber: template.SerialNumber, + SerialNumber: realTemplate.SerialNumber, SignatureAlgorithm: signatureAlgorithm, Issuer: asn1.RawValue{FullBytes: asn1Issuer}, - Validity: validity{template.NotBefore.UTC(), template.NotAfter.UTC()}, + Validity: validity{realTemplate.NotBefore.UTC(), realTemplate.NotAfter.UTC()}, Subject: asn1.RawValue{FullBytes: asn1Subject}, PublicKey: publicKeyInfo{nil, publicKeyAlgorithm, encodedPublicKey}, Extensions: extensions, @@ -1495,7 +1507,7 @@ func CreateCertificate(rand io.Reader, template, parent *x509.Certificate, pub, } var signerOpts crypto.SignerOpts = hashFunc - if template.SignatureAlgorithm != 0 && isRSAPSS(template.SignatureAlgorithm) { + if realTemplate.SignatureAlgorithm != 0 && isRSAPSS(realTemplate.SignatureAlgorithm) { signerOpts = &rsa.PSSOptions{ SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: hashFunc, @@ -1525,6 +1537,17 @@ func CreateCertificate(rand io.Reader, template, parent *x509.Certificate, pub, return signedCert, nil } +func toCertificate(in interface{}) (*x509.Certificate, error) { + switch c := in.(type) { + case *x509.Certificate: + return c, nil + case *Certificate: + return c.asX509(), nil + default: + return nil, fmt.Errorf("unsupported certificate of type %T", in) + } +} + // ParseCRL parses a CRL from the given bytes. It's often the case that PEM // encoded CRLs will appear where they should be DER encoded, so this function // will transparently handle PEM encoding as long as there isn't any leading @@ -1615,6 +1638,7 @@ func (c *CertificateRequest) asX509() *x509.CertificateRequest { return (*x509.CertificateRequest)(c) } +// ToX509 convert smx509.CertificateRequest reference to x509.CertificateRequest func (c *CertificateRequest) ToX509() *x509.CertificateRequest { return c.asX509() } diff --git a/smx509/x509_additional_test.go b/smx509/x509_additional_test.go index 43d2b20..34c0a5c 100644 --- a/smx509/x509_additional_test.go +++ b/smx509/x509_additional_test.go @@ -213,3 +213,52 @@ func TestMarshalECDHPKIXPublicKey(t *testing.T) { t.Fatal("should be same") } } + +func TestToCertificate(t *testing.T) { + x509Cert := new(x509.Certificate) + + c, err := toCertificate(x509Cert) + if err != nil || c != x509Cert { + t.Fatal("should be no error") + } + + smX509Cert := new(Certificate) + _, err = toCertificate(smX509Cert) + if err != nil { + t.Fatal("should be no error") + } + + _, err = toCertificate("test") + if err == nil { + t.Fatal("should be error") + } + + _, err = toCertificate(nil) + if err == nil { + t.Fatal("should be error") + } +} + +func TestInvalidParentTemplate(t *testing.T) { + random := rand.Reader + + sm2Priv, err := sm2.GenerateKey(rand.Reader) + if err != nil { + t.Fatalf("Failed to generate SM2 key: %s", err) + } + _, err = CreateCertificate(random, nil, nil, sm2Priv.PublicKey, sm2Priv) + if err == nil { + t.Fatal("should be error") + } + if err.Error() != "x509: unsupported template parameter type: " { + t.Fatalf("unexpected error message: %v", err.Error()) + } + + _, err = CreateCertificate(random, new(x509.Certificate), nil, sm2Priv.PublicKey, sm2Priv) + if err == nil { + t.Fatal("should be error") + } + if err.Error() != "x509: unsupported parent parameter type: " { + t.Fatalf("unexpected error message: %v", err.Error()) + } +} diff --git a/smx509/x509_test.go b/smx509/x509_test.go index bcca044..32edc15 100644 --- a/smx509/x509_test.go +++ b/smx509/x509_test.go @@ -2747,7 +2747,7 @@ func TestCreateCertificateLegacy(t *testing.T) { DNSNames: []string{"example.com"}, SignatureAlgorithm: sigAlg, } - _, err := CreateCertificate(rand.Reader, template.asX509(), template.asX509(), testPrivateKey.Public(), &brokenSigner{testPrivateKey.Public()}) + _, err := CreateCertificate(rand.Reader, template, template, testPrivateKey.Public(), &brokenSigner{testPrivateKey.Public()}) if err == nil { t.Fatal("CreateCertificate didn't fail when SignatureAlgorithm = MD5WithRSA") } @@ -3085,7 +3085,7 @@ func TestDisableSHA1ForCertOnly(t *testing.T) { IsCA: true, KeyUsage: KeyUsageCertSign | KeyUsageCRLSign, } - certDER, err := CreateCertificate(rand.Reader, tmpl.asX509(), tmpl.asX509(), rsaPrivateKey.Public(), rsaPrivateKey) + certDER, err := CreateCertificate(rand.Reader, tmpl, tmpl, rsaPrivateKey.Public(), rsaPrivateKey) if err != nil { t.Fatalf("failed to generate test cert: %s", err) } @@ -3148,7 +3148,7 @@ func TestOmitEmptyExtensions(t *testing.T) { NotAfter: time.Now().Add(time.Hour), NotBefore: time.Now().Add(-time.Hour), } - der, err := CreateCertificate(rand.Reader, tmpl.asX509(), tmpl.asX509(), k.Public(), k) + der, err := CreateCertificate(rand.Reader, tmpl, tmpl, k.Public(), k) if err != nil { t.Fatal(err) } @@ -3190,7 +3190,7 @@ func TestCreateNegativeSerial(t *testing.T) { NotBefore: time.Now().Add(-time.Hour), } expectedErr := "x509: serial number must be positive" - _, err = CreateCertificate(rand.Reader, tmpl.asX509(), tmpl.asX509(), k.Public(), k) + _, err = CreateCertificate(rand.Reader, tmpl, tmpl, k.Public(), k) if err == nil || err.Error() != expectedErr { t.Errorf("CreateCertificate returned unexpected error: want %q, got %q", expectedErr, err) }