From 37ac32fe8939bdbfa7fa91557c9f4820687c36e6 Mon Sep 17 00:00:00 2001
From: Jan Roehrich
Date: Tue, 13 Aug 2024 13:44:03 +0200
Subject: [PATCH] Fix race condition between service account availability and webhook invocation

A pod can reach the webhook before its ServiceAccount has been added to the
cache. Instead of skipping the mutation right away, the handler registers a
notification channel with the cache and waits up to 5s for the
ServiceAccount to show up.

Handlers are registered under the cache's write lock, every waiter for a
key is notified, the notification never blocks while the lock is held, and
a ServiceAccount without a role ARN still falls back to the ConfigMap.
---
 pkg/cache/cache.go     | 74 ++++++++++++++++++++++++++++++++++++++-----------
 pkg/cache/fake.go      | 17 ++++++++++---
 pkg/handler/handler.go | 19 ++++++++++++-
 3 files changed, 91 insertions(+), 19 deletions(-)

diff --git a/pkg/cache/cache.go b/pkg/cache/cache.go
index e411fee8e..2d469d3ab 100644
--- a/pkg/cache/cache.go
+++ b/pkg/cache/cache.go
@@ -42,7 +42,8 @@ type CacheResponse struct {
 
 type ServiceAccountCache interface {
 	Start(stop chan struct{})
-	Get(name, namespace string) (role, aud string, useRegionalSTS bool, tokenExpiration int64)
+	Get(name, namespace string) (role, aud string, useRegionalSTS bool, tokenExpiration int64, found bool)
+	GetOrNotify(name, namespace string, handler chan any) (role, aud string, useRegionalSTS bool, tokenExpiration int64, found bool)
 	GetCommonConfigurations(name, namespace string) (useRegionalSTS bool, tokenExpiration int64)
 	// ToJSON returns cache contents as JSON string
 	ToJSON() string
@@ -60,6 +61,7 @@ type serviceAccountCache struct {
 	composeRoleArn         ComposeRoleArn
 	defaultTokenExpiration int64
 	webhookUsage           prometheus.Gauge
+	notificationHandlers   map[string][]chan any // waiters per "namespace/name" key; the value sent on a channel is irrelevant
 }
 
 type ComposeRoleArn struct {
@@ -87,29 +89,43 @@ func init() {
 // Get will return the cached configuration of the given ServiceAccount.
 // It will first look at the set of ServiceAccounts configured using annotations. If none are found, it will look for any
 // ServiceAccount configured through the pod-identity-webhook ConfigMap.
-func (c *serviceAccountCache) Get(name, namespace string) (role, aud string, useRegionalSTS bool, tokenExpiration int64) {
+func (c *serviceAccountCache) Get(name, namespace string) (role, aud string, useRegionalSTS bool, tokenExpiration int64, found bool) {
+	return c.GetOrNotify(name, namespace, nil)
+}
+
+// GetOrNotify will return the cached configuration of the given ServiceAccount.
+// It will first look at the set of ServiceAccounts configured using annotations. If the ServiceAccount is not cached
+// yet and handler is non-nil, handler is registered and receives a single value as soon as the ServiceAccount is
+// added to the cache; handler should be buffered so that the notification cannot be missed. A ServiceAccount that
+// is missing or has no role ARN is then looked up in the pod-identity-webhook ConfigMap.
+// found reports whether the ServiceAccount is known to either the annotation cache or the ConfigMap.
+func (c *serviceAccountCache) GetOrNotify(name, namespace string, handler chan any) (role, aud string, useRegionalSTS bool, tokenExpiration int64, found bool) {
 	klog.V(5).Infof("Fetching sa %s/%s from cache", namespace, name)
+	saFound := false
 	{
-		resp := c.getSA(name, namespace)
-		if resp != nil && resp.RoleARN != "" {
-			return resp.RoleARN, resp.Audience, resp.UseRegionalSTS, resp.TokenExpiration
+		resp := c.getSAorNotify(name, namespace, handler)
+		saFound = resp != nil
+		if saFound && resp.RoleARN != "" {
+			return resp.RoleARN, resp.Audience, resp.UseRegionalSTS, resp.TokenExpiration, true
 		}
 	}
 	{
 		resp := c.getCM(name, namespace)
 		if resp != nil {
-			return resp.RoleARN, resp.Audience, resp.UseRegionalSTS, resp.TokenExpiration
+			return resp.RoleARN, resp.Audience, resp.UseRegionalSTS, resp.TokenExpiration, true
 		}
 	}
-	klog.V(5).Infof("Service account %s/%s not found in cache", namespace, name)
-	return "", "", false, pkg.DefaultTokenExpiration
+	if !saFound {
+		klog.V(5).Infof("Service account %s/%s not found in cache", namespace, name)
+	}
+	return "", "", false, pkg.DefaultTokenExpiration, saFound
 }
 
 // GetCommonConfigurations returns the common configurations that also applies to the new mutation method(i.e Container Credentials).
 // The config file for the container credentials does not contain "TokenExpiration" or "UseRegionalSTS". For backward compatibility,
 // Use these fields if they are set in the sa annotations or config map.
 func (c *serviceAccountCache) GetCommonConfigurations(name, namespace string) (useRegionalSTS bool, tokenExpiration int64) {
-	if resp := c.getSA(name, namespace); resp != nil {
+	if resp := c.getSAorNotify(name, namespace, nil); resp != nil {
 		return resp.UseRegionalSTS, resp.TokenExpiration
 	} else if resp := c.getCM(name, namespace); resp != nil {
 		return resp.UseRegionalSTS, resp.TokenExpiration
@@ -117,11 +133,25 @@ func (c *serviceAccountCache) GetCommonConfigurations(name, namespace string) (u
 	return false, pkg.DefaultTokenExpiration
 }
 
-func (c *serviceAccountCache) getSA(name, namespace string) *CacheResponse {
-	c.mu.RLock()
-	defer c.mu.RUnlock()
-	resp, ok := c.saCache[namespace+"/"+name]
-	if !ok {
+// getSAorNotify returns the annotation-based cache entry for the given ServiceAccount, or nil if there is none.
+// If the entry is missing and handler is non-nil, handler is queued to be notified by setSA once the entry is added.
+func (c *serviceAccountCache) getSAorNotify(name, namespace string, handler chan any) *CacheResponse {
+	key := namespace + "/" + name
+	if handler == nil {
+		// Pure lookup: the shared lock is sufficient.
+		c.mu.RLock()
+		defer c.mu.RUnlock()
+		return c.saCache[key]
+	}
+
+	// Registering a handler writes to notificationHandlers, so the exclusive lock is needed. Holding it across
+	// lookup and registration also guarantees that a concurrent setSA cannot slip in between the two.
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	resp, ok := c.saCache[key]
+	if !ok {
+		klog.V(5).Infof("Service Account %s/%s not found in cache, adding notification handler", namespace, name)
+		c.notificationHandlers[key] = append(c.notificationHandlers[key], handler)
 		return nil
 	}
 	return resp
@@ -212,8 +242,21 @@ func (c *serviceAccountCache) addSA(sa *v1.ServiceAccount) {
 func (c *serviceAccountCache) setSA(name, namespace string, resp *CacheResponse) {
 	c.mu.Lock()
 	defer c.mu.Unlock()
-	klog.V(5).Infof("Adding SA %s/%s to SA cache: %+v", namespace, name, resp)
+
+	key := namespace + "/" + name
+	klog.V(5).Infof("Adding SA %q to SA cache: %+v", key, resp)
 	c.saCache[namespace+"/"+name] = resp
+
+	// Wake up every waiter registered for this key. c.mu is held here, so the send must never block: a waiter may
+	// already have timed out or may have passed an unbuffered channel.
+	for _, handler := range c.notificationHandlers[key] {
+		klog.V(5).Infof("Notifying handler for %q", key)
+		select {
+		case handler <- 1:
+		default:
+		}
+	}
+	delete(c.notificationHandlers, key)
 }
 
 func (c *serviceAccountCache) setCM(name, namespace string, resp *CacheResponse) {
@@ -242,6 +285,7 @@ func New(defaultAudience, prefix string, defaultRegionalSTS bool, defaultTokenEx
 		defaultTokenExpiration: defaultTokenExpiration,
 		hasSynced:              hasSynced,
 		webhookUsage:           webhookUsage,
+		notificationHandlers:   map[string][]chan any{},
 	}
 
 	saInformer.Informer().AddEventHandler(
diff --git a/pkg/cache/fake.go b/pkg/cache/fake.go
index 0f5a67869..005bc375c 100644
--- a/pkg/cache/fake.go
+++ b/pkg/cache/fake.go
@@ -44,14 +44,25 @@ var _ ServiceAccountCache = &FakeServiceAccountCache{}
 func (f *FakeServiceAccountCache) Start(chan struct{}) {}
 
 // Get gets a service account from the cache
-func (f *FakeServiceAccountCache) Get(name, namespace string) (role, aud string, useRegionalSTS bool, tokenExpiration int64) {
+func (f *FakeServiceAccountCache) Get(name, namespace string) (role, aud string, useRegionalSTS bool, tokenExpiration int64, found bool) {
 	f.mu.RLock()
 	defer f.mu.RUnlock()
 	resp, ok := f.cache[namespace+"/"+name]
 	if !ok {
-		return "", "", false, pkg.DefaultTokenExpiration
+		return "", "", false, pkg.DefaultTokenExpiration, false
 	}
-	return resp.RoleARN, resp.Audience, resp.UseRegionalSTS, resp.TokenExpiration
+	return resp.RoleARN, resp.Audience, resp.UseRegionalSTS, resp.TokenExpiration, true
+}
+
+// GetOrNotify gets a service account from the cache
+func (f *FakeServiceAccountCache) GetOrNotify(name, namespace string, handler chan any) (role, aud string, useRegionalSTS bool, tokenExpiration int64, found bool) {
+	f.mu.RLock()
+	defer f.mu.RUnlock()
+	resp, ok := f.cache[namespace+"/"+name]
+	if !ok {
+		return "", "", false, pkg.DefaultTokenExpiration, false
+	}
+	return resp.RoleARN, resp.Audience, resp.UseRegionalSTS, resp.TokenExpiration, true
 }
 
 func (f *FakeServiceAccountCache) GetCommonConfigurations(name, namespace string) (useRegionalSTS bool, tokenExpiration int64) {
diff --git a/pkg/handler/handler.go b/pkg/handler/handler.go
index 1425ca14b..05252b6d4 100644
--- a/pkg/handler/handler.go
+++ b/pkg/handler/handler.go
@@ -24,6 +24,7 @@ import (
 	"path/filepath"
 	"strconv"
 	"strings"
+	"time"
 
 	"github.com/aws/amazon-eks-pod-identity-webhook/pkg/containercredentials"
 
@@ -425,6 +426,22 @@ func (m *Modifier) buildPodPatchConfig(pod *corev1.Pod) *podPatchConfig {
 	}
 
 	// Use the STS WebIdentity method if set
-	roleArn, audience, regionalSTS, tokenExpiration := m.Cache.Get(pod.Spec.ServiceAccountName, pod.Namespace)
+	handler := make(chan any, 1)
+	roleArn, audience, regionalSTS, tokenExpiration, found := m.Cache.GetOrNotify(pod.Spec.ServiceAccountName, pod.Namespace, handler)
+	if !found {
+		klog.Warningf("Service account %s/%s not found in the cache. Waiting up to 5s to be notified", pod.Namespace, pod.Spec.ServiceAccountName)
+		select {
+		case <-handler:
+			roleArn, audience, regionalSTS, tokenExpiration, found = m.Cache.Get(pod.Spec.ServiceAccountName, pod.Namespace)
+			if !found {
+				klog.Warningf("Service account %s/%s not found in the cache after being notified. Not mutating.", pod.Namespace, pod.Spec.ServiceAccountName)
+				return nil
+			}
+		case <-time.After(5 * time.Second):
+			klog.Warningf("Service account %s/%s not found in the cache after 5s. Not mutating.", pod.Namespace, pod.Spec.ServiceAccountName)
+			return nil
+		}
+	}
+	klog.V(5).Infof("Value of roleArn after after cache retrieval for service account %s/%s: %s", pod.Namespace, pod.Spec.ServiceAccountName, roleArn)
 	if roleArn != "" {
 		tokenExpiration, containersToSkip := m.parsePodAnnotations(pod, tokenExpiration)