smx509: implement SetFallbackRoots #211

This commit is contained in:
Sun Yimin 2024-03-06 13:02:56 +08:00 committed by GitHub
parent 3a2c7e2c9b
commit 66c05083bf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 161 additions and 5 deletions

View File

@ -376,17 +376,17 @@ func parseKeyUsageExtension(der cryptobyte.String) (KeyUsage, error) {
func parseBasicConstraintsExtension(der cryptobyte.String) (bool, int, error) { func parseBasicConstraintsExtension(der cryptobyte.String) (bool, int, error) {
var isCA bool var isCA bool
if !der.ReadASN1(&der, cryptobyte_asn1.SEQUENCE) { if !der.ReadASN1(&der, cryptobyte_asn1.SEQUENCE) {
return false, 0, errors.New("x509: invalid basic constraints a") return false, 0, errors.New("x509: invalid basic constraints")
} }
if der.PeekASN1Tag(cryptobyte_asn1.BOOLEAN) { if der.PeekASN1Tag(cryptobyte_asn1.BOOLEAN) {
if !der.ReadASN1Boolean(&isCA) { if !der.ReadASN1Boolean(&isCA) {
return false, 0, errors.New("x509: invalid basic constraints b") return false, 0, errors.New("x509: invalid basic constraints")
} }
} }
maxPathLen := -1 maxPathLen := -1
if der.PeekASN1Tag(cryptobyte_asn1.INTEGER) { if der.PeekASN1Tag(cryptobyte_asn1.INTEGER) {
if !der.ReadASN1Integer(&maxPathLen) { if !der.ReadASN1Integer(&maxPathLen) {
return false, 0, errors.New("x509: invalid basic constraints c") return false, 0, errors.New("x509: invalid basic constraints")
} }
} }

View File

@ -1,21 +1,66 @@
package smx509 package smx509
import "sync" import (
"sync"
"github.com/emmansun/gmsm/internal/godebug"
)
var ( var (
once sync.Once once sync.Once
systemRootsMu sync.RWMutex
systemRoots *CertPool systemRoots *CertPool
systemRootsErr error systemRootsErr error
fallbacksSet bool
) )
func systemRootsPool() *CertPool { func systemRootsPool() *CertPool {
once.Do(initSystemRoots) once.Do(initSystemRoots)
systemRootsMu.RLock()
defer systemRootsMu.RUnlock()
return systemRoots return systemRoots
} }
func initSystemRoots() { func initSystemRoots() {
systemRootsMu.Lock()
defer systemRootsMu.Unlock()
systemRoots, systemRootsErr = loadSystemRoots() systemRoots, systemRootsErr = loadSystemRoots()
if systemRootsErr != nil { if systemRootsErr != nil {
systemRoots = nil systemRoots = nil
} }
} }
// SetFallbackRoots sets the roots to use during certificate verification, if no
// custom roots are specified and a platform verifier or a system certificate
// pool is not available (for instance in a container which does not have a root
// certificate bundle). SetFallbackRoots will panic if roots is nil.
//
// SetFallbackRoots may only be called once, if called multiple times it will
// panic.
//
// The fallback behavior can be forced on all platforms, even when there is a
// system certificate pool, by setting GODEBUG=x509usefallbackroots=1 (note that
// on Windows and macOS this will disable usage of the platform verification
// APIs and cause the pure Go verifier to be used). Setting
// x509usefallbackroots=1 without calling SetFallbackRoots has no effect.
func SetFallbackRoots(roots *CertPool) {
if roots == nil {
panic("roots must be non-nil")
}
// trigger initSystemRoots if it hasn't already been called before we
// take the lock
_ = systemRootsPool()
systemRootsMu.Lock()
defer systemRootsMu.Unlock()
if fallbacksSet {
panic("SetFallbackRoots has already been called")
}
fallbacksSet = true
if systemRoots != nil && (systemRoots.len() > 0 || systemRoots.systemPool) && (godebug.Get("x509usefallbackroots") != "1") {
return
}
systemRoots, systemRootsErr = roots, nil
}

108
smx509/root_test.go Normal file
View File

@ -0,0 +1,108 @@
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package smx509
import (
"testing"
)
func TestFallbackPanic(t *testing.T) {
defer func() {
if recover() == nil {
t.Fatal("Multiple calls to SetFallbackRoots should panic")
}
}()
SetFallbackRoots(nil)
SetFallbackRoots(nil)
}
func TestFallback(t *testing.T) {
// call systemRootsPool so that the sync.Once is triggered, and we can
// manipulate systemRoots without worrying about our working being overwritten
systemRootsPool()
if systemRoots != nil {
originalSystemRoots := *systemRoots
defer func() { systemRoots = &originalSystemRoots }()
}
tests := []struct {
name string
systemRoots *CertPool
systemPool bool
poolContent []*Certificate
forceFallback bool
returnsFallback bool
}{
{
name: "nil systemRoots",
returnsFallback: true,
},
{
name: "empty systemRoots",
systemRoots: NewCertPool(),
returnsFallback: true,
},
{
name: "empty systemRoots system pool",
systemRoots: NewCertPool(),
systemPool: true,
},
{
name: "filled systemRoots system pool",
systemRoots: NewCertPool(),
poolContent: []*Certificate{{}},
systemPool: true,
},
{
name: "filled systemRoots",
systemRoots: NewCertPool(),
poolContent: []*Certificate{{}},
},
{
name: "filled systemRoots, force fallback",
systemRoots: NewCertPool(),
poolContent: []*Certificate{{}},
forceFallback: true,
returnsFallback: true,
},
{
name: "filled systemRoot system pool, force fallback",
systemRoots: NewCertPool(),
poolContent: []*Certificate{{}},
systemPool: true,
forceFallback: true,
returnsFallback: true,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
fallbacksSet = false
systemRoots = tc.systemRoots
if systemRoots != nil {
systemRoots.systemPool = tc.systemPool
}
for _, c := range tc.poolContent {
systemRoots.AddCert(c)
}
if tc.forceFallback {
t.Setenv("GODEBUG", "x509usefallbackroots=1")
} else {
t.Setenv("GODEBUG", "x509usefallbackroots=0")
}
fallbackPool := NewCertPool()
SetFallbackRoots(fallbackPool)
systemPoolIsFallback := systemRoots == fallbackPool
if tc.returnsFallback && !systemPoolIsFallback {
t.Error("systemRoots was not set to fallback pool")
} else if !tc.returnsFallback && systemPoolIsFallback {
t.Error("systemRoots was set to fallback pool when it shouldn't have been")
}
})
}
}

View File

@ -643,7 +643,10 @@ func (c *Certificate) Verify(opts VerifyOptions) (chains [][]*Certificate, err e
// Use platform verifiers, where available, if Roots is from SystemCertPool. // Use platform verifiers, where available, if Roots is from SystemCertPool.
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
if opts.Roots == nil { // Don't use the system verifier if the system pool was replaced with a non-system pool,
// i.e. if SetFallbackRoots was called with x509usefallbackroots=1.
systemPool := systemRootsPool()
if opts.Roots == nil && (systemPool == nil || systemPool.systemPool) {
return c.systemVerify(&opts) return c.systemVerify(&opts)
} }
if opts.Roots != nil && opts.Roots.systemPool { if opts.Roots != nil && opts.Roots.systemPool {