Skip to content

Commit 0750d17

Browse files
committed
Change how systemroot is created
Signed-off-by: Todd Short <[email protected]>
1 parent 543f099 commit 0750d17

File tree

2 files changed

+158
-25
lines changed

2 files changed

+158
-25
lines changed

internal/shared/util/http/certpoolwatcher_test.go

+63
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ func TestCertPoolWatcher(t *testing.T) {
9191
t.Logf("Create cert file at %q\n", certName)
9292
createCert(t, certName)
9393

94+
// Eventually, the pools updates
9495
require.Eventually(t, func() bool {
9596
secondPool, secondGen, err := cpw.Get()
9697
if err != nil {
@@ -99,3 +100,65 @@ func TestCertPoolWatcher(t *testing.T) {
99100
return secondGen != firstGen && !firstPool.Equal(secondPool)
100101
}, 30*time.Second, time.Second)
101102
}
103+
104+
func TestCertPoolWatcherNoEnv(t *testing.T) {
105+
// create a temporary directory
106+
tmpDir, err := os.MkdirTemp("", "cert-pool")
107+
require.NoError(t, err)
108+
defer os.RemoveAll(tmpDir)
109+
110+
// create the first cert
111+
certName := filepath.Join(tmpDir, "test1.pem")
112+
t.Logf("Create cert file at %q\n", certName)
113+
createCert(t, certName)
114+
115+
// Clear environment variables for the watcher
116+
os.Setenv("SSL_CERT_DIR", "")
117+
os.Setenv("SSL_CERT_FILE", "")
118+
119+
// Create the cert pool watcher
120+
cpw, err := httputil.NewCertPoolWatcher(tmpDir, log.FromContext(context.Background()))
121+
require.NoError(t, err)
122+
defer cpw.Done()
123+
124+
// Get the original pool
125+
firstPool, firstGen, err := cpw.Get()
126+
require.NoError(t, err)
127+
require.NotNil(t, firstPool)
128+
129+
// Create a second cert
130+
certName = filepath.Join(tmpDir, "test2.pem")
131+
t.Logf("Create cert file at %q\n", certName)
132+
createCert(t, certName)
133+
134+
// Eventually, the pool updates
135+
require.Eventually(t, func() bool {
136+
secondPool, secondGen, err := cpw.Get()
137+
if err != nil {
138+
return false
139+
}
140+
return secondGen != firstGen && !firstPool.Equal(secondPool)
141+
}, 30*time.Second, time.Second)
142+
}
143+
144+
func TestCertPoolWatcherNoEnvEqualsSystem(t *testing.T) {
145+
// Clear environment variables for the watcher
146+
os.Setenv("SSL_CERT_DIR", "")
147+
os.Setenv("SSL_CERT_FILE", "")
148+
149+
// Create the cert pool watcher
150+
cpw, err := httputil.NewCertPoolWatcher("", log.FromContext(context.Background()))
151+
require.NoError(t, err)
152+
defer cpw.Done()
153+
154+
// Get the original pool
155+
firstPool, _, err := cpw.Get()
156+
require.NoError(t, err)
157+
require.NotNil(t, firstPool)
158+
159+
// Compare to the system pool
160+
sysPool, err := x509.SystemCertPool()
161+
require.NoError(t, err)
162+
require.NotNil(t, sysPool)
163+
require.True(t, firstPool.Equal(sysPool))
164+
}

internal/shared/util/http/certutil.go

+95-25
Original file line numberDiff line numberDiff line change
@@ -5,50 +5,120 @@ import (
55
"fmt"
66
"os"
77
"path/filepath"
8+
"strings"
89

910
"github.com/go-logr/logr"
1011
)
1112

12-
func NewCertPool(caDir string, log logr.Logger) (*x509.CertPool, error) {
13-
caCertPool, err := x509.SystemCertPool()
13+
func readCertFile(pool *x509.CertPool, file string, log logr.Logger) (bool, error) {
14+
var certRead bool
15+
if file == "" {
16+
return certRead, nil
17+
}
18+
// These might be symlinks pointing to directories, so use Stat() to resolve
19+
fi, err := os.Stat(file)
1420
if err != nil {
15-
return nil, err
21+
// Ignore files that don't exist
22+
if os.IsNotExist(err) {
23+
return certRead, nil
24+
}
25+
return certRead, err
1626
}
17-
if caDir == "" {
18-
return caCertPool, nil
27+
if fi.IsDir() {
28+
log.V(defaultLogLevel).Info("skip directory", "name", file)
29+
return certRead, nil
1930
}
31+
log.V(defaultLogLevel).Info("load certificate", "name", file, "size", fi.Size(), "modtime", fi.ModTime())
32+
data, err := os.ReadFile(file)
33+
if err != nil {
34+
return certRead, fmt.Errorf("error reading cert file %q: %w", file, err)
35+
}
36+
// The return indicates if any certs were added
37+
if pool.AppendCertsFromPEM(data) {
38+
certRead = true
39+
}
40+
logPem(data, filepath.Base(file), filepath.Dir(file), "loading certificate file", log)
41+
42+
return certRead, nil
43+
}
2044

21-
dirEntries, err := os.ReadDir(caDir)
45+
func readCertDir(pool *x509.CertPool, dir string, log logr.Logger) (bool, error) {
46+
var certRead bool
47+
if dir == "" {
48+
return certRead, nil
49+
}
50+
dirEntries, err := os.ReadDir(dir)
2251
if err != nil {
23-
return nil, err
52+
// Ignore directories that don't exist
53+
if os.IsNotExist(err) {
54+
return certRead, nil
55+
}
56+
return certRead, err
2457
}
25-
count := 0
2658

2759
for _, e := range dirEntries {
28-
file := filepath.Join(caDir, e.Name())
29-
// These might be symlinks pointing to directories, so use Stat() to resolve
30-
fi, err := os.Stat(file)
60+
file := filepath.Join(dir, e.Name())
61+
c, err := readCertFile(pool, file, log)
3162
if err != nil {
32-
return nil, err
33-
}
34-
if fi.IsDir() {
35-
log.V(defaultLogLevel).Info("skip directory", "name", e.Name())
36-
continue
63+
return certRead, err
3764
}
38-
log.V(defaultLogLevel).Info("load certificate", "name", e.Name(), "size", fi.Size(), "modtime", fi.ModTime())
39-
data, err := os.ReadFile(file)
65+
certRead = certRead || c
66+
}
67+
return certRead, nil
68+
}
69+
70+
// This function looks explicitly at the SSL environment, and
71+
// uses it to create a "fresh" system cert pool
72+
func systemCertPool(log logr.Logger) (*x509.CertPool, error) {
73+
sslCertDir := os.Getenv("SSL_CERT_DIR")
74+
sslCertFile := os.Getenv("SSL_CERT_FILE")
75+
if sslCertDir == "" && sslCertFile == "" {
76+
log.V(defaultLogLevel).Info("SystemCertPool: SSL environment not set")
77+
return x509.SystemCertPool()
78+
}
79+
log.V(defaultLogLevel).Info("SystemCertPool: SSL environment set", "SSL_CERT_DIR", sslCertDir, "SSL_CERT_FILE", sslCertFile)
80+
81+
var certRead bool
82+
pool := x509.NewCertPool()
83+
84+
// SSL_CERT_DIR may consist of multiple entries separated by ":"
85+
for _, d := range strings.Split(sslCertDir, ":") {
86+
c, err := readCertDir(pool, d, log)
4087
if err != nil {
41-
return nil, fmt.Errorf("error reading cert file %q: %w", file, err)
42-
}
43-
// The return indicates if any certs were added
44-
if caCertPool.AppendCertsFromPEM(data) {
45-
count++
88+
return nil, err
4689
}
47-
logPem(data, e.Name(), caDir, "loading certificate file", log)
90+
certRead = certRead || c
91+
}
92+
// SSL_CERT_FILE may consist of only a single entry
93+
c, err := readCertFile(pool, sslCertFile, log)
94+
if err != nil {
95+
return nil, err
96+
}
97+
certRead = certRead || c
98+
99+
// If SSL_CERT_DIR and SSL_CERT_FILE resulted in no certs, then return the system cert pool
100+
if !certRead {
101+
return x509.SystemCertPool()
102+
}
103+
return pool, nil
104+
}
105+
106+
func NewCertPool(caDir string, log logr.Logger) (*x509.CertPool, error) {
107+
caCertPool, err := systemCertPool(log)
108+
if err != nil {
109+
return nil, err
110+
}
111+
112+
if caDir == "" {
113+
return caCertPool, nil
114+
}
115+
count, err := readCertDir(caCertPool, caDir, log)
116+
if err != nil {
117+
return nil, err
48118
}
49119

50120
// Found no certs!
51-
if count == 0 {
121+
if !count {
52122
return nil, fmt.Errorf("no certificates found in %q", caDir)
53123
}
54124

0 commit comments

Comments
 (0)