Skip to content

Commit

Permalink
Fix race condition between service account availability and webhook invocation
Browse files Browse the repository at this point in the history
  • Loading branch information
roehrijn committed Aug 13, 2024
1 parent ba509d3 commit 79fcca1
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 27 deletions.
3 changes: 3 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ func main() {

debug := flag.Bool("enable-debugging-handlers", false, "Enable debugging handlers. Currently /debug/alpha/cache is supported")

saLookupGracePeriod := flag.Duration("service-account-lookup-grace-period", 100*time.Millisecond, "The grace period for service account to be available in cache before not mutating a pod. Defaults to 100ms. Set to 0 to deactivate waiting. Carefully use higher values as it may have significant impact on Kubernetes' pod scheduling performance.")

klog.InitFlags(goflag.CommandLine)
// Add klog CommandLine flags to pflag CommandLine
goflag.CommandLine.VisitAll(func(f *goflag.Flag) {
Expand Down Expand Up @@ -208,6 +210,7 @@ func main() {
handler.WithServiceAccountCache(saCache),
handler.WithContainerCredentialsConfig(containerCredentialsConfig),
handler.WithRegion(*region),
handler.WithSALookupGraceTime(*saLookupGracePeriod),
)

addr := fmt.Sprintf(":%d", *port)
Expand Down
41 changes: 31 additions & 10 deletions pkg/cache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ type CacheResponse struct {

type ServiceAccountCache interface {
Start(stop chan struct{})
Get(name, namespace string) (role, aud string, useRegionalSTS bool, tokenExpiration int64)
Get(name, namespace string) (role, aud string, useRegionalSTS bool, tokenExpiration int64, found bool)
GetOrNotify(name, namespace string, handler chan any) (role, aud string, useRegionalSTS bool, tokenExpiration int64, found bool)
GetCommonConfigurations(name, namespace string) (useRegionalSTS bool, tokenExpiration int64)
// ToJSON returns cache contents as JSON string
ToJSON() string
Expand All @@ -60,6 +61,7 @@ type serviceAccountCache struct {
composeRoleArn ComposeRoleArn
defaultTokenExpiration int64
webhookUsage prometheus.Gauge
notificationHandlers map[string]chan any // type of channel doesn't matter. It's just for being notified
}

type ComposeRoleArn struct {
Expand Down Expand Up @@ -87,41 +89,51 @@ func init() {
// Get will return the cached configuration of the given ServiceAccount.
// It will first look at the set of ServiceAccounts configured using annotations. If none are found, it will look for any
// ServiceAccount configured through the pod-identity-webhook ConfigMap.
func (c *serviceAccountCache) Get(name, namespace string) (role, aud string, useRegionalSTS bool, tokenExpiration int64) {
func (c *serviceAccountCache) Get(name, namespace string) (role, aud string, useRegionalSTS bool, tokenExpiration int64, found bool) {
return c.GetOrNotify(name, namespace, nil)
}

// GetOrNotify will return the cached configuration of the given ServiceAccount.
// It will first look at the set of ServiceAccounts configured using annotations, passing handler on to that lookup so
// it can be registered for notification when the ServiceAccount is missing. If that lookup yields no entry with a
// role, it will check for a ServiceAccount configured through the pod-identity-webhook ConfigMap.
//
// found is true only when an annotation-based entry with a non-empty RoleARN or a ConfigMap entry was returned.
// When neither exists, the zero values plus pkg.DefaultTokenExpiration are returned with found == false.
//
// NOTE(review): an annotation-based entry that exists but has an empty RoleARN also ends up as found == false
// (unless the ConfigMap has an entry); a caller that waits on handler in that case would presumably only be
// released by its own timeout — confirm this is intended.
func (c *serviceAccountCache) GetOrNotify(name, namespace string, handler chan any) (role, aud string, useRegionalSTS bool, tokenExpiration int64, found bool) {
	klog.V(5).Infof("Fetching sa %s/%s from cache", namespace, name)

	// Annotation-based configuration wins, but only if it actually carries a role.
	if resp := c.getSAorNotify(name, namespace, handler); resp != nil && resp.RoleARN != "" {
		return resp.RoleARN, resp.Audience, resp.UseRegionalSTS, resp.TokenExpiration, true
	}

	// Fall back to the ConfigMap-based configuration.
	if resp := c.getCM(name, namespace); resp != nil {
		return resp.RoleARN, resp.Audience, resp.UseRegionalSTS, resp.TokenExpiration, true
	}

	klog.V(5).Infof("Service account %s/%s not found in cache", namespace, name)
	return "", "", false, pkg.DefaultTokenExpiration, false
}

// GetCommonConfigurations returns the common configurations that also applies to the new mutation method(i.e Container Credentials).
// The config file for the container credentials does not contain "TokenExpiration" or "UseRegionalSTS". For backward compatibility,
// Use these fields if they are set in the sa annotations or config map.
func (c *serviceAccountCache) GetCommonConfigurations(name, namespace string) (useRegionalSTS bool, tokenExpiration int64) {
if resp := c.getSA(name, namespace); resp != nil {
if resp := c.getSAorNotify(name, namespace, nil); resp != nil {
return resp.UseRegionalSTS, resp.TokenExpiration
} else if resp := c.getCM(name, namespace); resp != nil {
return resp.UseRegionalSTS, resp.TokenExpiration
}
return false, pkg.DefaultTokenExpiration
}

func (c *serviceAccountCache) getSA(name, namespace string) *CacheResponse {
func (c *serviceAccountCache) getSAorNotify(name, namespace string, handler chan any) *CacheResponse {
c.mu.RLock()
defer c.mu.RUnlock()
resp, ok := c.saCache[namespace+"/"+name]
if !ok {
if !ok && handler != nil {
klog.V(5).Infof("Service Account %s/%s not found in cache, adding notification handler", namespace, name)
c.notificationHandlers[namespace+"/"+name] = handler
return nil
}
return resp
Expand Down Expand Up @@ -212,8 +224,16 @@ func (c *serviceAccountCache) addSA(sa *v1.ServiceAccount) {
// setSA stores resp as the annotation-based cache entry for "namespace/name" and, if a notification handler is
// registered in notificationHandlers for that key, signals it once and removes it.
//
// The whole operation runs under the cache's exclusive lock. The signal is therefore sent without blocking: a
// blocking send to a channel nobody is receiving from would keep the lock held and stall every other cache
// operation. Handlers should be created with a buffer of at least one so the signal is not dropped when the
// waiter is not receiving at that instant.
func (c *serviceAccountCache) setSA(name, namespace string, resp *CacheResponse) {
	c.mu.Lock()
	defer c.mu.Unlock()

	key := namespace + "/" + name
	klog.V(5).Infof("Adding SA %q to SA cache: %+v", key, resp)
	c.saCache[key] = resp

	handler, found := c.notificationHandlers[key]
	if !found {
		return
	}
	// One-shot registration: drop it whether or not the signal could be delivered.
	delete(c.notificationHandlers, key)

	klog.V(5).Infof("Notifying handler for %q", key)
	select {
	case handler <- 1:
	default:
		// No buffer space and no receiver ready; never block while holding c.mu.
		klog.V(5).Infof("Handler for %q not ready to receive, dropping notification", key)
	}
}

func (c *serviceAccountCache) setCM(name, namespace string, resp *CacheResponse) {
Expand Down Expand Up @@ -242,6 +262,7 @@ func New(defaultAudience, prefix string, defaultRegionalSTS bool, defaultTokenEx
defaultTokenExpiration: defaultTokenExpiration,
hasSynced: hasSynced,
webhookUsage: webhookUsage,
notificationHandlers: map[string]chan any{},
}

saInformer.Informer().AddEventHandler(
Expand Down
29 changes: 16 additions & 13 deletions pkg/cache/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,18 @@ func TestSaCache(t *testing.T) {
webhookUsage: prometheus.NewGauge(prometheus.GaugeOpts{}),
}

role, aud, useRegionalSTS, tokenExpiration := cache.Get("default", "default")
role, aud, useRegionalSTS, tokenExpiration, found := cache.Get("default", "default")

assert.False(t, found, "Expected no cache entry to be found")
if role != "" || aud != "" {
t.Errorf("Expected role and aud to be empty, got %s, %s, %t, %d", role, aud, useRegionalSTS, tokenExpiration)
}

cache.addSA(testSA)

role, aud, useRegionalSTS, tokenExpiration = cache.Get("default", "default")
role, aud, useRegionalSTS, tokenExpiration, found = cache.Get("default", "default")

assert.True(t, found, "Expected cache entry to be found")
assert.Equal(t, roleArn, role, "Expected role to be %s, got %s", roleArn, role)
assert.Equal(t, "sts.amazonaws.com", aud, "Expected aud to be sts.amzonaws.com, got %s", aud)
assert.True(t, useRegionalSTS, "Expected regional STS to be true, got false")
Expand Down Expand Up @@ -157,7 +159,8 @@ func TestNonRegionalSTS(t *testing.T) {
t.Fatalf("cache never called addSA: %v", err)
}

gotRoleArn, gotAudience, useRegionalSTS, gotTokenExpiration := cache.Get("default", "default")
gotRoleArn, gotAudience, useRegionalSTS, gotTokenExpiration, found := cache.Get("default", "default")
assert.True(t, found, "Expected cache entry to be found")
if gotRoleArn != roleArn {
t.Errorf("got roleArn %v, expected %v", gotRoleArn, roleArn)
}
Expand Down Expand Up @@ -202,7 +205,7 @@ func TestPopulateCacheFromCM(t *testing.T) {
t.Errorf("failed to build cache: %v", err)
}

role, _, _, _ := c.Get("mysa2", "myns2")
role, _, _, _, _ := c.Get("mysa2", "myns2")
if role == "" {
t.Errorf("cloud not find entry that should have been added")
}
Expand All @@ -214,7 +217,7 @@ func TestPopulateCacheFromCM(t *testing.T) {
t.Errorf("failed to build cache: %v", err)
}

role, _, _, _ := c.Get("mysa2", "myns2")
role, _, _, _, _ := c.Get("mysa2", "myns2")
if role == "" {
t.Errorf("cloud not find entry that should have been added")
}
Expand All @@ -226,7 +229,7 @@ func TestPopulateCacheFromCM(t *testing.T) {
t.Errorf("failed to build cache: %v", err)
}

role, _, _, _ := c.Get("mysa2", "myns2")
role, _, _, _, _ := c.Get("mysa2", "myns2")
if role != "" {
t.Errorf("found entry that should have been removed")
}
Expand Down Expand Up @@ -256,7 +259,7 @@ func TestSAAnnotationRemoval(t *testing.T) {
c.addSA(oldSA)

{
gotRoleArn, _, _, _ := c.Get("default", "default")
gotRoleArn, _, _, _, _ := c.Get("default", "default")
if gotRoleArn != roleArn {
t.Errorf("got roleArn %q, expected %q", gotRoleArn, roleArn)
}
Expand All @@ -268,7 +271,7 @@ func TestSAAnnotationRemoval(t *testing.T) {
c.addSA(newSA)

{
gotRoleArn, _, _, _ := c.Get("default", "default")
gotRoleArn, _, _, _, _ := c.Get("default", "default")
if gotRoleArn != "" {
t.Errorf("got roleArn %v, expected %q", gotRoleArn, "")
}
Expand Down Expand Up @@ -323,7 +326,7 @@ func TestCachePrecedence(t *testing.T) {
t.Errorf("failed to build cache: %v", err)
}

role, _, _, exp := c.Get("mysa2", "myns2")
role, _, _, exp, _ := c.Get("mysa2", "myns2")
if role == "" {
t.Errorf("could not find entry that should have been added")
}
Expand All @@ -340,7 +343,7 @@ func TestCachePrecedence(t *testing.T) {
}

// Removing sa2 from CM, but SA still exists
role, _, _, exp := c.Get("mysa2", "myns2")
role, _, _, exp, _ := c.Get("mysa2", "myns2")
if role == "" {
t.Errorf("could not find entry that should still exist")
}
Expand All @@ -356,7 +359,7 @@ func TestCachePrecedence(t *testing.T) {
c.addSA(sa2)

// Neither cache should return any hits now
role, _, _, _ := c.Get("myns2", "mysa2")
role, _, _, _, _ := c.Get("myns2", "mysa2")
if role != "" {
t.Errorf("found entry that should not exist")
}
Expand All @@ -370,7 +373,7 @@ func TestCachePrecedence(t *testing.T) {
t.Errorf("failed to build cache: %v", err)
}

role, _, _, exp := c.Get("mysa2", "myns2")
role, _, _, exp, _ := c.Get("mysa2", "myns2")
if role == "" {
t.Errorf("cloud not find entry that should have been added")
}
Expand Down Expand Up @@ -422,7 +425,7 @@ func TestRoleArnComposition(t *testing.T) {

var roleArn string
err := wait.ExponentialBackoff(wait.Backoff{Duration: 10 * time.Millisecond, Factor: 1.0, Steps: 3}, func() (bool, error) {
roleArn, _, _, _ = cache.Get("default", "default")
roleArn, _, _, _, _ = cache.Get("default", "default")
return roleArn != "", nil
})
if err != nil {
Expand Down
17 changes: 14 additions & 3 deletions pkg/cache/fake.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,25 @@ var _ ServiceAccountCache = &FakeServiceAccountCache{}
func (f *FakeServiceAccountCache) Start(chan struct{}) {}

// Get gets a service account from the cache
func (f *FakeServiceAccountCache) Get(name, namespace string) (role, aud string, useRegionalSTS bool, tokenExpiration int64) {
func (f *FakeServiceAccountCache) Get(name, namespace string) (role, aud string, useRegionalSTS bool, tokenExpiration int64, found bool) {
f.mu.RLock()
defer f.mu.RUnlock()
resp, ok := f.cache[namespace+"/"+name]
if !ok {
return "", "", false, pkg.DefaultTokenExpiration
return "", "", false, pkg.DefaultTokenExpiration, false
}
return resp.RoleARN, resp.Audience, resp.UseRegionalSTS, resp.TokenExpiration
return resp.RoleARN, resp.Audience, resp.UseRegionalSTS, resp.TokenExpiration, true
}

// GetOrNotify gets a service account from the cache.
// The fake performs the lookup of "namespace/name" only; handler is accepted to satisfy the signature and is not
// used. A miss yields empty strings, false, pkg.DefaultTokenExpiration and found == false.
func (f *FakeServiceAccountCache) GetOrNotify(name, namespace string, handler chan any) (role, aud string, useRegionalSTS bool, tokenExpiration int64, found bool) {
	key := namespace + "/" + name

	f.mu.RLock()
	defer f.mu.RUnlock()

	if entry, hit := f.cache[key]; hit {
		return entry.RoleARN, entry.Audience, entry.UseRegionalSTS, entry.TokenExpiration, true
	}
	return "", "", false, pkg.DefaultTokenExpiration, false
}

func (f *FakeServiceAccountCache) GetCommonConfigurations(name, namespace string) (useRegionalSTS bool, tokenExpiration int64) {
Expand Down
27 changes: 26 additions & 1 deletion pkg/handler/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"path/filepath"
"strconv"
"strings"
"time"

"github.com/aws/amazon-eks-pod-identity-webhook/pkg/containercredentials"

Expand Down Expand Up @@ -77,6 +78,12 @@ func WithAnnotationDomain(domain string) ModifierOpt {
return func(m *Modifier) { m.AnnotationDomain = domain }
}

// WithSALookupGraceTime returns a ModifierOpt that sets the Modifier's grace time, i.e. how long to wait for a
// service account to appear in the cache.
func WithSALookupGraceTime(saLookupGraceTime time.Duration) ModifierOpt {
	return func(m *Modifier) {
		m.saLookupGraceTime = saLookupGraceTime
	}
}

// NewModifier returns a Modifier with default values
func NewModifier(opts ...ModifierOpt) *Modifier {
mod := &Modifier{
Expand All @@ -101,6 +108,7 @@ type Modifier struct {
ContainerCredentialsConfig containercredentials.Config
volName string
tokenName string
saLookupGraceTime time.Duration
}

type patchOperation struct {
Expand Down Expand Up @@ -425,7 +433,24 @@ func (m *Modifier) buildPodPatchConfig(pod *corev1.Pod) *podPatchConfig {
}

// Use the STS WebIdentity method if set
roleArn, audience, regionalSTS, tokenExpiration := m.Cache.Get(pod.Spec.ServiceAccountName, pod.Namespace)
handler := make(chan any, 1)
roleArn, audience, regionalSTS, tokenExpiration, found := m.Cache.GetOrNotify(pod.Spec.ServiceAccountName, pod.Namespace, handler)
key := pod.Namespace + "/" + pod.Spec.ServiceAccountName
if !found && m.saLookupGraceTime > 0 {
klog.Warningf("Service account %q not found in the cache. Waiting up to %s to be notified", key, m.saLookupGraceTime)
select {
case <-handler:
roleArn, audience, regionalSTS, tokenExpiration, found = m.Cache.Get(pod.Spec.ServiceAccountName, pod.Namespace)
if !found {
klog.Warningf("Service account %q not found in the cache after being notified. Not mutating.", key)
return nil
}
case <-time.After(m.saLookupGraceTime):
klog.Warningf("Service account %q not found in the cache after %s. Not mutating.", key, m.saLookupGraceTime)
return nil
}
}
klog.V(5).Infof("Value of roleArn after after cache retrieval for service account %q: %s", key, roleArn)
if roleArn != "" {
tokenExpiration, containersToSkip := m.parsePodAnnotations(pod, tokenExpiration)

Expand Down

0 comments on commit 79fcca1

Please sign in to comment.