Skip to content

Commit

Permalink
Fix race condition between service account availability and webhook invocation
Browse files Browse the repository at this point in the history
  • Loading branch information
roehrijn committed Aug 13, 2024
1 parent ba509d3 commit 37ac32f
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 15 deletions.
43 changes: 32 additions & 11 deletions pkg/cache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ type CacheResponse struct {

type ServiceAccountCache interface {
Start(stop chan struct{})
Get(name, namespace string) (role, aud string, useRegionalSTS bool, tokenExpiration int64)
Get(name, namespace string) (role, aud string, useRegionalSTS bool, tokenExpiration int64, found bool)
GetOrNotify(name, namespace string, handler chan any) (role, aud string, useRegionalSTS bool, tokenExpiration int64, found bool)
GetCommonConfigurations(name, namespace string) (useRegionalSTS bool, tokenExpiration int64)
// ToJSON returns cache contents as JSON string
ToJSON() string
Expand All @@ -60,6 +61,7 @@ type serviceAccountCache struct {
composeRoleArn ComposeRoleArn
defaultTokenExpiration int64
webhookUsage prometheus.Gauge
notificationHandlers map[string]chan any // type of channel doesn't matter. It's just for being notified
}

type ComposeRoleArn struct {
Expand Down Expand Up @@ -87,41 +89,51 @@ func init() {
// Get returns the cached configuration of the given ServiceAccount.
// It delegates to GetOrNotify with a nil handler, so a miss never registers a notification handler.
// The set of ServiceAccounts configured using annotations is consulted first. If none is found, any
// ServiceAccount configured through the pod-identity-webhook ConfigMap is used.
// found reports whether either source had an entry.
func (c *serviceAccountCache) Get(name, namespace string) (role, aud string, useRegionalSTS bool, tokenExpiration int64) {
func (c *serviceAccountCache) Get(name, namespace string) (role, aud string, useRegionalSTS bool, tokenExpiration int64, found bool) {
return c.GetOrNotify(name, namespace, nil)
}

// GetOrNotify returns the cached configuration of the given ServiceAccount.
// It first looks at the set of ServiceAccounts configured using annotations. If none is found and handler is
// non-nil, handler is registered to be notified as soon as a ServiceAccount with the given key is populated
// to the cache. Afterwards it checks for a ServiceAccount configured through the pod-identity-webhook ConfigMap.
// found is true when either lookup produced an entry; otherwise empty strings, false and
// pkg.DefaultTokenExpiration are returned with found set to false.
// Nothing in this function removes a handler once registered, so it stays registered even when the
// ConfigMap lookup succeeds.
func (c *serviceAccountCache) GetOrNotify(name, namespace string, handler chan any) (role, aud string, useRegionalSTS bool, tokenExpiration int64, found bool) {
klog.V(5).Infof("Fetching sa %s/%s from cache", namespace, name)
{
resp := c.getSA(name, namespace)
if resp != nil && resp.RoleARN != "" {
return resp.RoleARN, resp.Audience, resp.UseRegionalSTS, resp.TokenExpiration
resp := c.getSAorNotify(name, namespace, handler)
// NOTE(review): this check accepts any cached entry, whereas the check on c.getSA above also required a
// non-empty RoleARN. An entry with an empty RoleARN therefore skips the ConfigMap lookup below — confirm
// this is intended.
if resp != nil {
return resp.RoleARN, resp.Audience, resp.UseRegionalSTS, resp.TokenExpiration, true
}
}
{
// Fall back to the ConfigMap-based configuration.
resp := c.getCM(name, namespace)
if resp != nil {
return resp.RoleARN, resp.Audience, resp.UseRegionalSTS, resp.TokenExpiration
return resp.RoleARN, resp.Audience, resp.UseRegionalSTS, resp.TokenExpiration, true
}
}
klog.V(5).Infof("Service account %s/%s not found in cache", namespace, name)
return "", "", false, pkg.DefaultTokenExpiration
return "", "", false, pkg.DefaultTokenExpiration, false
}

// GetCommonConfigurations returns the common configurations that also apply to the new mutation method (i.e. Container Credentials).
// The config file for the container credentials does not contain "TokenExpiration" or "UseRegionalSTS". For backward compatibility,
// these fields are used if they are set in the SA annotations or the ConfigMap; otherwise false and
// pkg.DefaultTokenExpiration are returned.
func (c *serviceAccountCache) GetCommonConfigurations(name, namespace string) (useRegionalSTS bool, tokenExpiration int64) {
if resp := c.getSA(name, namespace); resp != nil {
// A nil handler makes this a plain lookup: no notification handler is registered on a miss.
if resp := c.getSAorNotify(name, namespace, nil); resp != nil {
return resp.UseRegionalSTS, resp.TokenExpiration
} else if resp := c.getCM(name, namespace); resp != nil {
return resp.UseRegionalSTS, resp.TokenExpiration
}
return false, pkg.DefaultTokenExpiration
}

// getSAorNotify looks up "namespace/name" in the annotation-based ServiceAccount cache.
// On a miss with a non-nil handler, the handler is stored in notificationHandlers under that key and nil is
// returned. In every other case the result of the map lookup is returned as-is (the zero value on a miss).
func (c *serviceAccountCache) getSA(name, namespace string) *CacheResponse {
func (c *serviceAccountCache) getSAorNotify(name, namespace string, handler chan any) *CacheResponse {
c.mu.RLock()
defer c.mu.RUnlock()
resp, ok := c.saCache[namespace+"/"+name]
if !ok {
// NOTE(review): notificationHandlers is written below while only the read lock (RLock) is held, so several
// callers can write the map at the same time. A write lock or a dedicated mutex looks necessary — confirm.
// NOTE(review): only one handler is kept per key; a second waiter for the same ServiceAccount replaces
// the first, which is then never notified.
if !ok && handler != nil {
klog.V(5).Infof("Service Account %s/%s not found in cache, adding notification handler", namespace, name)
c.notificationHandlers[namespace+"/"+name] = handler
return nil
}
return resp
Expand Down Expand Up @@ -212,8 +224,16 @@ func (c *serviceAccountCache) addSA(sa *v1.ServiceAccount) {
// setSA stores resp in the ServiceAccount cache under "namespace/name" while holding the write lock.
// If a notification handler is registered for that key, it is sent a value and then removed from
// notificationHandlers.
func (c *serviceAccountCache) setSA(name, namespace string, resp *CacheResponse) {
c.mu.Lock()
defer c.mu.Unlock()
klog.V(5).Infof("Adding SA %s/%s to SA cache: %+v", namespace, name, resp)

key := namespace + "/" + name
klog.V(5).Infof("Adding SA %q to SA cache: %+v", key, resp)
c.saCache[namespace+"/"+name] = resp

if handler, found := c.notificationHandlers[key]; found {
klog.V(5).Infof("Notifying handler for %q", key)
// NOTE(review): this send happens with c.mu held. It should not block for the 1-buffered channel created in
// buildPodPatchConfig, but an unbuffered or already-full handler would stall every cache access that needs
// the lock — consider a non-blocking send.
handler <- 1
delete(c.notificationHandlers, key)
}
}

func (c *serviceAccountCache) setCM(name, namespace string, resp *CacheResponse) {
Expand Down Expand Up @@ -242,6 +262,7 @@ func New(defaultAudience, prefix string, defaultRegionalSTS bool, defaultTokenEx
defaultTokenExpiration: defaultTokenExpiration,
hasSynced: hasSynced,
webhookUsage: webhookUsage,
notificationHandlers: map[string]chan any{},
}

saInformer.Informer().AddEventHandler(
Expand Down
17 changes: 14 additions & 3 deletions pkg/cache/fake.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,25 @@ var _ ServiceAccountCache = &FakeServiceAccountCache{}
func (f *FakeServiceAccountCache) Start(chan struct{}) {}

// Get returns the entry stored under "namespace/name" in the fake cache.
// found reports whether the key was present; on a miss empty strings, false and
// pkg.DefaultTokenExpiration are returned.
func (f *FakeServiceAccountCache) Get(name, namespace string) (role, aud string, useRegionalSTS bool, tokenExpiration int64) {
func (f *FakeServiceAccountCache) Get(name, namespace string) (role, aud string, useRegionalSTS bool, tokenExpiration int64, found bool) {
f.mu.RLock()
defer f.mu.RUnlock()
resp, ok := f.cache[namespace+"/"+name]
if !ok {
return "", "", false, pkg.DefaultTokenExpiration
return "", "", false, pkg.DefaultTokenExpiration, false
}
return resp.RoleARN, resp.Audience, resp.UseRegionalSTS, resp.TokenExpiration
return resp.RoleARN, resp.Audience, resp.UseRegionalSTS, resp.TokenExpiration, true
}

// GetOrNotify returns the entry stored under "namespace/name" in the fake cache.
// The handler argument is accepted but never registered or signalled by this fake.
// found reports whether the key was present; on a miss empty strings, false and
// pkg.DefaultTokenExpiration are returned.
func (f *FakeServiceAccountCache) GetOrNotify(name, namespace string, handler chan any) (role, aud string, useRegionalSTS bool, tokenExpiration int64, found bool) {
	key := namespace + "/" + name

	f.mu.RLock()
	defer f.mu.RUnlock()

	if entry, present := f.cache[key]; present {
		return entry.RoleARN, entry.Audience, entry.UseRegionalSTS, entry.TokenExpiration, true
	}
	return "", "", false, pkg.DefaultTokenExpiration, false
}

func (f *FakeServiceAccountCache) GetCommonConfigurations(name, namespace string) (useRegionalSTS bool, tokenExpiration int64) {
Expand Down
19 changes: 18 additions & 1 deletion pkg/handler/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"path/filepath"
"strconv"
"strings"
"time"

"github.com/aws/amazon-eks-pod-identity-webhook/pkg/containercredentials"

Expand Down Expand Up @@ -425,7 +426,23 @@ func (m *Modifier) buildPodPatchConfig(pod *corev1.Pod) *podPatchConfig {
}

// Use the STS WebIdentity method if set
roleArn, audience, regionalSTS, tokenExpiration := m.Cache.Get(pod.Spec.ServiceAccountName, pod.Namespace)
handler := make(chan any, 1)
roleArn, audience, regionalSTS, tokenExpiration, found := m.Cache.GetOrNotify(pod.Spec.ServiceAccountName, pod.Namespace, handler)
if !found {
klog.Warningf("Service account %s/%s not found in the cache. Waiting up to 5s to be notified", pod.Namespace, pod.Spec.ServiceAccountName)
select {
case <-handler:
roleArn, audience, regionalSTS, tokenExpiration, found = m.Cache.Get(pod.Spec.ServiceAccountName, pod.Namespace)
if !found {
klog.Warningf("Service account %s/%s not found in the cache after being notified. Not mutating.", pod.Namespace, pod.Spec.ServiceAccountName)
return nil
}
case <-time.After(5 * time.Second):
klog.Warningf("Service account %s/%s not found in the cache after 5s. Not mutating.", pod.Namespace, pod.Spec.ServiceAccountName)
return nil
}
}
klog.V(5).Infof("Value of roleArn after after cache retrieval for service account %s/%s: %s", pod.Namespace, pod.Spec.ServiceAccountName, roleArn)
if roleArn != "" {
tokenExpiration, containersToSkip := m.parsePodAnnotations(pod, tokenExpiration)

Expand Down

0 comments on commit 37ac32f

Please sign in to comment.