diff --git a/main.go b/main.go index 22310b4be..6e77300ff 100644 --- a/main.go +++ b/main.go @@ -86,6 +86,8 @@ func main() { debug := flag.Bool("enable-debugging-handlers", false, "Enable debugging handlers. Currently /debug/alpha/cache is supported") + saLookupGracePeriod := flag.Duration("service-account-lookup-grace-period", 100*time.Millisecond, "The grace period for service account to be available in cache before not mutating a pod. Defaults to 100ms. Set to 0 to deactivate waiting. Carefully use higher values as it may have significant impact on Kubernetes' pod scheduling performance.") + klog.InitFlags(goflag.CommandLine) // Add klog CommandLine flags to pflag CommandLine goflag.CommandLine.VisitAll(func(f *goflag.Flag) { @@ -208,6 +210,7 @@ func main() { handler.WithServiceAccountCache(saCache), handler.WithContainerCredentialsConfig(containerCredentialsConfig), handler.WithRegion(*region), + handler.WithSALookupGraceTime(*saLookupGracePeriod), ) addr := fmt.Sprintf(":%d", *port) diff --git a/pkg/cache/cache.go b/pkg/cache/cache.go index e411fee8e..7787fafaa 100644 --- a/pkg/cache/cache.go +++ b/pkg/cache/cache.go @@ -42,7 +42,8 @@ type CacheResponse struct { type ServiceAccountCache interface { Start(stop chan struct{}) - Get(name, namespace string) (role, aud string, useRegionalSTS bool, tokenExpiration int64) + Get(name, namespace string) (role, aud string, useRegionalSTS bool, tokenExpiration int64, found bool) + GetOrNotify(name, namespace string, handler chan any) (role, aud string, useRegionalSTS bool, tokenExpiration int64, found bool) GetCommonConfigurations(name, namespace string) (useRegionalSTS bool, tokenExpiration int64) // ToJSON returns cache contents as JSON string ToJSON() string @@ -60,6 +61,7 @@ type serviceAccountCache struct { composeRoleArn ComposeRoleArn defaultTokenExpiration int64 webhookUsage prometheus.Gauge + notificationHandlers map[string]chan any // type of channel doesn't matter. 
It's just for being notified } type ComposeRoleArn struct { @@ -87,29 +89,37 @@ func init() { // Get will return the cached configuration of the given ServiceAccount. // It will first look at the set of ServiceAccounts configured using annotations. If none are found, it will look for any // ServiceAccount configured through the pod-identity-webhook ConfigMap. -func (c *serviceAccountCache) Get(name, namespace string) (role, aud string, useRegionalSTS bool, tokenExpiration int64) { +func (c *serviceAccountCache) Get(name, namespace string) (role, aud string, useRegionalSTS bool, tokenExpiration int64, found bool) { + return c.GetOrNotify(name, namespace, nil) +} + +// GetOrNotify will return the cached configuration of the given ServiceAccount. +// It will first look at the set of ServiceAccounts configured using annotations. If none is found, it will register +// handler to be notified as soon as a ServiceAccount with given key is populated to the cache. Afterwards it will check +// for a ServiceAccount configured through the pod-identity-webhook ConfigMap. 
+func (c *serviceAccountCache) GetOrNotify(name, namespace string, handler chan any) (role, aud string, useRegionalSTS bool, tokenExpiration int64, found bool) {
 	klog.V(5).Infof("Fetching sa %s/%s from cache", namespace, name)
 	{
-		resp := c.getSA(name, namespace)
-		if resp != nil && resp.RoleARN != "" {
-			return resp.RoleARN, resp.Audience, resp.UseRegionalSTS, resp.TokenExpiration
+		resp := c.getSAorNotify(name, namespace, handler)
+		if found = resp != nil; found && resp.RoleARN != "" { // found records cache presence even when no role is set
+			return resp.RoleARN, resp.Audience, resp.UseRegionalSTS, resp.TokenExpiration, true
 		}
 	}
 	{
 		resp := c.getCM(name, namespace)
 		if resp != nil {
-			return resp.RoleARN, resp.Audience, resp.UseRegionalSTS, resp.TokenExpiration
+			return resp.RoleARN, resp.Audience, resp.UseRegionalSTS, resp.TokenExpiration, true
 		}
 	}
 	klog.V(5).Infof("Service account %s/%s not found in cache", namespace, name)
-	return "", "", false, pkg.DefaultTokenExpiration
+	return "", "", false, pkg.DefaultTokenExpiration, found // true when the SA is cached without a role, so callers do not wait for it
 }
 
 // GetCommonConfigurations returns the common configurations that also applies to the new mutation method(i.e Container Credentials).
 // The config file for the container credentials does not contain "TokenExpiration" or "UseRegionalSTS". For backward compatibility,
 // Use these fields if they are set in the sa annotations or config map.
func (c *serviceAccountCache) GetCommonConfigurations(name, namespace string) (useRegionalSTS bool, tokenExpiration int64) {
-	if resp := c.getSA(name, namespace); resp != nil {
+	if resp := c.getSAorNotify(name, namespace, nil); resp != nil {
 		return resp.UseRegionalSTS, resp.TokenExpiration
 	} else if resp := c.getCM(name, namespace); resp != nil {
 		return resp.UseRegionalSTS, resp.TokenExpiration
@@ -117,11 +127,27 @@ func (c *serviceAccountCache) GetCommonConfigurations(name, namespace string) (u
 	return false, pkg.DefaultTokenExpiration
 }
 
-func (c *serviceAccountCache) getSA(name, namespace string) *CacheResponse {
+// getSAorNotify returns the annotation-based cache entry for namespace/name, or
+// nil when there is none. If the entry is missing and handler is non-nil, handler
+// is registered so that setSA can signal it once the ServiceAccount shows up.
+func (c *serviceAccountCache) getSAorNotify(name, namespace string, handler chan any) *CacheResponse {
+	key := namespace + "/" + name
+
+	// Fast path: plain lookups only need the shared read lock.
 	c.mu.RLock()
-	defer c.mu.RUnlock()
-	resp, ok := c.saCache[namespace+"/"+name]
+	resp, ok := c.saCache[key]
+	c.mu.RUnlock()
+	if ok || handler == nil {
+		return resp
+	}
+
+	// Registering a handler writes to notificationHandlers, which is not safe under
+	// RLock. Take the write lock and re-check, since setSA may have run in between.
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	resp, ok = c.saCache[key]
 	if !ok {
-		return nil
+		klog.V(5).Infof("Service Account %s/%s not found in cache, adding notification handler", namespace, name)
+		c.notificationHandlers[key] = handler
 	}
 	return resp
@@ -212,8 +238,21 @@ func (c *serviceAccountCache) addSA(sa *v1.ServiceAccount) {
 func (c *serviceAccountCache) setSA(name, namespace string, resp *CacheResponse) {
 	c.mu.Lock()
 	defer c.mu.Unlock()
-	klog.V(5).Infof("Adding SA %s/%s to SA cache: %+v", namespace, name, resp)
-	c.saCache[namespace+"/"+name] = resp
+
+	key := namespace + "/" + name
+	klog.V(5).Infof("Adding SA %q to SA cache: %+v", key, resp)
+	c.saCache[key] = resp
+
+	if handler, found := c.notificationHandlers[key]; found {
+		klog.V(5).Infof("Notifying handler for %q", key)
+		// Non-blocking send: c.mu is held here, so a handler whose receiver has
+		// gone away (or whose buffer is full) must not stall the whole cache.
+		select {
+		case handler <- 1:
+		default:
+		}
+		delete(c.notificationHandlers, key)
+	}
 }
 
 func (c *serviceAccountCache) setCM(name, namespace string, resp *CacheResponse) {
@@ -242,6 +281,7 @@ func New(defaultAudience, prefix string, defaultRegionalSTS bool, defaultTokenEx
 		defaultTokenExpiration: defaultTokenExpiration,
 		hasSynced: hasSynced,
 		webhookUsage: webhookUsage,
+		notificationHandlers: map[string]chan any{},
 	}
 
 	saInformer.Informer().AddEventHandler(
diff --git 
a/pkg/cache/cache_test.go b/pkg/cache/cache_test.go index 58c495cac..a8e0b50d8 100644 --- a/pkg/cache/cache_test.go +++ b/pkg/cache/cache_test.go @@ -36,16 +36,18 @@ func TestSaCache(t *testing.T) { webhookUsage: prometheus.NewGauge(prometheus.GaugeOpts{}), } - role, aud, useRegionalSTS, tokenExpiration := cache.Get("default", "default") + role, aud, useRegionalSTS, tokenExpiration, found := cache.Get("default", "default") + assert.False(t, found, "Expected no cache entry to be found") if role != "" || aud != "" { t.Errorf("Expected role and aud to be empty, got %s, %s, %t, %d", role, aud, useRegionalSTS, tokenExpiration) } cache.addSA(testSA) - role, aud, useRegionalSTS, tokenExpiration = cache.Get("default", "default") + role, aud, useRegionalSTS, tokenExpiration, found = cache.Get("default", "default") + assert.True(t, found, "Expected cache entry to be found") assert.Equal(t, roleArn, role, "Expected role to be %s, got %s", roleArn, role) assert.Equal(t, "sts.amazonaws.com", aud, "Expected aud to be sts.amzonaws.com, got %s", aud) assert.True(t, useRegionalSTS, "Expected regional STS to be true, got false") @@ -157,7 +159,8 @@ func TestNonRegionalSTS(t *testing.T) { t.Fatalf("cache never called addSA: %v", err) } - gotRoleArn, gotAudience, useRegionalSTS, gotTokenExpiration := cache.Get("default", "default") + gotRoleArn, gotAudience, useRegionalSTS, gotTokenExpiration, found := cache.Get("default", "default") + assert.True(t, found, "Expected cache entry to be found") if gotRoleArn != roleArn { t.Errorf("got roleArn %v, expected %v", gotRoleArn, roleArn) } @@ -202,7 +205,7 @@ func TestPopulateCacheFromCM(t *testing.T) { t.Errorf("failed to build cache: %v", err) } - role, _, _, _ := c.Get("mysa2", "myns2") + role, _, _, _, _ := c.Get("mysa2", "myns2") if role == "" { t.Errorf("cloud not find entry that should have been added") } @@ -214,7 +217,7 @@ func TestPopulateCacheFromCM(t *testing.T) { t.Errorf("failed to build cache: %v", err) } - role, _, _, _ := 
c.Get("mysa2", "myns2") + role, _, _, _, _ := c.Get("mysa2", "myns2") if role == "" { t.Errorf("cloud not find entry that should have been added") } @@ -226,7 +229,7 @@ func TestPopulateCacheFromCM(t *testing.T) { t.Errorf("failed to build cache: %v", err) } - role, _, _, _ := c.Get("mysa2", "myns2") + role, _, _, _, _ := c.Get("mysa2", "myns2") if role != "" { t.Errorf("found entry that should have been removed") } @@ -256,7 +259,7 @@ func TestSAAnnotationRemoval(t *testing.T) { c.addSA(oldSA) { - gotRoleArn, _, _, _ := c.Get("default", "default") + gotRoleArn, _, _, _, _ := c.Get("default", "default") if gotRoleArn != roleArn { t.Errorf("got roleArn %q, expected %q", gotRoleArn, roleArn) } @@ -268,7 +271,7 @@ func TestSAAnnotationRemoval(t *testing.T) { c.addSA(newSA) { - gotRoleArn, _, _, _ := c.Get("default", "default") + gotRoleArn, _, _, _, _ := c.Get("default", "default") if gotRoleArn != "" { t.Errorf("got roleArn %v, expected %q", gotRoleArn, "") } @@ -323,7 +326,7 @@ func TestCachePrecedence(t *testing.T) { t.Errorf("failed to build cache: %v", err) } - role, _, _, exp := c.Get("mysa2", "myns2") + role, _, _, exp, _ := c.Get("mysa2", "myns2") if role == "" { t.Errorf("could not find entry that should have been added") } @@ -340,7 +343,7 @@ func TestCachePrecedence(t *testing.T) { } // Removing sa2 from CM, but SA still exists - role, _, _, exp := c.Get("mysa2", "myns2") + role, _, _, exp, _ := c.Get("mysa2", "myns2") if role == "" { t.Errorf("could not find entry that should still exist") } @@ -356,7 +359,7 @@ func TestCachePrecedence(t *testing.T) { c.addSA(sa2) // Neither cache should return any hits now - role, _, _, _ := c.Get("myns2", "mysa2") + role, _, _, _, _ := c.Get("myns2", "mysa2") if role != "" { t.Errorf("found entry that should not exist") } @@ -370,7 +373,7 @@ func TestCachePrecedence(t *testing.T) { t.Errorf("failed to build cache: %v", err) } - role, _, _, exp := c.Get("mysa2", "myns2") + role, _, _, exp, _ := c.Get("mysa2", "myns2") if 
role == "" { t.Errorf("cloud not find entry that should have been added") } @@ -422,7 +425,7 @@ func TestRoleArnComposition(t *testing.T) { var roleArn string err := wait.ExponentialBackoff(wait.Backoff{Duration: 10 * time.Millisecond, Factor: 1.0, Steps: 3}, func() (bool, error) { - roleArn, _, _, _ = cache.Get("default", "default") + roleArn, _, _, _, _ = cache.Get("default", "default") return roleArn != "", nil }) if err != nil { diff --git a/pkg/cache/fake.go b/pkg/cache/fake.go index 0f5a67869..005bc375c 100644 --- a/pkg/cache/fake.go +++ b/pkg/cache/fake.go @@ -44,14 +44,25 @@ var _ ServiceAccountCache = &FakeServiceAccountCache{} func (f *FakeServiceAccountCache) Start(chan struct{}) {} // Get gets a service account from the cache -func (f *FakeServiceAccountCache) Get(name, namespace string) (role, aud string, useRegionalSTS bool, tokenExpiration int64) { +func (f *FakeServiceAccountCache) Get(name, namespace string) (role, aud string, useRegionalSTS bool, tokenExpiration int64, found bool) { f.mu.RLock() defer f.mu.RUnlock() resp, ok := f.cache[namespace+"/"+name] if !ok { - return "", "", false, pkg.DefaultTokenExpiration + return "", "", false, pkg.DefaultTokenExpiration, false } - return resp.RoleARN, resp.Audience, resp.UseRegionalSTS, resp.TokenExpiration + return resp.RoleARN, resp.Audience, resp.UseRegionalSTS, resp.TokenExpiration, true +} + +// GetOrNotify gets a service account from the cache +func (f *FakeServiceAccountCache) GetOrNotify(name, namespace string, handler chan any) (role, aud string, useRegionalSTS bool, tokenExpiration int64, found bool) { + f.mu.RLock() + defer f.mu.RUnlock() + resp, ok := f.cache[namespace+"/"+name] + if !ok { + return "", "", false, pkg.DefaultTokenExpiration, false + } + return resp.RoleARN, resp.Audience, resp.UseRegionalSTS, resp.TokenExpiration, true } func (f *FakeServiceAccountCache) GetCommonConfigurations(name, namespace string) (useRegionalSTS bool, tokenExpiration int64) { diff --git 
a/pkg/handler/handler.go b/pkg/handler/handler.go
index 1425ca14b..a877cd017 100644
--- a/pkg/handler/handler.go
+++ b/pkg/handler/handler.go
@@ -24,6 +24,7 @@ import (
 	"path/filepath"
 	"strconv"
 	"strings"
+	"time"
 
 	"github.com/aws/amazon-eks-pod-identity-webhook/pkg/containercredentials"
 
@@ -77,6 +78,11 @@ func WithAnnotationDomain(domain string) ModifierOpt {
 	return func(m *Modifier) { m.AnnotationDomain = domain }
 }
 
+// WithSALookupGraceTime sets the grace time to wait for service accounts to appear in cache
+func WithSALookupGraceTime(saLookupGraceTime time.Duration) ModifierOpt {
+	return func(m *Modifier) { m.saLookupGraceTime = saLookupGraceTime }
+}
+
 // NewModifier returns a Modifier with default values
 func NewModifier(opts ...ModifierOpt) *Modifier {
 	mod := &Modifier{
@@ -101,6 +107,7 @@ type Modifier struct {
 	ContainerCredentialsConfig containercredentials.Config
 	volName string
 	tokenName string
+	saLookupGraceTime time.Duration
 }
 
 type patchOperation struct {
@@ -425,7 +432,29 @@ func (m *Modifier) buildPodPatchConfig(pod *corev1.Pod) *podPatchConfig {
 	}
 
 	// Use the STS WebIdentity method if set
-	roleArn, audience, regionalSTS, tokenExpiration := m.Cache.Get(pod.Spec.ServiceAccountName, pod.Namespace)
+	// Only hand the cache a notification channel when waiting is enabled. A nil
+	// handler makes GetOrNotify a plain lookup, so nothing is left registered in
+	// the cache for pods that will never wait on it.
+	var handler chan any
+	if m.saLookupGraceTime > 0 {
+		handler = make(chan any, 1)
+	}
+	roleArn, audience, regionalSTS, tokenExpiration, found := m.Cache.GetOrNotify(pod.Spec.ServiceAccountName, pod.Namespace, handler)
+	key := pod.Namespace + "/" + pod.Spec.ServiceAccountName
+	if !found && m.saLookupGraceTime > 0 {
+		klog.Warningf("Service account %q not found in the cache. Waiting up to %s to be notified", key, m.saLookupGraceTime)
+		select {
+		case <-handler:
+			roleArn, audience, regionalSTS, tokenExpiration, found = m.Cache.Get(pod.Spec.ServiceAccountName, pod.Namespace)
+			if !found {
+				klog.Warningf("Service account %q not found in the cache after being notified. Not mutating.", key)
+				return nil
+			}
+		case <-time.After(m.saLookupGraceTime):
+			klog.Warningf("Service account %q not found in the cache after %s. Not mutating.", key, m.saLookupGraceTime)
+			return nil
+		}
+	}
+	klog.V(5).Infof("Value of roleArn after cache retrieval for service account %q: %s", key, roleArn)
 	if roleArn != "" {
 		tokenExpiration, containersToSkip := m.parsePodAnnotations(pod, tokenExpiration)