smx509: [sync sdk] crypto/x509: implement AddCertWithConstraint #208

This commit is contained in:
Sun Yimin 2024-03-06 08:35:14 +08:00 committed by GitHub
parent 5adc912824
commit 3a2c7e2c9b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 163 additions and 34 deletions

View File

@ -40,6 +40,11 @@ type lazyCert struct {
// fewer allocations. // fewer allocations.
rawSubject []byte rawSubject []byte
// constraint is a function to run against a chain when it is a candidate to
// be added to the chain. This allows adding arbitrary constraints that are
// not specified in the certificate itself.
constraint func([]*Certificate) error
// getCert returns the certificate. // getCert returns the certificate.
// //
// It is not meant to do network operations or anything else // It is not meant to do network operations or anything else
@ -69,8 +74,9 @@ func (s *CertPool) len() int {
} }
// cert returns cert index n in s. // cert returns cert index n in s.
func (s *CertPool) cert(n int) (*Certificate, error) { func (s *CertPool) cert(n int) (*Certificate, func([]*Certificate) error, error) {
return s.lazyCerts[n].getCert() cert, err := s.lazyCerts[n].getCert()
return cert, s.lazyCerts[n].constraint, err
} }
// Clone returns a copy of s. // Clone returns a copy of s.
@ -112,9 +118,14 @@ func SystemCertPool() (*CertPool, error) {
return loadSystemRoots() return loadSystemRoots()
} }
type potentialParent struct {
cert *Certificate
constraint func([]*Certificate) error
}
// findPotentialParents returns the indexes of certificates in s which might // findPotentialParents returns the indexes of certificates in s which might
// have signed cert. // have signed cert.
func (s *CertPool) findPotentialParents(cert *Certificate) []*Certificate { func (s *CertPool) findPotentialParents(cert *Certificate) []potentialParent {
if s == nil { if s == nil {
return nil return nil
} }
@ -125,21 +136,21 @@ func (s *CertPool) findPotentialParents(cert *Certificate) []*Certificate {
// AKID and SKID match // AKID and SKID match
// AKID present, SKID missing / AKID missing, SKID present // AKID present, SKID missing / AKID missing, SKID present
// AKID and SKID don't match // AKID and SKID don't match
var matchingKeyID, oneKeyID, mismatchKeyID []*Certificate var matchingKeyID, oneKeyID, mismatchKeyID []potentialParent
for _, c := range s.byName[string(cert.RawIssuer)] { for _, c := range s.byName[string(cert.RawIssuer)] {
candidate, err := s.cert(c) candidate, constraint, err := s.cert(c)
if err != nil { if err != nil {
continue continue
} }
kidMatch := bytes.Equal(candidate.SubjectKeyId, cert.AuthorityKeyId) kidMatch := bytes.Equal(candidate.SubjectKeyId, cert.AuthorityKeyId)
switch { switch {
case kidMatch: case kidMatch:
matchingKeyID = append(matchingKeyID, candidate) matchingKeyID = append(matchingKeyID, potentialParent{candidate, constraint})
case (len(candidate.SubjectKeyId) == 0 && len(cert.AuthorityKeyId) > 0) || case (len(candidate.SubjectKeyId) == 0 && len(cert.AuthorityKeyId) > 0) ||
(len(candidate.SubjectKeyId) > 0 && len(cert.AuthorityKeyId) == 0): (len(candidate.SubjectKeyId) > 0 && len(cert.AuthorityKeyId) == 0):
oneKeyID = append(oneKeyID, candidate) oneKeyID = append(oneKeyID, potentialParent{candidate, constraint})
default: default:
mismatchKeyID = append(mismatchKeyID, candidate) mismatchKeyID = append(mismatchKeyID, potentialParent{candidate, constraint})
} }
} }
@ -147,7 +158,7 @@ func (s *CertPool) findPotentialParents(cert *Certificate) []*Certificate {
if found == 0 { if found == 0 {
return nil return nil
} }
candidates := make([]*Certificate, 0, found) candidates := make([]potentialParent, 0, found)
candidates = append(candidates, matchingKeyID...) candidates = append(candidates, matchingKeyID...)
candidates = append(candidates, oneKeyID...) candidates = append(candidates, oneKeyID...)
candidates = append(candidates, mismatchKeyID...) candidates = append(candidates, mismatchKeyID...)
@ -168,7 +179,7 @@ func (s *CertPool) AddCert(cert *Certificate) {
} }
s.addCertFunc(sha256.Sum224(cert.Raw), string(cert.RawSubject), func() (*Certificate, error) { s.addCertFunc(sha256.Sum224(cert.Raw), string(cert.RawSubject), func() (*Certificate, error) {
return cert, nil return cert, nil
}) }, nil)
} }
// addCertFunc adds metadata about a certificate to a pool, along with // addCertFunc adds metadata about a certificate to a pool, along with
@ -176,7 +187,7 @@ func (s *CertPool) AddCert(cert *Certificate) {
// //
// The rawSubject is Certificate.RawSubject and must be non-empty. // The rawSubject is Certificate.RawSubject and must be non-empty.
// The getCert func may be called 0 or more times. // The getCert func may be called 0 or more times.
func (s *CertPool) addCertFunc(rawSum224 sum224, rawSubject string, getCert func() (*Certificate, error)) { func (s *CertPool) addCertFunc(rawSum224 sum224, rawSubject string, getCert func() (*Certificate, error), constraint func([]*Certificate) error) {
if getCert == nil { if getCert == nil {
panic("getCert can't be nil") panic("getCert can't be nil")
} }
@ -190,6 +201,7 @@ func (s *CertPool) addCertFunc(rawSum224 sum224, rawSubject string, getCert func
s.lazyCerts = append(s.lazyCerts, lazyCert{ s.lazyCerts = append(s.lazyCerts, lazyCert{
rawSubject: []byte(rawSubject), rawSubject: []byte(rawSubject),
getCert: getCert, getCert: getCert,
constraint: constraint,
}) })
s.byName[rawSubject] = append(s.byName[rawSubject], len(s.lazyCerts)-1) s.byName[rawSubject] = append(s.byName[rawSubject], len(s.lazyCerts)-1)
} }
@ -227,7 +239,7 @@ func (s *CertPool) AppendCertsFromPEM(pemCerts []byte) (ok bool) {
certBytes = nil certBytes = nil
}) })
return lazyCert.v, nil return lazyCert.v, nil
}) }, nil)
ok = true ok = true
} }
@ -262,3 +274,17 @@ func (s *CertPool) Equal(other *CertPool) bool {
} }
return true return true
} }
// AddCertWithConstraint adds a certificate to the pool with the additional
// constraint. When Certificate.Verify builds a chain which is rooted by cert,
// it will additionally pass the whole chain to constraint to determine its
// validity. If constraint returns a non-nil error, the chain will be discarded.
// constraint may be called concurrently from multiple goroutines.
func (s *CertPool) AddCertWithConstraint(cert *Certificate, constraint func([]*Certificate) error) {
if cert == nil {
panic("adding nil Certificate to CertPool")
}
s.addCertFunc(sha256.Sum224(cert.Raw), string(cert.RawSubject), func() (*Certificate, error) {
return cert, nil
}, constraint)
}

View File

@ -40,7 +40,7 @@ func createStoreContext(leaf *Certificate, opts *VerifyOptions) (*syscall.CertCo
if opts.Intermediates != nil { if opts.Intermediates != nil {
for i := 0; i < opts.Intermediates.len(); i++ { for i := 0; i < opts.Intermediates.len(); i++ {
intermediate, err := opts.Intermediates.cert(i) intermediate, _, err := opts.Intermediates.cert(i)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -632,7 +632,7 @@ func (c *Certificate) Verify(opts VerifyOptions) (chains [][]*Certificate, err e
return nil, errNotParsed return nil, errNotParsed
} }
for i := 0; i < opts.Intermediates.len(); i++ { for i := 0; i < opts.Intermediates.len(); i++ {
c, err := opts.Intermediates.cert(i) c, _, err := opts.Intermediates.cert(i)
if err != nil { if err != nil {
return nil, fmt.Errorf("x509: error fetching intermediate: %w", err) return nil, fmt.Errorf("x509: error fetching intermediate: %w", err)
} }
@ -775,8 +775,8 @@ func (c *Certificate) buildChains(currentChain []*Certificate, sigChecks *int, o
hintCert *Certificate hintCert *Certificate
) )
considerCandidate := func(certType int, candidate *Certificate) { considerCandidate := func(certType int, candidate potentialParent) {
if alreadyInChain(candidate, currentChain) { if candidate.cert.PublicKey == nil ||alreadyInChain(candidate.cert, currentChain) {
return return
} }
@ -789,29 +789,39 @@ func (c *Certificate) buildChains(currentChain []*Certificate, sigChecks *int, o
return return
} }
if err := c.CheckSignatureFrom(candidate); err != nil { if err := c.CheckSignatureFrom(candidate.cert); err != nil {
if hintErr == nil { if hintErr == nil {
hintErr = err hintErr = err
hintCert = candidate hintCert = candidate.cert
} }
return return
} }
err = candidate.isValid(certType, currentChain, opts) err = candidate.cert.isValid(certType, currentChain, opts)
if err != nil { if err != nil {
if hintErr == nil { if hintErr == nil {
hintErr = err hintErr = err
hintCert = candidate hintCert = candidate.cert
} }
return return
} }
if candidate.constraint != nil {
if err := candidate.constraint(currentChain); err != nil {
if hintErr == nil {
hintErr = err
hintCert = candidate.cert
}
return
}
}
switch certType { switch certType {
case rootCertificate: case rootCertificate:
chains = append(chains, appendToFreshChain(currentChain, candidate)) chains = append(chains, appendToFreshChain(currentChain, candidate.cert))
case intermediateCertificate: case intermediateCertificate:
var childChains [][]*Certificate var childChains [][]*Certificate
childChains, err = candidate.buildChains(appendToFreshChain(currentChain, candidate), sigChecks, opts) childChains, err = candidate.cert.buildChains(appendToFreshChain(currentChain, candidate.cert), sigChecks, opts)
chains = append(chains, childChains...) chains = append(chains, childChains...)
} }
} }

View File

@ -1915,11 +1915,13 @@ type trustGraphEdge struct {
Subject string Subject string
Type int Type int
MutateTemplate func(*Certificate) MutateTemplate func(*Certificate)
Constraint func([]*Certificate) error
} }
type rootDescription struct { type rootDescription struct {
Subject string Subject string
MutateTemplate func(*Certificate) MutateTemplate func(*Certificate)
Constraint func([]*Certificate) error
} }
type trustGraphDescription struct { type trustGraphDescription struct {
@ -1972,19 +1974,23 @@ func buildTrustGraph(t *testing.T, d trustGraphDescription) (*CertPool, *CertPoo
certs := map[string]*Certificate{} certs := map[string]*Certificate{}
keys := map[string]crypto.Signer{} keys := map[string]crypto.Signer{}
roots := []*Certificate{} rootPool := NewCertPool()
for _, r := range d.Roots { for _, r := range d.Roots {
k, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) k, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil { if err != nil {
t.Fatalf("failed to generate test key: %s", err) t.Fatalf("failed to generate test key: %s", err)
} }
root := genCertEdge(t, r.Subject, k, r.MutateTemplate, rootCertificate, nil, nil) root := genCertEdge(t, r.Subject, k, r.MutateTemplate, rootCertificate, nil, nil)
roots = append(roots, root) if r.Constraint != nil {
rootPool.AddCertWithConstraint(root, r.Constraint)
} else {
rootPool.AddCert(root)
}
certs[r.Subject] = root certs[r.Subject] = root
keys[r.Subject] = k keys[r.Subject] = k
} }
intermediates := []*Certificate{} intermediatePool := NewCertPool()
var leaf *Certificate var leaf *Certificate
for _, e := range d.Graph { for _, e := range d.Graph {
issuerCert, ok := certs[e.Issuer] issuerCert, ok := certs[e.Issuer]
@ -2010,18 +2016,14 @@ func buildTrustGraph(t *testing.T, d trustGraphDescription) (*CertPool, *CertPoo
if e.Subject == d.Leaf { if e.Subject == d.Leaf {
leaf = cert leaf = cert
} else { } else {
intermediates = append(intermediates, cert) if e.Constraint != nil {
intermediatePool.AddCertWithConstraint(cert, e.Constraint)
} else {
intermediatePool.AddCert(cert)
}
} }
} }
rootPool, intermediatePool := NewCertPool(), NewCertPool()
for i := len(roots) - 1; i >= 0; i-- {
rootPool.AddCert(roots[i])
}
for i := len(intermediates) - 1; i >= 0; i-- {
intermediatePool.AddCert(intermediates[i])
}
return rootPool, intermediatePool, leaf return rootPool, intermediatePool, leaf
} }
@ -2476,6 +2478,78 @@ func TestPathBuilding(t *testing.T) {
}, },
}, },
expectedChains: []string{"CN=leaf -> CN=inter -> CN=root"}, expectedChains: []string{"CN=leaf -> CN=inter -> CN=root"},
},
{
// A code constraint on the root, applying to one of two intermediates in the graph, should
// result in only one valid chain.
name: "code constrained root, two paths, one valid",
graph: trustGraphDescription{
Roots: []rootDescription{{Subject: "root", Constraint: func(chain []*Certificate) error {
for _, c := range chain {
if c.Subject.CommonName == "inter a" {
return errors.New("bad")
}
}
return nil
}}},
Leaf: "leaf",
Graph: []trustGraphEdge{
{
Issuer: "root",
Subject: "inter a",
Type: intermediateCertificate,
},
{
Issuer: "root",
Subject: "inter b",
Type: intermediateCertificate,
},
{
Issuer: "inter a",
Subject: "inter c",
Type: intermediateCertificate,
},
{
Issuer: "inter b",
Subject: "inter c",
Type: intermediateCertificate,
},
{
Issuer: "inter c",
Subject: "leaf",
Type: leafCertificate,
},
},
},
expectedChains: []string{"CN=leaf -> CN=inter c -> CN=inter b -> CN=root"},
},
{
// A code constraint on the root, applying to the only path, should result in an error.
name: "code constrained root, one invalid path",
graph: trustGraphDescription{
Roots: []rootDescription{{Subject: "root", Constraint: func(chain []*Certificate) error {
for _, c := range chain {
if c.Subject.CommonName == "leaf" {
return errors.New("bad")
}
}
return nil
}}},
Leaf: "leaf",
Graph: []trustGraphEdge{
{
Issuer: "root",
Subject: "inter",
Type: intermediateCertificate,
},
{
Issuer: "inter",
Subject: "leaf",
Type: leafCertificate,
},
},
},
expectedErr: "x509: certificate signed by unknown authority (possibly because of \"bad\" while trying to verify candidate authority certificate \"root\")",
}, },
} }
@ -2690,3 +2764,22 @@ func TestVerifyEKURootAsLeaf(t *testing.T) {
} }
} }
func TestVerifyNilPubKey(t *testing.T) {
c := &Certificate{
RawIssuer: []byte{1, 2, 3},
AuthorityKeyId: []byte{1, 2, 3},
}
opts := &VerifyOptions{}
opts.Roots = NewCertPool()
r := &Certificate{
RawSubject: []byte{1, 2, 3},
SubjectKeyId: []byte{1, 2, 3},
}
opts.Roots.AddCert(r)
_, err := c.buildChains([]*Certificate{r}, nil, opts)
if _, ok := err.(UnknownAuthorityError); !ok {
t.Fatalf("buildChains returned unexpected error, got: %v, want %v", err, UnknownAuthorityError{})
}
}