Skip to content

Commit

Permalink
Fix race condition between service account availability and webhook invocation
Browse files Browse the repository at this point in the history
  • Loading branch information
roehrijn committed Aug 13, 2024
1 parent ba509d3 commit ea82b2f
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 58 deletions.
63 changes: 32 additions & 31 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -143,37 +143,38 @@ When running a container with a non-root user, you need to give the container ac

```
Usage of amazon-eks-pod-identity-webhook:
--add_dir_header If true, adds the file directory to the header
--alsologtostderr log to standard error as well as files
--annotation-prefix string The Service Account annotation to look for (default "eks.amazonaws.com")
--aws-default-region string If set, AWS_DEFAULT_REGION and AWS_REGION will be set to this value in mutated containers
--enable-debugging-handlers Enable debugging handlers. Currently /debug/alpha/cache is supported
--in-cluster Use in-cluster authentication and certificate request API (default true)
--kube-api string (out-of-cluster) The url to the API server
--kubeconfig string (out-of-cluster) Absolute path to the API server kubeconfig file
--log_backtrace_at traceLocation when logging hits line file:N, emit a stack trace (default :0)
--log_dir string If non-empty, write log files in this directory
--log_file string If non-empty, use this log file
--log_file_max_size uint Defines the maximum size a log file can grow to. Unit is megabytes. If the value is 0, the maximum file size is unlimited. (default 1800)
--logtostderr log to standard error instead of files (default true)
--metrics-port int Port to listen on for metrics (http) (default 9999)
--namespace string (in-cluster) The namespace name this webhook, the TLS secret, and configmap resides in (default "eks")
--port int Port to listen on (default 443)
--service-name string (in-cluster) The service name fronting this webhook (default "pod-identity-webhook")
--skip_headers If true, avoid header prefixes in the log messages
--skip_log_headers If true, avoid headers when opening log files
--stderrthreshold severity logs at or above this threshold go to stderr (default 2)
--sts-regional-endpoint false Whether to inject the AWS_STS_REGIONAL_ENDPOINTS=regional env var in mutated pods. Defaults to false.
--tls-cert string (out-of-cluster) TLS certificate file path (default "/etc/webhook/certs/tls.crt")
--tls-key string (out-of-cluster) TLS key file path (default "/etc/webhook/certs/tls.key")
--tls-secret string (in-cluster) The secret name for storing the TLS serving cert (default "pod-identity-webhook")
--token-audience string The default audience for tokens. Can be overridden by annotation (default "sts.amazonaws.com")
--token-expiration int The token expiration (default 86400)
--token-mount-path string The path to mount tokens (default "/var/run/secrets/eks.amazonaws.com/serviceaccount")
-v, --v Level number for the log level verbosity
--version Display the version and exit
--vmodule moduleSpec comma-separated list of pattern=N settings for file-filtered logging
--watch-config-map Enables watching serviceaccounts that are configured through the pod-identity-webhook configmap instead of using annotations
--add_dir_header If true, adds the file directory to the header
--alsologtostderr log to standard error as well as files
--annotation-prefix string The Service Account annotation to look for (default "eks.amazonaws.com")
--aws-default-region string If set, AWS_DEFAULT_REGION and AWS_REGION will be set to this value in mutated containers
--enable-debugging-handlers Enable debugging handlers. Currently /debug/alpha/cache is supported
--in-cluster Use in-cluster authentication and certificate request API (default true)
--kube-api string (out-of-cluster) The url to the API server
--kubeconfig string (out-of-cluster) Absolute path to the API server kubeconfig file
--log_backtrace_at traceLocation when logging hits line file:N, emit a stack trace (default :0)
--log_dir string If non-empty, write log files in this directory
--log_file string If non-empty, use this log file
--log_file_max_size uint Defines the maximum size a log file can grow to. Unit is megabytes. If the value is 0, the maximum file size is unlimited. (default 1800)
--logtostderr log to standard error instead of files (default true)
--metrics-port int Port to listen on for metrics (http) (default 9999)
--namespace string (in-cluster) The namespace name this webhook, the TLS secret, and configmap resides in (default "eks")
--port int Port to listen on (default 443)
--service-name string (in-cluster) The service name fronting this webhook (default "pod-identity-webhook")
--service-account-lookup-grace-period duration The grace period for a service account to be available in cache before not mutating a pod. Set to 0 to deactivate waiting. Carefully use higher values as it may have significant impact on Kubernetes' pod scheduling performance. (default 100ms)
--skip_headers If true, avoid header prefixes in the log messages
--skip_log_headers If true, avoid headers when opening log files
--stderrthreshold severity logs at or above this threshold go to stderr (default 2)
--sts-regional-endpoint false Whether to inject the AWS_STS_REGIONAL_ENDPOINTS=regional env var in mutated pods. Defaults to false.
--tls-cert string (out-of-cluster) TLS certificate file path (default "/etc/webhook/certs/tls.crt")
--tls-key string (out-of-cluster) TLS key file path (default "/etc/webhook/certs/tls.key")
--tls-secret string (in-cluster) The secret name for storing the TLS serving cert (default "pod-identity-webhook")
--token-audience string The default audience for tokens. Can be overridden by annotation (default "sts.amazonaws.com")
--token-expiration int The token expiration (default 86400)
--token-mount-path string The path to mount tokens (default "/var/run/secrets/eks.amazonaws.com/serviceaccount")
-v, --v Level number for the log level verbosity
--version Display the version and exit
--vmodule moduleSpec comma-separated list of pattern=N settings for file-filtered logging
--watch-config-map Enables watching serviceaccounts that are configured through the pod-identity-webhook configmap instead of using annotations
```
### AWS_DEFAULT_REGION Injection
Expand Down
3 changes: 3 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ func main() {

debug := flag.Bool("enable-debugging-handlers", false, "Enable debugging handlers. Currently /debug/alpha/cache is supported")

saLookupGracePeriod := flag.Duration("service-account-lookup-grace-period", 100*time.Millisecond, "The grace period for service account to be available in cache before not mutating a pod. Defaults to 100ms. Set to 0 to deactivate waiting. Carefully use higher values as it may have significant impact on Kubernetes' pod scheduling performance.")

klog.InitFlags(goflag.CommandLine)
// Add klog CommandLine flags to pflag CommandLine
goflag.CommandLine.VisitAll(func(f *goflag.Flag) {
Expand Down Expand Up @@ -208,6 +210,7 @@ func main() {
handler.WithServiceAccountCache(saCache),
handler.WithContainerCredentialsConfig(containerCredentialsConfig),
handler.WithRegion(*region),
handler.WithSALookupGraceTime(*saLookupGracePeriod),
)

addr := fmt.Sprintf(":%d", *port)
Expand Down
41 changes: 31 additions & 10 deletions pkg/cache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ type CacheResponse struct {

type ServiceAccountCache interface {
Start(stop chan struct{})
Get(name, namespace string) (role, aud string, useRegionalSTS bool, tokenExpiration int64)
Get(name, namespace string) (role, aud string, useRegionalSTS bool, tokenExpiration int64, found bool)
GetOrNotify(name, namespace string, handler chan any) (role, aud string, useRegionalSTS bool, tokenExpiration int64, found bool)
GetCommonConfigurations(name, namespace string) (useRegionalSTS bool, tokenExpiration int64)
// ToJSON returns cache contents as JSON string
ToJSON() string
Expand All @@ -60,6 +61,7 @@ type serviceAccountCache struct {
composeRoleArn ComposeRoleArn
defaultTokenExpiration int64
webhookUsage prometheus.Gauge
notificationHandlers map[string]chan any // type of channel doesn't matter. It's just for being notified
}

type ComposeRoleArn struct {
Expand Down Expand Up @@ -87,41 +89,51 @@ func init() {
// Get will return the cached configuration of the given ServiceAccount.
// It will first look at the set of ServiceAccounts configured using annotations. If none are found, it will look for any
// ServiceAccount configured through the pod-identity-webhook ConfigMap.
func (c *serviceAccountCache) Get(name, namespace string) (role, aud string, useRegionalSTS bool, tokenExpiration int64) {
func (c *serviceAccountCache) Get(name, namespace string) (role, aud string, useRegionalSTS bool, tokenExpiration int64, found bool) {
return c.GetOrNotify(name, namespace, nil)
}

// GetOrNotify will return the cached configuration of the given ServiceAccount.
// It will first look at the set of ServiceAccounts configured using annotations. If none is found, it will register
// handler to be notified as soon as a ServiceAccount with given key is populated to the cache. Afterwards it will check
// for a ServiceAccount configured through the pod-identity-webhook ConfigMap.
func (c *serviceAccountCache) GetOrNotify(name, namespace string, handler chan any) (role, aud string, useRegionalSTS bool, tokenExpiration int64, found bool) {
klog.V(5).Infof("Fetching sa %s/%s from cache", namespace, name)
{
resp := c.getSA(name, namespace)
resp := c.getSAorNotify(name, namespace, handler)
if resp != nil && resp.RoleARN != "" {
return resp.RoleARN, resp.Audience, resp.UseRegionalSTS, resp.TokenExpiration
return resp.RoleARN, resp.Audience, resp.UseRegionalSTS, resp.TokenExpiration, true
}
}
{
resp := c.getCM(name, namespace)
if resp != nil {
return resp.RoleARN, resp.Audience, resp.UseRegionalSTS, resp.TokenExpiration
return resp.RoleARN, resp.Audience, resp.UseRegionalSTS, resp.TokenExpiration, true
}
}
klog.V(5).Infof("Service account %s/%s not found in cache", namespace, name)
return "", "", false, pkg.DefaultTokenExpiration
return "", "", false, pkg.DefaultTokenExpiration, false
}

// GetCommonConfigurations returns the common configurations that also applies to the new mutation method(i.e Container Credentials).
// The config file for the container credentials does not contain "TokenExpiration" or "UseRegionalSTS". For backward compatibility,
// Use these fields if they are set in the sa annotations or config map.
func (c *serviceAccountCache) GetCommonConfigurations(name, namespace string) (useRegionalSTS bool, tokenExpiration int64) {
if resp := c.getSA(name, namespace); resp != nil {
if resp := c.getSAorNotify(name, namespace, nil); resp != nil {
return resp.UseRegionalSTS, resp.TokenExpiration
} else if resp := c.getCM(name, namespace); resp != nil {
return resp.UseRegionalSTS, resp.TokenExpiration
}
return false, pkg.DefaultTokenExpiration
}

func (c *serviceAccountCache) getSA(name, namespace string) *CacheResponse {
func (c *serviceAccountCache) getSAorNotify(name, namespace string, handler chan any) *CacheResponse {
c.mu.RLock()
defer c.mu.RUnlock()
resp, ok := c.saCache[namespace+"/"+name]
if !ok {
if !ok && handler != nil {
klog.V(5).Infof("Service Account %s/%s not found in cache, adding notification handler", namespace, name)
c.notificationHandlers[namespace+"/"+name] = handler
return nil
}
return resp
Expand Down Expand Up @@ -212,8 +224,16 @@ func (c *serviceAccountCache) addSA(sa *v1.ServiceAccount) {
// setSA stores resp in the annotation-based cache under "<namespace>/<name>"
// and, if a notification handler was registered for that key, removes it and
// signals it once.
//
// The signal is sent after the cache lock has been released: a channel send
// may block until the receiver is ready, and blocking while holding the write
// lock would stall every other cache reader and writer.
func (c *serviceAccountCache) setSA(name, namespace string, resp *CacheResponse) {
	key := namespace + "/" + name

	c.mu.Lock()
	klog.V(5).Infof("Adding SA %q to SA cache: %+v", key, resp)
	c.saCache[key] = resp

	// Detach the handler while still holding the lock so it is signalled at
	// most once, even if setSA is called again for the same key.
	handler, found := c.notificationHandlers[key]
	if found {
		delete(c.notificationHandlers, key)
	}
	c.mu.Unlock()

	if found {
		klog.V(5).Infof("Notifying handler for %q", key)
		handler <- 1
	}
}

func (c *serviceAccountCache) setCM(name, namespace string, resp *CacheResponse) {
Expand Down Expand Up @@ -242,6 +262,7 @@ func New(defaultAudience, prefix string, defaultRegionalSTS bool, defaultTokenEx
defaultTokenExpiration: defaultTokenExpiration,
hasSynced: hasSynced,
webhookUsage: webhookUsage,
notificationHandlers: map[string]chan any{},
}

saInformer.Informer().AddEventHandler(
Expand Down
29 changes: 16 additions & 13 deletions pkg/cache/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,18 @@ func TestSaCache(t *testing.T) {
webhookUsage: prometheus.NewGauge(prometheus.GaugeOpts{}),
}

role, aud, useRegionalSTS, tokenExpiration := cache.Get("default", "default")
role, aud, useRegionalSTS, tokenExpiration, found := cache.Get("default", "default")

assert.False(t, found, "Expected no cache entry to be found")
if role != "" || aud != "" {
t.Errorf("Expected role and aud to be empty, got %s, %s, %t, %d", role, aud, useRegionalSTS, tokenExpiration)
}

cache.addSA(testSA)

role, aud, useRegionalSTS, tokenExpiration = cache.Get("default", "default")
role, aud, useRegionalSTS, tokenExpiration, found = cache.Get("default", "default")

assert.True(t, found, "Expected cache entry to be found")
assert.Equal(t, roleArn, role, "Expected role to be %s, got %s", roleArn, role)
assert.Equal(t, "sts.amazonaws.com", aud, "Expected aud to be sts.amzonaws.com, got %s", aud)
assert.True(t, useRegionalSTS, "Expected regional STS to be true, got false")
Expand Down Expand Up @@ -157,7 +159,8 @@ func TestNonRegionalSTS(t *testing.T) {
t.Fatalf("cache never called addSA: %v", err)
}

gotRoleArn, gotAudience, useRegionalSTS, gotTokenExpiration := cache.Get("default", "default")
gotRoleArn, gotAudience, useRegionalSTS, gotTokenExpiration, found := cache.Get("default", "default")
assert.True(t, found, "Expected cache entry to be found")
if gotRoleArn != roleArn {
t.Errorf("got roleArn %v, expected %v", gotRoleArn, roleArn)
}
Expand Down Expand Up @@ -202,7 +205,7 @@ func TestPopulateCacheFromCM(t *testing.T) {
t.Errorf("failed to build cache: %v", err)
}

role, _, _, _ := c.Get("mysa2", "myns2")
role, _, _, _, _ := c.Get("mysa2", "myns2")
if role == "" {
t.Errorf("cloud not find entry that should have been added")
}
Expand All @@ -214,7 +217,7 @@ func TestPopulateCacheFromCM(t *testing.T) {
t.Errorf("failed to build cache: %v", err)
}

role, _, _, _ := c.Get("mysa2", "myns2")
role, _, _, _, _ := c.Get("mysa2", "myns2")
if role == "" {
t.Errorf("cloud not find entry that should have been added")
}
Expand All @@ -226,7 +229,7 @@ func TestPopulateCacheFromCM(t *testing.T) {
t.Errorf("failed to build cache: %v", err)
}

role, _, _, _ := c.Get("mysa2", "myns2")
role, _, _, _, _ := c.Get("mysa2", "myns2")
if role != "" {
t.Errorf("found entry that should have been removed")
}
Expand Down Expand Up @@ -256,7 +259,7 @@ func TestSAAnnotationRemoval(t *testing.T) {
c.addSA(oldSA)

{
gotRoleArn, _, _, _ := c.Get("default", "default")
gotRoleArn, _, _, _, _ := c.Get("default", "default")
if gotRoleArn != roleArn {
t.Errorf("got roleArn %q, expected %q", gotRoleArn, roleArn)
}
Expand All @@ -268,7 +271,7 @@ func TestSAAnnotationRemoval(t *testing.T) {
c.addSA(newSA)

{
gotRoleArn, _, _, _ := c.Get("default", "default")
gotRoleArn, _, _, _, _ := c.Get("default", "default")
if gotRoleArn != "" {
t.Errorf("got roleArn %v, expected %q", gotRoleArn, "")
}
Expand Down Expand Up @@ -323,7 +326,7 @@ func TestCachePrecedence(t *testing.T) {
t.Errorf("failed to build cache: %v", err)
}

role, _, _, exp := c.Get("mysa2", "myns2")
role, _, _, exp, _ := c.Get("mysa2", "myns2")
if role == "" {
t.Errorf("could not find entry that should have been added")
}
Expand All @@ -340,7 +343,7 @@ func TestCachePrecedence(t *testing.T) {
}

// Removing sa2 from CM, but SA still exists
role, _, _, exp := c.Get("mysa2", "myns2")
role, _, _, exp, _ := c.Get("mysa2", "myns2")
if role == "" {
t.Errorf("could not find entry that should still exist")
}
Expand All @@ -356,7 +359,7 @@ func TestCachePrecedence(t *testing.T) {
c.addSA(sa2)

// Neither cache should return any hits now
role, _, _, _ := c.Get("myns2", "mysa2")
role, _, _, _, _ := c.Get("myns2", "mysa2")
if role != "" {
t.Errorf("found entry that should not exist")
}
Expand All @@ -370,7 +373,7 @@ func TestCachePrecedence(t *testing.T) {
t.Errorf("failed to build cache: %v", err)
}

role, _, _, exp := c.Get("mysa2", "myns2")
role, _, _, exp, _ := c.Get("mysa2", "myns2")
if role == "" {
t.Errorf("cloud not find entry that should have been added")
}
Expand Down Expand Up @@ -422,7 +425,7 @@ func TestRoleArnComposition(t *testing.T) {

var roleArn string
err := wait.ExponentialBackoff(wait.Backoff{Duration: 10 * time.Millisecond, Factor: 1.0, Steps: 3}, func() (bool, error) {
roleArn, _, _, _ = cache.Get("default", "default")
roleArn, _, _, _, _ = cache.Get("default", "default")
return roleArn != "", nil
})
if err != nil {
Expand Down
17 changes: 14 additions & 3 deletions pkg/cache/fake.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,25 @@ var _ ServiceAccountCache = &FakeServiceAccountCache{}
func (f *FakeServiceAccountCache) Start(chan struct{}) {}

// Get gets a service account from the cache
func (f *FakeServiceAccountCache) Get(name, namespace string) (role, aud string, useRegionalSTS bool, tokenExpiration int64) {
func (f *FakeServiceAccountCache) Get(name, namespace string) (role, aud string, useRegionalSTS bool, tokenExpiration int64, found bool) {
f.mu.RLock()
defer f.mu.RUnlock()
resp, ok := f.cache[namespace+"/"+name]
if !ok {
return "", "", false, pkg.DefaultTokenExpiration
return "", "", false, pkg.DefaultTokenExpiration, false
}
return resp.RoleARN, resp.Audience, resp.UseRegionalSTS, resp.TokenExpiration
return resp.RoleARN, resp.Audience, resp.UseRegionalSTS, resp.TokenExpiration, true
}

// GetOrNotify gets a service account from the fake cache.
// The handler argument is accepted to satisfy the ServiceAccountCache
// interface but is never used: the fake performs the same lookup as Get and
// does not register any notification.
func (f *FakeServiceAccountCache) GetOrNotify(name, namespace string, handler chan any) (role, aud string, useRegionalSTS bool, tokenExpiration int64, found bool) {
	return f.Get(name, namespace)
}

func (f *FakeServiceAccountCache) GetCommonConfigurations(name, namespace string) (useRegionalSTS bool, tokenExpiration int64) {
Expand Down
Loading

0 comments on commit ea82b2f

Please sign in to comment.