diff --git a/smx509/cert_pool.go b/smx509/cert_pool.go index 2185f26..df04f61 100644 --- a/smx509/cert_pool.go +++ b/smx509/cert_pool.go @@ -244,4 +244,17 @@ func (s *CertPool) Subjects() [][]byte { res[i] = lc.rawSubject } return res -} \ No newline at end of file +} + +// Equal reports whether s and other are equal. +func (s *CertPool) Equal(other *CertPool) bool { + if s.systemPool != other.systemPool || len(s.haveSum) != len(other.haveSum) { + return false + } + for h := range s.haveSum { + if !other.haveSum[h] { + return false + } + } + return true +} diff --git a/smx509/cert_pool_test.go b/smx509/cert_pool_test.go new file mode 100644 index 0000000..700f2ed --- /dev/null +++ b/smx509/cert_pool_test.go @@ -0,0 +1,54 @@ +package smx509 + +import "testing" + +func TestCertPoolEqual(t *testing.T) { + a, b := NewCertPool(), NewCertPool() + if !a.Equal(b) { + t.Error("two empty pools not equal") + } + + tc := &Certificate{Raw: []byte{1, 2, 3}, RawSubject: []byte{2}} + a.AddCert(tc) + if a.Equal(b) { + t.Error("empty pool equals non-empty pool") + } + + b.AddCert(tc) + if !a.Equal(b) { + t.Error("two non-empty pools not equal") + } + + otherTC := &Certificate{Raw: []byte{9, 8, 7}, RawSubject: []byte{8}} + a.AddCert(otherTC) + if a.Equal(b) { + t.Error("non-equal pools equal") + } + + systemA, err := SystemCertPool() + if err != nil { + t.Fatalf("unable to load system cert pool: %s", err) + } + systemB, err := SystemCertPool() + if err != nil { + t.Fatalf("unable to load system cert pool: %s", err) + } + if !systemA.Equal(systemB) { + t.Error("two empty system pools not equal") + } + + systemA.AddCert(tc) + if systemA.Equal(systemB) { + t.Error("empty system pool equals non-empty system pool") + } + + systemB.AddCert(tc) + if !systemA.Equal(systemB) { + t.Error("two non-empty system pools not equal") + } + + systemA.AddCert(otherTC) + if systemA.Equal(systemB) { + t.Error("non-equal system pools equal") + } +}