From e7e60b35adb5508d4044fb7cac5c38960ef50b24 Mon Sep 17 00:00:00 2001 From: Todd Short Date: Tue, 15 Apr 2025 11:54:14 -0400 Subject: [PATCH] Change how systemroot is created when SSL_CERT environment is set The x509.SystemCertPool() looks at the SSL_CERT_FILE and SSL_CERT_DIR environment variables to generate the pool. However, if the contents of the referenced file (singular) or directories (multiple) change, there is no guarantee that x509.SystemCertPool() will be updated. Since we are watching these locations (defined by the environment) via fsnotify, we want to ensure that when those files are updated that the cert pool we use is also updated. So, if SSL_CERT_FILE or SSL_CERT_DIR are defined, create our cert pool from those variable _only_, ignoring the x509.SystemCertPool(). This is how the x509.SystemCertPool() would be created, so we do it explicitly instead. This allows us to properly refresh the pool when fsnotify tells us there are changes to our watches. This does not impact images/containers (i.e. impage pulling) directly, since that still uses x509.SystemCertPool(), so it may get a stale pool, but the catalogd client will have an up-to-date pool. See: https://pkg.go.dev/crypto/x509#SystemCertPool Signed-off-by: Todd Short --- .../shared/util/http/certpoolwatcher_test.go | 63 +++++++++ internal/shared/util/http/certutil.go | 120 ++++++++++++++---- 2 files changed, 158 insertions(+), 25 deletions(-) diff --git a/internal/shared/util/http/certpoolwatcher_test.go b/internal/shared/util/http/certpoolwatcher_test.go index ca13a478b..896a1dd8d 100644 --- a/internal/shared/util/http/certpoolwatcher_test.go +++ b/internal/shared/util/http/certpoolwatcher_test.go @@ -91,6 +91,7 @@ func TestCertPoolWatcher(t *testing.T) { t.Logf("Create cert file at %q\n", certName) createCert(t, certName) + // Eventually, the pools updates require.Eventually(t, func() bool { secondPool, secondGen, err := cpw.Get() if err != nil { @@ -99,3 +100,65 @@ func TestCertPoolWatcher(t *testing.T) { return secondGen != firstGen && !firstPool.Equal(secondPool) }, 30*time.Second, time.Second) } + +func TestCertPoolWatcherNoEnv(t *testing.T) { + // create a temporary directory + tmpDir, err := os.MkdirTemp("", "cert-pool") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + // create the first cert + certName := filepath.Join(tmpDir, "test1.pem") + t.Logf("Create cert file at %q\n", certName) + createCert(t, certName) + + // Clear environment variables for the watcher + os.Setenv("SSL_CERT_DIR", "") + os.Setenv("SSL_CERT_FILE", "") + + // Create the cert pool watcher + cpw, err := httputil.NewCertPoolWatcher(tmpDir, log.FromContext(context.Background())) + require.NoError(t, err) + defer cpw.Done() + + // Get the original pool + firstPool, firstGen, err := cpw.Get() + require.NoError(t, err) + require.NotNil(t, firstPool) + + // Create a second cert + certName = filepath.Join(tmpDir, "test2.pem") + t.Logf("Create cert file at %q\n", certName) + createCert(t, certName) + + // Eventually, the pool updates + require.Eventually(t, func() bool { + secondPool, secondGen, err := cpw.Get() + if err != nil { + return false + } + return secondGen != firstGen && !firstPool.Equal(secondPool) + }, 30*time.Second, time.Second) +} + +func TestCertPoolWatcherNoEnvEqualsSystem(t *testing.T) { + // Clear environment variables for the watcher + os.Setenv("SSL_CERT_DIR", "") + os.Setenv("SSL_CERT_FILE", "") + + // Create the cert pool watcher + cpw, err := httputil.NewCertPoolWatcher("", log.FromContext(context.Background())) + require.NoError(t, err) + defer cpw.Done() + + // Get the original pool + firstPool, _, err := cpw.Get() + require.NoError(t, err) + require.NotNil(t, firstPool) + + // Compare to the system pool + sysPool, err := x509.SystemCertPool() + require.NoError(t, err) + require.NotNil(t, sysPool) + require.True(t, firstPool.Equal(sysPool)) +} diff --git a/internal/shared/util/http/certutil.go b/internal/shared/util/http/certutil.go index fb7cdc4cb..c682416b1 100644 --- a/internal/shared/util/http/certutil.go +++ b/internal/shared/util/http/certutil.go @@ -5,50 +5,120 @@ import ( "fmt" "os" "path/filepath" + "strings" "github.com/go-logr/logr" ) -func NewCertPool(caDir string, log logr.Logger) (*x509.CertPool, error) { - caCertPool, err := x509.SystemCertPool() +func readCertFile(pool *x509.CertPool, file string, log logr.Logger) (bool, error) { + certRead := false + if file == "" { + return certRead, nil + } + // These might be symlinks pointing to directories, so use Stat() to resolve + fi, err := os.Stat(file) if err != nil { - return nil, err + // Ignore files that don't exist + if os.IsNotExist(err) { + return certRead, nil + } + return certRead, err } - if caDir == "" { - return caCertPool, nil + if fi.IsDir() { + log.V(defaultLogLevel).Info("skip directory", "name", file) + return certRead, nil } + log.V(defaultLogLevel).Info("load certificate", "name", file, "size", fi.Size(), "modtime", fi.ModTime()) + data, err := os.ReadFile(file) + if err != nil { + return certRead, fmt.Errorf("error reading cert file %q: %w", file, err) + } + // The return indicates if any certs were added + if pool.AppendCertsFromPEM(data) { + certRead = true + } + logPem(data, filepath.Base(file), filepath.Dir(file), "loading certificate file", log) + + return certRead, nil +} - dirEntries, err := os.ReadDir(caDir) +func readCertDir(pool *x509.CertPool, dir string, log logr.Logger) (bool, error) { + certRead := false + if dir == "" { + return certRead, nil + } + dirEntries, err := os.ReadDir(dir) if err != nil { - return nil, err + // Ignore directories that don't exist + if os.IsNotExist(err) { + return certRead, nil + } + return certRead, err } - count := 0 for _, e := range dirEntries { - file := filepath.Join(caDir, e.Name()) - // These might be symlinks pointing to directories, so use Stat() to resolve - fi, err := os.Stat(file) + file := filepath.Join(dir, e.Name()) + c, err := readCertFile(pool, file, log) if err != nil { - return nil, err - } - if fi.IsDir() { - log.V(defaultLogLevel).Info("skip directory", "name", e.Name()) - continue + return certRead, err } - log.V(defaultLogLevel).Info("load certificate", "name", e.Name(), "size", fi.Size(), "modtime", fi.ModTime()) - data, err := os.ReadFile(file) + certRead = certRead || c + } + return certRead, nil +} + +// This function looks explicitly at the SSL environment, and +// uses it to create a "fresh" system cert pool +func systemCertPool(log logr.Logger) (*x509.CertPool, error) { + sslCertDir := os.Getenv("SSL_CERT_DIR") + sslCertFile := os.Getenv("SSL_CERT_FILE") + if sslCertDir == "" && sslCertFile == "" { + log.V(defaultLogLevel).Info("SystemCertPool: SSL environment not set") + return x509.SystemCertPool() + } + log.V(defaultLogLevel).Info("SystemCertPool: SSL environment set", "SSL_CERT_DIR", sslCertDir, "SSL_CERT_FILE", sslCertFile) + + certRead := false + pool := x509.NewCertPool() + + // SSL_CERT_DIR may consist of multiple entries separated by ":" + for _, d := range strings.Split(sslCertDir, ":") { + c, err := readCertDir(pool, d, log) if err != nil { - return nil, fmt.Errorf("error reading cert file %q: %w", file, err) - } - // The return indicates if any certs were added - if caCertPool.AppendCertsFromPEM(data) { - count++ + return nil, err } - logPem(data, e.Name(), caDir, "loading certificate file", log) + certRead = certRead || c + } + // SSL_CERT_FILE may consist of only a single entry + c, err := readCertFile(pool, sslCertFile, log) + if err != nil { + return nil, err + } + certRead = certRead || c + + // If SSL_CERT_DIR and SSL_CERT_FILE resulted in no certs, then return the system cert pool + if !certRead { + return x509.SystemCertPool() + } + return pool, nil +} + +func NewCertPool(caDir string, log logr.Logger) (*x509.CertPool, error) { + caCertPool, err := systemCertPool(log) + if err != nil { + return nil, err + } + + if caDir == "" { + return caCertPool, nil + } + readCert, err := readCertDir(caCertPool, caDir, log) + if err != nil { + return nil, err } // Found no certs! - if count == 0 { + if !readCert { return nil, fmt.Errorf("no certificates found in %q", caDir) }