diff --git a/smx509/cert_pool.go b/smx509/cert_pool.go index cf1dd52..d8db37a 100644 --- a/smx509/cert_pool.go +++ b/smx509/cert_pool.go @@ -40,6 +40,11 @@ type lazyCert struct { // fewer allocations. rawSubject []byte + // constraint is a function to run against a chain when it is a candidate to + // be added to the chain. This allows adding arbitrary constraints that are + // not specified in the certificate itself. + constraint func([]*Certificate) error + // getCert returns the certificate. // // It is not meant to do network operations or anything else @@ -69,8 +74,9 @@ func (s *CertPool) len() int { } // cert returns cert index n in s. -func (s *CertPool) cert(n int) (*Certificate, error) { - return s.lazyCerts[n].getCert() +func (s *CertPool) cert(n int) (*Certificate, func([]*Certificate) error, error) { + cert, err := s.lazyCerts[n].getCert() + return cert, s.lazyCerts[n].constraint, err } // Clone returns a copy of s. @@ -112,9 +118,14 @@ func SystemCertPool() (*CertPool, error) { return loadSystemRoots() } +type potentialParent struct { + cert *Certificate + constraint func([]*Certificate) error +} + // findPotentialParents returns the indexes of certificates in s which might // have signed cert. -func (s *CertPool) findPotentialParents(cert *Certificate) []*Certificate { +func (s *CertPool) findPotentialParents(cert *Certificate) []potentialParent { if s == nil { return nil } @@ -125,21 +136,21 @@ func (s *CertPool) findPotentialParents(cert *Certificate) []*Certificate { // AKID and SKID match // AKID present, SKID missing / AKID missing, SKID present // AKID and SKID don't match - var matchingKeyID, oneKeyID, mismatchKeyID []*Certificate + var matchingKeyID, oneKeyID, mismatchKeyID []potentialParent for _, c := range s.byName[string(cert.RawIssuer)] { - candidate, err := s.cert(c) + candidate, constraint, err := s.cert(c) if err != nil { continue } kidMatch := bytes.Equal(candidate.SubjectKeyId, cert.AuthorityKeyId) switch { case kidMatch: - matchingKeyID = append(matchingKeyID, candidate) + matchingKeyID = append(matchingKeyID, potentialParent{candidate, constraint}) case (len(candidate.SubjectKeyId) == 0 && len(cert.AuthorityKeyId) > 0) || (len(candidate.SubjectKeyId) > 0 && len(cert.AuthorityKeyId) == 0): - oneKeyID = append(oneKeyID, candidate) + oneKeyID = append(oneKeyID, potentialParent{candidate, constraint}) default: - mismatchKeyID = append(mismatchKeyID, candidate) + mismatchKeyID = append(mismatchKeyID, potentialParent{candidate, constraint}) } } @@ -147,7 +158,7 @@ func (s *CertPool) findPotentialParents(cert *Certificate) []*Certificate { if found == 0 { return nil } - candidates := make([]*Certificate, 0, found) + candidates := make([]potentialParent, 0, found) candidates = append(candidates, matchingKeyID...) candidates = append(candidates, oneKeyID...) candidates = append(candidates, mismatchKeyID...) @@ -168,7 +179,7 @@ func (s *CertPool) AddCert(cert *Certificate) { } s.addCertFunc(sha256.Sum224(cert.Raw), string(cert.RawSubject), func() (*Certificate, error) { return cert, nil - }) + }, nil) } // addCertFunc adds metadata about a certificate to a pool, along with @@ -176,7 +187,7 @@ func (s *CertPool) AddCert(cert *Certificate) { // // The rawSubject is Certificate.RawSubject and must be non-empty. // The getCert func may be called 0 or more times. -func (s *CertPool) addCertFunc(rawSum224 sum224, rawSubject string, getCert func() (*Certificate, error)) { +func (s *CertPool) addCertFunc(rawSum224 sum224, rawSubject string, getCert func() (*Certificate, error), constraint func([]*Certificate) error) { if getCert == nil { panic("getCert can't be nil") } @@ -190,6 +201,7 @@ func (s *CertPool) addCertFunc(rawSum224 sum224, rawSubject string, getCert func s.lazyCerts = append(s.lazyCerts, lazyCert{ rawSubject: []byte(rawSubject), getCert: getCert, + constraint: constraint, }) s.byName[rawSubject] = append(s.byName[rawSubject], len(s.lazyCerts)-1) } @@ -227,7 +239,7 @@ func (s *CertPool) AppendCertsFromPEM(pemCerts []byte) (ok bool) { certBytes = nil }) return lazyCert.v, nil - }) + }, nil) ok = true } @@ -262,3 +274,17 @@ func (s *CertPool) Equal(other *CertPool) bool { } return true } + +// AddCertWithConstraint adds a certificate to the pool with the additional +// constraint. When Certificate.Verify builds a chain which is rooted by cert, +// it will additionally pass the whole chain to constraint to determine its +// validity. If constraint returns a non-nil error, the chain will be discarded. +// constraint may be called concurrently from multiple goroutines. +func (s *CertPool) AddCertWithConstraint(cert *Certificate, constraint func([]*Certificate) error) { + if cert == nil { + panic("adding nil Certificate to CertPool") + } + s.addCertFunc(sha256.Sum224(cert.Raw), string(cert.RawSubject), func() (*Certificate, error) { + return cert, nil + }, constraint) +} diff --git a/smx509/root_windows.go b/smx509/root_windows.go index 5b0e3a2..af035eb 100644 --- a/smx509/root_windows.go +++ b/smx509/root_windows.go @@ -40,7 +40,7 @@ func createStoreContext(leaf *Certificate, opts *VerifyOptions) (*syscall.CertCo if opts.Intermediates != nil { for i := 0; i < opts.Intermediates.len(); i++ { - intermediate, err := opts.Intermediates.cert(i) + intermediate, _, err := opts.Intermediates.cert(i) if err != nil { return nil, err } diff --git a/smx509/verify.go b/smx509/verify.go index 9c5a3c1..84fada9 100644 --- a/smx509/verify.go +++ b/smx509/verify.go @@ -632,7 +632,7 @@ func (c *Certificate) Verify(opts VerifyOptions) (chains [][]*Certificate, err e return nil, errNotParsed } for i := 0; i < opts.Intermediates.len(); i++ { - c, err := opts.Intermediates.cert(i) + c, _, err := opts.Intermediates.cert(i) if err != nil { return nil, fmt.Errorf("x509: error fetching intermediate: %w", err) } @@ -775,8 +775,8 @@ func (c *Certificate) buildChains(currentChain []*Certificate, sigChecks *int, o hintCert *Certificate ) - considerCandidate := func(certType int, candidate *Certificate) { - if alreadyInChain(candidate, currentChain) { + considerCandidate := func(certType int, candidate potentialParent) { + if candidate.cert.PublicKey == nil ||alreadyInChain(candidate.cert, currentChain) { return } @@ -789,29 +789,39 @@ func (c *Certificate) buildChains(currentChain []*Certificate, sigChecks *int, o return } - if err := c.CheckSignatureFrom(candidate); err != nil { + if err := c.CheckSignatureFrom(candidate.cert); err != nil { if hintErr == nil { hintErr = err - hintCert = candidate + hintCert = candidate.cert } return } - err = candidate.isValid(certType, currentChain, opts) + err = candidate.cert.isValid(certType, currentChain, opts) if err != nil { if hintErr == nil { hintErr = err - hintCert = candidate + hintCert = candidate.cert } return } + if candidate.constraint != nil { + if err := candidate.constraint(currentChain); err != nil { + if hintErr == nil { + hintErr = err + hintCert = candidate.cert + } + return + } + } + switch certType { case rootCertificate: - chains = append(chains, appendToFreshChain(currentChain, candidate)) + chains = append(chains, appendToFreshChain(currentChain, candidate.cert)) case intermediateCertificate: var childChains [][]*Certificate - childChains, err = candidate.buildChains(appendToFreshChain(currentChain, candidate), sigChecks, opts) + childChains, err = candidate.cert.buildChains(appendToFreshChain(currentChain, candidate.cert), sigChecks, opts) chains = append(chains, childChains...) } } diff --git a/smx509/verify_test.go b/smx509/verify_test.go index 2e45ec3..44e6e1c 100644 --- a/smx509/verify_test.go +++ b/smx509/verify_test.go @@ -1915,11 +1915,13 @@ type trustGraphEdge struct { Subject string Type int MutateTemplate func(*Certificate) + Constraint func([]*Certificate) error } type rootDescription struct { Subject string MutateTemplate func(*Certificate) + Constraint func([]*Certificate) error } type trustGraphDescription struct { @@ -1972,19 +1974,23 @@ func buildTrustGraph(t *testing.T, d trustGraphDescription) (*CertPool, *CertPoo certs := map[string]*Certificate{} keys := map[string]crypto.Signer{} - roots := []*Certificate{} + rootPool := NewCertPool() for _, r := range d.Roots { k, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { t.Fatalf("failed to generate test key: %s", err) } root := genCertEdge(t, r.Subject, k, r.MutateTemplate, rootCertificate, nil, nil) - roots = append(roots, root) + if r.Constraint != nil { + rootPool.AddCertWithConstraint(root, r.Constraint) + } else { + rootPool.AddCert(root) + } certs[r.Subject] = root keys[r.Subject] = k } - intermediates := []*Certificate{} + intermediatePool := NewCertPool() var leaf *Certificate for _, e := range d.Graph { issuerCert, ok := certs[e.Issuer] @@ -2010,18 +2016,14 @@ func buildTrustGraph(t *testing.T, d trustGraphDescription) (*CertPool, *CertPoo if e.Subject == d.Leaf { leaf = cert } else { - intermediates = append(intermediates, cert) + if e.Constraint != nil { + intermediatePool.AddCertWithConstraint(cert, e.Constraint) + } else { + intermediatePool.AddCert(cert) + } } } - rootPool, intermediatePool := NewCertPool(), NewCertPool() - for i := len(roots) - 1; i >= 0; i-- { - rootPool.AddCert(roots[i]) - } - for i := len(intermediates) - 1; i >= 0; i-- { - intermediatePool.AddCert(intermediates[i]) - } - return rootPool, intermediatePool, leaf } @@ -2476,6 +2478,78 @@ func TestPathBuilding(t *testing.T) { }, }, expectedChains: []string{"CN=leaf -> CN=inter -> CN=root"}, + }, + { + // A code constraint on the root, applying to one of two intermediates in the graph, should + // result in only one valid chain. + name: "code constrained root, two paths, one valid", + graph: trustGraphDescription{ + Roots: []rootDescription{{Subject: "root", Constraint: func(chain []*Certificate) error { + for _, c := range chain { + if c.Subject.CommonName == "inter a" { + return errors.New("bad") + } + } + return nil + }}}, + Leaf: "leaf", + Graph: []trustGraphEdge{ + { + Issuer: "root", + Subject: "inter a", + Type: intermediateCertificate, + }, + { + Issuer: "root", + Subject: "inter b", + Type: intermediateCertificate, + }, + { + Issuer: "inter a", + Subject: "inter c", + Type: intermediateCertificate, + }, + { + Issuer: "inter b", + Subject: "inter c", + Type: intermediateCertificate, + }, + { + Issuer: "inter c", + Subject: "leaf", + Type: leafCertificate, + }, + }, + }, + expectedChains: []string{"CN=leaf -> CN=inter c -> CN=inter b -> CN=root"}, + }, + { + // A code constraint on the root, applying to the only path, should result in an error. + name: "code constrained root, one invalid path", + graph: trustGraphDescription{ + Roots: []rootDescription{{Subject: "root", Constraint: func(chain []*Certificate) error { + for _, c := range chain { + if c.Subject.CommonName == "leaf" { + return errors.New("bad") + } + } + return nil + }}}, + Leaf: "leaf", + Graph: []trustGraphEdge{ + { + Issuer: "root", + Subject: "inter", + Type: intermediateCertificate, + }, + { + Issuer: "inter", + Subject: "leaf", + Type: leafCertificate, + }, + }, + }, + expectedErr: "x509: certificate signed by unknown authority (possibly because of \"bad\" while trying to verify candidate authority certificate \"root\")", }, } @@ -2690,3 +2764,22 @@ func TestVerifyEKURootAsLeaf(t *testing.T) { } } + +func TestVerifyNilPubKey(t *testing.T) { + c := &Certificate{ + RawIssuer: []byte{1, 2, 3}, + AuthorityKeyId: []byte{1, 2, 3}, + } + opts := &VerifyOptions{} + opts.Roots = NewCertPool() + r := &Certificate{ + RawSubject: []byte{1, 2, 3}, + SubjectKeyId: []byte{1, 2, 3}, + } + opts.Roots.AddCert(r) + + _, err := c.buildChains([]*Certificate{r}, nil, opts) + if _, ok := err.(UnknownAuthorityError); !ok { + t.Fatalf("buildChains returned unexpected error, got: %v, want %v", err, UnknownAuthorityError{}) + } +}