Skip to content

🌱 Change how systemroot is created when SSL_CERT environment is set #1921

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions internal/shared/util/http/certpoolwatcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ func TestCertPoolWatcher(t *testing.T) {
t.Logf("Create cert file at %q\n", certName)
createCert(t, certName)

// Eventually, the pools updates
require.Eventually(t, func() bool {
secondPool, secondGen, err := cpw.Get()
if err != nil {
Expand All @@ -99,3 +100,65 @@ func TestCertPoolWatcher(t *testing.T) {
return secondGen != firstGen && !firstPool.Equal(secondPool)
}, 30*time.Second, time.Second)
}

func TestCertPoolWatcherNoEnv(t *testing.T) {
// create a temporary directory
tmpDir, err := os.MkdirTemp("", "cert-pool")
require.NoError(t, err)
defer os.RemoveAll(tmpDir)

// create the first cert
certName := filepath.Join(tmpDir, "test1.pem")
t.Logf("Create cert file at %q\n", certName)
createCert(t, certName)

// Clear environment variables for the watcher
os.Setenv("SSL_CERT_DIR", "")
os.Setenv("SSL_CERT_FILE", "")

// Create the cert pool watcher
cpw, err := httputil.NewCertPoolWatcher(tmpDir, log.FromContext(context.Background()))
require.NoError(t, err)
defer cpw.Done()

// Get the original pool
firstPool, firstGen, err := cpw.Get()
require.NoError(t, err)
require.NotNil(t, firstPool)

// Create a second cert
certName = filepath.Join(tmpDir, "test2.pem")
t.Logf("Create cert file at %q\n", certName)
createCert(t, certName)

// Eventually, the pool updates
require.Eventually(t, func() bool {
secondPool, secondGen, err := cpw.Get()
if err != nil {
return false
}
return secondGen != firstGen && !firstPool.Equal(secondPool)
}, 30*time.Second, time.Second)
}

func TestCertPoolWatcherNoEnvEqualsSystem(t *testing.T) {
// Clear environment variables for the watcher
os.Setenv("SSL_CERT_DIR", "")
os.Setenv("SSL_CERT_FILE", "")

// Create the cert pool watcher
cpw, err := httputil.NewCertPoolWatcher("", log.FromContext(context.Background()))
require.NoError(t, err)
defer cpw.Done()

// Get the original pool
firstPool, _, err := cpw.Get()
require.NoError(t, err)
require.NotNil(t, firstPool)

// Compare to the system pool
sysPool, err := x509.SystemCertPool()
require.NoError(t, err)
require.NotNil(t, sysPool)
require.True(t, firstPool.Equal(sysPool))
}
120 changes: 95 additions & 25 deletions internal/shared/util/http/certutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,50 +5,120 @@ import (
"fmt"
"os"
"path/filepath"
"strings"

"github.com/go-logr/logr"
)

func NewCertPool(caDir string, log logr.Logger) (*x509.CertPool, error) {
caCertPool, err := x509.SystemCertPool()
func readCertFile(pool *x509.CertPool, file string, log logr.Logger) (bool, error) {
certRead := false
if file == "" {
return certRead, nil
}
// These might be symlinks pointing to directories, so use Stat() to resolve
fi, err := os.Stat(file)
if err != nil {
return nil, err
// Ignore files that don't exist
if os.IsNotExist(err) {
return certRead, nil
}
return certRead, err
}
if caDir == "" {
return caCertPool, nil
if fi.IsDir() {
log.V(defaultLogLevel).Info("skip directory", "name", file)
return certRead, nil
}
log.V(defaultLogLevel).Info("load certificate", "name", file, "size", fi.Size(), "modtime", fi.ModTime())
data, err := os.ReadFile(file)
if err != nil {
return certRead, fmt.Errorf("error reading cert file %q: %w", file, err)
}
// The return indicates if any certs were added
if pool.AppendCertsFromPEM(data) {
certRead = true
}
logPem(data, filepath.Base(file), filepath.Dir(file), "loading certificate file", log)

return certRead, nil
}

dirEntries, err := os.ReadDir(caDir)
func readCertDir(pool *x509.CertPool, dir string, log logr.Logger) (bool, error) {
certRead := false
if dir == "" {
return certRead, nil
}
dirEntries, err := os.ReadDir(dir)
if err != nil {
return nil, err
// Ignore directories that don't exist
if os.IsNotExist(err) {
return certRead, nil
}
return certRead, err
}
count := 0

for _, e := range dirEntries {
file := filepath.Join(caDir, e.Name())
// These might be symlinks pointing to directories, so use Stat() to resolve
fi, err := os.Stat(file)
file := filepath.Join(dir, e.Name())
c, err := readCertFile(pool, file, log)
if err != nil {
return nil, err
}
if fi.IsDir() {
log.V(defaultLogLevel).Info("skip directory", "name", e.Name())
continue
return certRead, err
}
log.V(defaultLogLevel).Info("load certificate", "name", e.Name(), "size", fi.Size(), "modtime", fi.ModTime())
data, err := os.ReadFile(file)
certRead = certRead || c
}
return certRead, nil
}

// This function looks explicitly at the SSL environment, and
// uses it to create a "fresh" system cert pool
func systemCertPool(log logr.Logger) (*x509.CertPool, error) {
sslCertDir := os.Getenv("SSL_CERT_DIR")
sslCertFile := os.Getenv("SSL_CERT_FILE")
if sslCertDir == "" && sslCertFile == "" {
log.V(defaultLogLevel).Info("SystemCertPool: SSL environment not set")
return x509.SystemCertPool()
}
log.V(defaultLogLevel).Info("SystemCertPool: SSL environment set", "SSL_CERT_DIR", sslCertDir, "SSL_CERT_FILE", sslCertFile)

certRead := false
pool := x509.NewCertPool()

// SSL_CERT_DIR may consist of multiple entries separated by ":"
for _, d := range strings.Split(sslCertDir, ":") {
c, err := readCertDir(pool, d, log)
if err != nil {
return nil, fmt.Errorf("error reading cert file %q: %w", file, err)
}
// The return indicates if any certs were added
if caCertPool.AppendCertsFromPEM(data) {
count++
return nil, err
}
logPem(data, e.Name(), caDir, "loading certificate file", log)
certRead = certRead || c
}
// SSL_CERT_FILE may consist of only a single entry
c, err := readCertFile(pool, sslCertFile, log)
if err != nil {
return nil, err
}
certRead = certRead || c

// If SSL_CERT_DIR and SSL_CERT_FILE resulted in no certs, then return the system cert pool
if !certRead {
return x509.SystemCertPool()
}
return pool, nil
}

func NewCertPool(caDir string, log logr.Logger) (*x509.CertPool, error) {
caCertPool, err := systemCertPool(log)
if err != nil {
return nil, err
}

if caDir == "" {
return caCertPool, nil
}
readCert, err := readCertDir(caCertPool, caDir, log)
if err != nil {
return nil, err
}

// Found no certs!
if count == 0 {
if !readCert {
return nil, fmt.Errorf("no certificates found in %q", caDir)
}

Expand Down
Loading