From ea82b2f2dcd52635682ff6bfc120e33e1aeca4e5 Mon Sep 17 00:00:00 2001 From: Jan Roehrich Date: Tue, 13 Aug 2024 13:44:03 +0200 Subject: [PATCH] Fix race condition between service account availability and webhook invocation --- README.md | 63 +++++++++++++++++++++-------------------- main.go | 3 ++ pkg/cache/cache.go | 41 ++++++++++++++++++++------- pkg/cache/cache_test.go | 29 ++++++++++--------- pkg/cache/fake.go | 17 +++++++++-- pkg/handler/handler.go | 27 +++++++++++++++++- 6 files changed, 122 insertions(+), 58 deletions(-) diff --git a/README.md b/README.md index bb8d9e13e..24bcc519f 100644 --- a/README.md +++ b/README.md @@ -143,37 +143,38 @@ When running a container with a non-root user, you need to give the container ac ``` Usage of amazon-eks-pod-identity-webhook: - --add_dir_header If true, adds the file directory to the header - --alsologtostderr log to standard error as well as files - --annotation-prefix string The Service Account annotation to look for (default "eks.amazonaws.com") - --aws-default-region string If set, AWS_DEFAULT_REGION and AWS_REGION will be set to this value in mutated containers - --enable-debugging-handlers Enable debugging handlers. Currently /debug/alpha/cache is supported - --in-cluster Use in-cluster authentication and certificate request API (default true) - --kube-api string (out-of-cluster) The url to the API server - --kubeconfig string (out-of-cluster) Absolute path to the API server kubeconfig file - --log_backtrace_at traceLocation when logging hits line file:N, emit a stack trace (default :0) - --log_dir string If non-empty, write log files in this directory - --log_file string If non-empty, use this log file - --log_file_max_size uint Defines the maximum size a log file can grow to. Unit is megabytes. If the value is 0, the maximum file size is unlimited. 
(default 1800) - --logtostderr log to standard error instead of files (default true) - --metrics-port int Port to listen on for metrics (http) (default 9999) - --namespace string (in-cluster) The namespace name this webhook, the TLS secret, and configmap resides in (default "eks") - --port int Port to listen on (default 443) - --service-name string (in-cluster) The service name fronting this webhook (default "pod-identity-webhook") - --skip_headers If true, avoid header prefixes in the log messages - --skip_log_headers If true, avoid headers when opening log files - --stderrthreshold severity logs at or above this threshold go to stderr (default 2) - --sts-regional-endpoint false Whether to inject the AWS_STS_REGIONAL_ENDPOINTS=regional env var in mutated pods. Defaults to false. - --tls-cert string (out-of-cluster) TLS certificate file path (default "/etc/webhook/certs/tls.crt") - --tls-key string (out-of-cluster) TLS key file path (default "/etc/webhook/certs/tls.key") - --tls-secret string (in-cluster) The secret name for storing the TLS serving cert (default "pod-identity-webhook") - --token-audience string The default audience for tokens. 
Can be overridden by annotation (default "sts.amazonaws.com") - --token-expiration int The token expiration (default 86400) - --token-mount-path string The path to mount tokens (default "/var/run/secrets/eks.amazonaws.com/serviceaccount") - -v, --v Level number for the log level verbosity - --version Display the version and exit - --vmodule moduleSpec comma-separated list of pattern=N settings for file-filtered logging - --watch-config-map Enables watching serviceaccounts that are configured through the pod-identity-webhook configmap instead of using annotations + --add_dir_header If true, adds the file directory to the header + --alsologtostderr log to standard error as well as files + --annotation-prefix string The Service Account annotation to look for (default "eks.amazonaws.com") + --aws-default-region string If set, AWS_DEFAULT_REGION and AWS_REGION will be set to this value in mutated containers + --enable-debugging-handlers Enable debugging handlers. Currently /debug/alpha/cache is supported + --in-cluster Use in-cluster authentication and certificate request API (default true) + --kube-api string (out-of-cluster) The url to the API server + --kubeconfig string (out-of-cluster) Absolute path to the API server kubeconfig file + --log_backtrace_at traceLocation when logging hits line file:N, emit a stack trace (default :0) + --log_dir string If non-empty, write log files in this directory + --log_file string If non-empty, use this log file + --log_file_max_size uint Defines the maximum size a log file can grow to. Unit is megabytes. If the value is 0, the maximum file size is unlimited. 
(default 1800) + --logtostderr log to standard error instead of files (default true) + --metrics-port int Port to listen on for metrics (http) (default 9999) + --namespace string (in-cluster) The namespace name this webhook, the TLS secret, and configmap resides in (default "eks") + --port int Port to listen on (default 443) + --service-name string (in-cluster) The service name fronting this webhook (default "pod-identity-webhook") + --service-account-lookup-grace-period The grace period for service account to be available in cache before not mutating a pod. Set to 0 to deactivate waiting. Carefully use higher values as it may have significant impact on Kubernetes' pod scheduling performance. (default 100ms) + --skip_headers If true, avoid header prefixes in the log messages + --skip_log_headers If true, avoid headers when opening log files + --stderrthreshold severity logs at or above this threshold go to stderr (default 2) + --sts-regional-endpoint false Whether to inject the AWS_STS_REGIONAL_ENDPOINTS=regional env var in mutated pods. Defaults to false. + --tls-cert string (out-of-cluster) TLS certificate file path (default "/etc/webhook/certs/tls.crt") + --tls-key string (out-of-cluster) TLS key file path (default "/etc/webhook/certs/tls.key") + --tls-secret string (in-cluster) The secret name for storing the TLS serving cert (default "pod-identity-webhook") + --token-audience string The default audience for tokens. 
Can be overridden by annotation (default "sts.amazonaws.com") + --token-expiration int The token expiration (default 86400) + --token-mount-path string The path to mount tokens (default "/var/run/secrets/eks.amazonaws.com/serviceaccount") + -v, --v Level number for the log level verbosity + --version Display the version and exit + --vmodule moduleSpec comma-separated list of pattern=N settings for file-filtered logging + --watch-config-map Enables watching serviceaccounts that are configured through the pod-identity-webhook configmap instead of using annotations ``` ### AWS_DEFAULT_REGION Injection diff --git a/main.go b/main.go index 22310b4be..6e77300ff 100644 --- a/main.go +++ b/main.go @@ -86,6 +86,8 @@ func main() { debug := flag.Bool("enable-debugging-handlers", false, "Enable debugging handlers. Currently /debug/alpha/cache is supported") + saLookupGracePeriod := flag.Duration("service-account-lookup-grace-period", 100*time.Millisecond, "The grace period for service account to be available in cache before not mutating a pod. Defaults to 100ms. Set to 0 to deactivate waiting. 
Carefully use higher values as it may have significant impact on Kubernetes' pod scheduling performance.") + klog.InitFlags(goflag.CommandLine) // Add klog CommandLine flags to pflag CommandLine goflag.CommandLine.VisitAll(func(f *goflag.Flag) { @@ -208,6 +210,7 @@ func main() { handler.WithServiceAccountCache(saCache), handler.WithContainerCredentialsConfig(containerCredentialsConfig), handler.WithRegion(*region), + handler.WithSALookupGraceTime(*saLookupGracePeriod), ) addr := fmt.Sprintf(":%d", *port) diff --git a/pkg/cache/cache.go b/pkg/cache/cache.go index e411fee8e..7787fafaa 100644 --- a/pkg/cache/cache.go +++ b/pkg/cache/cache.go @@ -42,7 +42,8 @@ type CacheResponse struct { type ServiceAccountCache interface { Start(stop chan struct{}) - Get(name, namespace string) (role, aud string, useRegionalSTS bool, tokenExpiration int64) + Get(name, namespace string) (role, aud string, useRegionalSTS bool, tokenExpiration int64, found bool) + GetOrNotify(name, namespace string, handler chan any) (role, aud string, useRegionalSTS bool, tokenExpiration int64, found bool) GetCommonConfigurations(name, namespace string) (useRegionalSTS bool, tokenExpiration int64) // ToJSON returns cache contents as JSON string ToJSON() string @@ -60,6 +61,7 @@ type serviceAccountCache struct { composeRoleArn ComposeRoleArn defaultTokenExpiration int64 webhookUsage prometheus.Gauge + notificationHandlers map[string]chan any // type of channel doesn't matter. It's just for being notified } type ComposeRoleArn struct { @@ -87,29 +89,37 @@ func init() { // Get will return the cached configuration of the given ServiceAccount. // It will first look at the set of ServiceAccounts configured using annotations. If none are found, it will look for any // ServiceAccount configured through the pod-identity-webhook ConfigMap. 
-func (c *serviceAccountCache) Get(name, namespace string) (role, aud string, useRegionalSTS bool, tokenExpiration int64) {
+func (c *serviceAccountCache) Get(name, namespace string) (role, aud string, useRegionalSTS bool, tokenExpiration int64, found bool) {
+	return c.GetOrNotify(name, namespace, nil)
+}
+
+// GetOrNotify returns the same data as Get. In addition, when no annotation-based entry exists for the ServiceAccount,
+// a non-nil handler is registered to be notified once that ServiceAccount is populated to the cache. found is true when
+// the ServiceAccount is cached at all, even without a role, so that callers wait for a notification only on a real miss.
+func (c *serviceAccountCache) GetOrNotify(name, namespace string, handler chan any) (role, aud string, useRegionalSTS bool, tokenExpiration int64, found bool) {
 	klog.V(5).Infof("Fetching sa %s/%s from cache", namespace, name)
 	{
-		resp := c.getSA(name, namespace)
+		resp := c.getSAorNotify(name, namespace, handler)
+		found = resp != nil
 		if resp != nil && resp.RoleARN != "" {
-			return resp.RoleARN, resp.Audience, resp.UseRegionalSTS, resp.TokenExpiration
+			return resp.RoleARN, resp.Audience, resp.UseRegionalSTS, resp.TokenExpiration, true
 		}
 	}
 	{
 		resp := c.getCM(name, namespace)
 		if resp != nil {
-			return resp.RoleARN, resp.Audience, resp.UseRegionalSTS, resp.TokenExpiration
+			return resp.RoleARN, resp.Audience, resp.UseRegionalSTS, resp.TokenExpiration, true
 		}
 	}
 	klog.V(5).Infof("Service account %s/%s not found in cache", namespace, name)
-	return "", "", false, pkg.DefaultTokenExpiration
+	return "", "", false, pkg.DefaultTokenExpiration, found
 }
 
 // GetCommonConfigurations returns the common configurations that also applies to the new mutation method(i.e Container Credentials).
 // The config file for the container credentials does not contain "TokenExpiration" or "UseRegionalSTS". For backward compatibility,
 // Use these fields if they are set in the sa annotations or config map.
func (c *serviceAccountCache) GetCommonConfigurations(name, namespace string) (useRegionalSTS bool, tokenExpiration int64) {
-	if resp := c.getSA(name, namespace); resp != nil {
+	if resp := c.getSAorNotify(name, namespace, nil); resp != nil {
 		return resp.UseRegionalSTS, resp.TokenExpiration
 	} else if resp := c.getCM(name, namespace); resp != nil {
 		return resp.UseRegionalSTS, resp.TokenExpiration
@@ -117,11 +127,28 @@ func (c *serviceAccountCache) GetCommonConfigurations(name, namespace string) (u
 	return false, pkg.DefaultTokenExpiration
 }
 
-func (c *serviceAccountCache) getSA(name, namespace string) *CacheResponse {
-	c.mu.RLock()
-	defer c.mu.RUnlock()
-	resp, ok := c.saCache[namespace+"/"+name]
-	if !ok {
+// getSAorNotify returns the annotation-based cache entry for namespace/name, or nil if there is none.
+// On a miss with a non-nil handler, the handler is stored so that setSA can signal it once the
+// ServiceAccount is added. Only one handler is kept per ServiceAccount; a later registration for the
+// same key replaces the earlier one.
+func (c *serviceAccountCache) getSAorNotify(name, namespace string, handler chan any) *CacheResponse {
+	key := namespace + "/" + name
+
+	c.mu.RLock()
+	resp, ok := c.saCache[key]
+	c.mu.RUnlock()
+	if ok || handler == nil {
+		return resp
+	}
+
+	// Registering a handler writes to notificationHandlers, which needs the write lock. Look the key
+	// up again under that lock in case setSA ran after the read lock was released.
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	resp, ok = c.saCache[key]
+	if !ok {
+		klog.V(5).Infof("Service Account %s/%s not found in cache, adding notification handler", namespace, name)
+		c.notificationHandlers[key] = handler
 		return nil
 	}
 	return resp
@@ -212,8 +239,21 @@ func (c *serviceAccountCache) addSA(sa *v1.ServiceAccount) {
 func (c *serviceAccountCache) setSA(name, namespace string, resp *CacheResponse) {
 	c.mu.Lock()
 	defer c.mu.Unlock()
-	klog.V(5).Infof("Adding SA %s/%s to SA cache: %+v", namespace, name, resp)
+
+	key := namespace + "/" + name
+	klog.V(5).Infof("Adding SA %q to SA cache: %+v", key, resp)
 	c.saCache[namespace+"/"+name] = resp
+
+	// Wake up the handler waiting for this ServiceAccount, if any. c.mu is held here, so the send must
+	// not block: a handler that cannot take the value right now is skipped instead of stalling the cache.
+	if handler, found := c.notificationHandlers[key]; found {
+		klog.V(5).Infof("Notifying handler for %q", key)
+		select {
+		case handler <- 1:
+		default:
+		}
+		delete(c.notificationHandlers, key)
+	}
 }
 
 func (c *serviceAccountCache) setCM(name, namespace string, resp *CacheResponse) {
@@ -242,6 +282,7 @@ func New(defaultAudience, prefix string, defaultRegionalSTS bool, defaultTokenEx
 		defaultTokenExpiration: defaultTokenExpiration,
 		hasSynced:              hasSynced,
 		webhookUsage:           webhookUsage,
+		notificationHandlers:   map[string]chan any{},
 	}
 
 	saInformer.Informer().AddEventHandler(
diff --git 
a/pkg/cache/cache_test.go b/pkg/cache/cache_test.go index 58c495cac..a8e0b50d8 100644 --- a/pkg/cache/cache_test.go +++ b/pkg/cache/cache_test.go @@ -36,16 +36,18 @@ func TestSaCache(t *testing.T) { webhookUsage: prometheus.NewGauge(prometheus.GaugeOpts{}), } - role, aud, useRegionalSTS, tokenExpiration := cache.Get("default", "default") + role, aud, useRegionalSTS, tokenExpiration, found := cache.Get("default", "default") + assert.False(t, found, "Expected no cache entry to be found") if role != "" || aud != "" { t.Errorf("Expected role and aud to be empty, got %s, %s, %t, %d", role, aud, useRegionalSTS, tokenExpiration) } cache.addSA(testSA) - role, aud, useRegionalSTS, tokenExpiration = cache.Get("default", "default") + role, aud, useRegionalSTS, tokenExpiration, found = cache.Get("default", "default") + assert.True(t, found, "Expected cache entry to be found") assert.Equal(t, roleArn, role, "Expected role to be %s, got %s", roleArn, role) assert.Equal(t, "sts.amazonaws.com", aud, "Expected aud to be sts.amzonaws.com, got %s", aud) assert.True(t, useRegionalSTS, "Expected regional STS to be true, got false") @@ -157,7 +159,8 @@ func TestNonRegionalSTS(t *testing.T) { t.Fatalf("cache never called addSA: %v", err) } - gotRoleArn, gotAudience, useRegionalSTS, gotTokenExpiration := cache.Get("default", "default") + gotRoleArn, gotAudience, useRegionalSTS, gotTokenExpiration, found := cache.Get("default", "default") + assert.True(t, found, "Expected cache entry to be found") if gotRoleArn != roleArn { t.Errorf("got roleArn %v, expected %v", gotRoleArn, roleArn) } @@ -202,7 +205,7 @@ func TestPopulateCacheFromCM(t *testing.T) { t.Errorf("failed to build cache: %v", err) } - role, _, _, _ := c.Get("mysa2", "myns2") + role, _, _, _, _ := c.Get("mysa2", "myns2") if role == "" { t.Errorf("cloud not find entry that should have been added") } @@ -214,7 +217,7 @@ func TestPopulateCacheFromCM(t *testing.T) { t.Errorf("failed to build cache: %v", err) } - role, _, _, _ := 
c.Get("mysa2", "myns2") + role, _, _, _, _ := c.Get("mysa2", "myns2") if role == "" { t.Errorf("cloud not find entry that should have been added") } @@ -226,7 +229,7 @@ func TestPopulateCacheFromCM(t *testing.T) { t.Errorf("failed to build cache: %v", err) } - role, _, _, _ := c.Get("mysa2", "myns2") + role, _, _, _, _ := c.Get("mysa2", "myns2") if role != "" { t.Errorf("found entry that should have been removed") } @@ -256,7 +259,7 @@ func TestSAAnnotationRemoval(t *testing.T) { c.addSA(oldSA) { - gotRoleArn, _, _, _ := c.Get("default", "default") + gotRoleArn, _, _, _, _ := c.Get("default", "default") if gotRoleArn != roleArn { t.Errorf("got roleArn %q, expected %q", gotRoleArn, roleArn) } @@ -268,7 +271,7 @@ func TestSAAnnotationRemoval(t *testing.T) { c.addSA(newSA) { - gotRoleArn, _, _, _ := c.Get("default", "default") + gotRoleArn, _, _, _, _ := c.Get("default", "default") if gotRoleArn != "" { t.Errorf("got roleArn %v, expected %q", gotRoleArn, "") } @@ -323,7 +326,7 @@ func TestCachePrecedence(t *testing.T) { t.Errorf("failed to build cache: %v", err) } - role, _, _, exp := c.Get("mysa2", "myns2") + role, _, _, exp, _ := c.Get("mysa2", "myns2") if role == "" { t.Errorf("could not find entry that should have been added") } @@ -340,7 +343,7 @@ func TestCachePrecedence(t *testing.T) { } // Removing sa2 from CM, but SA still exists - role, _, _, exp := c.Get("mysa2", "myns2") + role, _, _, exp, _ := c.Get("mysa2", "myns2") if role == "" { t.Errorf("could not find entry that should still exist") } @@ -356,7 +359,7 @@ func TestCachePrecedence(t *testing.T) { c.addSA(sa2) // Neither cache should return any hits now - role, _, _, _ := c.Get("myns2", "mysa2") + role, _, _, _, _ := c.Get("myns2", "mysa2") if role != "" { t.Errorf("found entry that should not exist") } @@ -370,7 +373,7 @@ func TestCachePrecedence(t *testing.T) { t.Errorf("failed to build cache: %v", err) } - role, _, _, exp := c.Get("mysa2", "myns2") + role, _, _, exp, _ := c.Get("mysa2", "myns2") if 
role == "" {
 			t.Errorf("cloud not find entry that should have been added")
 		}
@@ -422,7 +425,7 @@ func TestRoleArnComposition(t *testing.T) {
 
 	var roleArn string
 	err := wait.ExponentialBackoff(wait.Backoff{Duration: 10 * time.Millisecond, Factor: 1.0, Steps: 3}, func() (bool, error) {
-		roleArn, _, _, _ = cache.Get("default", "default")
+		roleArn, _, _, _, _ = cache.Get("default", "default")
 		return roleArn != "", nil
 	})
 	if err != nil {
diff --git a/pkg/cache/fake.go b/pkg/cache/fake.go
index 0f5a67869..005bc375c 100644
--- a/pkg/cache/fake.go
+++ b/pkg/cache/fake.go
@@ -44,14 +44,20 @@ var _ ServiceAccountCache = &FakeServiceAccountCache{}
 func (f *FakeServiceAccountCache) Start(chan struct{}) {}
 
 // Get gets a service account from the cache
-func (f *FakeServiceAccountCache) Get(name, namespace string) (role, aud string, useRegionalSTS bool, tokenExpiration int64) {
+func (f *FakeServiceAccountCache) Get(name, namespace string) (role, aud string, useRegionalSTS bool, tokenExpiration int64, found bool) {
 	f.mu.RLock()
 	defer f.mu.RUnlock()
 	resp, ok := f.cache[namespace+"/"+name]
 	if !ok {
-		return "", "", false, pkg.DefaultTokenExpiration
+		return "", "", false, pkg.DefaultTokenExpiration, false
 	}
-	return resp.RoleARN, resp.Audience, resp.UseRegionalSTS, resp.TokenExpiration
+	return resp.RoleARN, resp.Audience, resp.UseRegionalSTS, resp.TokenExpiration, true
+}
+
+// GetOrNotify gets a service account from the cache. The fake never sends notifications,
+// so handler is ignored and the call is equivalent to Get.
+func (f *FakeServiceAccountCache) GetOrNotify(name, namespace string, handler chan any) (role, aud string, useRegionalSTS bool, tokenExpiration int64, found bool) {
+	return f.Get(name, namespace)
 }
 
 func (f *FakeServiceAccountCache) GetCommonConfigurations(name, namespace string) (useRegionalSTS bool, tokenExpiration int64) {
diff --git 
a/pkg/handler/handler.go b/pkg/handler/handler.go
index 1425ca14b..a877cd017 100644
--- a/pkg/handler/handler.go
+++ b/pkg/handler/handler.go
@@ -24,6 +24,7 @@ import (
 	"path/filepath"
 	"strconv"
 	"strings"
+	"time"
 
 	"github.com/aws/amazon-eks-pod-identity-webhook/pkg/containercredentials"
 
@@ -77,6 +78,12 @@ func WithAnnotationDomain(domain string) ModifierOpt {
 	return func(m *Modifier) { m.AnnotationDomain = domain }
 }
 
+// WithSALookupGraceTime sets how long buildPodPatchConfig waits for a service account that is
+// missing from the cache to appear before giving up. A value of 0 disables waiting.
+func WithSALookupGraceTime(saLookupGraceTime time.Duration) ModifierOpt {
+	return func(m *Modifier) { m.saLookupGraceTime = saLookupGraceTime }
+}
+
 // NewModifier returns a Modifier with default values
 func NewModifier(opts ...ModifierOpt) *Modifier {
 	mod := &Modifier{
@@ -101,6 +108,7 @@ type Modifier struct {
 	ContainerCredentialsConfig containercredentials.Config
 	volName                    string
 	tokenName                  string
+	saLookupGraceTime          time.Duration
 }
 
 type patchOperation struct {
@@ -425,6 +433,28 @@ func (m *Modifier) buildPodPatchConfig(pod *corev1.Pod) *podPatchConfig {
 	}
 
 	// Use the STS WebIdentity method if set
-	roleArn, audience, regionalSTS, tokenExpiration := m.Cache.Get(pod.Spec.ServiceAccountName, pod.Namespace)
+	// Only register for a notification when we are prepared to wait for it. With a nil handler
+	// the cache performs a plain lookup and keeps no reference to a channel nobody listens on.
+	var handler chan any
+	if m.saLookupGraceTime > 0 {
+		handler = make(chan any, 1)
+	}
+	roleArn, audience, regionalSTS, tokenExpiration, found := m.Cache.GetOrNotify(pod.Spec.ServiceAccountName, pod.Namespace, handler)
+	key := pod.Namespace + "/" + pod.Spec.ServiceAccountName
+	if !found && m.saLookupGraceTime > 0 {
+		klog.Warningf("Service account %q not found in the cache. Waiting up to %s to be notified", key, m.saLookupGraceTime)
+		select {
+		case <-handler:
+			roleArn, audience, regionalSTS, tokenExpiration, found = m.Cache.Get(pod.Spec.ServiceAccountName, pod.Namespace)
+			if !found {
+				klog.Warningf("Service account %q not found in the cache after being notified. Not mutating.", key)
+				return nil
+			}
+		case <-time.After(m.saLookupGraceTime):
+			klog.Warningf("Service account %q not found in the cache after %s. Not mutating.", key, m.saLookupGraceTime)
+			return nil
+		}
+	}
+	klog.V(5).Infof("Value of roleArn after cache retrieval for service account %q: %s", key, roleArn)
 	if roleArn != "" {
 		tokenExpiration, containersToSkip := m.parsePodAnnotations(pod, tokenExpiration)