diff --git a/internal/shared/util/http/certpoolwatcher_test.go b/internal/shared/util/http/certpoolwatcher_test.go index ca13a478b..896a1dd8d 100644 --- a/internal/shared/util/http/certpoolwatcher_test.go +++ b/internal/shared/util/http/certpoolwatcher_test.go @@ -91,6 +91,7 @@ func TestCertPoolWatcher(t *testing.T) { t.Logf("Create cert file at %q\n", certName) createCert(t, certName) + // Eventually, the pools updates require.Eventually(t, func() bool { secondPool, secondGen, err := cpw.Get() if err != nil { @@ -99,3 +100,65 @@ func TestCertPoolWatcher(t *testing.T) { return secondGen != firstGen && !firstPool.Equal(secondPool) }, 30*time.Second, time.Second) } + +func TestCertPoolWatcherNoEnv(t *testing.T) { + // create a temporary directory + tmpDir, err := os.MkdirTemp("", "cert-pool") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + // create the first cert + certName := filepath.Join(tmpDir, "test1.pem") + t.Logf("Create cert file at %q\n", certName) + createCert(t, certName) + + // Clear environment variables for the watcher + os.Setenv("SSL_CERT_DIR", "") + os.Setenv("SSL_CERT_FILE", "") + + // Create the cert pool watcher + cpw, err := httputil.NewCertPoolWatcher(tmpDir, log.FromContext(context.Background())) + require.NoError(t, err) + defer cpw.Done() + + // Get the original pool + firstPool, firstGen, err := cpw.Get() + require.NoError(t, err) + require.NotNil(t, firstPool) + + // Create a second cert + certName = filepath.Join(tmpDir, "test2.pem") + t.Logf("Create cert file at %q\n", certName) + createCert(t, certName) + + // Eventually, the pool updates + require.Eventually(t, func() bool { + secondPool, secondGen, err := cpw.Get() + if err != nil { + return false + } + return secondGen != firstGen && !firstPool.Equal(secondPool) + }, 30*time.Second, time.Second) +} + +func TestCertPoolWatcherNoEnvEqualsSystem(t *testing.T) { + // Clear environment variables for the watcher + os.Setenv("SSL_CERT_DIR", "") + os.Setenv("SSL_CERT_FILE", "") + + // Create the cert pool watcher + cpw, err := httputil.NewCertPoolWatcher("", log.FromContext(context.Background())) + require.NoError(t, err) + defer cpw.Done() + + // Get the original pool + firstPool, _, err := cpw.Get() + require.NoError(t, err) + require.NotNil(t, firstPool) + + // Compare to the system pool + sysPool, err := x509.SystemCertPool() + require.NoError(t, err) + require.NotNil(t, sysPool) + require.True(t, firstPool.Equal(sysPool)) +} diff --git a/internal/shared/util/http/certutil.go b/internal/shared/util/http/certutil.go index fb7cdc4cb..c682416b1 100644 --- a/internal/shared/util/http/certutil.go +++ b/internal/shared/util/http/certutil.go @@ -5,50 +5,120 @@ import ( "fmt" "os" "path/filepath" + "strings" "github.com/go-logr/logr" ) -func NewCertPool(caDir string, log logr.Logger) (*x509.CertPool, error) { - caCertPool, err := x509.SystemCertPool() +func readCertFile(pool *x509.CertPool, file string, log logr.Logger) (bool, error) { + certRead := false + if file == "" { + return certRead, nil + } + // These might be symlinks pointing to directories, so use Stat() to resolve + fi, err := os.Stat(file) if err != nil { - return nil, err + // Ignore files that don't exist + if os.IsNotExist(err) { + return certRead, nil + } + return certRead, err } - if caDir == "" { - return caCertPool, nil + if fi.IsDir() { + log.V(defaultLogLevel).Info("skip directory", "name", file) + return certRead, nil } + log.V(defaultLogLevel).Info("load certificate", "name", file, "size", fi.Size(), "modtime", fi.ModTime()) + data, err := os.ReadFile(file) + if err != nil { + return certRead, fmt.Errorf("error reading cert file %q: %w", file, err) + } + // The return indicates if any certs were added + if pool.AppendCertsFromPEM(data) { + certRead = true + } + logPem(data, filepath.Base(file), filepath.Dir(file), "loading certificate file", log) + + return certRead, nil +} - dirEntries, err := os.ReadDir(caDir) +func readCertDir(pool *x509.CertPool, dir string, log logr.Logger) (bool, error) { + certRead := false + if dir == "" { + return certRead, nil + } + dirEntries, err := os.ReadDir(dir) if err != nil { - return nil, err + // Ignore directories that don't exist + if os.IsNotExist(err) { + return certRead, nil + } + return certRead, err } - count := 0 for _, e := range dirEntries { - file := filepath.Join(caDir, e.Name()) - // These might be symlinks pointing to directories, so use Stat() to resolve - fi, err := os.Stat(file) + file := filepath.Join(dir, e.Name()) + c, err := readCertFile(pool, file, log) if err != nil { - return nil, err - } - if fi.IsDir() { - log.V(defaultLogLevel).Info("skip directory", "name", e.Name()) - continue + return certRead, err } - log.V(defaultLogLevel).Info("load certificate", "name", e.Name(), "size", fi.Size(), "modtime", fi.ModTime()) - data, err := os.ReadFile(file) + certRead = certRead || c + } + return certRead, nil +} + +// This function looks explicitly at the SSL environment, and +// uses it to create a "fresh" system cert pool +func systemCertPool(log logr.Logger) (*x509.CertPool, error) { + sslCertDir := os.Getenv("SSL_CERT_DIR") + sslCertFile := os.Getenv("SSL_CERT_FILE") + if sslCertDir == "" && sslCertFile == "" { + log.V(defaultLogLevel).Info("SystemCertPool: SSL environment not set") + return x509.SystemCertPool() + } + log.V(defaultLogLevel).Info("SystemCertPool: SSL environment set", "SSL_CERT_DIR", sslCertDir, "SSL_CERT_FILE", sslCertFile) + + certRead := false + pool := x509.NewCertPool() + + // SSL_CERT_DIR may consist of multiple entries separated by ":" + for _, d := range strings.Split(sslCertDir, ":") { + c, err := readCertDir(pool, d, log) if err != nil { - return nil, fmt.Errorf("error reading cert file %q: %w", file, err) - } - // The return indicates if any certs were added - if caCertPool.AppendCertsFromPEM(data) { - count++ + return nil, err } - logPem(data, e.Name(), caDir, "loading certificate file", log) + certRead = certRead || c + } + // SSL_CERT_FILE may consist of only a single entry + c, err := readCertFile(pool, sslCertFile, log) + if err != nil { + return nil, err + } + certRead = certRead || c + + // If SSL_CERT_DIR and SSL_CERT_FILE resulted in no certs, then return the system cert pool + if !certRead { + return x509.SystemCertPool() + } + return pool, nil +} + +func NewCertPool(caDir string, log logr.Logger) (*x509.CertPool, error) { + caCertPool, err := systemCertPool(log) + if err != nil { + return nil, err + } + + if caDir == "" { + return caCertPool, nil + } + readCert, err := readCertDir(caCertPool, caDir, log) + if err != nil { + return nil, err } // Found no certs! - if count == 0 { + if !readCert { return nil, fmt.Errorf("no certificates found in %q", caDir) }