diff --git a/go.mod b/go.mod index 9d099b9..95e67d9 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,5 @@ module github.com/emmansun/gmsm go 1.14 + +require golang.org/x/crypto v0.0.0-20201221181555-eec23a3978ad diff --git a/sm2/p256_asm.go b/sm2/p256_asm.go index c9bd82e..1e26bb8 100644 --- a/sm2/p256_asm.go +++ b/sm2/p256_asm.go @@ -446,7 +446,7 @@ func boothW6(in uint) (int, int) { return int(d), int(s & 1) } -// table[i][j] = (2^(6*i))*j*G mod P +// table[i][j] = (2^(6*i))*(j+1)*G mod P func initTable() { p256Precomputed = new([43][32 * 8]uint64) diff --git a/sm2/sm2.go b/sm2/sm2.go index 7175463..2e2db1b 100644 --- a/sm2/sm2.go +++ b/sm2/sm2.go @@ -38,7 +38,7 @@ type combinedMult interface { CombinedMult(bigX, bigY *big.Int, baseScalar, scalar []byte) (x, y *big.Int) } -// PrivateKey represents an ECDSA private key. +// PrivateKey represents an ECDSA SM2 private key. type PrivateKey struct { ecdsa.PrivateKey } @@ -76,6 +76,15 @@ func (mode pointMarshalMode) mashal(curve elliptic.Curve, x, y *big.Int) []byte var defaultEncrypterOpts = EncrypterOpts{MarshalUncompressed} +// FromECPrivateKey convert an ecdsa private key to SM2 private key +func (priv *PrivateKey) FromECPrivateKey(key *ecdsa.PrivateKey) (*PrivateKey, error) { + if key.Curve != P256() { + return nil, errors.New("It's NOT a sm2 curve private key") + } + priv.PrivateKey = *key + return priv, nil +} + // Sign signs digest with priv, reading randomness from rand. The opts argument // is not currently used but, in keeping with the crypto.Signer interface, // should be the hash function used to digest the message. diff --git a/smx509/cert_pool.go b/smx509/cert_pool.go new file mode 100644 index 0000000..a53f718 --- /dev/null +++ b/smx509/cert_pool.go @@ -0,0 +1,155 @@ +package smx509 + +import ( + "encoding/pem" + "errors" + "runtime" +) + +// CertPool is a set of certificates. +type CertPool struct { + bySubjectKeyId map[string][]int + byName map[string][]int + certs []*Certificate +} + +// NewCertPool returns a new, empty CertPool. +func NewCertPool() *CertPool { + return &CertPool{ + bySubjectKeyId: make(map[string][]int), + byName: make(map[string][]int), + } +} + +func (s *CertPool) copy() *CertPool { + p := &CertPool{ + bySubjectKeyId: make(map[string][]int, len(s.bySubjectKeyId)), + byName: make(map[string][]int, len(s.byName)), + certs: make([]*Certificate, len(s.certs)), + } + for k, v := range s.bySubjectKeyId { + indexes := make([]int, len(v)) + copy(indexes, v) + p.bySubjectKeyId[k] = indexes + } + for k, v := range s.byName { + indexes := make([]int, len(v)) + copy(indexes, v) + p.byName[k] = indexes + } + copy(p.certs, s.certs) + return p +} + +// SystemCertPool returns a copy of the system cert pool. +// +// Any mutations to the returned pool are not written to disk and do +// not affect any other pool returned by SystemCertPool. +// +// New changes in the system cert pool might not be reflected +// in subsequent calls. +func SystemCertPool() (*CertPool, error) { + if runtime.GOOS == "windows" { + // Issue 16736, 18609: + return nil, errors.New("crypto/x509: system root pool is not available on Windows") + } + + if sysRoots := systemRootsPool(); sysRoots != nil { + return sysRoots.copy(), nil + } + + return loadSystemRoots() +} + +// findPotentialParents returns the indexes of certificates in s which might +// have signed cert. The caller must not modify the returned slice. +func (s *CertPool) findPotentialParents(cert *Certificate) []int { + if s == nil { + return nil + } + + var candidates []int + if len(cert.AuthorityKeyId) > 0 { + candidates = s.bySubjectKeyId[string(cert.AuthorityKeyId)] + } + if len(candidates) == 0 { + candidates = s.byName[string(cert.RawIssuer)] + } + return candidates +} + +func (s *CertPool) contains(cert *Certificate) bool { + if s == nil { + return false + } + + candidates := s.byName[string(cert.RawSubject)] + for _, c := range candidates { + if s.certs[c].Equal(cert) { + return true + } + } + + return false +} + +// AddCert adds a certificate to a pool. +func (s *CertPool) AddCert(cert *Certificate) { + if cert == nil { + panic("adding nil Certificate to CertPool") + } + + // Check that the certificate isn't being added twice. + if s.contains(cert) { + return + } + + n := len(s.certs) + s.certs = append(s.certs, cert) + + if len(cert.SubjectKeyId) > 0 { + keyId := string(cert.SubjectKeyId) + s.bySubjectKeyId[keyId] = append(s.bySubjectKeyId[keyId], n) + } + name := string(cert.RawSubject) + s.byName[name] = append(s.byName[name], n) +} + +// AppendCertsFromPEM attempts to parse a series of PEM encoded certificates. +// It appends any certificates found to s and reports whether any certificates +// were successfully parsed. +// +// On many Linux systems, /etc/ssl/cert.pem will contain the system wide set +// of root CAs in a format suitable for this function. +func (s *CertPool) AppendCertsFromPEM(pemCerts []byte) (ok bool) { + for len(pemCerts) > 0 { + var block *pem.Block + block, pemCerts = pem.Decode(pemCerts) + if block == nil { + break + } + if block.Type != "CERTIFICATE" || len(block.Headers) != 0 { + continue + } + + cert, err := ParseCertificate(block.Bytes) + if err != nil { + continue + } + + s.AddCert(cert) + ok = true + } + + return +} + +// Subjects returns a list of the DER-encoded subjects of +// all of the certificates in the pool. +func (s *CertPool) Subjects() [][]byte { + res := make([][]byte, len(s.certs)) + for i, c := range s.certs { + res[i] = c.RawSubject + } + return res +} diff --git a/smx509/pkcs8.go b/smx509/pkcs8.go new file mode 100644 index 0000000..13e5861 --- /dev/null +++ b/smx509/pkcs8.go @@ -0,0 +1,99 @@ +package smx509 + +import ( + "crypto/ecdsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/asn1" + "errors" + + "github.com/emmansun/gmsm/sm2" +) + +// pkcs8 reflects an ASN.1, PKCS#8 PrivateKey. See +// ftp://ftp.rsasecurity.com/pub/pkcs/pkcs-8/pkcs-8v1_2.asn +// and RFC 5208. +type pkcs8 struct { + Version int + Algo pkix.AlgorithmIdentifier + PrivateKey []byte + // optional attributes omitted. +} + +// ParsePKCS8PrivateKey parses an unencrypted private key in PKCS#8, ASN.1 DER form. +// +// It returns a *rsa.PrivateKey, a *ecdsa.PrivateKey, or a ed25519.PrivateKey. +// More types might be supported in the future. +// +// This kind of key is commonly encoded in PEM blocks of type "PRIVATE KEY". +func ParsePKCS8PrivateKey(der []byte) (key interface{}, err error) { + var privKey pkcs8 + if _, err := asn1.Unmarshal(der, &privKey); err != nil { + if _, err := asn1.Unmarshal(der, &ecPrivateKey{}); err == nil { + return nil, errors.New("x509: failed to parse private key (use ParseECPrivateKey instead for this key format)") + } + if _, err := asn1.Unmarshal(der, &pkcs1PrivateKey{}); err == nil { + return nil, errors.New("x509: failed to parse private key (use ParsePKCS1PrivateKey instead for this key format)") + } + return nil, err + } + if !privKey.Algo.Algorithm.Equal(oidPublicKeyECDSA) { + return x509.ParsePKCS8PrivateKey(der) + } + bytes := privKey.Algo.Parameters.FullBytes + namedCurveOID := new(asn1.ObjectIdentifier) + if _, err := asn1.Unmarshal(bytes, namedCurveOID); err != nil { + namedCurveOID = nil + } + ecKey, err := parseECPrivateKey(namedCurveOID, privKey.PrivateKey) + if err != nil { + return nil, errors.New("x509: failed to parse EC private key embedded in PKCS#8: " + err.Error()) + } + if namedCurveOID.Equal(oidNamedCurveP256SM2) { + key, err = new(sm2.PrivateKey).FromECPrivateKey(ecKey) + } else { + key = ecKey + } + return key, nil +} + +// MarshalPKCS8PrivateKey converts a private key to PKCS#8, ASN.1 DER form. +// +// The following key types are currently supported: *rsa.PrivateKey, *ecdsa.PrivateKey +// and ed25519.PrivateKey. Unsupported key types result in an error. +// +// This kind of key is commonly encoded in PEM blocks of type "PRIVATE KEY". +func MarshalPKCS8PrivateKey(key interface{}) ([]byte, error) { + switch k := key.(type) { + case *ecdsa.PrivateKey: + return marshalPKCS8ECPrivateKey(k) + case *sm2.PrivateKey: + return marshalPKCS8ECPrivateKey(&k.PrivateKey) + } + return x509.MarshalPKCS8PrivateKey(key) +} + +func marshalPKCS8ECPrivateKey(k *ecdsa.PrivateKey) ([]byte, error) { + var privKey pkcs8 + oid, ok := oidFromNamedCurve(k.Curve) + if !ok { + return nil, errors.New("x509: unknown curve while marshaling to PKCS#8") + } + + oidBytes, err := asn1.Marshal(oid) + if err != nil { + return nil, errors.New("x509: failed to marshal curve OID: " + err.Error()) + } + + privKey.Algo = pkix.AlgorithmIdentifier{ + Algorithm: oidPublicKeyECDSA, + Parameters: asn1.RawValue{ + FullBytes: oidBytes, + }, + } + + if privKey.PrivateKey, err = marshalECPrivateKeyWithOID(k, nil); err != nil { + return nil, errors.New("x509: failed to marshal EC private key while building PKCS#8: " + err.Error()) + } + return asn1.Marshal(privKey) +} diff --git a/smx509/pkcs8_test.go b/smx509/pkcs8_test.go new file mode 100644 index 0000000..46c4491 --- /dev/null +++ b/smx509/pkcs8_test.go @@ -0,0 +1,157 @@ +package smx509 + +import ( + "bytes" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "encoding/hex" + "fmt" + "reflect" + "strings" + "testing" + + "github.com/emmansun/gmsm/sm2" +) + +// Generated using: +// openssl genrsa 1024 | openssl pkcs8 -topk8 -nocrypt +var pkcs8RSAPrivateKeyHex = `30820278020100300d06092a864886f70d0101010500048202623082025e02010002818100cfb1b5bf9685ffa97b4f99df4ff122b70e59ac9b992f3bc2b3dde17d53c1a34928719b02e8fd17839499bfbd515bd6ef99c7a1c47a239718fe36bfd824c0d96060084b5f67f0273443007a24dfaf5634f7772c9346e10eb294c2306671a5a5e719ae24b4de467291bc571014b0e02dec04534d66a9bb171d644b66b091780e8d020301000102818100b595778383c4afdbab95d2bfed12b3f93bb0a73a7ad952f44d7185fd9ec6c34de8f03a48770f2009c8580bcd275e9632714e9a5e3f32f29dc55474b2329ff0ebc08b3ffcb35bc96e6516b483df80a4a59cceb71918cbabf91564e64a39d7e35dce21cb3031824fdbc845dba6458852ec16af5dddf51a8397a8797ae0337b1439024100ea0eb1b914158c70db39031dd8904d6f18f408c85fbbc592d7d20dee7986969efbda081fdf8bc40e1b1336d6b638110c836bfdc3f314560d2e49cd4fbde1e20b024100e32a4e793b574c9c4a94c8803db5152141e72d03de64e54ef2c8ed104988ca780cd11397bc359630d01b97ebd87067c5451ba777cf045ca23f5912f1031308c702406dfcdbbd5a57c9f85abc4edf9e9e29153507b07ce0a7ef6f52e60dcfebe1b8341babd8b789a837485da6c8d55b29bbb142ace3c24a1f5b54b454d01b51e2ad03024100bd6a2b60dee01e1b3bfcef6a2f09ed027c273cdbbaf6ba55a80f6dcc64e4509ee560f84b4f3e076bd03b11e42fe71a3fdd2dffe7e0902c8584f8cad877cdc945024100aa512fa4ada69881f1d8bb8ad6614f192b83200aef5edf4811313d5ef30a86cbd0a90f7b025c71ea06ec6b34db6306c86b1040670fd8654ad7291d066d06d031` + +// Generated using: +// openssl ecparam -genkey -name secp224r1 | openssl pkcs8 -topk8 -nocrypt +var pkcs8P224PrivateKeyHex = `3078020100301006072a8648ce3d020106052b810400210461305f020101041cca3d72b3e88fed2684576dad9b80a9180363a5424986900e3abcab3fa13c033a0004f8f2a6372872a4e61263ed893afb919576a4cacfecd6c081a2cbc76873cf4ba8530703c6042b3a00e2205087e87d2435d2e339e25702fae1` + +// Generated using: +// openssl ecparam -genkey -name secp256r1 | openssl pkcs8 -topk8 -nocrypt +var pkcs8P256PrivateKeyHex = `308187020100301306072a8648ce3d020106082a8648ce3d030107046d306b0201010420dad6b2f49ca774c36d8ae9517e935226f667c929498f0343d2424d0b9b591b43a14403420004b9c9b90095476afe7b860d8bd43568cab7bcb2eed7b8bf2fa0ce1762dd20b04193f859d2d782b1e4cbfd48492f1f533113a6804903f292258513837f07fda735` + +var pkcs8SM2P256PrivateKeyHex = `308187020100301306072a8648ce3d020106082a811ccf5501822d046d306b0201010420b26da57ba53004ddcd387ad46a361b51b308481f2327d47fb10c5fb3a8c86b92a144034200040d5365bfdbdc564c5b0eda0a85ddbd753821a709de90efe0666ba2544766acf1100ac0484d166842011da5cd6139e53dedb99ce37cea9edf4941628066e861bf` + +// Generated using: +// openssl ecparam -genkey -name secp384r1 | openssl pkcs8 -topk8 -nocrypt +var pkcs8P384PrivateKeyHex = `3081b6020100301006072a8648ce3d020106052b8104002204819e30819b02010104309bf832f6aaaeacb78ce47ffb15e6fd0fd48683ae79df6eca39bfb8e33829ac94aa29d08911568684c2264a08a4ceb679a164036200049070ad4ed993c7770d700e9f6dc2baa83f63dd165b5507f98e8ff29b5d2e78ccbe05c8ddc955dbf0f7497e8222cfa49314fe4e269459f8e880147f70d785e530f2939e4bf9f838325bb1a80ad4cf59272ae0e5efe9a9dc33d874492596304bd3` + +// Generated using: +// openssl ecparam -genkey -name secp521r1 | openssl pkcs8 -topk8 -nocrypt +// +// Note that OpenSSL will truncate the private key if it can (i.e. it emits it +// like an integer, even though it's an OCTET STRING field). Thus if you +// regenerate this you may, randomly, find that it's a byte shorter than +// expected and the Go test will fail to recreate it exactly. +var pkcs8P521PrivateKeyHex = `3081ee020100301006072a8648ce3d020106052b810400230481d63081d3020101044200cfe0b87113a205cf291bb9a8cd1a74ac6c7b2ebb8199aaa9a5010d8b8012276fa3c22ac913369fa61beec2a3b8b4516bc049bde4fb3b745ac11b56ab23ac52e361a1818903818600040138f75acdd03fbafa4f047a8e4b272ba9d555c667962b76f6f232911a5786a0964e5edea6bd21a6f8725720958de049c6e3e6661c1c91b227cebee916c0319ed6ca003db0a3206d372229baf9dd25d868bf81140a518114803ce40c1855074d68c4e9dab9e65efba7064c703b400f1767f217dac82715ac1f6d88c74baf47a7971de4ea` + +// From RFC 8410, Section 7. +var pkcs8Ed25519PrivateKeyHex = `302e020100300506032b657004220420d4ee72dbf913584ad5b6d8f1f769f8ad3afe7c28cbf1d4fbe097a88f44755842` + +func TestPKCS8(t *testing.T) { + tests := []struct { + name string + keyHex string + keyType reflect.Type + curve elliptic.Curve + }{ + { + name: "RSA private key", + keyHex: pkcs8RSAPrivateKeyHex, + keyType: reflect.TypeOf(&rsa.PrivateKey{}), + }, + { + name: "P-224 private key", + keyHex: pkcs8P224PrivateKeyHex, + keyType: reflect.TypeOf(&ecdsa.PrivateKey{}), + curve: elliptic.P224(), + }, + { + name: "P-256 private key", + keyHex: pkcs8P256PrivateKeyHex, + keyType: reflect.TypeOf(&ecdsa.PrivateKey{}), + curve: elliptic.P256(), + }, + { + name: "SM2 P-256 private key", + keyHex: pkcs8SM2P256PrivateKeyHex, + keyType: reflect.TypeOf(&sm2.PrivateKey{}), + curve: sm2.P256(), + }, + { + name: "P-384 private key", + keyHex: pkcs8P384PrivateKeyHex, + keyType: reflect.TypeOf(&ecdsa.PrivateKey{}), + curve: elliptic.P384(), + }, + { + name: "P-521 private key", + keyHex: pkcs8P521PrivateKeyHex, + keyType: reflect.TypeOf(&ecdsa.PrivateKey{}), + curve: elliptic.P521(), + }, + { + name: "Ed25519 private key", + keyHex: pkcs8Ed25519PrivateKeyHex, + keyType: reflect.TypeOf(ed25519.PrivateKey{}), + }, + } + + for _, test := range tests { + derBytes, err := hex.DecodeString(test.keyHex) + if err != nil { + t.Errorf("%s: failed to decode hex: %s", test.name, err) + continue + } + privKey, err := ParsePKCS8PrivateKey(derBytes) + if err != nil { + t.Errorf("%s: failed to decode PKCS#8: %s", test.name, err) + continue + } + if reflect.TypeOf(privKey) != test.keyType { + t.Errorf("%s: decoded PKCS#8 returned unexpected key type: %T", test.name, privKey) + continue + } + if ecKey, isEC := privKey.(*ecdsa.PrivateKey); isEC && ecKey.Curve != test.curve { + t.Errorf("%s: decoded PKCS#8 returned unexpected curve %#v", test.name, ecKey.Curve) + continue + } + reserialised, err := MarshalPKCS8PrivateKey(privKey) + if err != nil { + t.Errorf("%s: failed to marshal into PKCS#8: %s", test.name, err) + continue + } + if !bytes.Equal(derBytes, reserialised) { + t.Errorf("%s: marshaled PKCS#8 didn't match original: got %x, want %x", test.name, reserialised, derBytes) + continue + } + } +} + +const hexPKCS8TestPKCS1Key = "3082025c02010002818100b1a1e0945b9289c4d3f1329f8a982c4a2dcd59bfd372fb8085a9c517554607ebd2f7990eef216ac9f4605f71a03b04f42a5255b158cf8e0844191f5119348baa44c35056e20609bcf9510f30ead4b481c81d7865fb27b8e0090e112b717f3ee08cdfc4012da1f1f7cf2a1bc34c73a54a12b06372d09714742dd7895eadde4aa5020301000102818062b7fa1db93e993e40237de4d89b7591cc1ea1d04fed4904c643f17ae4334557b4295270d0491c161cb02a9af557978b32b20b59c267a721c4e6c956c2d147046e9ae5f2da36db0106d70021fa9343455f8f973a4b355a26fd19e6b39dee0405ea2b32deddf0f4817759ef705d02b34faab9ca93c6766e9f722290f119f34449024100d9c29a4a013a90e35fd1be14a3f747c589fac613a695282d61812a711906b8a0876c6181f0333ca1066596f57bff47e7cfcabf19c0fc69d9cd76df743038b3cb024100d0d3546fecf879b5551f2bd2c05e6385f2718a08a6face3d2aecc9d7e03645a480a46c81662c12ad6bd6901e3bd4f38029462de7290859567cdf371c79088d4f024100c254150657e460ea58573fcf01a82a4791e3d6223135c8bdfed69afe84fbe7857274f8eb5165180507455f9b4105c6b08b51fe8a481bb986a202245576b713530240045700003b7a867d0041df9547ae2e7f50248febd21c9040b12dae9c2feab0d3d4609668b208e4727a3541557f84d372ac68eaf74ce1018a4c9a0ef92682c8fd02405769731480bb3a4570abf422527c5f34bf732fa6c1e08cc322753c511ce055fac20fc770025663ad3165324314df907f1f1942f0448a7e9cdbf87ecd98b92156" +const hexPKCS8TestECKey = "3081a40201010430bdb9839c08ee793d1157886a7a758a3c8b2a17a4df48f17ace57c72c56b4723cf21dcda21d4e1ad57ff034f19fcfd98ea00706052b81040022a16403620004feea808b5ee2429cfcce13c32160e1c960990bd050bb0fdf7222f3decd0a55008e32a6aa3c9062051c4cba92a7a3b178b24567412d43cdd2f882fa5addddd726fe3e208d2c26d733a773a597abb749714df7256ead5105fa6e7b3650de236b50" + +var pkcs8MismatchKeyTests = []struct { + hexKey string + errorContains string +}{ + {hexKey: hexPKCS8TestECKey, errorContains: "use ParseECPrivateKey instead"}, + {hexKey: hexPKCS8TestPKCS1Key, errorContains: "use ParsePKCS1PrivateKey instead"}, +} + +func TestPKCS8MismatchKeyFormat(t *testing.T) { + for i, test := range pkcs8MismatchKeyTests { + derBytes, _ := hex.DecodeString(test.hexKey) + _, err := ParsePKCS8PrivateKey(derBytes) + if !strings.Contains(err.Error(), test.errorContains) { + t.Errorf("#%d: expected error containing %q, got %s", i, test.errorContains, err) + } + } +} + +func TestMarshalPKCS8SM2PrivateKey(t *testing.T) { + priv, _ := sm2.GenerateKey(rand.Reader) + res, err := MarshalPKCS8PrivateKey(priv) + if err != nil { + t.Fatalf("%v\n", err) + } + fmt.Printf("%s\n", hex.EncodeToString(res)) +} diff --git a/smx509/root.go b/smx509/root.go new file mode 100644 index 0000000..f779bee --- /dev/null +++ b/smx509/root.go @@ -0,0 +1,21 @@ +package smx509 + +import "sync" + +var ( + once sync.Once + systemRoots *CertPool + systemRootsErr error +) + +func systemRootsPool() *CertPool { + once.Do(initSystemRoots) + return systemRoots +} + +func initSystemRoots() { + systemRoots, systemRootsErr = loadSystemRoots() + if systemRootsErr != nil { + systemRoots = nil + } +} diff --git a/smx509/root_aix.go b/smx509/root_aix.go new file mode 100644 index 0000000..ecac5c9 --- /dev/null +++ b/smx509/root_aix.go @@ -0,0 +1,6 @@ +package smx509 + +// Possible certificate files; stop after finding one. +var certFiles = []string{ + "/var/ssl/certs/ca-bundle.crt", +} diff --git a/smx509/root_bsd.go b/smx509/root_bsd.go new file mode 100644 index 0000000..55a9ea9 --- /dev/null +++ b/smx509/root_bsd.go @@ -0,0 +1,11 @@ +// +build dragonfly freebsd netbsd openbsd + +package smx509 + +// Possible certificate files; stop after finding one. +var certFiles = []string{ + "/usr/local/etc/ssl/cert.pem", // FreeBSD + "/etc/ssl/cert.pem", // OpenBSD + "/usr/local/share/certs/ca-root-nss.crt", // DragonFly + "/etc/openssl/certs/ca-certificates.crt", // NetBSD +} diff --git a/smx509/root_js.go b/smx509/root_js.go new file mode 100644 index 0000000..148fd40 --- /dev/null +++ b/smx509/root_js.go @@ -0,0 +1,6 @@ +// +build js,wasm + +package smx509 + +// Possible certificate files; stop after finding one. +var certFiles = []string{} diff --git a/smx509/root_linux.go b/smx509/root_linux.go new file mode 100644 index 0000000..a0c2287 --- /dev/null +++ b/smx509/root_linux.go @@ -0,0 +1,11 @@ +package smx509 + +// Possible certificate files; stop after finding one. +var certFiles = []string{ + "/etc/ssl/certs/ca-certificates.crt", // Debian/Ubuntu/Gentoo etc. + "/etc/pki/tls/certs/ca-bundle.crt", // Fedora/RHEL 6 + "/etc/ssl/ca-bundle.pem", // OpenSUSE + "/etc/pki/tls/cacert.pem", // OpenELEC + "/etc/pki/ca-trust/extracted/pem/tls-ca-bundle.pem", // CentOS/RHEL 7 + "/etc/ssl/cert.pem", // Alpine Linux +} diff --git a/smx509/root_plan9.go b/smx509/root_plan9.go new file mode 100644 index 0000000..11c7487 --- /dev/null +++ b/smx509/root_plan9.go @@ -0,0 +1,36 @@ +// +build plan9 + +package smx509 + +import ( + "io/ioutil" + "os" +) + +// Possible certificate files; stop after finding one. +var certFiles = []string{ + "/sys/lib/tls/ca.pem", +} + +func (c *Certificate) systemVerify(opts *VerifyOptions) (chains [][]*Certificate, err error) { + return nil, nil +} + +func loadSystemRoots() (*CertPool, error) { + roots := NewCertPool() + var bestErr error + for _, file := range certFiles { + data, err := ioutil.ReadFile(file) + if err == nil { + roots.AppendCertsFromPEM(data) + return roots, nil + } + if bestErr == nil || (os.IsNotExist(bestErr) && !os.IsNotExist(err)) { + bestErr = err + } + } + if bestErr == nil { + return roots, nil + } + return nil, bestErr +} diff --git a/smx509/root_solaris.go b/smx509/root_solaris.go new file mode 100644 index 0000000..97c22ef --- /dev/null +++ b/smx509/root_solaris.go @@ -0,0 +1,8 @@ +package smx509 + +// Possible certificate files; stop after finding one. +var certFiles = []string{ + "/etc/certs/ca-certificates.crt", // Solaris 11.2+ + "/etc/ssl/certs/ca-certificates.crt", // Joyent SmartOS + "/etc/ssl/cacert.pem", // OmniOS +} diff --git a/smx509/root_unix.go b/smx509/root_unix.go new file mode 100644 index 0000000..a6cbeeb --- /dev/null +++ b/smx509/root_unix.go @@ -0,0 +1,85 @@ +// +build aix dragonfly freebsd js,wasm linux netbsd openbsd solaris + +package smx509 + +import ( + "io/ioutil" + "os" +) + +// Possible directories with certificate files; stop after successfully +// reading at least one file from a directory. +var certDirectories = []string{ + "/etc/ssl/certs", // SLES10/SLES11, https://golang.org/issue/12139 + "/system/etc/security/cacerts", // Android + "/usr/local/share/certs", // FreeBSD + "/etc/pki/tls/certs", // Fedora/RHEL + "/etc/openssl/certs", // NetBSD + "/var/ssl/certs", // AIX +} + +const ( + // certFileEnv is the environment variable which identifies where to locate + // the SSL certificate file. If set this overrides the system default. + certFileEnv = "SSL_CERT_FILE" + + // certDirEnv is the environment variable which identifies which directory + // to check for SSL certificate files. If set this overrides the system default. + certDirEnv = "SSL_CERT_DIR" +) + +func (c *Certificate) systemVerify(opts *VerifyOptions) (chains [][]*Certificate, err error) { + return nil, nil +} + +func loadSystemRoots() (*CertPool, error) { + roots := NewCertPool() + + files := certFiles + if f := os.Getenv(certFileEnv); f != "" { + files = []string{f} + } + + var firstErr error + for _, file := range files { + data, err := ioutil.ReadFile(file) + if err == nil { + roots.AppendCertsFromPEM(data) + break + } + if firstErr == nil && !os.IsNotExist(err) { + firstErr = err + } + } + + dirs := certDirectories + if d := os.Getenv(certDirEnv); d != "" { + dirs = []string{d} + } + + for _, directory := range dirs { + fis, err := ioutil.ReadDir(directory) + if err != nil { + if firstErr == nil && !os.IsNotExist(err) { + firstErr = err + } + continue + } + rootsAdded := false + for _, fi := range fis { + data, err := ioutil.ReadFile(directory + "/" + fi.Name()) + if err == nil && roots.AppendCertsFromPEM(data) { + rootsAdded = true + } + } + if rootsAdded { + return roots, nil + } + } + + if len(roots.certs) > 0 || firstErr == nil { + return roots, nil + } + + return nil, firstErr +} diff --git a/smx509/root_windows.go b/smx509/root_windows.go new file mode 100644 index 0000000..5aa56a7 --- /dev/null +++ b/smx509/root_windows.go @@ -0,0 +1,303 @@ +package smx509 + +import ( + "crypto/x509" + "errors" + "syscall" + "unsafe" +) + +// Creates a new *syscall.CertContext representing the leaf certificate in an in-memory +// certificate store containing itself and all of the intermediate certificates specified +// in the opts.Intermediates CertPool. +// +// A pointer to the in-memory store is available in the returned CertContext's Store field. +// The store is automatically freed when the CertContext is freed using +// syscall.CertFreeCertificateContext. +func createStoreContext(leaf *Certificate, opts *VerifyOptions) (*syscall.CertContext, error) { + var storeCtx *syscall.CertContext + + leafCtx, err := syscall.CertCreateCertificateContext(syscall.X509_ASN_ENCODING|syscall.PKCS_7_ASN_ENCODING, &leaf.Raw[0], uint32(len(leaf.Raw))) + if err != nil { + return nil, err + } + defer syscall.CertFreeCertificateContext(leafCtx) + + handle, err := syscall.CertOpenStore(syscall.CERT_STORE_PROV_MEMORY, 0, 0, syscall.CERT_STORE_DEFER_CLOSE_UNTIL_LAST_FREE_FLAG, 0) + if err != nil { + return nil, err + } + defer syscall.CertCloseStore(handle, 0) + + err = syscall.CertAddCertificateContextToStore(handle, leafCtx, syscall.CERT_STORE_ADD_ALWAYS, &storeCtx) + if err != nil { + return nil, err + } + + if opts.Intermediates != nil { + for _, intermediate := range opts.Intermediates.certs { + ctx, err := syscall.CertCreateCertificateContext(syscall.X509_ASN_ENCODING|syscall.PKCS_7_ASN_ENCODING, &intermediate.Raw[0], uint32(len(intermediate.Raw))) + if err != nil { + return nil, err + } + + err = syscall.CertAddCertificateContextToStore(handle, ctx, syscall.CERT_STORE_ADD_ALWAYS, nil) + syscall.CertFreeCertificateContext(ctx) + if err != nil { + return nil, err + } + } + } + + return storeCtx, nil +} + +// extractSimpleChain extracts the final certificate chain from a CertSimpleChain. +func extractSimpleChain(simpleChain **syscall.CertSimpleChain, count int) (chain []*Certificate, err error) { + if simpleChain == nil || count == 0 { + return nil, errors.New("x509: invalid simple chain") + } + + simpleChains := (*[1 << 20]*syscall.CertSimpleChain)(unsafe.Pointer(simpleChain))[:count:count] + lastChain := simpleChains[count-1] + elements := (*[1 << 20]*syscall.CertChainElement)(unsafe.Pointer(lastChain.Elements))[:lastChain.NumElements:lastChain.NumElements] + for i := 0; i < int(lastChain.NumElements); i++ { + // Copy the buf, since ParseCertificate does not create its own copy. + cert := elements[i].CertContext + encodedCert := (*[1 << 20]byte)(unsafe.Pointer(cert.EncodedCert))[:cert.Length:cert.Length] + buf := make([]byte, cert.Length) + copy(buf, encodedCert) + parsedCert, err := ParseCertificate(buf) + if err != nil { + return nil, err + } + chain = append(chain, parsedCert) + } + + return chain, nil +} + +// checkChainTrustStatus checks the trust status of the certificate chain, translating +// any errors it finds into Go errors in the process. +func checkChainTrustStatus(c *Certificate, chainCtx *syscall.CertChainContext) error { + if chainCtx.TrustStatus.ErrorStatus != syscall.CERT_TRUST_NO_ERROR { + status := chainCtx.TrustStatus.ErrorStatus + switch status { + case syscall.CERT_TRUST_IS_NOT_TIME_VALID: + return x509.CertificateInvalidError{&c.Certificate, x509.Expired, ""} + case syscall.CERT_TRUST_IS_NOT_VALID_FOR_USAGE: + return x509.CertificateInvalidError{&c.Certificate, x509.IncompatibleUsage, ""} + // TODO(filippo): surface more error statuses. + default: + return UnknownAuthorityError{c, nil, nil} + } + } + return nil +} + +// checkChainSSLServerPolicy checks that the certificate chain in chainCtx is valid for +// use as a certificate chain for a SSL/TLS server. +func checkChainSSLServerPolicy(c *Certificate, chainCtx *syscall.CertChainContext, opts *VerifyOptions) error { + servernamep, err := syscall.UTF16PtrFromString(opts.DNSName) + if err != nil { + return err + } + sslPara := &syscall.SSLExtraCertChainPolicyPara{ + AuthType: syscall.AUTHTYPE_SERVER, + ServerName: servernamep, + } + sslPara.Size = uint32(unsafe.Sizeof(*sslPara)) + + para := &syscall.CertChainPolicyPara{ + ExtraPolicyPara: (syscall.Pointer)(unsafe.Pointer(sslPara)), + } + para.Size = uint32(unsafe.Sizeof(*para)) + + status := syscall.CertChainPolicyStatus{} + err = syscall.CertVerifyCertificateChainPolicy(syscall.CERT_CHAIN_POLICY_SSL, chainCtx, para, &status) + if err != nil { + return err + } + + // TODO(mkrautz): use the lChainIndex and lElementIndex fields + // of the CertChainPolicyStatus to provide proper context, instead + // using c. + if status.Error != 0 { + switch status.Error { + case syscall.CERT_E_EXPIRED: + return x509.CertificateInvalidError{&c.Certificate, x509.Expired, ""} + case syscall.CERT_E_CN_NO_MATCH: + return x509.HostnameError{&c.Certificate, opts.DNSName} + case syscall.CERT_E_UNTRUSTEDROOT: + return UnknownAuthorityError{c, nil, nil} + default: + return UnknownAuthorityError{c, nil, nil} + } + } + + return nil +} + +// windowsExtKeyUsageOIDs are the C NUL-terminated string representations of the +// OIDs for use with the Windows API. +var windowsExtKeyUsageOIDs = make(map[x509.ExtKeyUsage][]byte, len(extKeyUsageOIDs)) + +func init() { + for _, eku := range extKeyUsageOIDs { + windowsExtKeyUsageOIDs[eku.extKeyUsage] = []byte(eku.oid.String() + "\x00") + } +} + +// systemVerify is like Verify, except that it uses CryptoAPI calls +// to build certificate chains and verify them. +func (c *Certificate) systemVerify(opts *VerifyOptions) (chains [][]*Certificate, err error) { + storeCtx, err := createStoreContext(c, opts) + if err != nil { + return nil, err + } + defer syscall.CertFreeCertificateContext(storeCtx) + + para := new(syscall.CertChainPara) + para.Size = uint32(unsafe.Sizeof(*para)) + + keyUsages := opts.KeyUsages + if len(keyUsages) == 0 { + keyUsages = []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth} + } + oids := make([]*byte, 0, len(keyUsages)) + for _, eku := range keyUsages { + if eku == x509.ExtKeyUsageAny { + oids = nil + break + } + if oid, ok := windowsExtKeyUsageOIDs[eku]; ok { + oids = append(oids, &oid[0]) + } + // Like the standard verifier, accept SGC EKUs as equivalent to ServerAuth. + if eku == x509.ExtKeyUsageServerAuth { + oids = append(oids, &syscall.OID_SERVER_GATED_CRYPTO[0]) + oids = append(oids, &syscall.OID_SGC_NETSCAPE[0]) + } + } + if oids != nil { + para.RequestedUsage.Type = syscall.USAGE_MATCH_TYPE_OR + para.RequestedUsage.Usage.Length = uint32(len(oids)) + para.RequestedUsage.Usage.UsageIdentifiers = &oids[0] + } else { + para.RequestedUsage.Type = syscall.USAGE_MATCH_TYPE_AND + para.RequestedUsage.Usage.Length = 0 + para.RequestedUsage.Usage.UsageIdentifiers = nil + } + + var verifyTime *syscall.Filetime + if opts != nil && !opts.CurrentTime.IsZero() { + ft := syscall.NsecToFiletime(opts.CurrentTime.UnixNano()) + verifyTime = &ft + } + + // CertGetCertificateChain will traverse Windows's root stores + // in an attempt to build a verified certificate chain. Once + // it has found a verified chain, it stops. MSDN docs on + // CERT_CHAIN_CONTEXT: + // + // When a CERT_CHAIN_CONTEXT is built, the first simple chain + // begins with an end certificate and ends with a self-signed + // certificate. If that self-signed certificate is not a root + // or otherwise trusted certificate, an attempt is made to + // build a new chain. CTLs are used to create the new chain + // beginning with the self-signed certificate from the original + // chain as the end certificate of the new chain. This process + // continues building additional simple chains until the first + // self-signed certificate is a trusted certificate or until + // an additional simple chain cannot be built. + // + // The result is that we'll only get a single trusted chain to + // return to our caller. + var chainCtx *syscall.CertChainContext + err = syscall.CertGetCertificateChain(syscall.Handle(0), storeCtx, verifyTime, storeCtx.Store, para, 0, 0, &chainCtx) + if err != nil { + return nil, err + } + defer syscall.CertFreeCertificateChain(chainCtx) + + err = checkChainTrustStatus(c, chainCtx) + if err != nil { + return nil, err + } + + if opts != nil && len(opts.DNSName) > 0 { + err = checkChainSSLServerPolicy(c, chainCtx, opts) + if err != nil { + return nil, err + } + } + + chain, err := extractSimpleChain(chainCtx.Chains, int(chainCtx.ChainCount)) + if err != nil { + return nil, err + } + if len(chain) < 1 { + return nil, errors.New("x509: internal error: system verifier returned an empty chain") + } + + // Mitigate CVE-2020-0601, where the Windows system verifier might be + // tricked into using custom curve parameters for a trusted root, by + // double-checking all ECDSA signatures. If the system was tricked into + // using spoofed parameters, the signature will be invalid for the correct + // ones we parsed. (We don't support custom curves ourselves.) + for i, parent := range chain[1:] { + if parent.PublicKeyAlgorithm != x509.ECDSA { + continue + } + if err := parent.CheckSignature(chain[i].SignatureAlgorithm, + chain[i].RawTBSCertificate, chain[i].Signature); err != nil { + return nil, err + } + } + + return [][]*Certificate{chain}, nil +} + +func loadSystemRoots() (*CertPool, error) { + // TODO: restore this functionality on Windows. We tried to do + // it in Go 1.8 but had to revert it. See Issue 18609. + // Returning (nil, nil) was the old behavior, prior to CL 30578. + // The if statement here avoids vet complaining about + // unreachable code below. + if true { + return nil, nil + } + + const CRYPT_E_NOT_FOUND = 0x80092004 + + store, err := syscall.CertOpenSystemStore(0, syscall.StringToUTF16Ptr("ROOT")) + if err != nil { + return nil, err + } + defer syscall.CertCloseStore(store, 0) + + roots := NewCertPool() + var cert *syscall.CertContext + for { + cert, err = syscall.CertEnumCertificatesInStore(store, cert) + if err != nil { + if errno, ok := err.(syscall.Errno); ok { + if errno == CRYPT_E_NOT_FOUND { + break + } + } + return nil, err + } + if cert == nil { + break + } + // Copy the buf, since ParseCertificate does not create its own copy. + buf := (*[1 << 20]byte)(unsafe.Pointer(cert.EncodedCert))[:cert.Length:cert.Length] + buf2 := make([]byte, cert.Length) + copy(buf2, buf) + if c, err := ParseCertificate(buf2); err == nil { + roots.AddCert(c) + } + } + return roots, nil +} diff --git a/smx509/sec1.go b/smx509/sec1.go new file mode 100644 index 0000000..601980f --- /dev/null +++ b/smx509/sec1.go @@ -0,0 +1,159 @@ +package smx509 + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "encoding/asn1" + "errors" + "fmt" + "math/big" + + "github.com/emmansun/gmsm/sm2" +) + +// pkcs1PrivateKey is a structure which mirrors the PKCS#1 ASN.1 for an RSA private key. +type pkcs1PrivateKey struct { + Version int + N *big.Int + E int + D *big.Int + P *big.Int + Q *big.Int + // We ignore these values, if present, because rsa will calculate them. + Dp *big.Int `asn1:"optional"` + Dq *big.Int `asn1:"optional"` + Qinv *big.Int `asn1:"optional"` + + AdditionalPrimes []pkcs1AdditionalRSAPrime `asn1:"optional,omitempty"` +} + +type pkcs1AdditionalRSAPrime struct { + Prime *big.Int + + // We ignore these values because rsa will calculate them. + Exp *big.Int + Coeff *big.Int +} + +const ecPrivKeyVersion = 1 + +// ecPrivateKey reflects an ASN.1 Elliptic Curve Private Key Structure. +// References: +// RFC 5915 +// SEC1 - http://www.secg.org/sec1-v2.pdf +// Per RFC 5915 the NamedCurveOID is marked as ASN.1 OPTIONAL, however in +// most cases it is not. +type ecPrivateKey struct { + Version int + PrivateKey []byte + NamedCurveOID asn1.ObjectIdentifier `asn1:"optional,explicit,tag:0"` + PublicKey asn1.BitString `asn1:"optional,explicit,tag:1"` +} + +// ParseECPrivateKey parses an EC private key in SEC 1, ASN.1 DER form. +// +// This kind of key is commonly encoded in PEM blocks of type "EC PRIVATE KEY". +func ParseECPrivateKey(der []byte) (*ecdsa.PrivateKey, error) { + return parseECPrivateKey(nil, der) +} + +// ParseSM2PrivateKey parses an SM2 private key +func ParseSM2PrivateKey(der []byte) (*sm2.PrivateKey, error) { + key, err := parseECPrivateKey(nil, der) + if err != nil { + return nil, err + } + return new(sm2.PrivateKey).FromECPrivateKey(key) +} + +// MarshalECPrivateKey converts an EC private key to SEC 1, ASN.1 DER form. +// +// This kind of key is commonly encoded in PEM blocks of type "EC PRIVATE KEY". +// For a more flexible key format which is not EC specific, use +// MarshalPKCS8PrivateKey. +func MarshalECPrivateKey(key *ecdsa.PrivateKey) ([]byte, error) { + oid, ok := oidFromNamedCurve(key.Curve) + if !ok { + return nil, errors.New("x509: unknown elliptic curve") + } + + return marshalECPrivateKeyWithOID(key, oid) +} + +// MarshalSM2PrivateKey convient method to marshal sm2 private key directly +func MarshalSM2PrivateKey(key *sm2.PrivateKey) ([]byte, error) { + return MarshalECPrivateKey(&key.PrivateKey) +} + +// marshalECPrivateKey marshals an EC private key into ASN.1, DER format and +// sets the curve ID to the given OID, or omits it if OID is nil. +func marshalECPrivateKeyWithOID(key *ecdsa.PrivateKey, oid asn1.ObjectIdentifier) ([]byte, error) { + privateKeyBytes := key.D.Bytes() + paddedPrivateKey := make([]byte, (key.Curve.Params().N.BitLen()+7)/8) + copy(paddedPrivateKey[len(paddedPrivateKey)-len(privateKeyBytes):], privateKeyBytes) + + return asn1.Marshal(ecPrivateKey{ + Version: 1, + PrivateKey: paddedPrivateKey, + NamedCurveOID: oid, + PublicKey: asn1.BitString{Bytes: elliptic.Marshal(key.Curve, key.X, key.Y)}, + }) +} + +// parseECPrivateKey parses an ASN.1 Elliptic Curve Private Key Structure. +// The OID for the named curve may be provided from another source (such as +// the PKCS8 container) - if it is provided then use this instead of the OID +// that may exist in the EC private key structure. +func parseECPrivateKey(namedCurveOID *asn1.ObjectIdentifier, der []byte) (key *ecdsa.PrivateKey, err error) { + var privKey ecPrivateKey + if _, err := asn1.Unmarshal(der, &privKey); err != nil { + if _, err := asn1.Unmarshal(der, &pkcs8{}); err == nil { + return nil, errors.New("x509: failed to parse private key (use ParsePKCS8PrivateKey instead for this key format)") + } + if _, err := asn1.Unmarshal(der, &pkcs1PrivateKey{}); err == nil { + return nil, errors.New("x509: failed to parse private key (use ParsePKCS1PrivateKey instead for this key format)") + } + return nil, errors.New("x509: failed to parse EC private key: " + err.Error()) + } + if privKey.Version != ecPrivKeyVersion { + return nil, fmt.Errorf("x509: unknown EC private key version %d", privKey.Version) + } + + var curve elliptic.Curve + if namedCurveOID != nil { + curve = namedCurveFromOID(*namedCurveOID) + } else { + curve = namedCurveFromOID(privKey.NamedCurveOID) + } + if curve == nil { + return nil, errors.New("x509: unknown elliptic curve") + } + + k := new(big.Int).SetBytes(privKey.PrivateKey) + curveOrder := curve.Params().N + if k.Cmp(curveOrder) >= 0 { + return nil, errors.New("x509: invalid elliptic curve private key value") + } + priv := new(ecdsa.PrivateKey) + priv.Curve = curve + priv.D = k + + privateKey := make([]byte, (curveOrder.BitLen()+7)/8) + + // Some private keys have leading zero padding. This is invalid + // according to [SEC1], but this code will ignore it. + for len(privKey.PrivateKey) > len(privateKey) { + if privKey.PrivateKey[0] != 0 { + return nil, errors.New("x509: invalid private key length") + } + privKey.PrivateKey = privKey.PrivateKey[1:] + } + + // Some private keys remove all leading zeros, this is also invalid + // according to [SEC1] but since OpenSSL used to do this, we ignore + // this too. + copy(privateKey[len(privateKey)-len(privKey.PrivateKey):], privKey.PrivateKey) + priv.X, priv.Y = curve.ScalarBaseMult(privateKey) + + return priv, nil +} diff --git a/smx509/sec1_test.go b/smx509/sec1_test.go new file mode 100644 index 0000000..36cac6e --- /dev/null +++ b/smx509/sec1_test.go @@ -0,0 +1,77 @@ +package smx509 + +import ( + "bytes" + "crypto/rand" + "encoding/hex" + "fmt" + "strings" + "testing" + + "github.com/emmansun/gmsm/sm2" +) + +var ecKeyTests = []struct { + derHex string + shouldReserialize bool +}{ + // Generated using: + // openssl ecparam -genkey -name secp384r1 -outform PEM + {"3081a40201010430bdb9839c08ee793d1157886a7a758a3c8b2a17a4df48f17ace57c72c56b4723cf21dcda21d4e1ad57ff034f19fcfd98ea00706052b81040022a16403620004feea808b5ee2429cfcce13c32160e1c960990bd050bb0fdf7222f3decd0a55008e32a6aa3c9062051c4cba92a7a3b178b24567412d43cdd2f882fa5addddd726fe3e208d2c26d733a773a597abb749714df7256ead5105fa6e7b3650de236b50", true}, + // Generated using MarshalSM2PrivateKey + {"30770201010420857dd87970aab4328dad891c781e3b270742aa9cf5d3d3764efe77f6c3d6e33aa00a06082a811ccf5501822da14403420004ced963a5705a0490ff13dde893cbda6de61f41fcaf917a5b4007d30cdec46426bc39b9c18d15b2a68a64dc333f262e600b675856285b42296f24741ee6f562a0", true}, + // This key was generated by GnuTLS and has illegal zero-padding of the + // private key. See https://golang.org/issues/13699. + {"3078020101042100f9f43a04b9bdc3ab01f53be6df80e7a7bc3eaf7b87fc24e630a4a0aa97633645a00a06082a8648ce3d030107a1440342000441a51bc318461b4c39a45048a16d4fc2a935b1ea7fe86e8c1fa219d6f2438f7c7fd62957d3442efb94b6a23eb0ea66dda663dc42f379cda6630b21b7888a5d3d", false}, + // This was generated using an old version of OpenSSL and is missing a + // leading zero byte in the private key that should be present. + {"3081db0201010441607b4f985774ac21e633999794542e09312073480baa69550914d6d43d8414441e61b36650567901da714f94dffb3ce0e2575c31928a0997d51df5c440e983ca17a00706052b81040023a181890381860004001661557afedd7ac8d6b70e038e576558c626eb62edda36d29c3a1310277c11f67a8c6f949e5430a37dcfb95d902c1b5b5379c389873b9dd17be3bdb088a4774a7401072f830fb9a08d93bfa50a03dd3292ea07928724ddb915d831917a338f6b0aecfbc3cf5352c4a1295d356890c41c34116d29eeb93779aab9d9d78e2613437740f6", false}, +} + +func TestParseECPrivateKey(t *testing.T) { + for i, test := range ecKeyTests { + derBytes, _ := hex.DecodeString(test.derHex) + key, err := ParseECPrivateKey(derBytes) + if err != nil { + t.Fatalf("#%d: failed to decode EC private key: %s", i, err) + } + serialized, err := MarshalECPrivateKey(key) + if err != nil { + t.Fatalf("#%d: failed to encode EC private key: %s", i, err) + } + matches := bytes.Equal(serialized, derBytes) + if matches != test.shouldReserialize { + t.Fatalf("#%d: when serializing key: matches=%t, should match=%t: original %x, reserialized %x", i, matches, test.shouldReserialize, serialized, derBytes) + } + } +} + +const hexECTestPKCS1Key = "3082025c02010002818100b1a1e0945b9289c4d3f1329f8a982c4a2dcd59bfd372fb8085a9c517554607ebd2f7990eef216ac9f4605f71a03b04f42a5255b158cf8e0844191f5119348baa44c35056e20609bcf9510f30ead4b481c81d7865fb27b8e0090e112b717f3ee08cdfc4012da1f1f7cf2a1bc34c73a54a12b06372d09714742dd7895eadde4aa5020301000102818062b7fa1db93e993e40237de4d89b7591cc1ea1d04fed4904c643f17ae4334557b4295270d0491c161cb02a9af557978b32b20b59c267a721c4e6c956c2d147046e9ae5f2da36db0106d70021fa9343455f8f973a4b355a26fd19e6b39dee0405ea2b32deddf0f4817759ef705d02b34faab9ca93c6766e9f722290f119f34449024100d9c29a4a013a90e35fd1be14a3f747c589fac613a695282d61812a711906b8a0876c6181f0333ca1066596f57bff47e7cfcabf19c0fc69d9cd76df743038b3cb024100d0d3546fecf879b5551f2bd2c05e6385f2718a08a6face3d2aecc9d7e03645a480a46c81662c12ad6bd6901e3bd4f38029462de7290859567cdf371c79088d4f024100c254150657e460ea58573fcf01a82a4791e3d6223135c8bdfed69afe84fbe7857274f8eb5165180507455f9b4105c6b08b51fe8a481bb986a202245576b713530240045700003b7a867d0041df9547ae2e7f50248febd21c9040b12dae9c2feab0d3d4609668b208e4727a3541557f84d372ac68eaf74ce1018a4c9a0ef92682c8fd02405769731480bb3a4570abf422527c5f34bf732fa6c1e08cc322753c511ce055fac20fc770025663ad3165324314df907f1f1942f0448a7e9cdbf87ecd98b92156" +const hexECTestPKCS8Key = "30820278020100300d06092a864886f70d0101010500048202623082025e02010002818100cfb1b5bf9685ffa97b4f99df4ff122b70e59ac9b992f3bc2b3dde17d53c1a34928719b02e8fd17839499bfbd515bd6ef99c7a1c47a239718fe36bfd824c0d96060084b5f67f0273443007a24dfaf5634f7772c9346e10eb294c2306671a5a5e719ae24b4de467291bc571014b0e02dec04534d66a9bb171d644b66b091780e8d020301000102818100b595778383c4afdbab95d2bfed12b3f93bb0a73a7ad952f44d7185fd9ec6c34de8f03a48770f2009c8580bcd275e9632714e9a5e3f32f29dc55474b2329ff0ebc08b3ffcb35bc96e6516b483df80a4a59cceb71918cbabf91564e64a39d7e35dce21cb3031824fdbc845dba6458852ec16af5dddf51a8397a8797ae0337b1439024100ea0eb1b914158c70db39031dd8904d6f18f408c85fbbc592d7d20dee7986969efbda081fdf8bc40e1b1336d6b638110c836bfdc3f314560d2e49cd4fbde1e20b024100e32a4e793b574c9c4a94c8803db5152141e72d03de64e54ef2c8ed104988ca780cd11397bc359630d01b97ebd87067c5451ba777cf045ca23f5912f1031308c702406dfcdbbd5a57c9f85abc4edf9e9e29153507b07ce0a7ef6f52e60dcfebe1b8341babd8b789a837485da6c8d55b29bbb142ace3c24a1f5b54b454d01b51e2ad03024100bd6a2b60dee01e1b3bfcef6a2f09ed027c273cdbbaf6ba55a80f6dcc64e4509ee560f84b4f3e076bd03b11e42fe71a3fdd2dffe7e0902c8584f8cad877cdc945024100aa512fa4ada69881f1d8bb8ad6614f192b83200aef5edf4811313d5ef30a86cbd0a90f7b025c71ea06ec6b34db6306c86b1040670fd8654ad7291d066d06d031" + +var ecMismatchKeyTests = []struct { + hexKey string + errorContains string +}{ + {hexKey: hexECTestPKCS8Key, errorContains: "use ParsePKCS8PrivateKey instead"}, + {hexKey: hexECTestPKCS1Key, errorContains: "use ParsePKCS1PrivateKey instead"}, +} + +func TestECMismatchKeyFormat(t *testing.T) { + for i, test := range ecMismatchKeyTests { + derBytes, _ := hex.DecodeString(test.hexKey) + _, err := ParseECPrivateKey(derBytes) + if !strings.Contains(err.Error(), test.errorContains) { + t.Errorf("#%d: expected error containing %q, got %s", i, test.errorContains, err) + } + } +} + +func TestMarshalSM2PrivateKey(t *testing.T) { + priv, _ := sm2.GenerateKey(rand.Reader) + res, err := MarshalSM2PrivateKey(priv) + if err != nil { + t.Fatalf("%v\n", err) + } + fmt.Printf("%s\n", hex.EncodeToString(res)) +} diff --git a/smx509/verify.go b/smx509/verify.go new file mode 100644 index 0000000..52402ed --- /dev/null +++ b/smx509/verify.go @@ -0,0 +1,984 @@ +package smx509 + +import ( + "bytes" + "crypto/x509" + "errors" + "fmt" + "net" + "net/url" + "os" + "reflect" + "runtime" + "strings" + "time" + "unicode/utf8" +) + +// ignoreCN disables interpreting Common Name as a hostname. See issue 24151. +var ignoreCN = strings.Contains(os.Getenv("GODEBUG"), "x509ignoreCN=1") + +// CertificateInvalidError results when an odd error occurs. Users of this +// library probably want to handle all these errors uniformly. +type CertificateInvalidError struct { + Cert *Certificate + Reason x509.InvalidReason + Detail string +} + +func (e CertificateInvalidError) Error() string { + switch e.Reason { + case x509.NotAuthorizedToSign: + return "x509: certificate is not authorized to sign other certificates" + case x509.Expired: + return "x509: certificate has expired or is not yet valid: " + e.Detail + case x509.CANotAuthorizedForThisName: + return "x509: a root or intermediate certificate is not authorized to sign for this name: " + e.Detail + case x509.CANotAuthorizedForExtKeyUsage: + return "x509: a root or intermediate certificate is not authorized for an extended key usage: " + e.Detail + case x509.TooManyIntermediates: + return "x509: too many intermediates for path length constraint" + case x509.IncompatibleUsage: + return "x509: certificate specifies an incompatible key usage" + case x509.NameMismatch: + return "x509: issuer name does not match subject from issuing certificate" + case x509.NameConstraintsWithoutSANs: + return "x509: issuer has name constraints but leaf doesn't have a SAN extension" + case x509.UnconstrainedName: + return "x509: issuer has name constraints but leaf contains unknown or unconstrained name: " + e.Detail + } + return "x509: unknown error" +} + +// UnknownAuthorityError results when the certificate issuer is unknown +type UnknownAuthorityError struct { + Cert *Certificate + // hintErr contains an error that may be helpful in determining why an + // authority wasn't found. + hintErr error + // hintCert contains a possible authority certificate that was rejected + // because of the error in hintErr. + hintCert *Certificate +} + +func (e UnknownAuthorityError) Error() string { + s := "x509: certificate signed by unknown authority" + if e.hintErr != nil { + certName := e.hintCert.Subject.CommonName + if len(certName) == 0 { + if len(e.hintCert.Subject.Organization) > 0 { + certName = e.hintCert.Subject.Organization[0] + } else { + certName = "serial:" + e.hintCert.SerialNumber.String() + } + } + s += fmt.Sprintf(" (possibly because of %q while trying to verify candidate authority certificate %q)", e.hintErr, certName) + } + return s +} + +// SystemRootsError results when we fail to load the system root certificates. +type SystemRootsError struct { + Err error +} + +func (se SystemRootsError) Error() string { + msg := "x509: failed to load system roots and no roots provided" + if se.Err != nil { + return msg + "; " + se.Err.Error() + } + return msg +} + +// errNotParsed is returned when a certificate without ASN.1 contents is +// verified. Platform-specific verification needs the ASN.1 contents. +var errNotParsed = errors.New("x509: missing ASN.1 contents; use ParseCertificate") + +// VerifyOptions contains parameters for Certificate.Verify. +type VerifyOptions struct { + // DNSName, if set, is checked against the leaf certificate with + // Certificate.VerifyHostname or the platform verifier. + DNSName string + + // Intermediates is an optional pool of certificates that are not trust + // anchors, but can be used to form a chain from the leaf certificate to a + // root certificate. + Intermediates *CertPool + // Roots is the set of trusted root certificates the leaf certificate needs + // to chain up to. If nil, the system roots or the platform verifier are used. + Roots *CertPool + + // CurrentTime is used to check the validity of all certificates in the + // chain. If zero, the current time is used. + CurrentTime time.Time + + // KeyUsages specifies which Extended Key Usage values are acceptable. A + // chain is accepted if it allows any of the listed values. An empty list + // means ExtKeyUsageServerAuth. To accept any key usage, include ExtKeyUsageAny. + KeyUsages []x509.ExtKeyUsage + + // MaxConstraintComparisions is the maximum number of comparisons to + // perform when checking a given certificate's name constraints. If + // zero, a sensible default is used. This limit prevents pathological + // certificates from consuming excessive amounts of CPU time when + // validating. It does not apply to the platform verifier. + MaxConstraintComparisions int +} + +const ( + leafCertificate = iota + intermediateCertificate + rootCertificate +) + +// rfc2821Mailbox represents a “mailbox” (which is an email address to most +// people) by breaking it into the “local” (i.e. before the '@') and “domain” +// parts. +type rfc2821Mailbox struct { + local, domain string +} + +// parseRFC2821Mailbox parses an email address into local and domain parts, +// based on the ABNF for a “Mailbox” from RFC 2821. According to RFC 5280, +// Section 4.2.1.6 that's correct for an rfc822Name from a certificate: “The +// format of an rfc822Name is a "Mailbox" as defined in RFC 2821, Section 4.1.2”. +func parseRFC2821Mailbox(in string) (mailbox rfc2821Mailbox, ok bool) { + if len(in) == 0 { + return mailbox, false + } + + localPartBytes := make([]byte, 0, len(in)/2) + + if in[0] == '"' { + // Quoted-string = DQUOTE *qcontent DQUOTE + // non-whitespace-control = %d1-8 / %d11 / %d12 / %d14-31 / %d127 + // qcontent = qtext / quoted-pair + // qtext = non-whitespace-control / + // %d33 / %d35-91 / %d93-126 + // quoted-pair = ("\" text) / obs-qp + // text = %d1-9 / %d11 / %d12 / %d14-127 / obs-text + // + // (Names beginning with “obs-” are the obsolete syntax from RFC 2822, + // Section 4. Since it has been 16 years, we no longer accept that.) + in = in[1:] + QuotedString: + for { + if len(in) == 0 { + return mailbox, false + } + c := in[0] + in = in[1:] + + switch { + case c == '"': + break QuotedString + + case c == '\\': + // quoted-pair + if len(in) == 0 { + return mailbox, false + } + if in[0] == 11 || + in[0] == 12 || + (1 <= in[0] && in[0] <= 9) || + (14 <= in[0] && in[0] <= 127) { + localPartBytes = append(localPartBytes, in[0]) + in = in[1:] + } else { + return mailbox, false + } + + case c == 11 || + c == 12 || + // Space (char 32) is not allowed based on the + // BNF, but RFC 3696 gives an example that + // assumes that it is. Several “verified” + // errata continue to argue about this point. + // We choose to accept it. + c == 32 || + c == 33 || + c == 127 || + (1 <= c && c <= 8) || + (14 <= c && c <= 31) || + (35 <= c && c <= 91) || + (93 <= c && c <= 126): + // qtext + localPartBytes = append(localPartBytes, c) + + default: + return mailbox, false + } + } + } else { + // Atom ("." Atom)* + NextChar: + for len(in) > 0 { + // atext from RFC 2822, Section 3.2.4 + c := in[0] + + switch { + case c == '\\': + // Examples given in RFC 3696 suggest that + // escaped characters can appear outside of a + // quoted string. Several “verified” errata + // continue to argue the point. We choose to + // accept it. + in = in[1:] + if len(in) == 0 { + return mailbox, false + } + fallthrough + + case ('0' <= c && c <= '9') || + ('a' <= c && c <= 'z') || + ('A' <= c && c <= 'Z') || + c == '!' || c == '#' || c == '$' || c == '%' || + c == '&' || c == '\'' || c == '*' || c == '+' || + c == '-' || c == '/' || c == '=' || c == '?' || + c == '^' || c == '_' || c == '`' || c == '{' || + c == '|' || c == '}' || c == '~' || c == '.': + localPartBytes = append(localPartBytes, in[0]) + in = in[1:] + + default: + break NextChar + } + } + + if len(localPartBytes) == 0 { + return mailbox, false + } + + // From RFC 3696, Section 3: + // “period (".") may also appear, but may not be used to start + // or end the local part, nor may two or more consecutive + // periods appear.” + twoDots := []byte{'.', '.'} + if localPartBytes[0] == '.' || + localPartBytes[len(localPartBytes)-1] == '.' || + bytes.Contains(localPartBytes, twoDots) { + return mailbox, false + } + } + + if len(in) == 0 || in[0] != '@' { + return mailbox, false + } + in = in[1:] + + // The RFC species a format for domains, but that's known to be + // violated in practice so we accept that anything after an '@' is the + // domain part. + if _, ok := domainToReverseLabels(in); !ok { + return mailbox, false + } + + mailbox.local = string(localPartBytes) + mailbox.domain = in + return mailbox, true +} + +// checkNameConstraints checks that c permits a child certificate to claim the +// given name, of type nameType. The argument parsedName contains the parsed +// form of name, suitable for passing to the match function. The total number +// of comparisons is tracked in the given count and should not exceed the given +// limit. +func (c *Certificate) checkNameConstraints(count *int, + maxConstraintComparisons int, + nameType string, + name string, + parsedName interface{}, + match func(parsedName, constraint interface{}) (match bool, err error), + permitted, excluded interface{}) error { + + excludedValue := reflect.ValueOf(excluded) + + *count += excludedValue.Len() + if *count > maxConstraintComparisons { + return CertificateInvalidError{c, x509.TooManyConstraints, ""} + } + + for i := 0; i < excludedValue.Len(); i++ { + constraint := excludedValue.Index(i).Interface() + match, err := match(parsedName, constraint) + if err != nil { + return CertificateInvalidError{c, x509.CANotAuthorizedForThisName, err.Error()} + } + + if match { + return CertificateInvalidError{c, x509.CANotAuthorizedForThisName, fmt.Sprintf("%s %q is excluded by constraint %q", nameType, name, constraint)} + } + } + + permittedValue := reflect.ValueOf(permitted) + + *count += permittedValue.Len() + if *count > maxConstraintComparisons { + return CertificateInvalidError{c, x509.TooManyConstraints, ""} + } + + ok := true + for i := 0; i < permittedValue.Len(); i++ { + constraint := permittedValue.Index(i).Interface() + + var err error + if ok, err = match(parsedName, constraint); err != nil { + return CertificateInvalidError{c, x509.CANotAuthorizedForThisName, err.Error()} + } + + if ok { + break + } + } + + if !ok { + return CertificateInvalidError{c, x509.CANotAuthorizedForThisName, fmt.Sprintf("%s %q is not permitted by any constraint", nameType, name)} + } + + return nil +} + +// isValid performs validity checks on c given that it is a candidate to append +// to the chain in currentChain. +func (c *Certificate) isValid(certType int, currentChain []*Certificate, opts *VerifyOptions) error { + if len(c.UnhandledCriticalExtensions) > 0 { + return x509.UnhandledCriticalExtension{} + } + + if len(currentChain) > 0 { + child := currentChain[len(currentChain)-1] + if !bytes.Equal(child.RawIssuer, c.RawSubject) { + return CertificateInvalidError{c, x509.NameMismatch, ""} + } + } + + now := opts.CurrentTime + if now.IsZero() { + now = time.Now() + } + if now.Before(c.NotBefore) { + return CertificateInvalidError{ + Cert: c, + Reason: x509.Expired, + Detail: fmt.Sprintf("current time %s is before %s", now.Format(time.RFC3339), c.NotBefore.Format(time.RFC3339)), + } + } else if now.After(c.NotAfter) { + return CertificateInvalidError{ + Cert: c, + Reason: x509.Expired, + Detail: fmt.Sprintf("current time %s is after %s", now.Format(time.RFC3339), c.NotAfter.Format(time.RFC3339)), + } + } + + maxConstraintComparisons := opts.MaxConstraintComparisions + if maxConstraintComparisons == 0 { + maxConstraintComparisons = 250000 + } + comparisonCount := 0 + + var leaf *Certificate + if certType == intermediateCertificate || certType == rootCertificate { + if len(currentChain) == 0 { + return errors.New("x509: internal error: empty chain when appending CA cert") + } + leaf = currentChain[0] + } + + checkNameConstraints := (certType == intermediateCertificate || certType == rootCertificate) && c.hasNameConstraints() + if checkNameConstraints && leaf.commonNameAsHostname() { + // This is the deprecated, legacy case of depending on the commonName as + // a hostname. We don't enforce name constraints against the CN, but + // VerifyHostname will look for hostnames in there if there are no SANs. + // In order to ensure VerifyHostname will not accept an unchecked name, + // return an error here. + return CertificateInvalidError{c, x509.NameConstraintsWithoutSANs, ""} + } else if checkNameConstraints && leaf.hasSANExtension() { + err := forEachSAN(leaf.getSANExtension(), func(tag int, data []byte) error { + switch tag { + case nameTypeEmail: + name := string(data) + mailbox, ok := parseRFC2821Mailbox(name) + if !ok { + return fmt.Errorf("x509: cannot parse rfc822Name %q", mailbox) + } + + if err := c.checkNameConstraints(&comparisonCount, maxConstraintComparisons, "email address", name, mailbox, + func(parsedName, constraint interface{}) (bool, error) { + return matchEmailConstraint(parsedName.(rfc2821Mailbox), constraint.(string)) + }, c.PermittedEmailAddresses, c.ExcludedEmailAddresses); err != nil { + return err + } + + case nameTypeDNS: + name := string(data) + if _, ok := domainToReverseLabels(name); !ok { + return fmt.Errorf("x509: cannot parse dnsName %q", name) + } + + if err := c.checkNameConstraints(&comparisonCount, maxConstraintComparisons, "DNS name", name, name, + func(parsedName, constraint interface{}) (bool, error) { + return matchDomainConstraint(parsedName.(string), constraint.(string)) + }, c.PermittedDNSDomains, c.ExcludedDNSDomains); err != nil { + return err + } + + case nameTypeURI: + name := string(data) + uri, err := url.Parse(name) + if err != nil { + return fmt.Errorf("x509: internal error: URI SAN %q failed to parse", name) + } + + if err := c.checkNameConstraints(&comparisonCount, maxConstraintComparisons, "URI", name, uri, + func(parsedName, constraint interface{}) (bool, error) { + return matchURIConstraint(parsedName.(*url.URL), constraint.(string)) + }, c.PermittedURIDomains, c.ExcludedURIDomains); err != nil { + return err + } + + case nameTypeIP: + ip := net.IP(data) + if l := len(ip); l != net.IPv4len && l != net.IPv6len { + return fmt.Errorf("x509: internal error: IP SAN %x failed to parse", data) + } + + if err := c.checkNameConstraints(&comparisonCount, maxConstraintComparisons, "IP address", ip.String(), ip, + func(parsedName, constraint interface{}) (bool, error) { + return matchIPConstraint(parsedName.(net.IP), constraint.(*net.IPNet)) + }, c.PermittedIPRanges, c.ExcludedIPRanges); err != nil { + return err + } + + default: + // Unknown SAN types are ignored. + } + + return nil + }) + + if err != nil { + return err + } + } + + // KeyUsage status flags are ignored. From Engineering Security, Peter + // Gutmann: A European government CA marked its signing certificates as + // being valid for encryption only, but no-one noticed. Another + // European CA marked its signature keys as not being valid for + // signatures. A different CA marked its own trusted root certificate + // as being invalid for certificate signing. Another national CA + // distributed a certificate to be used to encrypt data for the + // country’s tax authority that was marked as only being usable for + // digital signatures but not for encryption. Yet another CA reversed + // the order of the bit flags in the keyUsage due to confusion over + // encoding endianness, essentially setting a random keyUsage in + // certificates that it issued. Another CA created a self-invalidating + // certificate by adding a certificate policy statement stipulating + // that the certificate had to be used strictly as specified in the + // keyUsage, and a keyUsage containing a flag indicating that the RSA + // encryption key could only be used for Diffie-Hellman key agreement. + + if certType == intermediateCertificate && (!c.BasicConstraintsValid || !c.IsCA) { + return CertificateInvalidError{c, x509.NotAuthorizedToSign, ""} + } + + if c.BasicConstraintsValid && c.MaxPathLen >= 0 { + numIntermediates := len(currentChain) - 1 + if numIntermediates > c.MaxPathLen { + return CertificateInvalidError{c, x509.TooManyIntermediates, ""} + } + } + + return nil +} + +// Verify attempts to verify c by building one or more chains from c to a +// certificate in opts.Roots, using certificates in opts.Intermediates if +// needed. If successful, it returns one or more chains where the first +// element of the chain is c and the last element is from opts.Roots. +// +// If opts.Roots is nil, the platform verifier might be used, and +// verification details might differ from what is described below. If system +// roots are unavailable the returned error will be of type SystemRootsError. +// +// Name constraints in the intermediates will be applied to all names claimed +// in the chain, not just opts.DNSName. Thus it is invalid for a leaf to claim +// example.com if an intermediate doesn't permit it, even if example.com is not +// the name being validated. Note that DirectoryName constraints are not +// supported. +// +// Extended Key Usage values are enforced nested down a chain, so an intermediate +// or root that enumerates EKUs prevents a leaf from asserting an EKU not in that +// list. (While this is not specified, it is common practice in order to limit +// the types of certificates a CA can issue.) +// +// WARNING: this function doesn't do any revocation checking. +func (c *Certificate) Verify(opts VerifyOptions) (chains [][]*Certificate, err error) { + // Platform-specific verification needs the ASN.1 contents so + // this makes the behavior consistent across platforms. + if len(c.Raw) == 0 { + return nil, errNotParsed + } + if opts.Intermediates != nil { + for _, intermediate := range opts.Intermediates.certs { + if len(intermediate.Raw) == 0 { + return nil, errNotParsed + } + } + } + + // Use Windows's own verification and chain building. + if opts.Roots == nil && runtime.GOOS == "windows" { + return c.systemVerify(&opts) + } + + if opts.Roots == nil { + opts.Roots = systemRootsPool() + if opts.Roots == nil { + return nil, SystemRootsError{systemRootsErr} + } + } + + err = c.isValid(leafCertificate, nil, &opts) + if err != nil { + return + } + + if len(opts.DNSName) > 0 { + err = c.VerifyHostname(opts.DNSName) + if err != nil { + return + } + } + + var candidateChains [][]*Certificate + if opts.Roots.contains(c) { + candidateChains = append(candidateChains, []*Certificate{c}) + } else { + if candidateChains, err = c.buildChains(nil, []*Certificate{c}, nil, &opts); err != nil { + return nil, err + } + } + + keyUsages := opts.KeyUsages + if len(keyUsages) == 0 { + keyUsages = []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth} + } + + // If any key usage is acceptable then we're done. + for _, usage := range keyUsages { + if usage == x509.ExtKeyUsageAny { + return candidateChains, nil + } + } + + for _, candidate := range candidateChains { + if checkChainForKeyUsage(candidate, keyUsages) { + chains = append(chains, candidate) + } + } + + if len(chains) == 0 { + return nil, CertificateInvalidError{c, x509.IncompatibleUsage, ""} + } + + return chains, nil +} + +func appendToFreshChain(chain []*Certificate, cert *Certificate) []*Certificate { + n := make([]*Certificate, len(chain)+1) + copy(n, chain) + n[len(chain)] = cert + return n +} + +// maxChainSignatureChecks is the maximum number of CheckSignatureFrom calls +// that an invocation of buildChains will (tranistively) make. Most chains are +// less than 15 certificates long, so this leaves space for multiple chains and +// for failed checks due to different intermediates having the same Subject. +const maxChainSignatureChecks = 100 + +func (c *Certificate) buildChains(cache map[*Certificate][][]*Certificate, currentChain []*Certificate, sigChecks *int, opts *VerifyOptions) (chains [][]*Certificate, err error) { + var ( + hintErr error + hintCert *Certificate + ) + + considerCandidate := func(certType int, candidate *Certificate) { + for _, cert := range currentChain { + if cert.Equal(candidate) { + return + } + } + + if sigChecks == nil { + sigChecks = new(int) + } + *sigChecks++ + if *sigChecks > maxChainSignatureChecks { + err = errors.New("x509: signature check attempts limit reached while verifying certificate chain") + return + } + + if err := c.CheckSignatureFrom(candidate); err != nil { + if hintErr == nil { + hintErr = err + hintCert = candidate + } + return + } + + err = candidate.isValid(certType, currentChain, opts) + if err != nil { + return + } + + switch certType { + case rootCertificate: + chains = append(chains, appendToFreshChain(currentChain, candidate)) + case intermediateCertificate: + if cache == nil { + cache = make(map[*Certificate][][]*Certificate) + } + childChains, ok := cache[candidate] + if !ok { + childChains, err = candidate.buildChains(cache, appendToFreshChain(currentChain, candidate), sigChecks, opts) + cache[candidate] = childChains + } + chains = append(chains, childChains...) + } + } + + for _, rootNum := range opts.Roots.findPotentialParents(c) { + considerCandidate(rootCertificate, opts.Roots.certs[rootNum]) + } + for _, intermediateNum := range opts.Intermediates.findPotentialParents(c) { + considerCandidate(intermediateCertificate, opts.Intermediates.certs[intermediateNum]) + } + + if len(chains) > 0 { + err = nil + } + if len(chains) == 0 && err == nil { + err = UnknownAuthorityError{c, hintErr, hintCert} + } + + return +} + +// validHostname reports whether host is a valid hostname that can be matched or +// matched against according to RFC 6125 2.2, with some leniency to accommodate +// legacy values. +func validHostname(host string) bool { + host = strings.TrimSuffix(host, ".") + + if len(host) == 0 { + return false + } + + for i, part := range strings.Split(host, ".") { + if part == "" { + // Empty label. + return false + } + if i == 0 && part == "*" { + // Only allow full left-most wildcards, as those are the only ones + // we match, and matching literal '*' characters is probably never + // the expected behavior. + continue + } + for j, c := range part { + if 'a' <= c && c <= 'z' { + continue + } + if '0' <= c && c <= '9' { + continue + } + if 'A' <= c && c <= 'Z' { + continue + } + if c == '-' && j != 0 { + continue + } + if c == '_' || c == ':' { + // Not valid characters in hostnames, but commonly + // found in deployments outside the WebPKI. + continue + } + return false + } + } + + return true +} + +// commonNameAsHostname reports whether the Common Name field should be +// considered the hostname that the certificate is valid for. This is a legacy +// behavior, disabled if the Subject Alt Name extension is present. +// +// It applies the strict validHostname check to the Common Name field, so that +// certificates without SANs can still be validated against CAs with name +// constraints if there is no risk the CN would be matched as a hostname. +// See NameConstraintsWithoutSANs and issue 24151. +func (c *Certificate) commonNameAsHostname() bool { + return !ignoreCN && !c.hasSANExtension() && validHostname(c.Subject.CommonName) +} + +func matchHostnames(pattern, host string) bool { + host = strings.TrimSuffix(host, ".") + pattern = strings.TrimSuffix(pattern, ".") + + if len(pattern) == 0 || len(host) == 0 { + return false + } + + patternParts := strings.Split(pattern, ".") + hostParts := strings.Split(host, ".") + + if len(patternParts) != len(hostParts) { + return false + } + + for i, patternPart := range patternParts { + if i == 0 && patternPart == "*" { + continue + } + if patternPart != hostParts[i] { + return false + } + } + + return true +} + +// toLowerCaseASCII returns a lower-case version of in. See RFC 6125 6.4.1. We use +// an explicitly ASCII function to avoid any sharp corners resulting from +// performing Unicode operations on DNS labels. +func toLowerCaseASCII(in string) string { + // If the string is already lower-case then there's nothing to do. + isAlreadyLowerCase := true + for _, c := range in { + if c == utf8.RuneError { + // If we get a UTF-8 error then there might be + // upper-case ASCII bytes in the invalid sequence. + isAlreadyLowerCase = false + break + } + if 'A' <= c && c <= 'Z' { + isAlreadyLowerCase = false + break + } + } + + if isAlreadyLowerCase { + return in + } + + out := []byte(in) + for i, c := range out { + if 'A' <= c && c <= 'Z' { + out[i] += 'a' - 'A' + } + } + return string(out) +} + +// VerifyHostname returns nil if c is a valid certificate for the named host. +// Otherwise it returns an error describing the mismatch. +func (c *Certificate) VerifyHostname(h string) error { + // IP addresses may be written in [ ]. + candidateIP := h + if len(h) >= 3 && h[0] == '[' && h[len(h)-1] == ']' { + candidateIP = h[1 : len(h)-1] + } + if ip := net.ParseIP(candidateIP); ip != nil { + // We only match IP addresses against IP SANs. + // See RFC 6125, Appendix B.2. + for _, candidate := range c.IPAddresses { + if ip.Equal(candidate) { + return nil + } + } + return x509.HostnameError{&c.Certificate, candidateIP} + } + + lowered := toLowerCaseASCII(h) + + if c.commonNameAsHostname() { + if matchHostnames(toLowerCaseASCII(c.Subject.CommonName), lowered) { + return nil + } + } else { + for _, match := range c.DNSNames { + if matchHostnames(toLowerCaseASCII(match), lowered) { + return nil + } + } + } + + return x509.HostnameError{&c.Certificate, h} +} + +func checkChainForKeyUsage(chain []*Certificate, keyUsages []x509.ExtKeyUsage) bool { + usages := make([]x509.ExtKeyUsage, len(keyUsages)) + copy(usages, keyUsages) + + if len(chain) == 0 { + return false + } + + usagesRemaining := len(usages) + + // We walk down the list and cross out any usages that aren't supported + // by each certificate. If we cross out all the usages, then the chain + // is unacceptable. + +NextCert: + for i := len(chain) - 1; i >= 0; i-- { + cert := chain[i] + if len(cert.ExtKeyUsage) == 0 && len(cert.UnknownExtKeyUsage) == 0 { + // The certificate doesn't have any extended key usage specified. + continue + } + + for _, usage := range cert.ExtKeyUsage { + if usage == x509.ExtKeyUsageAny { + // The certificate is explicitly good for any usage. + continue NextCert + } + } + + const invalidUsage x509.ExtKeyUsage = -1 + + NextRequestedUsage: + for i, requestedUsage := range usages { + if requestedUsage == invalidUsage { + continue + } + + for _, usage := range cert.ExtKeyUsage { + if requestedUsage == usage { + continue NextRequestedUsage + } else if requestedUsage == x509.ExtKeyUsageServerAuth && + (usage == x509.ExtKeyUsageNetscapeServerGatedCrypto || + usage == x509.ExtKeyUsageMicrosoftServerGatedCrypto) { + // In order to support COMODO + // certificate chains, we have to + // accept Netscape or Microsoft SGC + // usages as equal to ServerAuth. + continue NextRequestedUsage + } + } + + usages[i] = invalidUsage + usagesRemaining-- + if usagesRemaining == 0 { + return false + } + } + } + + return true +} + +func matchEmailConstraint(mailbox rfc2821Mailbox, constraint string) (bool, error) { + // If the constraint contains an @, then it specifies an exact mailbox + // name. + if strings.Contains(constraint, "@") { + constraintMailbox, ok := parseRFC2821Mailbox(constraint) + if !ok { + return false, fmt.Errorf("x509: internal error: cannot parse constraint %q", constraint) + } + return mailbox.local == constraintMailbox.local && strings.EqualFold(mailbox.domain, constraintMailbox.domain), nil + } + + // Otherwise the constraint is like a DNS constraint of the domain part + // of the mailbox. + return matchDomainConstraint(mailbox.domain, constraint) +} + +func matchURIConstraint(uri *url.URL, constraint string) (bool, error) { + // From RFC 5280, Section 4.2.1.10: + // “a uniformResourceIdentifier that does not include an authority + // component with a host name specified as a fully qualified domain + // name (e.g., if the URI either does not include an authority + // component or includes an authority component in which the host name + // is specified as an IP address), then the application MUST reject the + // certificate.” + + host := uri.Host + if len(host) == 0 { + return false, fmt.Errorf("URI with empty host (%q) cannot be matched against constraints", uri.String()) + } + + if strings.Contains(host, ":") && !strings.HasSuffix(host, "]") { + var err error + host, _, err = net.SplitHostPort(uri.Host) + if err != nil { + return false, err + } + } + + if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") || + net.ParseIP(host) != nil { + return false, fmt.Errorf("URI with IP (%q) cannot be matched against constraints", uri.String()) + } + + return matchDomainConstraint(host, constraint) +} + +func matchIPConstraint(ip net.IP, constraint *net.IPNet) (bool, error) { + if len(ip) != len(constraint.IP) { + return false, nil + } + + for i := range ip { + if mask := constraint.Mask[i]; ip[i]&mask != constraint.IP[i]&mask { + return false, nil + } + } + + return true, nil +} + +func matchDomainConstraint(domain, constraint string) (bool, error) { + // The meaning of zero length constraints is not specified, but this + // code follows NSS and accepts them as matching everything. + if len(constraint) == 0 { + return true, nil + } + + domainLabels, ok := domainToReverseLabels(domain) + if !ok { + return false, fmt.Errorf("x509: internal error: cannot parse domain %q", domain) + } + + // RFC 5280 says that a leading period in a domain name means that at + // least one label must be prepended, but only for URI and email + // constraints, not DNS constraints. The code also supports that + // behaviour for DNS constraints. + + mustHaveSubdomains := false + if constraint[0] == '.' { + mustHaveSubdomains = true + constraint = constraint[1:] + } + + constraintLabels, ok := domainToReverseLabels(constraint) + if !ok { + return false, fmt.Errorf("x509: internal error: cannot parse domain %q", constraint) + } + + if len(domainLabels) < len(constraintLabels) || + (mustHaveSubdomains && len(domainLabels) == len(constraintLabels)) { + return false, nil + } + + for i, constraintLabel := range constraintLabels { + if !strings.EqualFold(constraintLabel, domainLabels[i]) { + return false, nil + } + } + + return true, nil +} diff --git a/smx509/x509.go b/smx509/x509.go index 93febaf..6008755 100644 --- a/smx509/x509.go +++ b/smx509/x509.go @@ -1,9 +1,13 @@ package smx509 import ( + "bytes" "crypto" + "crypto/dsa" "crypto/ecdsa" + "crypto/ed25519" "crypto/elliptic" + "crypto/rsa" "crypto/x509" "crypto/x509/pkix" "encoding/asn1" @@ -15,6 +19,11 @@ import ( "net/url" "strconv" "strings" + "time" + "unicode/utf8" + + "golang.org/x/crypto/cryptobyte" + cryptobyte_asn1 "golang.org/x/crypto/cryptobyte/asn1" "github.com/emmansun/gmsm/sm2" ) @@ -32,12 +41,47 @@ type publicKeyInfo struct { PublicKey asn1.BitString } +// These structures reflect the ASN.1 structure of X.509 certificates.: +type validity struct { + NotBefore, NotAfter time.Time +} + +type certificate struct { + Raw asn1.RawContent + TBSCertificate tbsCertificate + SignatureAlgorithm pkix.AlgorithmIdentifier + SignatureValue asn1.BitString +} + +type tbsCertificate struct { + Raw asn1.RawContent + Version int `asn1:"optional,explicit,default:0,tag:0"` + SerialNumber *big.Int + SignatureAlgorithm pkix.AlgorithmIdentifier + Issuer asn1.RawValue + Validity validity + Subject asn1.RawValue + PublicKey publicKeyInfo + UniqueId asn1.BitString `asn1:"optional,tag:1"` + SubjectUniqueId asn1.BitString `asn1:"optional,tag:2"` + Extensions []pkix.Extension `asn1:"optional,explicit,tag:3"` +} + +// RFC 5280, 4.2.1.1 +type authKeyId struct { + Id []byte `asn1:"optional,tag:0"` +} + // pkcs1PublicKey reflects the ASN.1 structure of a PKCS#1 public key. type pkcs1PublicKey struct { N *big.Int E int } +type dsaAlgorithmParameters struct { + P, Q, G *big.Int +} + type ecdsaSignature struct { R, S *big.Int } @@ -50,9 +94,34 @@ var ( oidNamedCurveP521 = asn1.ObjectIdentifier{1, 3, 132, 0, 35} oidNamedCurveP256SM2 = asn1.ObjectIdentifier{1, 2, 156, 10197, 1, 301} - oidSignatureSM2WithSM3 = asn1.ObjectIdentifier{1, 2, 156, 10197, 1, 501} - oidSignatureSM2WithSHA1 = asn1.ObjectIdentifier{1, 2, 156, 10197, 1, 502} - oidSignatureSM2WithSHA256 = asn1.ObjectIdentifier{1, 2, 156, 10197, 1, 503} + oidSignatureMD2WithRSA = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 2} + oidSignatureMD5WithRSA = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 4} + oidSignatureSHA1WithRSA = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 5} + oidSignatureSHA256WithRSA = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 11} + oidSignatureSHA384WithRSA = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 12} + oidSignatureSHA512WithRSA = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 13} + oidSignatureRSAPSS = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 10} + oidSignatureDSAWithSHA1 = asn1.ObjectIdentifier{1, 2, 840, 10040, 4, 3} + oidSignatureDSAWithSHA256 = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 3, 2} + oidSignatureECDSAWithSHA1 = asn1.ObjectIdentifier{1, 2, 840, 10045, 4, 1} + oidSignatureECDSAWithSHA256 = asn1.ObjectIdentifier{1, 2, 840, 10045, 4, 3, 2} + oidSignatureECDSAWithSHA384 = asn1.ObjectIdentifier{1, 2, 840, 10045, 4, 3, 3} + oidSignatureECDSAWithSHA512 = asn1.ObjectIdentifier{1, 2, 840, 10045, 4, 3, 4} + oidSignatureEd25519 = asn1.ObjectIdentifier{1, 3, 101, 112} + oidSignatureSM2WithSM3 = asn1.ObjectIdentifier{1, 2, 156, 10197, 1, 501} + oidSignatureSM2WithSHA1 = asn1.ObjectIdentifier{1, 2, 156, 10197, 1, 502} + oidSignatureSM2WithSHA256 = asn1.ObjectIdentifier{1, 2, 156, 10197, 1, 503} + + oidSHA256 = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 2, 1} + oidSHA384 = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 2, 2} + oidSHA512 = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 2, 3} + + oidMGF1 = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 8} + + // oidISOSignatureSHA1WithRSA means the same as oidSignatureSHA1WithRSA + // but it's specified by ISO. Microsoft's makecert.exe has been known + // to produce certificates with this OID. + oidISOSignatureSHA1WithRSA = asn1.ObjectIdentifier{1, 3, 14, 3, 2, 29} ) func oidFromNamedCurve(curve elliptic.Curve) (asn1.ObjectIdentifier, bool) { @@ -88,6 +157,89 @@ func namedCurveFromOID(oid asn1.ObjectIdentifier) elliptic.Curve { return nil } +var signatureAlgorithmDetails = []struct { + algo x509.SignatureAlgorithm + name string + oid asn1.ObjectIdentifier + pubKeyAlgo x509.PublicKeyAlgorithm + hash crypto.Hash +}{ + {x509.MD2WithRSA, "MD2-RSA", oidSignatureMD2WithRSA, x509.RSA, crypto.Hash(0) /* no value for MD2 */}, + {x509.MD5WithRSA, "MD5-RSA", oidSignatureMD5WithRSA, x509.RSA, crypto.MD5}, + {x509.SHA1WithRSA, "SHA1-RSA", oidSignatureSHA1WithRSA, x509.RSA, crypto.SHA1}, + {x509.SHA1WithRSA, "SHA1-RSA", oidISOSignatureSHA1WithRSA, x509.RSA, crypto.SHA1}, + {x509.SHA256WithRSA, "SHA256-RSA", oidSignatureSHA256WithRSA, x509.RSA, crypto.SHA256}, + {x509.SHA384WithRSA, "SHA384-RSA", oidSignatureSHA384WithRSA, x509.RSA, crypto.SHA384}, + {x509.SHA512WithRSA, "SHA512-RSA", oidSignatureSHA512WithRSA, x509.RSA, crypto.SHA512}, + {x509.SHA256WithRSAPSS, "SHA256-RSAPSS", oidSignatureRSAPSS, x509.RSA, crypto.SHA256}, + {x509.SHA384WithRSAPSS, "SHA384-RSAPSS", oidSignatureRSAPSS, x509.RSA, crypto.SHA384}, + {x509.SHA512WithRSAPSS, "SHA512-RSAPSS", oidSignatureRSAPSS, x509.RSA, crypto.SHA512}, + {x509.DSAWithSHA1, "DSA-SHA1", oidSignatureDSAWithSHA1, x509.DSA, crypto.SHA1}, + {x509.DSAWithSHA256, "DSA-SHA256", oidSignatureDSAWithSHA256, x509.DSA, crypto.SHA256}, + {x509.ECDSAWithSHA1, "ECDSA-SHA1", oidSignatureECDSAWithSHA1, x509.ECDSA, crypto.SHA1}, + {x509.ECDSAWithSHA256, "ECDSA-SHA256", oidSignatureECDSAWithSHA256, x509.ECDSA, crypto.SHA256}, + {x509.ECDSAWithSHA384, "ECDSA-SHA384", oidSignatureECDSAWithSHA384, x509.ECDSA, crypto.SHA384}, + {x509.ECDSAWithSHA512, "ECDSA-SHA512", oidSignatureECDSAWithSHA512, x509.ECDSA, crypto.SHA512}, + {x509.PureEd25519, "Ed25519", oidSignatureEd25519, x509.Ed25519, crypto.Hash(0) /* no pre-hashing */}, +} + +// pssParameters reflects the parameters in an AlgorithmIdentifier that +// specifies RSA PSS. See RFC 3447, Appendix A.2.3. +type pssParameters struct { + // The following three fields are not marked as + // optional because the default values specify SHA-1, + // which is no longer suitable for use in signatures. + Hash pkix.AlgorithmIdentifier `asn1:"explicit,tag:0"` + MGF pkix.AlgorithmIdentifier `asn1:"explicit,tag:1"` + SaltLength int `asn1:"explicit,tag:2"` + TrailerField int `asn1:"optional,explicit,tag:3,default:1"` +} + +// rsaPSSParameters returns an asn1.RawValue suitable for use as the Parameters +// in an AlgorithmIdentifier that specifies RSA PSS. +func rsaPSSParameters(hashFunc crypto.Hash) asn1.RawValue { + var hashOID asn1.ObjectIdentifier + + switch hashFunc { + case crypto.SHA256: + hashOID = oidSHA256 + case crypto.SHA384: + hashOID = oidSHA384 + case crypto.SHA512: + hashOID = oidSHA512 + } + + params := pssParameters{ + Hash: pkix.AlgorithmIdentifier{ + Algorithm: hashOID, + Parameters: asn1.NullRawValue, + }, + MGF: pkix.AlgorithmIdentifier{ + Algorithm: oidMGF1, + }, + SaltLength: hashFunc.Size(), + TrailerField: 1, + } + + mgf1Params := pkix.AlgorithmIdentifier{ + Algorithm: hashOID, + Parameters: asn1.NullRawValue, + } + + var err error + params.MGF.Parameters.FullBytes, err = asn1.Marshal(mgf1Params) + if err != nil { + panic(err) + } + + serialized, err := asn1.Marshal(params) + if err != nil { + panic(err) + } + + return asn1.RawValue{FullBytes: serialized} +} + // ParsePKIXPublicKey parses a public key in PKIX, ASN.1 DER form. // // It returns a *rsa.PublicKey, *dsa.PublicKey, *ecdsa.PublicKey, or @@ -142,13 +294,11 @@ func ParsePKIXPublicKey(derBytes []byte) (interface{}, error) { // // This kind of key is commonly encoded in PEM blocks of type "PUBLIC KEY". func MarshalPKIXPublicKey(pub interface{}) ([]byte, error) { - ecdPub, ok := pub.(*ecdsa.PublicKey) - if !ok || ecdPub.Curve != sm2.P256() { - return x509.MarshalPKIXPublicKey(pub) - } + var publicKeyBytes []byte + var publicKeyAlgorithm pkix.AlgorithmIdentifier + var err error - publicKeyBytes, publicKeyAlgorithm, err := marshalPublicKey(ecdPub) - if err != nil { + if publicKeyBytes, publicKeyAlgorithm, err = marshalPublicKey(pub); err != nil { return nil, err } @@ -160,15 +310,45 @@ func MarshalPKIXPublicKey(pub interface{}) ([]byte, error) { }, } - return asn1.Marshal(pkix) + ret, _ := asn1.Marshal(pkix) + return ret, nil } -func marshalPublicKey(pub *ecdsa.PublicKey) (publicKeyBytes []byte, publicKeyAlgorithm pkix.AlgorithmIdentifier, err error) { - publicKeyAlgorithm = pkix.AlgorithmIdentifier{Algorithm: oidPublicKeyECDSA} - publicKeyBytes = elliptic.Marshal(pub.Curve, pub.X, pub.Y) - paramBytes, err := asn1.Marshal(oidNamedCurveP256SM2) - publicKeyAlgorithm.Parameters.FullBytes = paramBytes - return +func marshalPublicKey(pub interface{}) (publicKeyBytes []byte, publicKeyAlgorithm pkix.AlgorithmIdentifier, err error) { + switch pub := pub.(type) { + case *rsa.PublicKey: + publicKeyBytes, err = asn1.Marshal(pkcs1PublicKey{ + N: pub.N, + E: pub.E, + }) + if err != nil { + return nil, pkix.AlgorithmIdentifier{}, err + } + publicKeyAlgorithm.Algorithm = oidPublicKeyRSA + // This is a NULL parameters value which is required by + // RFC 3279, Section 2.3.1. + publicKeyAlgorithm.Parameters = asn1.NullRawValue + case *ecdsa.PublicKey: + publicKeyBytes = elliptic.Marshal(pub.Curve, pub.X, pub.Y) + oid, ok := oidFromNamedCurve(pub.Curve) + if !ok { + return nil, pkix.AlgorithmIdentifier{}, errors.New("x509: unsupported elliptic curve") + } + publicKeyAlgorithm.Algorithm = oidPublicKeyECDSA + var paramBytes []byte + paramBytes, err = asn1.Marshal(oid) + if err != nil { + return + } + publicKeyAlgorithm.Parameters.FullBytes = paramBytes + case ed25519.PublicKey: + publicKeyBytes = pub + publicKeyAlgorithm.Algorithm = oidPublicKeyEd25519 + default: + return nil, pkix.AlgorithmIdentifier{}, fmt.Errorf("x509: unsupported public key type: %T", pub) + } + + return publicKeyBytes, publicKeyAlgorithm, nil } // CreateCertificateRequest creates a new certificate request based on a @@ -197,7 +377,11 @@ func CreateCertificateRequest(rand io.Reader, template *x509.CertificateRequest, } privKey, ok := key.(*sm2.PrivateKey) if !ok { - return x509.CreateCertificateRequest(rand, template, priv) + ecKey, ok := key.(*ecdsa.PrivateKey) + if !ok || ecKey.Curve != sm2.P256() { + return x509.CreateCertificateRequest(rand, template, priv) + } + privKey, _ = new(sm2.PrivateKey).FromECPrivateKey(ecKey) } var sigAlgo = pkix.AlgorithmIdentifier{} sigAlgo.Algorithm = oidSignatureSM2WithSM3 @@ -448,7 +632,7 @@ func marshalSANs(dnsNames, emailAddresses []string, ipAddresses []net.IP, uris [ // ParseCertificateRequest parses a single certificate request from the // given ASN.1 DER data. -func ParseCertificateRequest(asn1Data []byte) (*x509.CertificateRequest, error) { +func ParseCertificateRequest(asn1Data []byte) (*CertificateRequest, error) { var csr certificateRequest rest, err := asn1.Unmarshal(asn1Data, &csr) @@ -458,16 +642,20 @@ func ParseCertificateRequest(asn1Data []byte) (*x509.CertificateRequest, error) return nil, asn1.SyntaxError{Msg: "trailing data"} } if !csr.SignatureAlgorithm.Algorithm.Equal(oidSignatureSM2WithSM3) { - return x509.ParseCertificateRequest(asn1Data) + csrR, err := x509.ParseCertificateRequest(asn1Data) + if err != nil { + return nil, err + } + return &CertificateRequest{*csrR}, nil } return parseCertificateRequest(&csr) } -func parseCertificateRequest(in *certificateRequest) (*x509.CertificateRequest, error) { +func parseCertificateRequest(in *certificateRequest) (*CertificateRequest, error) { if !oidSignatureSM2WithSM3.Equal(in.SignatureAlgorithm.Algorithm) { return nil, errors.New("unsupport signature algorithm") } - out := &x509.CertificateRequest{ + out := &CertificateRequest{x509.CertificateRequest{ Raw: in.Raw, RawTBSCertificateRequest: in.TBSCSR.Raw, RawSubjectPublicKeyInfo: in.TBSCSR.PublicKey.Raw, @@ -479,10 +667,11 @@ func parseCertificateRequest(in *certificateRequest) (*x509.CertificateRequest, Version: in.TBSCSR.Version, Attributes: parseRawAttributes(in.TBSCSR.RawAttributes), + }, } var err error - out.PublicKey, err = parsePublicKey(&in.TBSCSR.PublicKey) + out.PublicKey, err = parsePublicKey(out.PublicKeyAlgorithm, &in.TBSCSR.PublicKey) if err != nil { return nil, err } @@ -512,31 +701,106 @@ func parseCertificateRequest(in *certificateRequest) (*x509.CertificateRequest, return out, nil } -func parsePublicKey(keyData *publicKeyInfo) (interface{}, error) { +func parsePublicKey(algo x509.PublicKeyAlgorithm, keyData *publicKeyInfo) (interface{}, error) { asn1Data := keyData.PublicKey.RightAlign() - paramsData := keyData.Algorithm.Parameters.FullBytes - namedCurveOID := new(asn1.ObjectIdentifier) - rest, err := asn1.Unmarshal(paramsData, namedCurveOID) - if err != nil { - return nil, errors.New("x509: failed to parse ECDSA parameters as named curve") + switch algo { + case x509.RSA: + // RSA public keys must have a NULL in the parameters. + // See RFC 3279, Section 2.3.1. + if !bytes.Equal(keyData.Algorithm.Parameters.FullBytes, asn1.NullBytes) { + return nil, errors.New("x509: RSA key missing NULL parameters") + } + + p := new(pkcs1PublicKey) + rest, err := asn1.Unmarshal(asn1Data, p) + if err != nil { + return nil, err + } + if len(rest) != 0 { + return nil, errors.New("x509: trailing data after RSA public key") + } + + if p.N.Sign() <= 0 { + return nil, errors.New("x509: RSA modulus is not a positive number") + } + if p.E <= 0 { + return nil, errors.New("x509: RSA public exponent is not a positive number") + } + + pub := &rsa.PublicKey{ + E: p.E, + N: p.N, + } + return pub, nil + case x509.DSA: + var p *big.Int + rest, err := asn1.Unmarshal(asn1Data, &p) + if err != nil { + return nil, err + } + if len(rest) != 0 { + return nil, errors.New("x509: trailing data after DSA public key") + } + paramsData := keyData.Algorithm.Parameters.FullBytes + params := new(dsaAlgorithmParameters) + rest, err = asn1.Unmarshal(paramsData, params) + if err != nil { + return nil, err + } + if len(rest) != 0 { + return nil, errors.New("x509: trailing data after DSA parameters") + } + if p.Sign() <= 0 || params.P.Sign() <= 0 || params.Q.Sign() <= 0 || params.G.Sign() <= 0 { + return nil, errors.New("x509: zero or negative DSA parameter") + } + pub := &dsa.PublicKey{ + Parameters: dsa.Parameters{ + P: params.P, + Q: params.Q, + G: params.G, + }, + Y: p, + } + return pub, nil + case x509.ECDSA: + paramsData := keyData.Algorithm.Parameters.FullBytes + namedCurveOID := new(asn1.ObjectIdentifier) + rest, err := asn1.Unmarshal(paramsData, namedCurveOID) + if err != nil { + return nil, errors.New("x509: failed to parse ECDSA parameters as named curve") + } + if len(rest) != 0 { + return nil, errors.New("x509: trailing data after ECDSA parameters") + } + namedCurve := namedCurveFromOID(*namedCurveOID) + if namedCurve == nil { + return nil, errors.New("x509: unsupported elliptic curve") + } + x, y := elliptic.Unmarshal(namedCurve, asn1Data) + if x == nil { + return nil, errors.New("x509: failed to unmarshal elliptic curve point") + } + pub := &ecdsa.PublicKey{ + Curve: namedCurve, + X: x, + Y: y, + } + return pub, nil + case x509.Ed25519: + // RFC 8410, Section 3 + // > For all of the OIDs, the parameters MUST be absent. + if len(keyData.Algorithm.Parameters.FullBytes) != 0 { + return nil, errors.New("x509: Ed25519 key encoded with illegal parameters") + } + if len(asn1Data) != ed25519.PublicKeySize { + return nil, errors.New("x509: wrong Ed25519 public key size") + } + pub := make([]byte, ed25519.PublicKeySize) + copy(pub, asn1Data) + return ed25519.PublicKey(pub), nil + default: + return nil, nil } - if len(rest) != 0 { - return nil, errors.New("x509: trailing data after ECDSA parameters") - } - namedCurve := namedCurveFromOID(*namedCurveOID) - if namedCurve == nil { - return nil, errors.New("x509: unsupported elliptic curve") - } - x, y := elliptic.Unmarshal(namedCurve, asn1Data) - if x == nil { - return nil, errors.New("x509: failed to unmarshal elliptic curve point") - } - pub := &ecdsa.PublicKey{ - Curve: namedCurve, - X: x, - Y: y, - } - return pub, nil } // parseRawAttributes Unmarshals RawAttributes into AttributeTypeAndValueSETs. @@ -712,14 +976,22 @@ func parseSANExtension(value []byte) (dnsNames, emailAddresses []string, ipAddre // id-ecPublicKey OBJECT IDENTIFIER ::= { // iso(1) member-body(2) us(840) ansi-X9-62(10045) keyType(2) 1 } var ( - oidPublicKeyRSA = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 1} - oidPublicKeyDSA = asn1.ObjectIdentifier{1, 2, 840, 10040, 4, 1} - oidPublicKeyECDSA = asn1.ObjectIdentifier{1, 2, 840, 10045, 2, 1} + oidPublicKeyRSA = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 1} + oidPublicKeyDSA = asn1.ObjectIdentifier{1, 2, 840, 10040, 4, 1} + oidPublicKeyECDSA = asn1.ObjectIdentifier{1, 2, 840, 10045, 2, 1} + oidPublicKeyEd25519 = asn1.ObjectIdentifier{1, 3, 101, 112} ) func getPublicKeyAlgorithmFromOID(oid asn1.ObjectIdentifier) x509.PublicKeyAlgorithm { - if oid.Equal(oidPublicKeyECDSA) { + switch { + case oid.Equal(oidPublicKeyRSA): + return x509.RSA + case oid.Equal(oidPublicKeyDSA): + return x509.DSA + case oid.Equal(oidPublicKeyECDSA): return x509.ECDSA + case oid.Equal(oidPublicKeyEd25519): + return x509.Ed25519 } return x509.UnknownPublicKeyAlgorithm } @@ -735,6 +1007,16 @@ func CheckSignature(c *x509.CertificateRequest) error { return c.CheckSignature() } +// CertificateRequest represents a PKCS #10, certificate signature request. +type CertificateRequest struct { + x509.CertificateRequest +} + +// CheckSignature reports whether the signature on c is valid. +func (c *CertificateRequest) CheckSignature() error { + return CheckSignature(&c.CertificateRequest) +} + // CheckSignature verifies that signature is a valid signature over signed from // a crypto.PublicKey. func checkSignature(c *x509.CertificateRequest, publicKey *ecdsa.PublicKey) (err error) { @@ -753,3 +1035,1311 @@ func checkSignature(c *x509.CertificateRequest, publicKey *ecdsa.PublicKey) (err } return } + +func reverseBitsInAByte(in byte) byte { + b1 := in>>4 | in<<4 + b2 := b1>>2&0x33 | b1<<2&0xcc + b3 := b2>>1&0x55 | b2<<1&0xaa + return b3 +} + +// asn1BitLength returns the bit-length of bitString by considering the +// most-significant bit in a byte to be the "first" bit. This convention +// matches ASN.1, but differs from almost everything else. +func asn1BitLength(bitString []byte) int { + bitLen := len(bitString) * 8 + + for i := range bitString { + b := bitString[len(bitString)-i-1] + + for bit := uint(0); bit < 8; bit++ { + if (b>>bit)&1 == 1 { + return bitLen + } + bitLen-- + } + } + + return 0 +} + +// RFC 5280, 4.2.1.12 Extended Key Usage +// +// anyExtendedKeyUsage OBJECT IDENTIFIER ::= { id-ce-extKeyUsage 0 } +// +// id-kp OBJECT IDENTIFIER ::= { id-pkix 3 } +// +// id-kp-serverAuth OBJECT IDENTIFIER ::= { id-kp 1 } +// id-kp-clientAuth OBJECT IDENTIFIER ::= { id-kp 2 } +// id-kp-codeSigning OBJECT IDENTIFIER ::= { id-kp 3 } +// id-kp-emailProtection OBJECT IDENTIFIER ::= { id-kp 4 } +// id-kp-timeStamping OBJECT IDENTIFIER ::= { id-kp 8 } +// id-kp-OCSPSigning OBJECT IDENTIFIER ::= { id-kp 9 } +var ( + oidExtKeyUsageAny = asn1.ObjectIdentifier{2, 5, 29, 37, 0} + oidExtKeyUsageServerAuth = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 1} + oidExtKeyUsageClientAuth = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 2} + oidExtKeyUsageCodeSigning = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 3} + oidExtKeyUsageEmailProtection = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 4} + oidExtKeyUsageIPSECEndSystem = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 5} + oidExtKeyUsageIPSECTunnel = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 6} + oidExtKeyUsageIPSECUser = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 7} + oidExtKeyUsageTimeStamping = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 8} + oidExtKeyUsageOCSPSigning = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 9} + oidExtKeyUsageMicrosoftServerGatedCrypto = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 311, 10, 3, 3} + oidExtKeyUsageNetscapeServerGatedCrypto = asn1.ObjectIdentifier{2, 16, 840, 1, 113730, 4, 1} + oidExtKeyUsageMicrosoftCommercialCodeSigning = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 311, 2, 1, 22} + oidExtKeyUsageMicrosoftKernelCodeSigning = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 311, 61, 1, 1} +) + +// extKeyUsageOIDs contains the mapping between an ExtKeyUsage and its OID. +var extKeyUsageOIDs = []struct { + extKeyUsage x509.ExtKeyUsage + oid asn1.ObjectIdentifier +}{ + {x509.ExtKeyUsageAny, oidExtKeyUsageAny}, + {x509.ExtKeyUsageServerAuth, oidExtKeyUsageServerAuth}, + {x509.ExtKeyUsageClientAuth, oidExtKeyUsageClientAuth}, + {x509.ExtKeyUsageCodeSigning, oidExtKeyUsageCodeSigning}, + {x509.ExtKeyUsageEmailProtection, oidExtKeyUsageEmailProtection}, + {x509.ExtKeyUsageIPSECEndSystem, oidExtKeyUsageIPSECEndSystem}, + {x509.ExtKeyUsageIPSECTunnel, oidExtKeyUsageIPSECTunnel}, + {x509.ExtKeyUsageIPSECUser, oidExtKeyUsageIPSECUser}, + {x509.ExtKeyUsageTimeStamping, oidExtKeyUsageTimeStamping}, + {x509.ExtKeyUsageOCSPSigning, oidExtKeyUsageOCSPSigning}, + {x509.ExtKeyUsageMicrosoftServerGatedCrypto, oidExtKeyUsageMicrosoftServerGatedCrypto}, + {x509.ExtKeyUsageNetscapeServerGatedCrypto, oidExtKeyUsageNetscapeServerGatedCrypto}, + {x509.ExtKeyUsageMicrosoftCommercialCodeSigning, oidExtKeyUsageMicrosoftCommercialCodeSigning}, + {x509.ExtKeyUsageMicrosoftKernelCodeSigning, oidExtKeyUsageMicrosoftKernelCodeSigning}, +} + +func extKeyUsageFromOID(oid asn1.ObjectIdentifier) (eku x509.ExtKeyUsage, ok bool) { + for _, pair := range extKeyUsageOIDs { + if oid.Equal(pair.oid) { + return pair.extKeyUsage, true + } + } + return +} + +func oidFromExtKeyUsage(eku x509.ExtKeyUsage) (oid asn1.ObjectIdentifier, ok bool) { + for _, pair := range extKeyUsageOIDs { + if eku == pair.extKeyUsage { + return pair.oid, true + } + } + return +} + +type basicConstraints struct { + IsCA bool `asn1:"optional"` + MaxPathLen int `asn1:"optional,default:-1"` +} + +// RFC 5280, 4.2.2.1 +type authorityInfoAccess struct { + Method asn1.ObjectIdentifier + Location asn1.RawValue +} + +// RFC 5280 4.2.1.4 +type policyInformation struct { + Policy asn1.ObjectIdentifier + // policyQualifiers omitted +} + +var ( + oidAuthorityInfoAccessOcsp = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 48, 1} + oidAuthorityInfoAccessIssuers = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 48, 2} +) + +func isIA5String(s string) error { + for _, r := range s { + if r >= utf8.RuneSelf { + return fmt.Errorf("x509: %q cannot be encoded as an IA5String", s) + } + } + + return nil +} + +type distributionPointName struct { + FullName []asn1.RawValue `asn1:"optional,tag:0"` + RelativeName pkix.RDNSequence `asn1:"optional,tag:1"` +} + +// RFC 5280, 4.2.1.14 +type distributionPoint struct { + DistributionPoint distributionPointName `asn1:"optional,tag:0"` + Reason asn1.BitString `asn1:"optional,tag:1"` + CRLIssuer asn1.RawValue `asn1:"optional,tag:2"` +} + +func buildExtensions(template *x509.Certificate, subjectIsEmpty bool, authorityKeyId []byte) (ret []pkix.Extension, err error) { + ret = make([]pkix.Extension, 10 /* maximum number of elements. */) + n := 0 + + if template.KeyUsage != 0 && + !oidInExtensions(oidExtensionKeyUsage, template.ExtraExtensions) { + ret[n].Id = oidExtensionKeyUsage + ret[n].Critical = true + + var a [2]byte + a[0] = reverseBitsInAByte(byte(template.KeyUsage)) + a[1] = reverseBitsInAByte(byte(template.KeyUsage >> 8)) + + l := 1 + if a[1] != 0 { + l = 2 + } + + bitString := a[:l] + ret[n].Value, err = asn1.Marshal(asn1.BitString{Bytes: bitString, BitLength: asn1BitLength(bitString)}) + if err != nil { + return + } + n++ + } + + if (len(template.ExtKeyUsage) > 0 || len(template.UnknownExtKeyUsage) > 0) && + !oidInExtensions(oidExtensionExtendedKeyUsage, template.ExtraExtensions) { + ret[n].Id = oidExtensionExtendedKeyUsage + + var oids []asn1.ObjectIdentifier + for _, u := range template.ExtKeyUsage { + if oid, ok := oidFromExtKeyUsage(u); ok { + oids = append(oids, oid) + } else { + panic("internal error") + } + } + + oids = append(oids, template.UnknownExtKeyUsage...) + + ret[n].Value, err = asn1.Marshal(oids) + if err != nil { + return + } + n++ + } + + if template.BasicConstraintsValid && !oidInExtensions(oidExtensionBasicConstraints, template.ExtraExtensions) { + // Leaving MaxPathLen as zero indicates that no maximum path + // length is desired, unless MaxPathLenZero is set. A value of + // -1 causes encoding/asn1 to omit the value as desired. + maxPathLen := template.MaxPathLen + if maxPathLen == 0 && !template.MaxPathLenZero { + maxPathLen = -1 + } + ret[n].Id = oidExtensionBasicConstraints + ret[n].Value, err = asn1.Marshal(basicConstraints{template.IsCA, maxPathLen}) + ret[n].Critical = true + if err != nil { + return + } + n++ + } + + if len(template.SubjectKeyId) > 0 && !oidInExtensions(oidExtensionSubjectKeyId, template.ExtraExtensions) { + ret[n].Id = oidExtensionSubjectKeyId + ret[n].Value, err = asn1.Marshal(template.SubjectKeyId) + if err != nil { + return + } + n++ + } + + if len(authorityKeyId) > 0 && !oidInExtensions(oidExtensionAuthorityKeyId, template.ExtraExtensions) { + ret[n].Id = oidExtensionAuthorityKeyId + ret[n].Value, err = asn1.Marshal(authKeyId{authorityKeyId}) + if err != nil { + return + } + n++ + } + + if (len(template.OCSPServer) > 0 || len(template.IssuingCertificateURL) > 0) && + !oidInExtensions(oidExtensionAuthorityInfoAccess, template.ExtraExtensions) { + ret[n].Id = oidExtensionAuthorityInfoAccess + var aiaValues []authorityInfoAccess + for _, name := range template.OCSPServer { + aiaValues = append(aiaValues, authorityInfoAccess{ + Method: oidAuthorityInfoAccessOcsp, + Location: asn1.RawValue{Tag: 6, Class: 2, Bytes: []byte(name)}, + }) + } + for _, name := range template.IssuingCertificateURL { + aiaValues = append(aiaValues, authorityInfoAccess{ + Method: oidAuthorityInfoAccessIssuers, + Location: asn1.RawValue{Tag: 6, Class: 2, Bytes: []byte(name)}, + }) + } + ret[n].Value, err = asn1.Marshal(aiaValues) + if err != nil { + return + } + n++ + } + + if (len(template.DNSNames) > 0 || len(template.EmailAddresses) > 0 || len(template.IPAddresses) > 0 || len(template.URIs) > 0) && + !oidInExtensions(oidExtensionSubjectAltName, template.ExtraExtensions) { + ret[n].Id = oidExtensionSubjectAltName + // From RFC 5280, Section 4.2.1.6: + // “If the subject field contains an empty sequence ... then + // subjectAltName extension ... is marked as critical” + ret[n].Critical = subjectIsEmpty + ret[n].Value, err = marshalSANs(template.DNSNames, template.EmailAddresses, template.IPAddresses, template.URIs) + if err != nil { + return + } + n++ + } + + if len(template.PolicyIdentifiers) > 0 && + !oidInExtensions(oidExtensionCertificatePolicies, template.ExtraExtensions) { + ret[n].Id = oidExtensionCertificatePolicies + policies := make([]policyInformation, len(template.PolicyIdentifiers)) + for i, policy := range template.PolicyIdentifiers { + policies[i].Policy = policy + } + ret[n].Value, err = asn1.Marshal(policies) + if err != nil { + return + } + n++ + } + + if (len(template.PermittedDNSDomains) > 0 || len(template.ExcludedDNSDomains) > 0 || + len(template.PermittedIPRanges) > 0 || len(template.ExcludedIPRanges) > 0 || + len(template.PermittedEmailAddresses) > 0 || len(template.ExcludedEmailAddresses) > 0 || + len(template.PermittedURIDomains) > 0 || len(template.ExcludedURIDomains) > 0) && + !oidInExtensions(oidExtensionNameConstraints, template.ExtraExtensions) { + ret[n].Id = oidExtensionNameConstraints + ret[n].Critical = template.PermittedDNSDomainsCritical + + ipAndMask := func(ipNet *net.IPNet) []byte { + maskedIP := ipNet.IP.Mask(ipNet.Mask) + ipAndMask := make([]byte, 0, len(maskedIP)+len(ipNet.Mask)) + ipAndMask = append(ipAndMask, maskedIP...) + ipAndMask = append(ipAndMask, ipNet.Mask...) + return ipAndMask + } + + serialiseConstraints := func(dns []string, ips []*net.IPNet, emails []string, uriDomains []string) (der []byte, err error) { + var b cryptobyte.Builder + + for _, name := range dns { + if err = isIA5String(name); err != nil { + return nil, err + } + + b.AddASN1(cryptobyte_asn1.SEQUENCE, func(b *cryptobyte.Builder) { + b.AddASN1(cryptobyte_asn1.Tag(2).ContextSpecific(), func(b *cryptobyte.Builder) { + b.AddBytes([]byte(name)) + }) + }) + } + + for _, ipNet := range ips { + b.AddASN1(cryptobyte_asn1.SEQUENCE, func(b *cryptobyte.Builder) { + b.AddASN1(cryptobyte_asn1.Tag(7).ContextSpecific(), func(b *cryptobyte.Builder) { + b.AddBytes(ipAndMask(ipNet)) + }) + }) + } + + for _, email := range emails { + if err = isIA5String(email); err != nil { + return nil, err + } + + b.AddASN1(cryptobyte_asn1.SEQUENCE, func(b *cryptobyte.Builder) { + b.AddASN1(cryptobyte_asn1.Tag(1).ContextSpecific(), func(b *cryptobyte.Builder) { + b.AddBytes([]byte(email)) + }) + }) + } + + for _, uriDomain := range uriDomains { + if err = isIA5String(uriDomain); err != nil { + return nil, err + } + + b.AddASN1(cryptobyte_asn1.SEQUENCE, func(b *cryptobyte.Builder) { + b.AddASN1(cryptobyte_asn1.Tag(6).ContextSpecific(), func(b *cryptobyte.Builder) { + b.AddBytes([]byte(uriDomain)) + }) + }) + } + + return b.Bytes() + } + + permitted, err := serialiseConstraints(template.PermittedDNSDomains, template.PermittedIPRanges, template.PermittedEmailAddresses, template.PermittedURIDomains) + if err != nil { + return nil, err + } + + excluded, err := serialiseConstraints(template.ExcludedDNSDomains, template.ExcludedIPRanges, template.ExcludedEmailAddresses, template.ExcludedURIDomains) + if err != nil { + return nil, err + } + + var b cryptobyte.Builder + b.AddASN1(cryptobyte_asn1.SEQUENCE, func(b *cryptobyte.Builder) { + if len(permitted) > 0 { + b.AddASN1(cryptobyte_asn1.Tag(0).ContextSpecific().Constructed(), func(b *cryptobyte.Builder) { + b.AddBytes(permitted) + }) + } + + if len(excluded) > 0 { + b.AddASN1(cryptobyte_asn1.Tag(1).ContextSpecific().Constructed(), func(b *cryptobyte.Builder) { + b.AddBytes(excluded) + }) + } + }) + + ret[n].Value, err = b.Bytes() + if err != nil { + return nil, err + } + n++ + } + + if len(template.CRLDistributionPoints) > 0 && + !oidInExtensions(oidExtensionCRLDistributionPoints, template.ExtraExtensions) { + ret[n].Id = oidExtensionCRLDistributionPoints + + var crlDp []distributionPoint + for _, name := range template.CRLDistributionPoints { + dp := distributionPoint{ + DistributionPoint: distributionPointName{ + FullName: []asn1.RawValue{ + {Tag: 6, Class: 2, Bytes: []byte(name)}, + }, + }, + } + crlDp = append(crlDp, dp) + } + + ret[n].Value, err = asn1.Marshal(crlDp) + if err != nil { + return + } + n++ + } + + // Adding another extension here? Remember to update the maximum number + // of elements in the make() at the top of the function and the list of + // template fields used in CreateCertificate documentation. + + return append(ret[:n], template.ExtraExtensions...), nil +} + +func subjectBytes(cert *x509.Certificate) ([]byte, error) { + if len(cert.RawSubject) > 0 { + return cert.RawSubject, nil + } + + return asn1.Marshal(cert.Subject.ToRDNSequence()) +} + +// emptyASN1Subject is the ASN.1 DER encoding of an empty Subject, which is +// just an empty SEQUENCE. +var emptyASN1Subject = []byte{0x30, 0} + +// A Certificate represents an X.509 certificate. +type Certificate struct { + x509.Certificate +} + +func (c *Certificate) Equal(other *Certificate) bool { + if c == nil || other == nil { + return c == other + } + return bytes.Equal(c.Raw, other.Raw) +} + +func (c *Certificate) hasNameConstraints() bool { + return oidInExtensions(oidExtensionNameConstraints, c.Extensions) +} + +func (c *Certificate) hasSANExtension() bool { + return oidInExtensions(oidExtensionSubjectAltName, c.Extensions) +} + +func (c *Certificate) getSANExtension() []byte { + for _, e := range c.Extensions { + if e.Id.Equal(oidExtensionSubjectAltName) { + return e.Value + } + } + return nil +} + +func isRSAPSS(algo x509.SignatureAlgorithm) bool { + switch algo { + case x509.SHA256WithRSAPSS, x509.SHA384WithRSAPSS, x509.SHA512WithRSAPSS: + return true + default: + return false + } +} + +// signingParamsForPublicKey returns the parameters to use for signing with +// priv. If requestedSigAlgo is not zero then it overrides the default +// signature algorithm. +func signingParamsForPublicKey(pub interface{}, requestedSigAlgo x509.SignatureAlgorithm) (hashFunc crypto.Hash, sigAlgo pkix.AlgorithmIdentifier, err error) { + var pubType x509.PublicKeyAlgorithm + + switch pub := pub.(type) { + case *rsa.PublicKey: + pubType = x509.RSA + hashFunc = crypto.SHA256 + sigAlgo.Algorithm = oidSignatureSHA256WithRSA + sigAlgo.Parameters = asn1.NullRawValue + + case *ecdsa.PublicKey: + pubType = x509.ECDSA + + switch pub.Curve { + case elliptic.P224(), elliptic.P256(): + hashFunc = crypto.SHA256 + sigAlgo.Algorithm = oidSignatureECDSAWithSHA256 + case elliptic.P384(): + hashFunc = crypto.SHA384 + sigAlgo.Algorithm = oidSignatureECDSAWithSHA384 + case elliptic.P521(): + hashFunc = crypto.SHA512 + sigAlgo.Algorithm = oidSignatureECDSAWithSHA512 + case sm2.P256(): + hashFunc = crypto.Hash(0) + sigAlgo.Algorithm = oidSignatureSM2WithSM3 + default: + err = errors.New("x509: unknown elliptic curve") + } + + case ed25519.PublicKey: + pubType = x509.Ed25519 + sigAlgo.Algorithm = oidSignatureEd25519 + + default: + err = errors.New("x509: only RSA, ECDSA and Ed25519 keys supported") + } + + if err != nil { + return + } + + if requestedSigAlgo == 0 { + return + } + + found := false + for _, details := range signatureAlgorithmDetails { + if details.algo == requestedSigAlgo { + if details.pubKeyAlgo != pubType { + err = errors.New("x509: requested SignatureAlgorithm does not match private key type") + return + } + sigAlgo.Algorithm, hashFunc = details.oid, details.hash + if hashFunc == 0 && pubType != x509.Ed25519 { + err = errors.New("x509: cannot sign with hash function requested") + return + } + if isRSAPSS(requestedSigAlgo) { + sigAlgo.Parameters = rsaPSSParameters(hashFunc) + } + found = true + break + } + } + + if !found { + err = errors.New("x509: unknown SignatureAlgorithm") + } + + return +} + +// CreateCertificate creates a new X.509v3 certificate based on a template. +// The following members of template are used: +// +// - AuthorityKeyId +// - BasicConstraintsValid +// - CRLDistributionPoints +// - DNSNames +// - EmailAddresses +// - ExcludedDNSDomains +// - ExcludedEmailAddresses +// - ExcludedIPRanges +// - ExcludedURIDomains +// - ExtKeyUsage +// - ExtraExtensions +// - IPAddresses +// - IsCA +// - IssuingCertificateURL +// - KeyUsage +// - MaxPathLen +// - MaxPathLenZero +// - NotAfter +// - NotBefore +// - OCSPServer +// - PermittedDNSDomains +// - PermittedDNSDomainsCritical +// - PermittedEmailAddresses +// - PermittedIPRanges +// - PermittedURIDomains +// - PolicyIdentifiers +// - SerialNumber +// - SignatureAlgorithm +// - Subject +// - SubjectKeyId +// - URIs +// - UnknownExtKeyUsage +// +// The certificate is signed by parent. If parent is equal to template then the +// certificate is self-signed. The parameter pub is the public key of the +// signee and priv is the private key of the signer. +// +// The returned slice is the certificate in DER encoding. +// +// The currently supported key types are *rsa.PublicKey, *ecdsa.PublicKey and +// ed25519.PublicKey. pub must be a supported key type, and priv must be a +// crypto.Signer with a supported public key. +// +// The AuthorityKeyId will be taken from the SubjectKeyId of parent, if any, +// unless the resulting certificate is self-signed. Otherwise the value from +// template will be used. +func CreateCertificate(rand io.Reader, template, parent *x509.Certificate, pub, priv interface{}) (cert []byte, err error) { + key, ok := priv.(crypto.Signer) + if !ok { + return nil, errors.New("x509: certificate private key does not implement crypto.Signer") + } + hashFunc, signatureAlgorithm, err := signingParamsForPublicKey(key.Public(), template.SignatureAlgorithm) + if err != nil { + return nil, err + } + if template.SerialNumber == nil { + return nil, errors.New("x509: no SerialNumber given") + } + publicKeyBytes, publicKeyAlgorithm, err := marshalPublicKey(pub) + if err != nil { + return nil, err + } + + asn1Issuer, err := subjectBytes(parent) + if err != nil { + return + } + + asn1Subject, err := subjectBytes(template) + if err != nil { + return + } + + authorityKeyId := template.AuthorityKeyId + if !bytes.Equal(asn1Issuer, asn1Subject) && len(parent.SubjectKeyId) > 0 { + authorityKeyId = parent.SubjectKeyId + } + + extensions, err := buildExtensions(template, bytes.Equal(asn1Subject, emptyASN1Subject), authorityKeyId) + if err != nil { + return + } + + encodedPublicKey := asn1.BitString{BitLength: len(publicKeyBytes) * 8, Bytes: publicKeyBytes} + c := tbsCertificate{ + Version: 2, + SerialNumber: template.SerialNumber, + SignatureAlgorithm: signatureAlgorithm, + Issuer: asn1.RawValue{FullBytes: asn1Issuer}, + Validity: validity{template.NotBefore.UTC(), template.NotAfter.UTC()}, + Subject: asn1.RawValue{FullBytes: asn1Subject}, + PublicKey: publicKeyInfo{nil, publicKeyAlgorithm, encodedPublicKey}, + Extensions: extensions, + } + + tbsCertContents, err := asn1.Marshal(c) + if err != nil { + return + } + c.Raw = tbsCertContents + + signed := tbsCertContents + + var signature []byte + if signatureAlgorithm.Algorithm.Equal(oidSignatureSM2WithSM3) { + privKey, ok := key.(*sm2.PrivateKey) + if !ok { + ecKey, ok := key.(*ecdsa.PrivateKey) + if ok && ecKey.Curve == sm2.P256() { + privKey, _ = new(sm2.PrivateKey).FromECPrivateKey(ecKey) + } + } + signature, err = privKey.SignWithSM2(rand, nil, signed) + } else { + if hashFunc != 0 { + h := hashFunc.New() + h.Write(signed) + signed = h.Sum(nil) + } + + var signerOpts crypto.SignerOpts = hashFunc + if template.SignatureAlgorithm != 0 && isRSAPSS(template.SignatureAlgorithm) { + signerOpts = &rsa.PSSOptions{ + SaltLength: rsa.PSSSaltLengthEqualsHash, + Hash: hashFunc, + } + } + + signature, err = key.Sign(rand, signed, signerOpts) + } + if err != nil { + return + } + + return asn1.Marshal(certificate{ + nil, + c, + signatureAlgorithm, + asn1.BitString{Bytes: signature, BitLength: len(signature) * 8}, + }) +} + +// isValidIPMask reports whether mask consists of zero or more 1 bits, followed by zero bits. +func isValidIPMask(mask []byte) bool { + seenZero := false + + for _, b := range mask { + if seenZero { + if b != 0 { + return false + } + + continue + } + + switch b { + case 0x00, 0x80, 0xc0, 0xe0, 0xf0, 0xf8, 0xfc, 0xfe: + seenZero = true + case 0xff: + default: + return false + } + } + + return true +} + +func parseNameConstraintsExtension(out *x509.Certificate, e pkix.Extension) (unhandled bool, err error) { + // RFC 5280, 4.2.1.10 + + // NameConstraints ::= SEQUENCE { + // permittedSubtrees [0] GeneralSubtrees OPTIONAL, + // excludedSubtrees [1] GeneralSubtrees OPTIONAL } + // + // GeneralSubtrees ::= SEQUENCE SIZE (1..MAX) OF GeneralSubtree + // + // GeneralSubtree ::= SEQUENCE { + // base GeneralName, + // minimum [0] BaseDistance DEFAULT 0, + // maximum [1] BaseDistance OPTIONAL } + // + // BaseDistance ::= INTEGER (0..MAX) + + outer := cryptobyte.String(e.Value) + var toplevel, permitted, excluded cryptobyte.String + var havePermitted, haveExcluded bool + if !outer.ReadASN1(&toplevel, cryptobyte_asn1.SEQUENCE) || + !outer.Empty() || + !toplevel.ReadOptionalASN1(&permitted, &havePermitted, cryptobyte_asn1.Tag(0).ContextSpecific().Constructed()) || + !toplevel.ReadOptionalASN1(&excluded, &haveExcluded, cryptobyte_asn1.Tag(1).ContextSpecific().Constructed()) || + !toplevel.Empty() { + return false, errors.New("x509: invalid NameConstraints extension") + } + + if !havePermitted && !haveExcluded || len(permitted) == 0 && len(excluded) == 0 { + // From RFC 5280, Section 4.2.1.10: + // “either the permittedSubtrees field + // or the excludedSubtrees MUST be + // present” + return false, errors.New("x509: empty name constraints extension") + } + + getValues := func(subtrees cryptobyte.String) (dnsNames []string, ips []*net.IPNet, emails, uriDomains []string, err error) { + for !subtrees.Empty() { + var seq, value cryptobyte.String + var tag cryptobyte_asn1.Tag + if !subtrees.ReadASN1(&seq, cryptobyte_asn1.SEQUENCE) || + !seq.ReadAnyASN1(&value, &tag) { + return nil, nil, nil, nil, fmt.Errorf("x509: invalid NameConstraints extension") + } + + var ( + dnsTag = cryptobyte_asn1.Tag(2).ContextSpecific() + emailTag = cryptobyte_asn1.Tag(1).ContextSpecific() + ipTag = cryptobyte_asn1.Tag(7).ContextSpecific() + uriTag = cryptobyte_asn1.Tag(6).ContextSpecific() + ) + + switch tag { + case dnsTag: + domain := string(value) + if err := isIA5String(domain); err != nil { + return nil, nil, nil, nil, errors.New("x509: invalid constraint value: " + err.Error()) + } + + trimmedDomain := domain + if len(trimmedDomain) > 0 && trimmedDomain[0] == '.' { + // constraints can have a leading + // period to exclude the domain + // itself, but that's not valid in a + // normal domain name. + trimmedDomain = trimmedDomain[1:] + } + if _, ok := domainToReverseLabels(trimmedDomain); !ok { + return nil, nil, nil, nil, fmt.Errorf("x509: failed to parse dnsName constraint %q", domain) + } + dnsNames = append(dnsNames, domain) + + case ipTag: + l := len(value) + var ip, mask []byte + + switch l { + case 8: + ip = value[:4] + mask = value[4:] + + case 32: + ip = value[:16] + mask = value[16:] + + default: + return nil, nil, nil, nil, fmt.Errorf("x509: IP constraint contained value of length %d", l) + } + + if !isValidIPMask(mask) { + return nil, nil, nil, nil, fmt.Errorf("x509: IP constraint contained invalid mask %x", mask) + } + + ips = append(ips, &net.IPNet{IP: net.IP(ip), Mask: net.IPMask(mask)}) + + case emailTag: + constraint := string(value) + if err := isIA5String(constraint); err != nil { + return nil, nil, nil, nil, errors.New("x509: invalid constraint value: " + err.Error()) + } + + // If the constraint contains an @ then + // it specifies an exact mailbox name. + if strings.Contains(constraint, "@") { + if _, ok := parseRFC2821Mailbox(constraint); !ok { + return nil, nil, nil, nil, fmt.Errorf("x509: failed to parse rfc822Name constraint %q", constraint) + } + } else { + // Otherwise it's a domain name. + domain := constraint + if len(domain) > 0 && domain[0] == '.' { + domain = domain[1:] + } + if _, ok := domainToReverseLabels(domain); !ok { + return nil, nil, nil, nil, fmt.Errorf("x509: failed to parse rfc822Name constraint %q", constraint) + } + } + emails = append(emails, constraint) + + case uriTag: + domain := string(value) + if err := isIA5String(domain); err != nil { + return nil, nil, nil, nil, errors.New("x509: invalid constraint value: " + err.Error()) + } + + if net.ParseIP(domain) != nil { + return nil, nil, nil, nil, fmt.Errorf("x509: failed to parse URI constraint %q: cannot be IP address", domain) + } + + trimmedDomain := domain + if len(trimmedDomain) > 0 && trimmedDomain[0] == '.' { + // constraints can have a leading + // period to exclude the domain itself, + // but that's not valid in a normal + // domain name. + trimmedDomain = trimmedDomain[1:] + } + if _, ok := domainToReverseLabels(trimmedDomain); !ok { + return nil, nil, nil, nil, fmt.Errorf("x509: failed to parse URI constraint %q", domain) + } + uriDomains = append(uriDomains, domain) + + default: + unhandled = true + } + } + + return dnsNames, ips, emails, uriDomains, nil + } + + if out.PermittedDNSDomains, out.PermittedIPRanges, out.PermittedEmailAddresses, out.PermittedURIDomains, err = getValues(permitted); err != nil { + return false, err + } + if out.ExcludedDNSDomains, out.ExcludedIPRanges, out.ExcludedEmailAddresses, out.ExcludedURIDomains, err = getValues(excluded); err != nil { + return false, err + } + out.PermittedDNSDomainsCritical = e.Critical + + return unhandled, nil +} + +func getSignatureAlgorithmFromAI(ai pkix.AlgorithmIdentifier) x509.SignatureAlgorithm { + if ai.Algorithm.Equal(oidSignatureEd25519) { + // RFC 8410, Section 3 + // > For all of the OIDs, the parameters MUST be absent. + if len(ai.Parameters.FullBytes) != 0 { + return x509.UnknownSignatureAlgorithm + } + } + + if !ai.Algorithm.Equal(oidSignatureRSAPSS) { + for _, details := range signatureAlgorithmDetails { + if ai.Algorithm.Equal(details.oid) { + return details.algo + } + } + return x509.UnknownSignatureAlgorithm + } + + // RSA PSS is special because it encodes important parameters + // in the Parameters. + + var params pssParameters + if _, err := asn1.Unmarshal(ai.Parameters.FullBytes, ¶ms); err != nil { + return x509.UnknownSignatureAlgorithm + } + + var mgf1HashFunc pkix.AlgorithmIdentifier + if _, err := asn1.Unmarshal(params.MGF.Parameters.FullBytes, &mgf1HashFunc); err != nil { + return x509.UnknownSignatureAlgorithm + } + + // PSS is greatly overburdened with options. This code forces them into + // three buckets by requiring that the MGF1 hash function always match the + // message hash function (as recommended in RFC 3447, Section 8.1), that the + // salt length matches the hash length, and that the trailer field has the + // default value. + if (len(params.Hash.Parameters.FullBytes) != 0 && !bytes.Equal(params.Hash.Parameters.FullBytes, asn1.NullBytes)) || + !params.MGF.Algorithm.Equal(oidMGF1) || + !mgf1HashFunc.Algorithm.Equal(params.Hash.Algorithm) || + (len(mgf1HashFunc.Parameters.FullBytes) != 0 && !bytes.Equal(mgf1HashFunc.Parameters.FullBytes, asn1.NullBytes)) || + params.TrailerField != 1 { + return x509.UnknownSignatureAlgorithm + } + + switch { + case params.Hash.Algorithm.Equal(oidSHA256) && params.SaltLength == 32: + return x509.SHA256WithRSAPSS + case params.Hash.Algorithm.Equal(oidSHA384) && params.SaltLength == 48: + return x509.SHA384WithRSAPSS + case params.Hash.Algorithm.Equal(oidSHA512) && params.SaltLength == 64: + return x509.SHA512WithRSAPSS + } + + return x509.UnknownSignatureAlgorithm +} + +func parseCertificate(in *certificate) (*x509.Certificate, error) { + out := new(x509.Certificate) + out.Raw = in.Raw + out.RawTBSCertificate = in.TBSCertificate.Raw + out.RawSubjectPublicKeyInfo = in.TBSCertificate.PublicKey.Raw + out.RawSubject = in.TBSCertificate.Subject.FullBytes + out.RawIssuer = in.TBSCertificate.Issuer.FullBytes + + out.Signature = in.SignatureValue.RightAlign() + out.SignatureAlgorithm = getSignatureAlgorithmFromAI(in.TBSCertificate.SignatureAlgorithm) + + out.PublicKeyAlgorithm = + getPublicKeyAlgorithmFromOID(in.TBSCertificate.PublicKey.Algorithm.Algorithm) + var err error + out.PublicKey, err = parsePublicKey(out.PublicKeyAlgorithm, &in.TBSCertificate.PublicKey) + if err != nil { + return nil, err + } + + out.Version = in.TBSCertificate.Version + 1 + out.SerialNumber = in.TBSCertificate.SerialNumber + + var issuer, subject pkix.RDNSequence + if rest, err := asn1.Unmarshal(in.TBSCertificate.Subject.FullBytes, &subject); err != nil { + return nil, err + } else if len(rest) != 0 { + return nil, errors.New("x509: trailing data after X.509 subject") + } + if rest, err := asn1.Unmarshal(in.TBSCertificate.Issuer.FullBytes, &issuer); err != nil { + return nil, err + } else if len(rest) != 0 { + return nil, errors.New("x509: trailing data after X.509 subject") + } + + out.Issuer.FillFromRDNSequence(&issuer) + out.Subject.FillFromRDNSequence(&subject) + + out.NotBefore = in.TBSCertificate.Validity.NotBefore + out.NotAfter = in.TBSCertificate.Validity.NotAfter + + for _, e := range in.TBSCertificate.Extensions { + out.Extensions = append(out.Extensions, e) + unhandled := false + + if len(e.Id) == 4 && e.Id[0] == 2 && e.Id[1] == 5 && e.Id[2] == 29 { + switch e.Id[3] { + case 15: + // RFC 5280, 4.2.1.3 + var usageBits asn1.BitString + if rest, err := asn1.Unmarshal(e.Value, &usageBits); err != nil { + return nil, err + } else if len(rest) != 0 { + return nil, errors.New("x509: trailing data after X.509 KeyUsage") + } + + var usage int + for i := 0; i < 9; i++ { + if usageBits.At(i) != 0 { + usage |= 1 << uint(i) + } + } + out.KeyUsage = x509.KeyUsage(usage) + + case 19: + // RFC 5280, 4.2.1.9 + var constraints basicConstraints + if rest, err := asn1.Unmarshal(e.Value, &constraints); err != nil { + return nil, err + } else if len(rest) != 0 { + return nil, errors.New("x509: trailing data after X.509 BasicConstraints") + } + + out.BasicConstraintsValid = true + out.IsCA = constraints.IsCA + out.MaxPathLen = constraints.MaxPathLen + out.MaxPathLenZero = out.MaxPathLen == 0 + // TODO: map out.MaxPathLen to 0 if it has the -1 default value? (Issue 19285) + case 17: + out.DNSNames, out.EmailAddresses, out.IPAddresses, out.URIs, err = parseSANExtension(e.Value) + if err != nil { + return nil, err + } + + if len(out.DNSNames) == 0 && len(out.EmailAddresses) == 0 && len(out.IPAddresses) == 0 && len(out.URIs) == 0 { + // If we didn't parse anything then we do the critical check, below. + unhandled = true + } + + case 30: + unhandled, err = parseNameConstraintsExtension(out, e) + if err != nil { + return nil, err + } + + case 31: + // RFC 5280, 4.2.1.13 + + // CRLDistributionPoints ::= SEQUENCE SIZE (1..MAX) OF DistributionPoint + // + // DistributionPoint ::= SEQUENCE { + // distributionPoint [0] DistributionPointName OPTIONAL, + // reasons [1] ReasonFlags OPTIONAL, + // cRLIssuer [2] GeneralNames OPTIONAL } + // + // DistributionPointName ::= CHOICE { + // fullName [0] GeneralNames, + // nameRelativeToCRLIssuer [1] RelativeDistinguishedName } + + var cdp []distributionPoint + if rest, err := asn1.Unmarshal(e.Value, &cdp); err != nil { + return nil, err + } else if len(rest) != 0 { + return nil, errors.New("x509: trailing data after X.509 CRL distribution point") + } + + for _, dp := range cdp { + // Per RFC 5280, 4.2.1.13, one of distributionPoint or cRLIssuer may be empty. + if len(dp.DistributionPoint.FullName) == 0 { + continue + } + + for _, fullName := range dp.DistributionPoint.FullName { + if fullName.Tag == 6 { + out.CRLDistributionPoints = append(out.CRLDistributionPoints, string(fullName.Bytes)) + } + } + } + + case 35: + // RFC 5280, 4.2.1.1 + var a authKeyId + if rest, err := asn1.Unmarshal(e.Value, &a); err != nil { + return nil, err + } else if len(rest) != 0 { + return nil, errors.New("x509: trailing data after X.509 authority key-id") + } + out.AuthorityKeyId = a.Id + + case 37: + // RFC 5280, 4.2.1.12. Extended Key Usage + + // id-ce-extKeyUsage OBJECT IDENTIFIER ::= { id-ce 37 } + // + // ExtKeyUsageSyntax ::= SEQUENCE SIZE (1..MAX) OF KeyPurposeId + // + // KeyPurposeId ::= OBJECT IDENTIFIER + + var keyUsage []asn1.ObjectIdentifier + if rest, err := asn1.Unmarshal(e.Value, &keyUsage); err != nil { + return nil, err + } else if len(rest) != 0 { + return nil, errors.New("x509: trailing data after X.509 ExtendedKeyUsage") + } + + for _, u := range keyUsage { + if extKeyUsage, ok := extKeyUsageFromOID(u); ok { + out.ExtKeyUsage = append(out.ExtKeyUsage, extKeyUsage) + } else { + out.UnknownExtKeyUsage = append(out.UnknownExtKeyUsage, u) + } + } + + case 14: + // RFC 5280, 4.2.1.2 + var keyid []byte + if rest, err := asn1.Unmarshal(e.Value, &keyid); err != nil { + return nil, err + } else if len(rest) != 0 { + return nil, errors.New("x509: trailing data after X.509 key-id") + } + out.SubjectKeyId = keyid + + case 32: + // RFC 5280 4.2.1.4: Certificate Policies + var policies []policyInformation + if rest, err := asn1.Unmarshal(e.Value, &policies); err != nil { + return nil, err + } else if len(rest) != 0 { + return nil, errors.New("x509: trailing data after X.509 certificate policies") + } + out.PolicyIdentifiers = make([]asn1.ObjectIdentifier, len(policies)) + for i, policy := range policies { + out.PolicyIdentifiers[i] = policy.Policy + } + + default: + // Unknown extensions are recorded if critical. + unhandled = true + } + } else if e.Id.Equal(oidExtensionAuthorityInfoAccess) { + // RFC 5280 4.2.2.1: Authority Information Access + var aia []authorityInfoAccess + if rest, err := asn1.Unmarshal(e.Value, &aia); err != nil { + return nil, err + } else if len(rest) != 0 { + return nil, errors.New("x509: trailing data after X.509 authority information") + } + + for _, v := range aia { + // GeneralName: uniformResourceIdentifier [6] IA5String + if v.Location.Tag != 6 { + continue + } + if v.Method.Equal(oidAuthorityInfoAccessOcsp) { + out.OCSPServer = append(out.OCSPServer, string(v.Location.Bytes)) + } else if v.Method.Equal(oidAuthorityInfoAccessIssuers) { + out.IssuingCertificateURL = append(out.IssuingCertificateURL, string(v.Location.Bytes)) + } + } + } else { + // Unknown extensions are recorded if critical. + unhandled = true + } + + if e.Critical && unhandled { + out.UnhandledCriticalExtensions = append(out.UnhandledCriticalExtensions, e.Id) + } + } + + return out, nil +} + +// ParseCertificate parses a single certificate from the given ASN.1 DER data. +func ParseCertificate(asn1Data []byte) (*Certificate, error) { + var cert certificate + rest, err := asn1.Unmarshal(asn1Data, &cert) + if err != nil { + return nil, err + } + if len(rest) > 0 { + return nil, asn1.SyntaxError{Msg: "trailing data"} + } + + var result *x509.Certificate + result, err = parseCertificate(&cert) + + if err != nil { + return nil, err + } + return &Certificate{*result}, nil +} + +// ParseCertificates parses one or more certificates from the given ASN.1 DER +// data. The certificates must be concatenated with no intermediate padding. +func ParseCertificates(asn1Data []byte) ([]*Certificate, error) { + var v []*certificate + + for len(asn1Data) > 0 { + cert := new(certificate) + var err error + asn1Data, err = asn1.Unmarshal(asn1Data, cert) + if err != nil { + return nil, err + } + v = append(v, cert) + } + + ret := make([]*Certificate, len(v)) + for i, ci := range v { + cert, err := parseCertificate(ci) + if err != nil { + return nil, err + } + ret[i] = &Certificate{*cert} + } + + return ret, nil +} + +// CheckSignatureFrom verifies that the signature on c is a valid signature +// from parent. +func (c *Certificate) CheckSignatureFrom(parent *Certificate) error { + // RFC 5280, 4.2.1.9: + // "If the basic constraints extension is not present in a version 3 + // certificate, or the extension is present but the cA boolean is not + // asserted, then the certified public key MUST NOT be used to verify + // certificate signatures." + if parent.Version == 3 && !parent.BasicConstraintsValid || + parent.BasicConstraintsValid && !parent.IsCA { + return x509.ConstraintViolationError{} + } + + if parent.KeyUsage != 0 && parent.KeyUsage&x509.KeyUsageCertSign == 0 { + return x509.ConstraintViolationError{} + } + + if parent.PublicKeyAlgorithm == x509.UnknownPublicKeyAlgorithm { + return x509.ErrUnsupportedAlgorithm + } + + // TODO(agl): don't ignore the path length constraint. + + return parent.CheckSignature(c.SignatureAlgorithm, c.RawTBSCertificate, c.Signature) +} + +// CheckSignature verifies that signature is a valid signature over signed from +// c's public key. +func (c *Certificate) CheckSignature(algo x509.SignatureAlgorithm, signed, signature []byte) error { + key, ok := c.PublicKey.(*ecdsa.PublicKey) + if !ok { + return c.Certificate.CheckSignature(algo, signed, signature) + } + if key.Curve != sm2.P256() { + return c.Certificate.CheckSignature(algo, signed, signature) + } + ecdsaSig := new(ecdsaSignature) + if rest, err := asn1.Unmarshal(signature, ecdsaSig); err != nil { + return err + } else if len(rest) != 0 { + return errors.New("x509: trailing data after ECDSA signature") + } + if ecdsaSig.R.Sign() <= 0 || ecdsaSig.S.Sign() <= 0 { + return errors.New("x509: ECDSA signature contained zero or negative values") + } + if !sm2.VerifyWithSM2(key, nil, signed, ecdsaSig.R, ecdsaSig.S) { + return errors.New("x509: ECDSA verification failure") + } + return nil +} + +// CreateCRL returns a DER encoded CRL, signed by this Certificate, that +// contains the given list of revoked certificates. +func (c *Certificate) CreateCRL(rand io.Reader, priv interface{}, revokedCerts []pkix.RevokedCertificate, now, expiry time.Time) (crlBytes []byte, err error) { + key, ok := priv.(crypto.Signer) + if !ok { + return nil, errors.New("x509: certificate private key does not implement crypto.Signer") + } + + hashFunc, signatureAlgorithm, err := signingParamsForPublicKey(key.Public(), 0) + if err != nil { + return nil, err + } + + // Force revocation times to UTC per RFC 5280. + revokedCertsUTC := make([]pkix.RevokedCertificate, len(revokedCerts)) + for i, rc := range revokedCerts { + rc.RevocationTime = rc.RevocationTime.UTC() + revokedCertsUTC[i] = rc + } + + tbsCertList := pkix.TBSCertificateList{ + Version: 1, + Signature: signatureAlgorithm, + Issuer: c.Subject.ToRDNSequence(), + ThisUpdate: now.UTC(), + NextUpdate: expiry.UTC(), + RevokedCertificates: revokedCertsUTC, + } + + // Authority Key Id + if len(c.SubjectKeyId) > 0 { + var aki pkix.Extension + aki.Id = oidExtensionAuthorityKeyId + aki.Value, err = asn1.Marshal(authKeyId{Id: c.SubjectKeyId}) + if err != nil { + return + } + tbsCertList.Extensions = append(tbsCertList.Extensions, aki) + } + + tbsCertListContents, err := asn1.Marshal(tbsCertList) + if err != nil { + return + } + + signed := tbsCertListContents + var signature []byte + if signatureAlgorithm.Algorithm.Equal(oidSignatureSM2WithSM3) { + privKey, ok := key.(*sm2.PrivateKey) + if !ok { + ecKey, ok := key.(*ecdsa.PrivateKey) + if ok && ecKey.Curve == sm2.P256() { + privKey, _ = new(sm2.PrivateKey).FromECPrivateKey(ecKey) + } + } + signature, err = privKey.SignWithSM2(rand, nil, signed) + } else { + if hashFunc != 0 { + h := hashFunc.New() + h.Write(signed) + signed = h.Sum(nil) + } + signature, err = key.Sign(rand, signed, hashFunc) + } + if err != nil { + return + } + + return asn1.Marshal(pkix.CertificateList{ + TBSCertList: tbsCertList, + SignatureAlgorithm: signatureAlgorithm, + SignatureValue: asn1.BitString{Bytes: signature, BitLength: len(signature) * 8}, + }) +} diff --git a/smx509/x509_test.go b/smx509/x509_test.go index ab6d381..751f452 100644 --- a/smx509/x509_test.go +++ b/smx509/x509_test.go @@ -1,8 +1,12 @@ package smx509 import ( + "bytes" "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" "crypto/rand" + "crypto/rsa" "crypto/x509" "crypto/x509/pkix" "encoding/asn1" @@ -11,8 +15,13 @@ import ( "encoding/pem" "errors" "fmt" + "math/big" + "net" + "net/url" + "reflect" "strings" "testing" + "time" "github.com/emmansun/gmsm/sm2" ) @@ -43,6 +52,108 @@ bxIHjKZHc2sztHCXe7cseWGiLq0syg== -----END CERTIFICATE REQUEST----- ` +const pemCertificate = `-----BEGIN CERTIFICATE----- +MIIDATCCAemgAwIBAgIRAKQkkrFx1T/dgB/Go/xBM5swDQYJKoZIhvcNAQELBQAw +EjEQMA4GA1UEChMHQWNtZSBDbzAeFw0xNjA4MTcyMDM2MDdaFw0xNzA4MTcyMDM2 +MDdaMBIxEDAOBgNVBAoTB0FjbWUgQ28wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAw +ggEKAoIBAQDAoJtjG7M6InsWwIo+l3qq9u+g2rKFXNu9/mZ24XQ8XhV6PUR+5HQ4 +jUFWC58ExYhottqK5zQtKGkw5NuhjowFUgWB/VlNGAUBHtJcWR/062wYrHBYRxJH +qVXOpYKbIWwFKoXu3hcpg/CkdOlDWGKoZKBCwQwUBhWE7MDhpVdQ+ZljUJWL+FlK +yQK5iRsJd5TGJ6VUzLzdT4fmN2DzeK6GLeyMpVpU3sWV90JJbxWQ4YrzkKzYhMmB +EcpXTG2wm+ujiHU/k2p8zlf8Sm7VBM/scmnMFt0ynNXop4FWvJzEm1G0xD2t+e2I +5Utr04dOZPCgkm++QJgYhtZvgW7ZZiGTAgMBAAGjUjBQMA4GA1UdDwEB/wQEAwIF +oDATBgNVHSUEDDAKBggrBgEFBQcDATAMBgNVHRMBAf8EAjAAMBsGA1UdEQQUMBKC +EHRlc3QuZXhhbXBsZS5jb20wDQYJKoZIhvcNAQELBQADggEBADpqKQxrthH5InC7 +X96UP0OJCu/lLEMkrjoEWYIQaFl7uLPxKH5AmQPH4lYwF7u7gksR7owVG9QU9fs6 +1fK7II9CVgCd/4tZ0zm98FmU4D0lHGtPARrrzoZaqVZcAvRnFTlPX5pFkPhVjjai +/mkxX9LpD8oK1445DFHxK5UjLMmPIIWd8EOi+v5a+hgGwnJpoW7hntSl8kHMtTmy +fnnktsblSUV4lRCit0ymC7Ojhe+gzCCwkgs5kDzVVag+tnl/0e2DloIjASwOhpbH +KVcg7fBd484ht/sS+l0dsB4KDOSpd8JzVDMF8OZqlaydizoJO0yWr9GbCN1+OKq5 +EhLrEqU= +-----END CERTIFICATE-----` + +var pemPrivateKey = testingKey(` +-----BEGIN RSA TESTING KEY----- +MIICXAIBAAKBgQCxoeCUW5KJxNPxMp+KmCxKLc1Zv9Ny+4CFqcUXVUYH69L3mQ7v +IWrJ9GBfcaA7BPQqUlWxWM+OCEQZH1EZNIuqRMNQVuIGCbz5UQ8w6tS0gcgdeGX7 +J7jgCQ4RK3F/PuCM38QBLaHx988qG8NMc6VKErBjctCXFHQt14lerd5KpQIDAQAB +AoGAYrf6Hbk+mT5AI33k2Jt1kcweodBP7UkExkPxeuQzRVe0KVJw0EkcFhywKpr1 +V5eLMrILWcJnpyHE5slWwtFHBG6a5fLaNtsBBtcAIfqTQ0Vfj5c6SzVaJv0Z5rOd +7gQF6isy3t3w9IF3We9wXQKzT6q5ypPGdm6fciKQ8RnzREkCQQDZwppKATqQ41/R +vhSj90fFifrGE6aVKC1hgSpxGQa4oIdsYYHwMzyhBmWW9Xv/R+fPyr8ZwPxp2c12 +33QwOLPLAkEA0NNUb+z4ebVVHyvSwF5jhfJxigim+s49KuzJ1+A2RaSApGyBZiwS +rWvWkB471POAKUYt5ykIWVZ83zcceQiNTwJBAMJUFQZX5GDqWFc/zwGoKkeR49Yi +MTXIvf7Wmv6E++eFcnT461FlGAUHRV+bQQXGsItR/opIG7mGogIkVXa3E1MCQARX +AAA7eoZ9AEHflUeuLn9QJI/r0hyQQLEtrpwv6rDT1GCWaLII5HJ6NUFVf4TTcqxo +6vdM4QGKTJoO+SaCyP0CQFdpcxSAuzpFcKv0IlJ8XzS/cy+mweCMwyJ1PFEc4FX6 +wg/HcAJWY60xZTJDFN+Qfx8ZQvBEin6c2/h+zZi5IVY= +-----END RSA TESTING KEY----- +`) + +const ed25519CRLKey = `-----BEGIN PRIVATE KEY----- +MC4CAQAwBQYDK2VwBCIEINdKh2096vUBYu4EIFpjShsUSh3vimKya1sQ1YTT4RZG +-----END PRIVATE KEY-----` + +const ed25519CRLCertificate = ` +Certificate: +Data: + Version: 3 (0x2) + Serial Number: + 7a:07:a0:9d:14:04:16:fc:1f:d8:e5:fe:d1:1d:1f:8d + Signature Algorithm: ED25519 + Issuer: CN = Ed25519 CRL Test CA + Validity + Not Before: Oct 30 01:20:20 2019 GMT + Not After : Dec 31 23:59:59 9999 GMT + Subject: CN = Ed25519 CRL Test CA + Subject Public Key Info: + Public Key Algorithm: ED25519 + ED25519 Public-Key: + pub: + 95:73:3b:b0:06:2a:31:5a:b6:a7:a6:6e:ef:71:df: + ac:6f:6b:39:03:85:5e:63:4b:f8:a6:0f:68:c6:6f: + 75:21 + X509v3 extensions: + X509v3 Key Usage: critical + Digital Signature, Certificate Sign, CRL Sign + X509v3 Extended Key Usage: + TLS Web Client Authentication, TLS Web Server Authentication, OCSP Signing + X509v3 Basic Constraints: critical + CA:TRUE + X509v3 Subject Key Identifier: + B7:17:DA:16:EA:C5:ED:1F:18:49:44:D3:D2:E3:A0:35:0A:81:93:60 + X509v3 Authority Key Identifier: + keyid:B7:17:DA:16:EA:C5:ED:1F:18:49:44:D3:D2:E3:A0:35:0A:81:93:60 +Signature Algorithm: ED25519 + fc:3e:14:ea:bb:70:c2:6f:38:34:70:bc:c8:a7:f4:7c:0d:1e: + 28:d7:2a:9f:22:8a:45:e8:02:76:84:1e:2d:64:2d:1e:09:b5: + 29:71:1f:95:8a:4e:79:87:51:60:9a:e7:86:40:f6:60:c7:d1: + ee:68:76:17:1d:90:cc:92:93:07 +-----BEGIN CERTIFICATE----- +MIIBijCCATygAwIBAgIQegegnRQEFvwf2OX+0R0fjTAFBgMrZXAwHjEcMBoGA1UE +AxMTRWQyNTUxOSBDUkwgVGVzdCBDQTAgFw0xOTEwMzAwMTIwMjBaGA85OTk5MTIz +MTIzNTk1OVowHjEcMBoGA1UEAxMTRWQyNTUxOSBDUkwgVGVzdCBDQTAqMAUGAytl +cAMhAJVzO7AGKjFatqembu9x36xvazkDhV5jS/imD2jGb3Uho4GNMIGKMA4GA1Ud +DwEB/wQEAwIBhjAnBgNVHSUEIDAeBggrBgEFBQcDAgYIKwYBBQUHAwEGCCsGAQUF +BwMJMA8GA1UdEwEB/wQFMAMBAf8wHQYDVR0OBBYEFLcX2hbqxe0fGElE09LjoDUK +gZNgMB8GA1UdIwQYMBaAFLcX2hbqxe0fGElE09LjoDUKgZNgMAUGAytlcANBAPw+ +FOq7cMJvODRwvMin9HwNHijXKp8iikXoAnaEHi1kLR4JtSlxH5WKTnmHUWCa54ZA +9mDH0e5odhcdkMySkwc= +-----END CERTIFICATE-----` + +var testPrivateKey *rsa.PrivateKey + +func init() { + block, _ := pem.Decode([]byte(pemPrivateKey)) + + var err error + if testPrivateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes); err != nil { + panic("Failed to parse private key: " + err.Error()) + } +} + +func testingKey(s string) string { return strings.ReplaceAll(s, "TESTING KEY", "PRIVATE KEY") } + func getPublicKey(pemContent []byte) (interface{}, error) { block, _ := pem.Decode(pemContent) if block == nil { @@ -57,7 +168,7 @@ func parseAndCheckCsr(csrblock []byte) error { return err } fmt.Printf("%v\n", csr) - return CheckSignature(csr) + return csr.CheckSignature() } func TestParseCertificateRequest(t *testing.T) { @@ -73,6 +184,7 @@ func TestParseCertificateRequest(t *testing.T) { func TestCreateCertificateRequest(t *testing.T) { priv, _ := sm2.GenerateKey(rand.Reader) + names := pkix.Name{CommonName: "TestName"} var template = x509.CertificateRequest{Subject: names} csrblock, err := CreateCertificateRequest(rand.Reader, &template, priv) @@ -140,3 +252,420 @@ func TestMarshalPKIXPublicKey(t *testing.T) { t.Errorf("expected=%s, result=%s", publicKeyPemFromAliKms, pemContent) } } + +func Test_CreateCertificateRequest(t *testing.T) { + random := rand.Reader + + sm2Priv, err := sm2.GenerateKey(rand.Reader) + if err != nil { + t.Fatalf("Failed to generate SM2 key: %s", err) + } + + ecdsa256Priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("Failed to generate ECDSA key: %s", err) + } + + ecdsa384Priv, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader) + if err != nil { + t.Fatalf("Failed to generate ECDSA key: %s", err) + } + + ecdsa521Priv, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) + if err != nil { + t.Fatalf("Failed to generate ECDSA key: %s", err) + } + + _, ed25519Priv, err := ed25519.GenerateKey(random) + if err != nil { + t.Fatalf("Failed to generate Ed25519 key: %s", err) + } + + tests := []struct { + name string + priv interface{} + sigAlgo x509.SignatureAlgorithm + }{ + {"RSA", testPrivateKey, x509.SHA1WithRSA}, + {"SM2-256", sm2Priv, -1}, + {"ECDSA-256", ecdsa256Priv, x509.ECDSAWithSHA1}, + {"ECDSA-384", ecdsa384Priv, x509.ECDSAWithSHA1}, + {"ECDSA-521", ecdsa521Priv, x509.ECDSAWithSHA1}, + {"Ed25519", ed25519Priv, x509.PureEd25519}, + } + + for _, test := range tests { + template := x509.CertificateRequest{ + Subject: pkix.Name{ + CommonName: "test.example.com", + Organization: []string{"Σ Acme Co"}, + }, + SignatureAlgorithm: test.sigAlgo, + DNSNames: []string{"test.example.com"}, + EmailAddresses: []string{"gopher@golang.org"}, + IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1).To4(), net.ParseIP("2001:4860:0:2001::68")}, + } + + derBytes, err := CreateCertificateRequest(random, &template, test.priv) + if err != nil { + t.Errorf("%s: failed to create certificate request: %s", test.name, err) + continue + } + + out, err := ParseCertificateRequest(derBytes) + if err != nil { + t.Errorf("%s: failed to create certificate request: %s", test.name, err) + continue + } + + err = out.CheckSignature() + if err != nil { + t.Errorf("%s: failed to check certificate request signature: %s", test.name, err) + continue + } + + if out.Subject.CommonName != template.Subject.CommonName { + t.Errorf("%s: output subject common name and template subject common name don't match", test.name) + } else if len(out.Subject.Organization) != len(template.Subject.Organization) { + t.Errorf("%s: output subject organisation and template subject organisation don't match", test.name) + } else if len(out.DNSNames) != len(template.DNSNames) { + t.Errorf("%s: output DNS names and template DNS names don't match", test.name) + } else if len(out.EmailAddresses) != len(template.EmailAddresses) { + t.Errorf("%s: output email addresses and template email addresses don't match", test.name) + } else if len(out.IPAddresses) != len(template.IPAddresses) { + t.Errorf("%s: output IP addresses and template IP addresses names don't match", test.name) + } + } +} + +func parseCIDR(s string) *net.IPNet { + _, net, err := net.ParseCIDR(s) + if err != nil { + panic(err) + } + return net +} + +func parseURI(s string) *url.URL { + uri, err := url.Parse(s) + if err != nil { + panic(err) + } + return uri +} + +func TestCreateSelfSignedCertificate(t *testing.T) { + random := rand.Reader + + sm2Priv, err := sm2.GenerateKey(rand.Reader) + if err != nil { + t.Fatalf("Failed to generate SM2 key: %s", err) + } + + ecdsaPriv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("Failed to generate ECDSA key: %s", err) + } + + ed25519Pub, ed25519Priv, err := ed25519.GenerateKey(random) + if err != nil { + t.Fatalf("Failed to generate Ed25519 key: %s", err) + } + + tests := []struct { + name string + pub, priv interface{} + checkSig bool + sigAlgo x509.SignatureAlgorithm + }{ + {"RSA/RSA", &testPrivateKey.PublicKey, testPrivateKey, true, x509.SHA1WithRSA}, + {"RSA/ECDSA", &testPrivateKey.PublicKey, ecdsaPriv, false, x509.ECDSAWithSHA384}, + {"RSA/SM2", &testPrivateKey.PublicKey, sm2Priv, false, x509.UnknownSignatureAlgorithm}, + {"ECDSA/RSA", &ecdsaPriv.PublicKey, testPrivateKey, false, x509.SHA256WithRSA}, + {"ECDSA/ECDSA", &ecdsaPriv.PublicKey, ecdsaPriv, true, x509.ECDSAWithSHA1}, + {"ECDSA/SM2", &ecdsaPriv.PublicKey, sm2Priv, false, x509.UnknownSignatureAlgorithm}, + {"SM2/ECDSA", &sm2Priv.PublicKey, ecdsaPriv, false, x509.ECDSAWithSHA1}, + {"RSAPSS/RSAPSS", &testPrivateKey.PublicKey, testPrivateKey, true, x509.SHA256WithRSAPSS}, + {"ECDSA/RSAPSS", &ecdsaPriv.PublicKey, testPrivateKey, false, x509.SHA256WithRSAPSS}, + {"SM2/RSAPSS", &sm2Priv.PublicKey, testPrivateKey, false, x509.SHA256WithRSAPSS}, + {"RSAPSS/ECDSA", &testPrivateKey.PublicKey, ecdsaPriv, false, x509.ECDSAWithSHA384}, + {"Ed25519", ed25519Pub, ed25519Priv, true, x509.PureEd25519}, + {"SM2", &sm2Priv.PublicKey, sm2Priv, true, x509.UnknownSignatureAlgorithm}, + } + + testExtKeyUsage := []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth} + testUnknownExtKeyUsage := []asn1.ObjectIdentifier{[]int{1, 2, 3}, []int{2, 59, 1}} + extraExtensionData := []byte("extra extension") + + for _, test := range tests { + commonName := "test.example.com" + template := x509.Certificate{ + // SerialNumber is negative to ensure that negative + // values are parsed. This is due to the prevalence of + // buggy code that produces certificates with negative + // serial numbers. + SerialNumber: big.NewInt(-1), + Subject: pkix.Name{ + CommonName: commonName, + Organization: []string{"Σ Acme Co"}, + Country: []string{"US"}, + ExtraNames: []pkix.AttributeTypeAndValue{ + { + Type: []int{2, 5, 4, 42}, + Value: "Gopher", + }, + // This should override the Country, above. + { + Type: []int{2, 5, 4, 6}, + Value: "NL", + }, + }, + }, + NotBefore: time.Unix(1000, 0), + NotAfter: time.Unix(100000, 0), + + SignatureAlgorithm: test.sigAlgo, + + SubjectKeyId: []byte{1, 2, 3, 4}, + KeyUsage: x509.KeyUsageCertSign, + + ExtKeyUsage: testExtKeyUsage, + UnknownExtKeyUsage: testUnknownExtKeyUsage, + + BasicConstraintsValid: true, + IsCA: true, + + OCSPServer: []string{"http://ocsp.example.com"}, + IssuingCertificateURL: []string{"http://crt.example.com/ca1.crt"}, + + DNSNames: []string{"test.example.com"}, + EmailAddresses: []string{"gopher@golang.org"}, + IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1).To4(), net.ParseIP("2001:4860:0:2001::68")}, + URIs: []*url.URL{parseURI("https://foo.com/wibble#foo")}, + + PolicyIdentifiers: []asn1.ObjectIdentifier{[]int{1, 2, 3}}, + PermittedDNSDomains: []string{".example.com", "example.com"}, + ExcludedDNSDomains: []string{"bar.example.com"}, + PermittedIPRanges: []*net.IPNet{parseCIDR("192.168.1.1/16"), parseCIDR("1.2.3.4/8")}, + ExcludedIPRanges: []*net.IPNet{parseCIDR("2001:db8::/48")}, + PermittedEmailAddresses: []string{"foo@example.com"}, + ExcludedEmailAddresses: []string{".example.com", "example.com"}, + PermittedURIDomains: []string{".bar.com", "bar.com"}, + ExcludedURIDomains: []string{".bar2.com", "bar2.com"}, + + CRLDistributionPoints: []string{"http://crl1.example.com/ca1.crl", "http://crl2.example.com/ca1.crl"}, + + ExtraExtensions: []pkix.Extension{ + { + Id: []int{1, 2, 3, 4}, + Value: extraExtensionData, + }, + // This extension should override the SubjectKeyId, above. + { + Id: oidExtensionSubjectKeyId, + Critical: false, + Value: []byte{0x04, 0x04, 4, 3, 2, 1}, + }, + }, + } + + derBytes, err := CreateCertificate(random, &template, &template, test.pub, test.priv) + if err != nil { + t.Errorf("%s: failed to create certificate: %s", test.name, err) + continue + } + + cert, err := ParseCertificate(derBytes) + if err != nil { + t.Errorf("%s: failed to parse certificate: %s", test.name, err) + continue + } + + if len(cert.PolicyIdentifiers) != 1 || !cert.PolicyIdentifiers[0].Equal(template.PolicyIdentifiers[0]) { + t.Errorf("%s: failed to parse policy identifiers: got:%#v want:%#v", test.name, cert.PolicyIdentifiers, template.PolicyIdentifiers) + } + + if len(cert.PermittedDNSDomains) != 2 || cert.PermittedDNSDomains[0] != ".example.com" || cert.PermittedDNSDomains[1] != "example.com" { + t.Errorf("%s: failed to parse name constraints: %#v", test.name, cert.PermittedDNSDomains) + } + + if len(cert.ExcludedDNSDomains) != 1 || cert.ExcludedDNSDomains[0] != "bar.example.com" { + t.Errorf("%s: failed to parse name constraint exclusions: %#v", test.name, cert.ExcludedDNSDomains) + } + + if len(cert.PermittedIPRanges) != 2 || cert.PermittedIPRanges[0].String() != "192.168.0.0/16" || cert.PermittedIPRanges[1].String() != "1.0.0.0/8" { + t.Errorf("%s: failed to parse IP constraints: %#v", test.name, cert.PermittedIPRanges) + } + + if len(cert.ExcludedIPRanges) != 1 || cert.ExcludedIPRanges[0].String() != "2001:db8::/48" { + t.Errorf("%s: failed to parse IP constraint exclusions: %#v", test.name, cert.ExcludedIPRanges) + } + + if len(cert.PermittedEmailAddresses) != 1 || cert.PermittedEmailAddresses[0] != "foo@example.com" { + t.Errorf("%s: failed to parse permitted email addreses: %#v", test.name, cert.PermittedEmailAddresses) + } + + if len(cert.ExcludedEmailAddresses) != 2 || cert.ExcludedEmailAddresses[0] != ".example.com" || cert.ExcludedEmailAddresses[1] != "example.com" { + t.Errorf("%s: failed to parse excluded email addreses: %#v", test.name, cert.ExcludedEmailAddresses) + } + + if len(cert.PermittedURIDomains) != 2 || cert.PermittedURIDomains[0] != ".bar.com" || cert.PermittedURIDomains[1] != "bar.com" { + t.Errorf("%s: failed to parse permitted URIs: %#v", test.name, cert.PermittedURIDomains) + } + + if len(cert.ExcludedURIDomains) != 2 || cert.ExcludedURIDomains[0] != ".bar2.com" || cert.ExcludedURIDomains[1] != "bar2.com" { + t.Errorf("%s: failed to parse excluded URIs: %#v", test.name, cert.ExcludedURIDomains) + } + + if cert.Subject.CommonName != commonName { + t.Errorf("%s: subject wasn't correctly copied from the template. Got %s, want %s", test.name, cert.Subject.CommonName, commonName) + } + + if len(cert.Subject.Country) != 1 || cert.Subject.Country[0] != "NL" { + t.Errorf("%s: ExtraNames didn't override Country", test.name) + } + + for _, ext := range cert.Extensions { + if ext.Id.Equal(oidExtensionSubjectAltName) { + if ext.Critical { + t.Fatal("SAN extension is marked critical") + } + } + } + + found := false + for _, atv := range cert.Subject.Names { + if atv.Type.Equal([]int{2, 5, 4, 42}) { + found = true + break + } + } + if !found { + t.Errorf("%s: Names didn't contain oid 2.5.4.42 from ExtraNames", test.name) + } + + if cert.Issuer.CommonName != commonName { + t.Errorf("%s: issuer wasn't correctly copied from the template. Got %s, want %s", test.name, cert.Issuer.CommonName, commonName) + } + + if cert.SignatureAlgorithm != test.sigAlgo { + t.Errorf("%s: SignatureAlgorithm wasn't copied from template. Got %v, want %v", test.name, cert.SignatureAlgorithm, test.sigAlgo) + } + + if !reflect.DeepEqual(cert.ExtKeyUsage, testExtKeyUsage) { + t.Errorf("%s: extkeyusage wasn't correctly copied from the template. Got %v, want %v", test.name, cert.ExtKeyUsage, testExtKeyUsage) + } + + if !reflect.DeepEqual(cert.UnknownExtKeyUsage, testUnknownExtKeyUsage) { + t.Errorf("%s: unknown extkeyusage wasn't correctly copied from the template. Got %v, want %v", test.name, cert.UnknownExtKeyUsage, testUnknownExtKeyUsage) + } + + if !reflect.DeepEqual(cert.OCSPServer, template.OCSPServer) { + t.Errorf("%s: OCSP servers differ from template. Got %v, want %v", test.name, cert.OCSPServer, template.OCSPServer) + } + + if !reflect.DeepEqual(cert.IssuingCertificateURL, template.IssuingCertificateURL) { + t.Errorf("%s: Issuing certificate URLs differ from template. Got %v, want %v", test.name, cert.IssuingCertificateURL, template.IssuingCertificateURL) + } + + if !reflect.DeepEqual(cert.DNSNames, template.DNSNames) { + t.Errorf("%s: SAN DNS names differ from template. Got %v, want %v", test.name, cert.DNSNames, template.DNSNames) + } + + if !reflect.DeepEqual(cert.EmailAddresses, template.EmailAddresses) { + t.Errorf("%s: SAN emails differ from template. Got %v, want %v", test.name, cert.EmailAddresses, template.EmailAddresses) + } + + if len(cert.URIs) != 1 || cert.URIs[0].String() != "https://foo.com/wibble#foo" { + t.Errorf("%s: URIs differ from template. Got %v, want %v", test.name, cert.URIs, template.URIs) + } + + if !reflect.DeepEqual(cert.IPAddresses, template.IPAddresses) { + t.Errorf("%s: SAN IPs differ from template. Got %v, want %v", test.name, cert.IPAddresses, template.IPAddresses) + } + + if !reflect.DeepEqual(cert.CRLDistributionPoints, template.CRLDistributionPoints) { + t.Errorf("%s: CRL distribution points differ from template. Got %v, want %v", test.name, cert.CRLDistributionPoints, template.CRLDistributionPoints) + } + + if !bytes.Equal(cert.SubjectKeyId, []byte{4, 3, 2, 1}) { + t.Errorf("%s: ExtraExtensions didn't override SubjectKeyId", test.name) + } + + if !bytes.Contains(derBytes, extraExtensionData) { + t.Errorf("%s: didn't find extra extension in DER output", test.name) + } + + if test.checkSig { + err = cert.CheckSignatureFrom(cert) + if err != nil { + t.Errorf("%s: signature verification failed: %s", test.name, err) + } + } + } +} + +func TestCRLCreation(t *testing.T) { + block, _ := pem.Decode([]byte(pemPrivateKey)) + privRSA, _ := x509.ParsePKCS1PrivateKey(block.Bytes) + block, _ = pem.Decode([]byte(pemCertificate)) + certRSA, _ := ParseCertificate(block.Bytes) + + block, _ = pem.Decode([]byte(ed25519CRLKey)) + privEd25519, _ := ParsePKCS8PrivateKey(block.Bytes) + block, _ = pem.Decode([]byte(ed25519CRLCertificate)) + certEd25519, _ := ParseCertificate(block.Bytes) + + tests := []struct { + name string + priv interface{} + cert *Certificate + }{ + {"RSA CA", privRSA, certRSA}, + {"Ed25519 CA", privEd25519, certEd25519}, + } + + loc := time.FixedZone("Oz/Atlantis", int((2 * time.Hour).Seconds())) + + now := time.Unix(1000, 0).In(loc) + nowUTC := now.UTC() + expiry := time.Unix(10000, 0) + + revokedCerts := []pkix.RevokedCertificate{ + { + SerialNumber: big.NewInt(1), + RevocationTime: nowUTC, + }, + { + SerialNumber: big.NewInt(42), + // RevocationTime should be converted to UTC before marshaling. + RevocationTime: now, + }, + } + expectedCerts := []pkix.RevokedCertificate{ + { + SerialNumber: big.NewInt(1), + RevocationTime: nowUTC, + }, + { + SerialNumber: big.NewInt(42), + RevocationTime: nowUTC, + }, + } + + for _, test := range tests { + crlBytes, err := test.cert.CreateCRL(rand.Reader, test.priv, revokedCerts, now, expiry) + if err != nil { + t.Errorf("%s: error creating CRL: %s", test.name, err) + } + + parsedCRL, err := x509.ParseDERCRL(crlBytes) + if err != nil { + t.Errorf("%s: error reparsing CRL: %s", test.name, err) + } + if !reflect.DeepEqual(parsedCRL.TBSCertList.RevokedCertificates, expectedCerts) { + t.Errorf("%s: RevokedCertificates mismatch: got %v; want %v.", test.name, + parsedCRL.TBSCertList.RevokedCertificates, expectedCerts) + } + } +}