diff --git a/cmd/metadata-server/tokencache.go b/cmd/metadata-server/tokencache.go index cc4f65f..34020a2 100644 --- a/cmd/metadata-server/tokencache.go +++ b/cmd/metadata-server/tokencache.go @@ -16,6 +16,11 @@ type KnownToken struct { expires time.Time } +type inflightLock struct { + handle *shared.TicketLock + lastUsed time.Time +} + // TokenCache is a cache for previously requested tokens type TokenCache struct { lock *sync.Mutex @@ -25,6 +30,8 @@ type TokenCache struct { hitMetric prometheus.Counter missMetric prometheus.Counter setMetric prometheus.Counter + + inflight map[TokenUID]*inflightLock } // NewTokenCache creates a new token cache with a garbage collection interval. @@ -68,6 +75,7 @@ func NewTokenCache(gcInterval, minLifetime time.Duration) *TokenCache { hitMetric: hitMetric, missMetric: missMetric, setMetric: setMetric, + inflight: make(map[TokenUID]*inflightLock), } if gcInterval > 0 { @@ -86,6 +94,30 @@ func NewTokenCache(gcInterval, minLifetime time.Duration) *TokenCache { return cache } +// GetTokenLock returns a ticket lock for the given token identifier. +// The lock can be used to prevent multiple parallel requests for the same token. +func (t *TokenCache) GetTokenLock(tokenIdentifier TokenLookup) *shared.TicketLock { + t.lock.Lock() + defer t.lock.Unlock() + + id := tokenIdentifier.ToTokenUID() + lock, ok := t.inflight[id] + + if ok { + if time.Since(lock.lastUsed) <= t.minTokenLifetime { + lock.lastUsed = time.Now() + return lock.handle + } + } + + lock = &inflightLock{ + handle: shared.NewTicketLock(5 * time.Millisecond), + lastUsed: time.Now(), + } + t.inflight[id] = lock + return lock.handle +} + // StopGC stops the garbage collection timer. func (t *TokenCache) StopGC() { if t.gcTimer != nil { @@ -110,6 +142,20 @@ func (t *TokenCache) GC() { for _, id := range staleTokens { delete(t.data, id) } + + // Cleanup inflight locks. + // As this is a map, we can delete keys while iterating. + for id, lock := range t.inflight { + if time.Since(lock.lastUsed) > t.minTokenLifetime { + // Locks that are held for longer than minTokenLifetime are a sign + // of a bug, like not releasing the lock. Fetching a token should + // always be _much_ shorter than minTokenLifetime. 
+ if lock.handle.IsLocked() { + log.Warn().Msg("Timed-out inflight lock is still held by a thread, this should not happen") + } + delete(t.inflight, id) + } + } } // Get reurns the known token for the given service account or nil diff --git a/cmd/metadata-server/tokencache_test.go b/cmd/metadata-server/tokencache_test.go index 65d4b39..288d13f 100644 --- a/cmd/metadata-server/tokencache_test.go +++ b/cmd/metadata-server/tokencache_test.go @@ -1,6 +1,7 @@ package main import ( + "context" "testing" "time" @@ -222,3 +223,50 @@ func TestTokenCacheGC(t *testing.T) { assert.Nil(cache.Get(accessTokenId1)) assert.Nil(cache.Get(identityTokenId1)) } + +func TestTokenCacheGetTokenLock(t *testing.T) { + // Test the token cache + assert := assert.New(t) + lifeTime := 200 * time.Millisecond + + // Create a new cache + cache := NewTokenCache(lifeTime/2, time.Minute) + defer cache.StopGC() + + fakeIdentity1 := MockSourceIdentity{ + Name: "test", + BoundGSA: "test@gcp.com", + } + fakeIdentity2 := MockSourceIdentity{ + Name: "test2", + BoundGSA: "test2@gcp.com", + } + + // Generate access tokens and locks + tokenId1 := NewLookup(TokenTypeAccess, fakeIdentity1) + tokenId2 := NewLookup(TokenTypeAccess, fakeIdentity2) + + lock1 := cache.GetTokenLock(tokenId1) + lock2 := cache.GetTokenLock(tokenId2) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + // Normal lock + ticket1 := lock1.Lock() + defer lock1.Unlock() + assert.NotZero(ticket1, "lock1 should always return a non-zero ticket") + + // Parallel lock on different tokens should not block + ticket2 := lock2.LockWithContext(ctx) + defer lock2.Unlock() + + assert.NotZero(ticket2, "lock1 should not block lock2") + + // Parallel lock on same token should block and cancel the context + ctx2, cancel2 := context.WithTimeout(context.Background(), time.Second) + defer cancel2() + + ticket3 := lock1.LockWithContext(ctx2) + assert.Zero(ticket3, "lock2 should return a zero ticket if the context is done") +} diff --git a/cmd/metadata-server/tokenhandlers.go b/cmd/metadata-server/tokenhandlers.go index 0d5fdc4..6aa07cd 100644 --- a/cmd/metadata-server/tokenhandlers.go +++ b/cmd/metadata-server/tokenhandlers.go @@ -1,6 +1,7 @@ package main import ( + "errors" "identity-metadata-server/internal/shared" "net/http" "strings" @@ -48,44 +49,69 @@ func HandleGetAccessToken(c *gin.Context) { } tokenID := NewLookupWithScopeAndAudience(TokenTypeAccess, srcIdentity, scopes, additionalAudiences) - cachedToken := knownTokens.Get(tokenID) - - if cachedToken == nil { - // The documentation is a bit patchy here, so we don't know if we can - // actually override the token lifetime through a request. - // TODO: Reverse-engineering is required here. We need to find a - // call that sets the token lifetime and see which parameter is - // being used. - tokenLifeTime := AccessTokenLifetime - - trt, err := tokenProvider.GetTokenRequestToken(c.Request.Context(), srcIdentity, tokenLifeTime, scopes, additionalAudiences) - if trt == nil { - shared.HttpError(c, http.StatusInternalServerError, err) - return + returnToken := func(cachedToken *KnownToken) { + // The format is explained in the documentation. + // https://cloud.google.com/compute/docs/access/authenticate-workloads#applications + // The response format is identical to the one used by the STS endpoint, + // which might be by design, but is not documented. 
+ response := shared.TokenExchangeResponse{ + AccessToken: cachedToken.token, + ExpiresIn: int(time.Until(cachedToken.expires).Seconds()), + TokenType: "Bearer", } - // Get the token for the given parameters. - accessToken, err := tokenProvider.GetAccessToken(c.Request.Context(), *trt, tokenLifeTime, scopes, gsa) - if accessToken == nil { - shared.HttpError(c, http.StatusInternalServerError, err) - return - } + c.Header("Metadata-Flavor", "Google") + c.JSON(http.StatusOK, response) + } - cachedToken = knownTokens.StoreUntil(tokenID, accessToken.AccessToken, accessToken.ExpireTime) + // Try to get the token from the cache + if cachedToken := knownTokens.Get(tokenID); cachedToken != nil { + returnToken(cachedToken) + return } - // The format is explained in the documentation. - // https://cloud.google.com/compute/docs/access/authenticate-workloads#applications - // The response format is identical to the one used by the STS endpoint, - // which might be by design, but is not documented. - response := shared.TokenExchangeResponse{ - AccessToken: cachedToken.token, - ExpiresIn: int(time.Until(cachedToken.expires).Seconds()), - TokenType: "Bearer", + // Cache miss. Acquire a lock to block inflight requests for the same tokenID. + inflightLock := knownTokens.GetTokenLock(tokenID) + if inflightLock.LockWithContext(c.Request.Context()) == 0 { + c.Header("Retry-After", "5") + shared.HttpError(c, http.StatusTooManyRequests, errors.New("timed out while waiting for another token fetch to finish")) + return } + defer inflightLock.Unlock() - c.Header("Metadata-Flavor", "Google") - c.JSON(http.StatusOK, response) + // Try to get the token from the cache again. This time we might have a token + // in the cache, as another request might have fetched the token while we were + // waiting for the lock. + if cachedToken := knownTokens.Get(tokenID); cachedToken != nil { + returnToken(cachedToken) + return + } + + // True cache miss. Fetch the token from the token provider. + + // The documentation is a bit patchy here, so we don't know if we can + // actually override the token lifetime through a request. + // TODO: Reverse-engineering is required here. We need to find a + // call that sets the token lifetime and see which parameter is + // being used. + tokenLifeTime := AccessTokenLifetime + + trt, err := tokenProvider.GetTokenRequestToken(c.Request.Context(), srcIdentity, tokenLifeTime, scopes, additionalAudiences) + if trt == nil { + shared.HttpError(c, http.StatusInternalServerError, err) + return + } + + // Get the token for the given parameters. + accessToken, err := tokenProvider.GetAccessToken(c.Request.Context(), *trt, tokenLifeTime, scopes, gsa) + if accessToken == nil { + shared.HttpError(c, http.StatusInternalServerError, err) + return + } + + // Store the token in the cache and return it + cachedToken := knownTokens.StoreUntil(tokenID, accessToken.AccessToken, accessToken.ExpireTime) + returnToken(cachedToken) } // HandleGetIdentityToken handles an identity token request. diff --git a/internal/shared/intheap.go b/internal/shared/intheap.go new file mode 100644 index 0000000..de2903c --- /dev/null +++ b/internal/shared/intheap.go @@ -0,0 +1,41 @@ +package shared + +// HeapUint64 implements a min-heap of uint64 values on top of a slice. +// Use this type with the container/heap package. +type HeapUint64 []uint64 + +// Len returns the number of elements in the heap. 
+func (h HeapUint64) Len() int { + return len(h) +} + +// Less returns true if the element at index i is less than the element at index j. +func (h HeapUint64) Less(i, j int) bool { return h[i] < h[j] } + +// Swap swaps the elements at index i and j. +func (h HeapUint64) Swap(i, j int) { h[i], h[j] = h[j], h[i] } + +// Push adds a new element to the heap. +// Use heap.Push(h, x) instead of this function. +func (h *HeapUint64) Push(x any) { + *h = append(*h, x.(uint64)) +} + +// Peek returns the smallest element from the heap without removing it. +// It returns false if the heap is empty. +func (h *HeapUint64) Peek() (uint64, bool) { + if len(*h) == 0 { + return 0, false + } + return (*h)[0], true +} + +// Pop removes and returns the smallest element from the heap. +// Use heap.Pop(h) instead of this function. +func (h *HeapUint64) Pop() any { + old := *h + n := len(old) + x := old[n-1] + *h = old[0 : n-1] + return x +} diff --git a/internal/shared/ticketlock.go b/internal/shared/ticketlock.go new file mode 100644 index 0000000..2a57fa9 --- /dev/null +++ b/internal/shared/ticketlock.go @@ -0,0 +1,129 @@ +package shared + +import ( + "container/heap" + "context" + "sync" + "sync/atomic" + "time" +) + +type TicketLock struct { + nextTicket uint64 + activeTicket uint64 + pauseDuration time.Duration + canceledTickets *HeapUint64 + ticketGuard *sync.Mutex +} + +// NewTicketLock creates a new ticket lock with the given granularity. +// The granularity is the time to wait between each lock acquisition check. +// The granularity should be small enough to not block the main thread for too +// long, but large enough to not waste too much time. +// A granularity of 1-5 milliseconds is a good starting point. +// If the granularity is zero or less it will be set to 1 millisecond. +func NewTicketLock(granularity time.Duration) *TicketLock { + if granularity <= 0 { + granularity = time.Millisecond + } + return &TicketLock{ + nextTicket: 1, + activeTicket: 1, + pauseDuration: granularity, + canceledTickets: &HeapUint64{}, + ticketGuard: &sync.Mutex{}, + } +} + +// IsLocked returns true if the lock is currently held by a thread. +// Please note that this status can change right after the call to IsLocked(). +// I.e. this is not a reliable way to check if the lock is currently held by a +// thread. It is only meant for debugging purposes. +func (l *TicketLock) IsLocked() bool { + return atomic.LoadUint64(&l.activeTicket) != atomic.LoadUint64(&l.nextTicket) +} + +// Lock tries to acquire a lock in a FIFO way. +func (l *TicketLock) Lock() uint64 { + return l.LockWithContext(context.Background()) +} + +// LockWithContext tries to acquire a lock in a FIFO way. +// It returns 0 when the lock failed to be acquired due to a context +// cancellation or a timeout. +// If the lock was acquired, it returns the ticket number of the lock. +func (l *TicketLock) LockWithContext(ctx context.Context) uint64 { + ticket := atomic.AddUint64(&l.nextTicket, 1) - 1 + var pause *time.Ticker + + for { + if atomic.LoadUint64(&l.activeTicket) == ticket { + return ticket + } + + // Do a lazy initialization of the ticker to avoid creating a ticker if + // it is not needed. + if pause == nil { + pause = time.NewTicker(l.pauseDuration) + defer pause.Stop() + } + + select { + // We use a ticker to yield the CPU during waiting and to be able to + // check on the context while pausing. + case <-pause.C: + continue + + case <-ctx.Done(): + // We need to keep track of canceled tickets as tickets are linearly + // ordered. 
If we don't do this, we cannot properly unlock the lock + // in the correct order. + l.ticketGuard.Lock() + defer l.ticketGuard.Unlock() + + if atomic.LoadUint64(&l.activeTicket) == ticket { + // It might happen that we got the lock while waiting for pause + // and the context was done. In this case, we need to return the + // ticket to avoid a deadlock. This puts the task of unlocking + // to the caller, but this is simpler than unlocking here, as we + // already hold the mutex. + return ticket + } + heap.Push(l.canceledTickets, ticket) + return 0 + } + } +} + +// Unlock releases the lock. +func (l *TicketLock) Unlock() { + l.ticketGuard.Lock() + defer l.ticketGuard.Unlock() + + for { + ticket := atomic.AddUint64(&l.activeTicket, 1) + nextCanceledTicket, hasCanceledTickets := l.canceledTickets.Peek() + + switch { + // No canceled tickets, we can return + case !hasCanceledTickets: + return + + // Discard any stale canceled tickets that are less than the current + // ticket. This should not happen, but we handle it to be safe. + case nextCanceledTicket < ticket: + heap.Pop(l.canceledTickets) + + // The last canceled ticket is the same as the current ticket. + // We need to try again with the next ticket (which might also be + // canceled). + case nextCanceledTicket == ticket: + heap.Pop(l.canceledTickets) + + // There are canceled tickets, but the current ticket is smaller than + // the first canceled ticket. + default: + return + } + } +} diff --git a/internal/shared/ticketlock_test.go b/internal/shared/ticketlock_test.go new file mode 100644 index 0000000..e118b78 --- /dev/null +++ b/internal/shared/ticketlock_test.go @@ -0,0 +1,175 @@ +package shared + +import ( + "context" + "slices" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestTicketLock(t *testing.T) { + assert := assert.New(t) + + lock := NewTicketLock(time.Millisecond) + + ticket1 := lock.Lock() + assert.NotZero(ticket1, "Lock should return a non-zero ticket") + assert.Equal(uint64(1), ticket1, "Lock should return the first ticket") + assert.True(lock.IsLocked(), "Lock should return a non-zero ticket") + + // Test timeout + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + ticket2 := lock.LockWithContext(ctx) + assert.Zero(ticket2, "LockWithContext should return a zero ticket as it timed out before the lock was acquired") + + // Test release + lock.Unlock() + assert.False(lock.IsLocked(), "Lock should return a zero ticket after the lock was released") + + // Test if release properly increments the active ticket + ticket3 := lock.Lock() + assert.NotZero(ticket3, "Lock should return a non-zero ticket after the previous lock was released") + assert.NotEqual(ticket1, ticket3, "Lock should return a different ticket after the previous lock was released") + assert.Equal(uint64(3), ticket3, "Lock should return the third ticket, as the second lock was aborted") + + // Test if release properly increments the active ticket with consecutive discards + ticket4 := lock.LockWithContext(ctx) + assert.Zero(ticket4, "LockWithContext should return a zero ticket as it timed out before the lock was acquired") + ticket5 := lock.LockWithContext(ctx) + assert.Zero(ticket5, "LockWithContext should return a zero ticket as it timed out before the lock was acquired") + + lock.Unlock() + ticket6 := lock.Lock() + assert.NotZero(ticket6, "Lock should return a non-zero ticket after the previous lock was released") + assert.NotEqual(ticket3, ticket6, "Lock should return a 
different ticket after the previous lock was released") + assert.Equal(uint64(6), ticket6, "Lock should return the sixth ticket, as the fourth and fifth locks were aborted") +} + +func TestTicketLockConcurrency(t *testing.T) { + assert := assert.New(t) + + lock := NewTicketLock(time.Millisecond) + + wg := sync.WaitGroup{} + wg.Add(10) + done := make(chan struct{}) + order := make([]uint64, 0, 10) + + go func() { + for i := 0; i < 10; i++ { + go func() { + defer wg.Done() + ticket := lock.Lock() + order = append(order, ticket) + time.Sleep(2 * time.Millisecond) + lock.Unlock() + }() + } + wg.Wait() + close(done) + }() + + select { + case <-done: + break + case <-time.After(time.Second): + t.Fatal("Test should have finished within 1 second") + } + + assert.True(slices.IsSorted(order), "Locks should be ordered") + assert.Equal(uint64(10), order[len(order)-1], "Last ticket should be equal to the number of runs") + assert.False(lock.IsLocked(), "Lock should not be locked after all locks have been released") +} + +func TestTicketLockConcurrencyWithContext(t *testing.T) { + assert := assert.New(t) + + lock := NewTicketLock(time.Millisecond) + + wg := sync.WaitGroup{} + wg.Add(10) + done := make(chan struct{}) + order := make([]uint64, 0, 10) + + go func() { + for i := 0; i < 10; i++ { + go func() { + defer wg.Done() + + for { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + + ticket := lock.LockWithContext(ctx) + time.Sleep(5 * time.Millisecond) + if ticket != 0 { + order = append(order, ticket) + lock.Unlock() + return + } + } + }() + } + wg.Wait() + close(done) + }() + + select { + case <-done: + break + case <-time.After(10 * time.Second): + t.Fatal("Test should have finished within 10 seconds") + } + + assert.True(slices.IsSorted(order), "Locks should be ordered") + assert.Greater(order[len(order)-1], uint64(10), "Last ticket should be greater than 10 as locks should be aborted if the context is done") + assert.False(lock.IsLocked(), "Lock should not be locked after all locks have been released") +} + +func TestTicketLockConcurrencyWithContextAndPause(t *testing.T) { + assert := assert.New(t) + + wg := sync.WaitGroup{} + wg.Add(1) + done := make(chan struct{}) + + // Granularity must be larger than the timeout of the context to test the + // behavior. + lock := NewTicketLock(100 * time.Millisecond) + + ticket1 := lock.Lock() + assert.NotZero(ticket1, "Lock should return a non-zero ticket") + + // Unlock while the second lock is waiting for the pause. + // As granularity is larger than this delay the second lock is still waiting + // when the unlock is called. + time.AfterFunc(5*time.Millisecond, func() { + lock.Unlock() + }) + + // Make sure to have a context timeout that is between the granularity and + // the delay of the unlock. + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + + // If the test case fails, we have a deadlock here + go func() { + defer wg.Done() + ticket2 := lock.LockWithContext(ctx) + assert.NotZero(ticket2, "LockWithContext should return a non-zero ticket as the lock was acquired before the context was done") + lock.Unlock() + close(done) + }() + + select { + case <-done: + break + case <-time.After(1 * time.Second): + t.Fatal("Test should have finished within 1 seconds") + } +}
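
Reviewer notes (not part of the patch):

Below is a minimal sketch of the fetch-once pattern that GetTokenLock enables, mirroring the flow in HandleGetAccessToken above. The getOrFetchToken helper and the fetchAndStore callback are hypothetical; the callback stands in for the tokenProvider round trip plus knownTokens.StoreUntil. Only Get, GetTokenLock, LockWithContext and Unlock come from this change.

package main

import (
	"context"
	"errors"
)

// getOrFetchToken is an illustrative helper, not part of this change.
// fetchAndStore is assumed to perform the provider calls and StoreUntil,
// returning the freshly cached token.
func getOrFetchToken(ctx context.Context, cache *TokenCache, tokenID TokenLookup,
	fetchAndStore func(context.Context) (*KnownToken, error)) (*KnownToken, error) {

	// Fast path: serve straight from the cache.
	if tok := cache.Get(tokenID); tok != nil {
		return tok, nil
	}

	// Cache miss: serialize concurrent fetches for the same token identifier.
	lock := cache.GetTokenLock(tokenID)
	if lock.LockWithContext(ctx) == 0 {
		return nil, errors.New("timed out waiting for an inflight token fetch")
	}
	defer lock.Unlock()

	// Re-check: another request may have filled the cache while we waited.
	if tok := cache.Get(tokenID); tok != nil {
		return tok, nil
	}

	// True miss: exactly one caller per token ID reaches this point at a time.
	return fetchAndStore(ctx)
}

The double Get around the lock is what collapses a burst of identical requests into a single upstream fetch: the first caller populates the cache while the rest queue on the FIFO ticket and are then served by the second Get.

HeapUint64 is only consumed by TicketLock in this change, but since it is exported, here is a minimal container/heap usage sketch:

package main

import (
	"container/heap"
	"fmt"

	"identity-metadata-server/internal/shared"
)

func main() {
	h := &shared.HeapUint64{}
	heap.Init(h)
	for _, v := range []uint64{42, 7, 19} {
		// heap.Push maintains the min-heap invariant; calling Push on the
		// type directly only appends and does not.
		heap.Push(h, v)
	}
	if smallest, ok := h.Peek(); ok {
		fmt.Println("smallest:", smallest) // 7, without removing it
	}
	fmt.Println(heap.Pop(h)) // 7; values come out in ascending order
}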