package sm2 import ( "crypto" "crypto/ecdsa" "crypto/elliptic" "crypto/x509" "crypto/x509/pkix" "encoding/asn1" "errors" "fmt" "io" "math/big" "net" "net/url" "strconv" "strings" "github.com/emmansun/gmsm/sm3" ) // pkixPublicKey reflects a PKIX public key structure. See SubjectPublicKeyInfo // in RFC 3280. type pkixPublicKey struct { Algo pkix.AlgorithmIdentifier BitString asn1.BitString } type publicKeyInfo struct { Raw asn1.RawContent Algorithm pkix.AlgorithmIdentifier PublicKey asn1.BitString } // pkcs1PublicKey reflects the ASN.1 structure of a PKCS#1 public key. type pkcs1PublicKey struct { N *big.Int E int } type dsaSignature struct { R, S *big.Int } type ecdsaSignature dsaSignature // http://gmssl.org/docs/oid.html var ( oidNamedCurveP224 = asn1.ObjectIdentifier{1, 3, 132, 0, 33} oidNamedCurveP256 = asn1.ObjectIdentifier{1, 2, 840, 10045, 3, 1, 7} oidNamedCurveP384 = asn1.ObjectIdentifier{1, 3, 132, 0, 34} oidNamedCurveP521 = asn1.ObjectIdentifier{1, 3, 132, 0, 35} oidNamedCurveP256SM2 = asn1.ObjectIdentifier{1, 2, 156, 10197, 1, 301} oidSignatureSM2WithSM3 = asn1.ObjectIdentifier{1, 2, 156, 10197, 1, 501} oidSignatureSM2WithSHA1 = asn1.ObjectIdentifier{1, 2, 156, 10197, 1, 502} oidSignatureSM2WithSHA256 = asn1.ObjectIdentifier{1, 2, 156, 10197, 1, 503} ) func oidFromNamedCurve(curve elliptic.Curve) (asn1.ObjectIdentifier, bool) { switch curve { case elliptic.P224(): return oidNamedCurveP224, true case elliptic.P256(): return oidNamedCurveP256, true case elliptic.P384(): return oidNamedCurveP384, true case elliptic.P521(): return oidNamedCurveP521, true case P256(): return oidNamedCurveP256SM2, true } return nil, false } func namedCurveFromOID(oid asn1.ObjectIdentifier) elliptic.Curve { switch { case oid.Equal(oidNamedCurveP224): return elliptic.P224() case oid.Equal(oidNamedCurveP256): return elliptic.P256() case oid.Equal(oidNamedCurveP384): return elliptic.P384() case oid.Equal(oidNamedCurveP521): return elliptic.P521() case 
oid.Equal(oidNamedCurveP256SM2): return P256() } return nil } // ParsePKIXPublicKey parses a public key in PKIX, ASN.1 DER form. // // It returns a *rsa.PublicKey, *dsa.PublicKey, *ecdsa.PublicKey, or // ed25519.PublicKey. More types might be supported in the future. // // This kind of key is commonly encoded in PEM blocks of type "PUBLIC KEY". func ParsePKIXPublicKey(derBytes []byte) (interface{}, error) { var pki publicKeyInfo if rest, err := asn1.Unmarshal(derBytes, &pki); err != nil { if _, err := asn1.Unmarshal(derBytes, &pkcs1PublicKey{}); err == nil { return nil, errors.New("x509: failed to parse public key (use ParsePKCS1PublicKey instead for this key format)") } return nil, err } else if len(rest) != 0 { return nil, errors.New("x509: trailing data after ASN.1 of public-key") } if !pki.Algorithm.Algorithm.Equal(oidPublicKeyECDSA) { return x509.ParsePKIXPublicKey(derBytes) } keyData := &pki asn1Data := keyData.PublicKey.RightAlign() paramsData := keyData.Algorithm.Parameters.FullBytes namedCurveOID := new(asn1.ObjectIdentifier) rest, err := asn1.Unmarshal(paramsData, namedCurveOID) if err != nil { return nil, errors.New("x509: failed to parse ECDSA parameters as named curve") } if len(rest) != 0 { return nil, errors.New("x509: trailing data after ECDSA parameters") } if !namedCurveOID.Equal(oidNamedCurveP256SM2) { return x509.ParsePKIXPublicKey(derBytes) } namedCurve := P256() x, y := elliptic.Unmarshal(namedCurve, asn1Data) if x == nil { return nil, errors.New("x509: failed to unmarshal elliptic curve point") } pub := &ecdsa.PublicKey{ Curve: namedCurve, X: x, Y: y, } return pub, nil } // MarshalPKIXPublicKey converts a public key to PKIX, ASN.1 DER form. // // The following key types are currently supported: *rsa.PublicKey, *ecdsa.PublicKey // and ed25519.PublicKey. Unsupported key types result in an error. // // This kind of key is commonly encoded in PEM blocks of type "PUBLIC KEY". 
func MarshalPKIXPublicKey(pub interface{}) ([]byte, error) {
	// Only SM2 keys are encoded here; everything else keeps the standard
	// library's encoding.
	ecPub, ok := pub.(*ecdsa.PublicKey)
	if !ok || ecPub.Curve != P256() {
		return x509.MarshalPKIXPublicKey(pub)
	}
	publicKeyBytes, publicKeyAlgorithm, err := marshalPublicKey(ecPub)
	if err != nil {
		return nil, err
	}
	// Named spki rather than pkix so the pkix package is not shadowed.
	spki := pkixPublicKey{
		Algo: publicKeyAlgorithm,
		BitString: asn1.BitString{
			Bytes:     publicKeyBytes,
			BitLength: 8 * len(publicKeyBytes),
		},
	}
	return asn1.Marshal(spki)
}

// marshalPublicKey returns the elliptic.Marshal point encoding of pub together
// with an id-ecPublicKey AlgorithmIdentifier whose parameters hold the SM2
// named-curve OID.
//
// The curve OID is always SM2's, whatever pub.Curve is, so callers must only
// pass keys on the SM2 curve.
func marshalPublicKey(pub *ecdsa.PublicKey) (publicKeyBytes []byte, publicKeyAlgorithm pkix.AlgorithmIdentifier, err error) {
	paramBytes, err := asn1.Marshal(oidNamedCurveP256SM2)
	if err != nil {
		return nil, pkix.AlgorithmIdentifier{}, err
	}
	publicKeyAlgorithm = pkix.AlgorithmIdentifier{
		Algorithm:  oidPublicKeyECDSA,
		Parameters: asn1.RawValue{FullBytes: paramBytes},
	}
	publicKeyBytes = elliptic.Marshal(pub.Curve, pub.X, pub.Y)
	return publicKeyBytes, publicKeyAlgorithm, nil
}

// CreateCertificateRequest creates a new certificate request based on a
// template.
//
// priv must implement crypto.Signer. When priv is a *PrivateKey (SM2), the
// request is built and signed here. The signature algorithm is always
// SM2-with-SM3; template.SignatureAlgorithm is not consulted. The following
// members of template are used:
//
//   - Subject (ignored when RawSubject is non-empty)
//   - RawSubject
//   - DNSNames
//   - EmailAddresses
//   - IPAddresses
//   - URIs
//   - ExtraExtensions
//   - Attributes (deprecated)
//
// Any other signer is passed to x509.CreateCertificateRequest, which follows
// the standard library's rules (including SignatureAlgorithm).
//
// The returned slice is the certificate request in DER encoding.
func CreateCertificateRequest(rand io.Reader, template *x509.CertificateRequest, priv interface{}) (csr []byte, err error) {
	key, ok := priv.(crypto.Signer)
	if !ok {
		return nil, errors.New("x509: certificate private key does not implement crypto.Signer")
	}
	privKey, ok := key.(*PrivateKey)
	if !ok {
		// Not an SM2 key: the standard library builds and signs the CSR.
		return x509.CreateCertificateRequest(rand, template, priv)
	}

	sigAlgo := pkix.AlgorithmIdentifier{Algorithm: oidSignatureSM2WithSM3}

	// Checked assertion: an unexpected Public() result yields an error rather
	// than a panic.
	pub, ok := key.Public().(*ecdsa.PublicKey)
	if !ok {
		return nil, errors.New("x509: SM2 private key did not return an *ecdsa.PublicKey")
	}
	publicKeyBytes, publicKeyAlgorithm, err := marshalPublicKey(pub)
	if err != nil {
		return nil, err
	}

	var extensions []pkix.Extension

	// Synthesize a subjectAltName extension from the template's SAN fields
	// unless the caller supplied one explicitly.
	if (len(template.DNSNames) > 0 || len(template.EmailAddresses) > 0 || len(template.IPAddresses) > 0 || len(template.URIs) > 0) &&
		!oidInExtensions(oidExtensionSubjectAltName, template.ExtraExtensions) {
		sanBytes, err := marshalSANs(template.DNSNames, template.EmailAddresses, template.IPAddresses, template.URIs)
		if err != nil {
			return nil, err
		}

		extensions = append(extensions, pkix.Extension{
			Id:    oidExtensionSubjectAltName,
			Value: sanBytes,
		})
	}

	extensions = append(extensions, template.ExtraExtensions...)

	// Make a copy of template.Attributes because we may alter it below.
	attributes := make([]pkix.AttributeTypeAndValueSET, 0, len(template.Attributes))
	for _, attr := range template.Attributes {
		values := make([][]pkix.AttributeTypeAndValue, len(attr.Value))
		copy(values, attr.Value)
		attributes = append(attributes, pkix.AttributeTypeAndValueSET{
			Type:  attr.Type,
			Value: values,
		})
	}

	extensionsAppended := false
	if len(extensions) > 0 {
		// Append the extensions to an existing attribute if possible.
		for _, atvSet := range attributes {
			if !atvSet.Type.Equal(oidExtensionRequest) || len(atvSet.Value) == 0 {
				continue
			}

			// specifiedExtensions contains all the extensions that we
			// found specified via template.Attributes.
			specifiedExtensions := make(map[string]bool)

			for _, atvs := range atvSet.Value {
				for _, atv := range atvs {
					specifiedExtensions[atv.Type.String()] = true
				}
			}

			newValue := make([]pkix.AttributeTypeAndValue, 0, len(atvSet.Value[0])+len(extensions))
			newValue = append(newValue, atvSet.Value[0]...)

			for _, e := range extensions {
				if specifiedExtensions[e.Id.String()] {
					// Attributes already contained a value for
					// this extension and it takes priority.
					continue
				}

				newValue = append(newValue, pkix.AttributeTypeAndValue{
					// There is no place for the critical
					// flag in an AttributeTypeAndValue.
					Type:  e.Id,
					Value: e.Value,
				})
			}

			// atvSet is a copy of the element, but its Value slice shares a
			// backing array with the element in attributes, so this write
			// is visible through attributes.
			atvSet.Value[0] = newValue
			extensionsAppended = true
			break
		}
	}

	rawAttributes, err := newRawAttributes(attributes)
	if err != nil {
		return nil, err
	}

	// If not included in attributes, add a new attribute for the
	// extensions.
	if len(extensions) > 0 && !extensionsAppended {
		attr := struct {
			Type  asn1.ObjectIdentifier
			Value [][]pkix.Extension `asn1:"set"`
		}{
			Type:  oidExtensionRequest,
			Value: [][]pkix.Extension{extensions},
		}

		b, err := asn1.Marshal(attr)
		if err != nil {
			return nil, errors.New("x509: failed to serialise extensions attribute: " + err.Error())
		}

		var rawValue asn1.RawValue
		if _, err := asn1.Unmarshal(b, &rawValue); err != nil {
			return nil, err
		}

		rawAttributes = append(rawAttributes, rawValue)
	}

	asn1Subject := template.RawSubject
	if len(asn1Subject) == 0 {
		asn1Subject, err = asn1.Marshal(template.Subject.ToRDNSequence())
		if err != nil {
			return nil, err
		}
	}

	tbsCSR := tbsCertificateRequest{
		Version: 0, // PKCS #10, RFC 2986
		Subject: asn1.RawValue{FullBytes: asn1Subject},
		PublicKey: publicKeyInfo{
			Algorithm: publicKeyAlgorithm,
			PublicKey: asn1.BitString{
				Bytes:     publicKeyBytes,
				BitLength: len(publicKeyBytes) * 8,
			},
		},
		RawAttributes: rawAttributes,
	}

	tbsCSRContents, err := asn1.Marshal(tbsCSR)
	if err != nil {
		return nil, err
	}
	// With Raw set, encoding/asn1 re-emits these exact bytes for the TBS part
	// when the outer certificateRequest is marshalled below.
	tbsCSR.Raw = tbsCSRContents

	// The signed digest is SM3(ZA || TBS), where ZA is computed by CalculateZA
	// from the public key and defaultUID.
	// TODO: consider whether template.Subject should be used as the UID.
	za, err := CalculateZA(&privKey.PublicKey, defaultUID)
	if err != nil {
		return nil, err
	}
	h := sm3.New()
	h.Write(za)
	h.Write(tbsCSRContents)
	digest := h.Sum(nil)

	signature, err := privKey.Sign(rand, digest, nil)
	if err != nil {
		return nil, err
	}

	return asn1.Marshal(certificateRequest{
		TBSCSR:             tbsCSR,
		SignatureAlgorithm: sigAlgo,
		SignatureValue: asn1.BitString{
			Bytes:     signature,
			BitLength: len(signature) * 8,
		},
	})
}

// These structures reflect the ASN.1 structure of X.509 certificate
// signing requests (see RFC 2986):
type tbsCertificateRequest struct {
	Raw           asn1.RawContent
	Version       int
	Subject       asn1.RawValue
	PublicKey     publicKeyInfo
	RawAttributes []asn1.RawValue `asn1:"tag:0"`
}

type certificateRequest struct {
	Raw                asn1.RawContent
	TBSCSR             tbsCertificateRequest
	SignatureAlgorithm pkix.AlgorithmIdentifier
	SignatureValue     asn1.BitString
}

// oidExtensionRequest is a PKCS#9 OBJECT IDENTIFIER that indicates requested
// extensions in a CSR.
var oidExtensionRequest = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 9, 14}

// Certificate extension OIDs (id-ce arc 2.5.29, plus id-pe-authorityInfoAccess).
var (
	oidExtensionSubjectKeyId          = []int{2, 5, 29, 14}
	oidExtensionKeyUsage              = []int{2, 5, 29, 15}
	oidExtensionExtendedKeyUsage      = []int{2, 5, 29, 37}
	oidExtensionAuthorityKeyId        = []int{2, 5, 29, 35}
	oidExtensionBasicConstraints      = []int{2, 5, 29, 19}
	oidExtensionSubjectAltName        = []int{2, 5, 29, 17}
	oidExtensionCertificatePolicies   = []int{2, 5, 29, 32}
	oidExtensionNameConstraints       = []int{2, 5, 29, 30}
	oidExtensionCRLDistributionPoints = []int{2, 5, 29, 31}
	oidExtensionAuthorityInfoAccess   = []int{1, 3, 6, 1, 5, 5, 7, 1, 1}
)

// newRawAttributes converts AttributeTypeAndValueSETs from a template
// CertificateRequest's Attributes into tbsCertificateRequest RawAttributes.
func newRawAttributes(attributes []pkix.AttributeTypeAndValueSET) ([]asn1.RawValue, error) {
	// Round-trip through DER: marshal the whole list as a SEQUENCE OF, then
	// split it back into one RawValue per encoded attribute.
	var rawAttributes []asn1.RawValue
	b, err := asn1.Marshal(attributes)
	if err != nil {
		return nil, err
	}
	rest, err := asn1.Unmarshal(b, &rawAttributes)
	if err != nil {
		return nil, err
	}
	if len(rest) != 0 {
		return nil, errors.New("x509: failed to unmarshal raw CSR Attributes")
	}
	return rawAttributes, nil
}

// oidInExtensions reports whether an extension with the given oid exists in
// extensions.
func oidInExtensions(oid asn1.ObjectIdentifier, extensions []pkix.Extension) bool {
	for _, e := range extensions {
		if e.Id.Equal(oid) {
			return true
		}
	}
	return false
}

// Context-specific tags of the GeneralName CHOICE (RFC 5280, 4.2.1.6) that
// this package encodes and decodes.
const (
	nameTypeEmail = 1
	nameTypeDNS   = 2
	nameTypeURI   = 6
	nameTypeIP    = 7
)

// marshalSANs marshals a list of addresses into the contents of an X.509
// SubjectAlternativeName extension. Entries are emitted in the order DNS
// names, email addresses, IP addresses, URIs.
func marshalSANs(dnsNames, emailAddresses []string, ipAddresses []net.IP, uris []*url.URL) (derBytes []byte, err error) {
	rawValues := make([]asn1.RawValue, 0, len(dnsNames)+len(emailAddresses)+len(ipAddresses)+len(uris))
	for _, name := range dnsNames {
		rawValues = append(rawValues, asn1.RawValue{Tag: nameTypeDNS, Class: 2, Bytes: []byte(name)})
	}
	for _, email := range emailAddresses {
		rawValues = append(rawValues, asn1.RawValue{Tag: nameTypeEmail, Class: 2, Bytes: []byte(email)})
	}
	for _, rawIP := range ipAddresses {
		// If possible, we always want to encode IPv4 addresses in 4 bytes.
		ip := rawIP.To4()
		if ip == nil {
			ip = rawIP
		}
		rawValues = append(rawValues, asn1.RawValue{Tag: nameTypeIP, Class: 2, Bytes: ip})
	}
	for _, uri := range uris {
		rawValues = append(rawValues, asn1.RawValue{Tag: nameTypeURI, Class: 2, Bytes: []byte(uri.String())})
	}
	return asn1.Marshal(rawValues)
}

// ParseCertificateRequest parses a single certificate request from the
// given ASN.1 DER data.
//
// Requests signed with SM2-with-SM3 are decoded by this package; any other
// signature algorithm is passed to x509.ParseCertificateRequest.
func ParseCertificateRequest(asn1Data []byte) (*x509.CertificateRequest, error) {
	var csr certificateRequest
	rest, err := asn1.Unmarshal(asn1Data, &csr)
	if err != nil {
		return nil, err
	}
	if len(rest) != 0 {
		return nil, asn1.SyntaxError{Msg: "trailing data"}
	}
	// Only SM2-with-SM3 requests need this package; defer everything else.
	if !csr.SignatureAlgorithm.Algorithm.Equal(oidSignatureSM2WithSM3) {
		return x509.ParseCertificateRequest(asn1Data)
	}
	return parseCertificateRequest(&csr)
}

// parseCertificateRequest converts a decoded SM2-with-SM3 certificateRequest
// into an *x509.CertificateRequest. The public key is decoded with
// parsePublicKey, so it must be an EC key on a curve known to
// namedCurveFromOID. out.SignatureAlgorithm is left at its zero value because
// crypto/x509 has no constant for SM2-with-SM3.
func parseCertificateRequest(in *certificateRequest) (*x509.CertificateRequest, error) {
	if !oidSignatureSM2WithSM3.Equal(in.SignatureAlgorithm.Algorithm) {
		return nil, errors.New("x509: unsupported signature algorithm")
	}

	out := &x509.CertificateRequest{
		Raw:                      in.Raw,
		RawTBSCertificateRequest: in.TBSCSR.Raw,
		RawSubjectPublicKeyInfo:  in.TBSCSR.PublicKey.Raw,
		RawSubject:               in.TBSCSR.Subject.FullBytes,

		Signature:          in.SignatureValue.RightAlign(),
		PublicKeyAlgorithm: getPublicKeyAlgorithmFromOID(in.TBSCSR.PublicKey.Algorithm.Algorithm),

		Version:    in.TBSCSR.Version,
		Attributes: parseRawAttributes(in.TBSCSR.RawAttributes),
	}

	var err error
	out.PublicKey, err = parsePublicKey(&in.TBSCSR.PublicKey)
	if err != nil {
		return nil, err
	}

	var subject pkix.RDNSequence
	rest, err := asn1.Unmarshal(in.TBSCSR.Subject.FullBytes, &subject)
	if err != nil {
		return nil, err
	}
	if len(rest) != 0 {
		return nil, errors.New("x509: trailing data after X.509 Subject")
	}
	out.Subject.FillFromRDNSequence(&subject)

	if out.Extensions, err = parseCSRExtensions(in.TBSCSR.RawAttributes); err != nil {
		return nil, err
	}

	for _, extension := range out.Extensions {
		if extension.Id.Equal(oidExtensionSubjectAltName) {
			out.DNSNames, out.EmailAddresses, out.IPAddresses, out.URIs, err = parseSANExtension(extension.Value)
			if err != nil {
				return nil, err
			}
		}
	}

	return out, nil
}

// parsePublicKey decodes an EC public key from keyData. The algorithm
// parameters must be a named-curve OID recognised by namedCurveFromOID; the
// algorithm OID itself is not checked here.
func parsePublicKey(keyData *publicKeyInfo) (interface{}, error) {
	asn1Data := keyData.PublicKey.RightAlign()
	paramsData := keyData.Algorithm.Parameters.FullBytes
	namedCurveOID := new(asn1.ObjectIdentifier)
	rest, err := asn1.Unmarshal(paramsData, namedCurveOID)
	if err != nil {
		return nil, errors.New("x509: failed to parse ECDSA parameters as named curve")
	}
	if len(rest) != 0 {
		return nil, errors.New("x509: trailing data after ECDSA parameters")
	}
	namedCurve := namedCurveFromOID(*namedCurveOID)
	if namedCurve == nil {
		return nil, errors.New("x509: unsupported elliptic curve")
	}
	x, y := elliptic.Unmarshal(namedCurve, asn1Data)
	if x == nil {
		return nil, errors.New("x509: failed to unmarshal elliptic curve point")
	}
	pub := &ecdsa.PublicKey{
		Curve: namedCurve,
		X:     x,
		Y:     y,
	}
	return pub, nil
}

// parseRawAttributes Unmarshals RawAttributes into AttributeTypeAndValueSETs.
func parseRawAttributes(rawAttributes []asn1.RawValue) []pkix.AttributeTypeAndValueSET {
	var attributes []pkix.AttributeTypeAndValueSET
	for _, rawAttr := range rawAttributes {
		var attr pkix.AttributeTypeAndValueSET
		rest, err := asn1.Unmarshal(rawAttr.FullBytes, &attr)
		// Ignore attributes that don't parse into pkix.AttributeTypeAndValueSET
		// (i.e.: challengePassword or unstructuredName).
		if err == nil && len(rest) == 0 {
			attributes = append(attributes, attr)
		}
	}
	return attributes
}

// parseCSRExtensions parses the attributes from a CSR and extracts any
// requested extensions.
func parseCSRExtensions(rawAttributes []asn1.RawValue) ([]pkix.Extension, error) {
	// pkcs10Attribute reflects the Attribute structure from RFC 2986, Section 4.1.
	type pkcs10Attribute struct {
		Id     asn1.ObjectIdentifier
		Values []asn1.RawValue `asn1:"set"`
	}

	var ret []pkix.Extension
	for _, rawAttr := range rawAttributes {
		var attr pkcs10Attribute
		if rest, err := asn1.Unmarshal(rawAttr.FullBytes, &attr); err != nil || len(rest) != 0 || len(attr.Values) == 0 {
			// Ignore attributes that don't parse.
			continue
		}

		if !attr.Id.Equal(oidExtensionRequest) {
			continue
		}

		// Only the first value of an extensionRequest attribute is read.
		var extensions []pkix.Extension
		if _, err := asn1.Unmarshal(attr.Values[0].FullBytes, &extensions); err != nil {
			return nil, err
		}
		ret = append(ret, extensions...)
	}

	return ret, nil
}

// forEachSAN walks the GeneralNames of a DER-encoded SubjectAltName extension
// value and calls callback with each entry's tag and content bytes.
func forEachSAN(extension []byte, callback func(tag int, data []byte) error) error {
	// RFC 5280, 4.2.1.6

	// SubjectAltName ::= GeneralNames
	//
	// GeneralNames ::= SEQUENCE SIZE (1..MAX) OF GeneralName
	//
	// GeneralName ::= CHOICE {
	//      otherName                       [0]     OtherName,
	//      rfc822Name                      [1]     IA5String,
	//      dNSName                         [2]     IA5String,
	//      x400Address                     [3]     ORAddress,
	//      directoryName                   [4]     Name,
	//      ediPartyName                    [5]     EDIPartyName,
	//      uniformResourceIdentifier       [6]     IA5String,
	//      iPAddress                       [7]     OCTET STRING,
	//      registeredID                    [8]     OBJECT IDENTIFIER }
	var seq asn1.RawValue
	rest, err := asn1.Unmarshal(extension, &seq)
	if err != nil {
		return err
	} else if len(rest) != 0 {
		return errors.New("x509: trailing data after X.509 extension")
	}
	if !seq.IsCompound || seq.Tag != 16 || seq.Class != 0 {
		return asn1.StructuralError{Msg: "bad SAN sequence"}
	}

	rest = seq.Bytes
	for len(rest) > 0 {
		var v asn1.RawValue
		rest, err = asn1.Unmarshal(rest, &v)
		if err != nil {
			return err
		}

		if err := callback(v.Tag, v.Bytes); err != nil {
			return err
		}
	}

	return nil
}

// domainToReverseLabels converts a textual domain name like foo.example.com to
// the list of labels in reverse order, e.g. ["com", "example", "foo"].
func domainToReverseLabels(domain string) (reverseLabels []string, ok bool) {
	for len(domain) > 0 {
		if i := strings.LastIndexByte(domain, '.'); i == -1 {
			reverseLabels = append(reverseLabels, domain)
			domain = ""
		} else {
			reverseLabels = append(reverseLabels, domain[i+1:])
			domain = domain[:i]
		}
	}

	if len(reverseLabels) > 0 && len(reverseLabels[0]) == 0 {
		// An empty label at the end indicates an absolute value.
		return nil, false
	}

	for _, label := range reverseLabels {
		if len(label) == 0 {
			// Empty labels are otherwise invalid.
			return nil, false
		}

		for _, c := range label {
			if c < 33 || c > 126 {
				// Invalid character.
				return nil, false
			}
		}
	}

	return reverseLabels, true
}

// parseSANExtension decodes the email, DNS, URI and IP entries of a
// SubjectAltName extension value. Other GeneralName kinds are skipped.
func parseSANExtension(value []byte) (dnsNames, emailAddresses []string, ipAddresses []net.IP, uris []*url.URL, err error) {
	err = forEachSAN(value, func(tag int, data []byte) error {
		switch tag {
		case nameTypeEmail:
			emailAddresses = append(emailAddresses, string(data))
		case nameTypeDNS:
			dnsNames = append(dnsNames, string(data))
		case nameTypeURI:
			uri, err := url.Parse(string(data))
			if err != nil {
				return fmt.Errorf("x509: cannot parse URI %q: %s", string(data), err)
			}
			if len(uri.Host) > 0 {
				if _, ok := domainToReverseLabels(uri.Host); !ok {
					return fmt.Errorf("x509: cannot parse URI %q: invalid domain", string(data))
				}
			}
			uris = append(uris, uri)
		case nameTypeIP:
			switch len(data) {
			case net.IPv4len, net.IPv6len:
				ipAddresses = append(ipAddresses, data)
			default:
				return errors.New("x509: cannot parse IP address of length " + strconv.Itoa(len(data)))
			}
		}

		return nil
	})

	return
}

// RFC 3279, 2.3 Public Key Algorithms
//
// pkcs-1 OBJECT IDENTIFIER ::== { iso(1) member-body(2) us(840)
//    rsadsi(113549) pkcs(1) 1 }
//
// rsaEncryption OBJECT IDENTIFIER ::== { pkcs1-1 1 }
//
// id-dsa OBJECT IDENTIFIER ::== { iso(1) member-body(2) us(840)
//    x9-57(10040) x9cm(4) 1 }
//
// RFC 5480, 2.1.1 Unrestricted Algorithm Identifier and Parameters
//
// id-ecPublicKey OBJECT IDENTIFIER ::= {
//       iso(1) member-body(2) us(840) ansi-X9-62(10045) keyType(2) 1 }
var (
	oidPublicKeyRSA   = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 1}
	oidPublicKeyDSA   = asn1.ObjectIdentifier{1, 2, 840, 10040, 4, 1}
	oidPublicKeyECDSA = asn1.ObjectIdentifier{1, 2, 840, 10045, 2, 1}
)

// getPublicKeyAlgorithmFromOID maps a SubjectPublicKeyInfo algorithm OID to
// the corresponding x509.PublicKeyAlgorithm. Unrecognised OIDs yield
// x509.UnknownPublicKeyAlgorithm.
func getPublicKeyAlgorithmFromOID(oid asn1.ObjectIdentifier) x509.PublicKeyAlgorithm {
	switch {
	case oid.Equal(oidPublicKeyRSA):
		return x509.RSA
	case oid.Equal(oidPublicKeyDSA):
		return x509.DSA
	case oid.Equal(oidPublicKeyECDSA):
		return x509.ECDSA
	}
	return x509.UnknownPublicKeyAlgorithm
}

// CheckSignature reports whether the signature on c is valid.
//
// If c carries an *ecdsa.PublicKey whose curve name matches P256() (the SM2
// curve), the signature is verified with SM2. Otherwise verification is
// delegated to (*x509.CertificateRequest).CheckSignature.
func CheckSignature(c *x509.CertificateRequest) error {
	if c.PublicKeyAlgorithm == x509.ECDSA {
		pub, ok := c.PublicKey.(*ecdsa.PublicKey)
		// Curves are compared by (case-insensitive) Params().Name rather than
		// by identity. The nil checks avoid a panic on a malformed request.
		if ok && pub != nil && pub.Curve != nil &&
			strings.EqualFold(P256().Params().Name, pub.Curve.Params().Name) {
			return checkSignature(c, pub)
		}
	}
	return c.CheckSignature()
}

// checkSignature verifies c.Signature, a DER-encoded (r, s) pair, over
// c.RawTBSCertificateRequest with SM2 under publicKey.
//
// VerifyWithSM2 is called with a nil UID. This presumably makes it fall back
// to the same default UID that CreateCertificateRequest signs with — TODO
// confirm.
func checkSignature(c *x509.CertificateRequest, publicKey *ecdsa.PublicKey) error {
	ecdsaSig := new(ecdsaSignature)
	rest, err := asn1.Unmarshal(c.Signature, ecdsaSig)
	if err != nil {
		return err
	}
	if len(rest) != 0 {
		return errors.New("x509: trailing data after ECDSA signature")
	}
	if ecdsaSig.R.Sign() <= 0 || ecdsaSig.S.Sign() <= 0 {
		return errors.New("x509: ECDSA signature contained zero or negative values")
	}
	if !VerifyWithSM2(publicKey, nil, c.RawTBSCertificateRequest, ecdsaSig.R, ecdsaSig.S) {
		return errors.New("x509: ECDSA verification failure")
	}
	return nil
}