Skip to content

Commit e7e60b3

Browse files
committed
Change how systemroot is created when SSL_CERT environment is set
The x509.SystemCertPool() looks at the SSL_CERT_FILE and SSL_CERT_DIR environment variables to generate the pool. However, if the contents of the referenced file (singular) or directories (multiple) change, there is no guarantee that x509.SystemCertPool() will be updated. Since we are watching these locations (defined by the environment) via fsnotify, we want to ensure that when those files are updated that the cert pool we use is also updated. So, if SSL_CERT_FILE or SSL_CERT_DIR are defined, create our cert pool from those variable _only_, ignoring the x509.SystemCertPool(). This is how the x509.SystemCertPool() would be created, so we do it explicitly instead. This allows us to properly refresh the pool when fsnotify tells us there are changes to our watches. This does not impact images/containers (i.e. impage pulling) directly, since that still uses x509.SystemCertPool(), so it may get a stale pool, but the catalogd client will have an up-to-date pool. See: https://pkg.go.dev/crypto/x509#SystemCertPool Signed-off-by: Todd Short <[email protected]>
1 parent 543f099 commit e7e60b3

File tree

2 files changed

+158
-25
lines changed

2 files changed

+158
-25
lines changed

internal/shared/util/http/certpoolwatcher_test.go

+63
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ func TestCertPoolWatcher(t *testing.T) {
9191
t.Logf("Create cert file at %q\n", certName)
9292
createCert(t, certName)
9393

94+
// Eventually, the pools updates
9495
require.Eventually(t, func() bool {
9596
secondPool, secondGen, err := cpw.Get()
9697
if err != nil {
@@ -99,3 +100,65 @@ func TestCertPoolWatcher(t *testing.T) {
99100
return secondGen != firstGen && !firstPool.Equal(secondPool)
100101
}, 30*time.Second, time.Second)
101102
}
103+
104+
func TestCertPoolWatcherNoEnv(t *testing.T) {
105+
// create a temporary directory
106+
tmpDir, err := os.MkdirTemp("", "cert-pool")
107+
require.NoError(t, err)
108+
defer os.RemoveAll(tmpDir)
109+
110+
// create the first cert
111+
certName := filepath.Join(tmpDir, "test1.pem")
112+
t.Logf("Create cert file at %q\n", certName)
113+
createCert(t, certName)
114+
115+
// Clear environment variables for the watcher
116+
os.Setenv("SSL_CERT_DIR", "")
117+
os.Setenv("SSL_CERT_FILE", "")
118+
119+
// Create the cert pool watcher
120+
cpw, err := httputil.NewCertPoolWatcher(tmpDir, log.FromContext(context.Background()))
121+
require.NoError(t, err)
122+
defer cpw.Done()
123+
124+
// Get the original pool
125+
firstPool, firstGen, err := cpw.Get()
126+
require.NoError(t, err)
127+
require.NotNil(t, firstPool)
128+
129+
// Create a second cert
130+
certName = filepath.Join(tmpDir, "test2.pem")
131+
t.Logf("Create cert file at %q\n", certName)
132+
createCert(t, certName)
133+
134+
// Eventually, the pool updates
135+
require.Eventually(t, func() bool {
136+
secondPool, secondGen, err := cpw.Get()
137+
if err != nil {
138+
return false
139+
}
140+
return secondGen != firstGen && !firstPool.Equal(secondPool)
141+
}, 30*time.Second, time.Second)
142+
}
143+
144+
func TestCertPoolWatcherNoEnvEqualsSystem(t *testing.T) {
145+
// Clear environment variables for the watcher
146+
os.Setenv("SSL_CERT_DIR", "")
147+
os.Setenv("SSL_CERT_FILE", "")
148+
149+
// Create the cert pool watcher
150+
cpw, err := httputil.NewCertPoolWatcher("", log.FromContext(context.Background()))
151+
require.NoError(t, err)
152+
defer cpw.Done()
153+
154+
// Get the original pool
155+
firstPool, _, err := cpw.Get()
156+
require.NoError(t, err)
157+
require.NotNil(t, firstPool)
158+
159+
// Compare to the system pool
160+
sysPool, err := x509.SystemCertPool()
161+
require.NoError(t, err)
162+
require.NotNil(t, sysPool)
163+
require.True(t, firstPool.Equal(sysPool))
164+
}

internal/shared/util/http/certutil.go

+95-25
Original file line numberDiff line numberDiff line change
@@ -5,50 +5,120 @@ import (
55
"fmt"
66
"os"
77
"path/filepath"
8+
"strings"
89

910
"github.com/go-logr/logr"
1011
)
1112

12-
func NewCertPool(caDir string, log logr.Logger) (*x509.CertPool, error) {
13-
caCertPool, err := x509.SystemCertPool()
13+
func readCertFile(pool *x509.CertPool, file string, log logr.Logger) (bool, error) {
14+
certRead := false
15+
if file == "" {
16+
return certRead, nil
17+
}
18+
// These might be symlinks pointing to directories, so use Stat() to resolve
19+
fi, err := os.Stat(file)
1420
if err != nil {
15-
return nil, err
21+
// Ignore files that don't exist
22+
if os.IsNotExist(err) {
23+
return certRead, nil
24+
}
25+
return certRead, err
1626
}
17-
if caDir == "" {
18-
return caCertPool, nil
27+
if fi.IsDir() {
28+
log.V(defaultLogLevel).Info("skip directory", "name", file)
29+
return certRead, nil
1930
}
31+
log.V(defaultLogLevel).Info("load certificate", "name", file, "size", fi.Size(), "modtime", fi.ModTime())
32+
data, err := os.ReadFile(file)
33+
if err != nil {
34+
return certRead, fmt.Errorf("error reading cert file %q: %w", file, err)
35+
}
36+
// The return indicates if any certs were added
37+
if pool.AppendCertsFromPEM(data) {
38+
certRead = true
39+
}
40+
logPem(data, filepath.Base(file), filepath.Dir(file), "loading certificate file", log)
41+
42+
return certRead, nil
43+
}
2044

21-
dirEntries, err := os.ReadDir(caDir)
45+
func readCertDir(pool *x509.CertPool, dir string, log logr.Logger) (bool, error) {
46+
certRead := false
47+
if dir == "" {
48+
return certRead, nil
49+
}
50+
dirEntries, err := os.ReadDir(dir)
2251
if err != nil {
23-
return nil, err
52+
// Ignore directories that don't exist
53+
if os.IsNotExist(err) {
54+
return certRead, nil
55+
}
56+
return certRead, err
2457
}
25-
count := 0
2658

2759
for _, e := range dirEntries {
28-
file := filepath.Join(caDir, e.Name())
29-
// These might be symlinks pointing to directories, so use Stat() to resolve
30-
fi, err := os.Stat(file)
60+
file := filepath.Join(dir, e.Name())
61+
c, err := readCertFile(pool, file, log)
3162
if err != nil {
32-
return nil, err
33-
}
34-
if fi.IsDir() {
35-
log.V(defaultLogLevel).Info("skip directory", "name", e.Name())
36-
continue
63+
return certRead, err
3764
}
38-
log.V(defaultLogLevel).Info("load certificate", "name", e.Name(), "size", fi.Size(), "modtime", fi.ModTime())
39-
data, err := os.ReadFile(file)
65+
certRead = certRead || c
66+
}
67+
return certRead, nil
68+
}
69+
70+
// This function looks explicitly at the SSL environment, and
71+
// uses it to create a "fresh" system cert pool
72+
func systemCertPool(log logr.Logger) (*x509.CertPool, error) {
73+
sslCertDir := os.Getenv("SSL_CERT_DIR")
74+
sslCertFile := os.Getenv("SSL_CERT_FILE")
75+
if sslCertDir == "" && sslCertFile == "" {
76+
log.V(defaultLogLevel).Info("SystemCertPool: SSL environment not set")
77+
return x509.SystemCertPool()
78+
}
79+
log.V(defaultLogLevel).Info("SystemCertPool: SSL environment set", "SSL_CERT_DIR", sslCertDir, "SSL_CERT_FILE", sslCertFile)
80+
81+
certRead := false
82+
pool := x509.NewCertPool()
83+
84+
// SSL_CERT_DIR may consist of multiple entries separated by ":"
85+
for _, d := range strings.Split(sslCertDir, ":") {
86+
c, err := readCertDir(pool, d, log)
4087
if err != nil {
41-
return nil, fmt.Errorf("error reading cert file %q: %w", file, err)
42-
}
43-
// The return indicates if any certs were added
44-
if caCertPool.AppendCertsFromPEM(data) {
45-
count++
88+
return nil, err
4689
}
47-
logPem(data, e.Name(), caDir, "loading certificate file", log)
90+
certRead = certRead || c
91+
}
92+
// SSL_CERT_FILE may consist of only a single entry
93+
c, err := readCertFile(pool, sslCertFile, log)
94+
if err != nil {
95+
return nil, err
96+
}
97+
certRead = certRead || c
98+
99+
// If SSL_CERT_DIR and SSL_CERT_FILE resulted in no certs, then return the system cert pool
100+
if !certRead {
101+
return x509.SystemCertPool()
102+
}
103+
return pool, nil
104+
}
105+
106+
func NewCertPool(caDir string, log logr.Logger) (*x509.CertPool, error) {
107+
caCertPool, err := systemCertPool(log)
108+
if err != nil {
109+
return nil, err
110+
}
111+
112+
if caDir == "" {
113+
return caCertPool, nil
114+
}
115+
readCert, err := readCertDir(caCertPool, caDir, log)
116+
if err != nil {
117+
return nil, err
48118
}
49119

50120
// Found no certs!
51-
if count == 0 {
121+
if !readCert {
52122
return nil, fmt.Errorf("no certificates found in %q", caDir)
53123
}
54124

0 commit comments

Comments
 (0)