Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions cmd/metadata-server/tokencache.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ type TokenCache struct {
hitMetric prometheus.Counter
missMetric prometheus.Counter
setMetric prometheus.Counter

inflight map[TokenUID]*shared.TicketLock
Comment thread
arnecls marked this conversation as resolved.
Outdated
}

// NewTokenCache creates a new token cache with a garbage collection interval.
Expand Down Expand Up @@ -68,6 +70,7 @@ func NewTokenCache(gcInterval, minLifetime time.Duration) *TokenCache {
hitMetric: hitMetric,
missMetric: missMetric,
setMetric: setMetric,
inflight: make(map[TokenUID]*shared.TicketLock),
}

if gcInterval > 0 {
Expand All @@ -86,6 +89,21 @@ func NewTokenCache(gcInterval, minLifetime time.Duration) *TokenCache {
return cache
}

// GetTokenLock returns a ticket lock for the given token identifier.
// The lock can be used to prevent multiple parallel requests for the same token.
func (t *TokenCache) GetTokenLock(tokenIdentifier TokenLookup) *shared.TicketLock {
Comment thread
arnecls marked this conversation as resolved.
t.lock.Lock()
defer t.lock.Unlock()

id := tokenIdentifier.ToTokenUID()
lock, ok := t.inflight[id]
if !ok {
lock = shared.NewTicketLock(10 * time.Millisecond)
t.inflight[id] = lock
Comment thread
arnecls marked this conversation as resolved.
Outdated
}
return lock
}
Comment thread
arnecls marked this conversation as resolved.
Comment thread
arnecls marked this conversation as resolved.

// StopGC stops the garbage collection timer.
func (t *TokenCache) StopGC() {
if t.gcTimer != nil {
Expand Down
89 changes: 58 additions & 31 deletions cmd/metadata-server/tokenhandlers.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"errors"
"identity-metadata-server/internal/shared"
"net/http"
"strings"
Expand Down Expand Up @@ -48,44 +49,70 @@ func HandleGetAccessToken(c *gin.Context) {
}

tokenID := NewLookupWithScopeAndAudience(TokenTypeAccess, srcIdentity, scopes, additionalAudiences)
cachedToken := knownTokens.Get(tokenID)

if cachedToken == nil {
// The documentation is a bit patchy here, so we don't know if we can
// actually override the token lifetime through a request.
// TODO: Reverse-engineering is required here. We need to find a
// call that sets the token lifetime and see which parameter is
// being used.
tokenLifeTime := AccessTokenLifetime

trt, err := tokenProvider.GetTokenRequestToken(c.Request.Context(), srcIdentity, tokenLifeTime, scopes, additionalAudiences)
if trt == nil {
shared.HttpError(c, http.StatusInternalServerError, err)
return
returnToken := func(cachedToken *KnownToken) {
// The format is explained in the documentation.
// https://cloud.google.com/compute/docs/access/authenticate-workloads#applications
// The response format is identical to the one used by the STS endpoint,
// which might be by design, but is not documented.
response := shared.TokenExchangeResponse{
AccessToken: cachedToken.token,
ExpiresIn: int(time.Until(cachedToken.expires).Seconds()),
TokenType: "Bearer",
}

// Get the token for the given parameters.
accessToken, err := tokenProvider.GetAccessToken(c.Request.Context(), *trt, tokenLifeTime, scopes, gsa)
if accessToken == nil {
shared.HttpError(c, http.StatusInternalServerError, err)
return
}
c.Header("Metadata-Flavor", "Google")
c.JSON(http.StatusOK, response)
}

cachedToken = knownTokens.StoreUntil(tokenID, accessToken.AccessToken, accessToken.ExpireTime)
// Try to get the token from the cache
if cachedToken := knownTokens.Get(tokenID); cachedToken != nil {
returnToken(cachedToken)
return
}

// The format is explained in the documentation.
// https://cloud.google.com/compute/docs/access/authenticate-workloads#applications
// The response format is identical to the one used by the STS endpoint,
// which might be by design, but is not documented.
response := shared.TokenExchangeResponse{
AccessToken: cachedToken.token,
ExpiresIn: int(time.Until(cachedToken.expires).Seconds()),
TokenType: "Bearer",
// Cache miss. Acquire a lock to block inflight requests for the same tokenID
// Block inflight requests for the same tokenID
Comment thread
arnecls marked this conversation as resolved.
Outdated
inflightLock := knownTokens.GetTokenLock(tokenID)
if inflightLock.LockWithContext(c.Request.Context()) == 0 {
c.Header("Retry-After", "5")
shared.HttpError(c, http.StatusTooManyRequests, errors.New("timed out while waiting for another token fetch to finish"))
return
}
Comment thread
arnecls marked this conversation as resolved.
defer inflightLock.Unlock()

c.Header("Metadata-Flavor", "Google")
c.JSON(http.StatusOK, response)
// Try to get the token from the cache again. This time we might have a token
// in the cache, as another request might have fetched the token while we were
// waiting for the lock.
if cachedToken := knownTokens.Get(tokenID); cachedToken != nil {
returnToken(cachedToken)
return
}

// True cache miss. Fetch the token from the token provider.

// The documentation is a bit patchy here, so we don't know if we can
// actually override the token lifetime through a request.
// TODO: Reverse-engineering is required here. We need to find a
// call that sets the token lifetime and see which parameter is
// being used.
tokenLifeTime := AccessTokenLifetime

trt, err := tokenProvider.GetTokenRequestToken(c.Request.Context(), srcIdentity, tokenLifeTime, scopes, additionalAudiences)
if trt == nil {
shared.HttpError(c, http.StatusInternalServerError, err)
return
}

// Get the token for the given parameters.
accessToken, err := tokenProvider.GetAccessToken(c.Request.Context(), *trt, tokenLifeTime, scopes, gsa)
if accessToken == nil {
shared.HttpError(c, http.StatusInternalServerError, err)
return
}

// Store the token in the cache and return it
cachedToken := knownTokens.StoreUntil(tokenID, accessToken.AccessToken, accessToken.ExpireTime)
returnToken(cachedToken)
}

// HandleGetIdentityToken handles an identity token request.
Expand Down
41 changes: 41 additions & 0 deletions internal/shared/intheap.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package shared

// HeapUint64 is a slice-backed min-heap of uint64 values.
// It provides the methods required by container/heap; use heap.Push and
// heap.Pop from that package to keep the heap ordering intact.
type HeapUint64 []uint64

// Len reports how many elements are currently stored.
func (h HeapUint64) Len() int { return len(h) }

// Less reports whether the element at index i sorts before the element at
// index j.
func (h HeapUint64) Less(i, j int) bool {
	return h[j] > h[i]
}

// Swap exchanges the elements at index i and j.
func (h HeapUint64) Swap(i, j int) {
	tmp := h[i]
	h[i] = h[j]
	h[j] = tmp
}

// Push appends x, which must be a uint64, to the underlying slice.
// Call heap.Push(h, x) rather than this method directly.
func (h *HeapUint64) Push(x any) {
	value := x.(uint64)
	*h = append(*h, value)
}

// Peek returns the element at the root of the heap (the smallest one)
// without removing it. The boolean is false when the heap is empty.
func (h *HeapUint64) Peek() (uint64, bool) {
	items := *h
	if len(items) > 0 {
		return items[0], true
	}
	return 0, false
}

// Pop removes and returns the last element of the underlying slice.
// Call heap.Pop(h) rather than this method directly; heap.Pop moves the
// smallest element to the end before invoking it.
func (h *HeapUint64) Pop() any {
	items := *h
	last := len(items) - 1
	value := items[last]
	*h = items[:last]
	return value
}
93 changes: 93 additions & 0 deletions internal/shared/ticketlock.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package shared

import (
"container/heap"
"context"
"sync"
"sync/atomic"
"time"
)

// TicketLock is a lock that is handed out in FIFO order.
// Every caller of LockWithContext draws a ticket and waits until that ticket
// becomes the active one; Unlock advances the active ticket.
type TicketLock struct {
// nextTicket is the ticket number the next caller of LockWithContext
// receives. It is incremented atomically.
nextTicket uint64
// activeTicket is the ticket number that currently owns the lock. It is read
// atomically by waiters and advanced by Unlock.
activeTicket uint64
// granularity is the time a waiter sleeps between two checks of activeTicket.
granularity time.Duration
// canceledTickets is a min-heap of tickets whose waiters gave up because
// their context was done. Unlock skips these tickets.
canceledTickets *HeapUint64
// ticketGuard is held while canceledTickets is modified and while Unlock
// advances activeTicket.
ticketGuard *sync.Mutex
}

// NewTicketLock returns an unlocked ticket lock whose waiters poll at the
// given granularity.
// The granularity is the pause between two consecutive checks of whether a
// waiter's ticket has become active. Smaller values reduce the hand-over
// latency, larger values reduce the number of wake-ups; 5-10 milliseconds is
// a reasonable starting point.
func NewTicketLock(granularity time.Duration) *TicketLock {
	canceled := HeapUint64{}

	lock := new(TicketLock)
	// Tickets start at 1 so that 0 stays free as the "not acquired" result.
	lock.nextTicket = 1
	lock.activeTicket = 1
	lock.granularity = granularity
	lock.canceledTickets = &canceled
	lock.ticketGuard = new(sync.Mutex)
	return lock
}
Comment thread
arnecls marked this conversation as resolved.

// Lock acquires the lock in FIFO order and returns the ticket number it was
// granted under. It waits with a background context, i.e. without deadline
// or cancellation.
func (l *TicketLock) Lock() uint64 {
	ticket := l.LockWithContext(context.Background())
	return ticket
}

// LockWithContext tries to aquire a lock in a FIFO way.
// It returns 0 when the lock failed to be acquired due to a context
// cancellation or a timeout.
// if the lock was acquired, it returns the ticket number of the lock.
Comment thread
arnecls marked this conversation as resolved.
Outdated
Comment thread
arnecls marked this conversation as resolved.
Outdated
func (l *TicketLock) LockWithContext(ctx context.Context) uint64 {
ticket := atomic.AddUint64(&l.nextTicket, 1) - 1

for {
if atomic.LoadUint64(&l.activeTicket) == ticket {
return ticket
}

select {
case <-time.After(l.granularity):
continue
Comment thread
arnecls marked this conversation as resolved.

case <-ctx.Done():
// We need to keep track of canceled tickets as tickets are linearly
// ordered. If we don't do this, we cannot properly unlock the lock
// in the correct order.
l.ticketGuard.Lock()
defer l.ticketGuard.Unlock()
heap.Push(l.canceledTickets, ticket)
Comment thread
arnecls marked this conversation as resolved.
return 0
}
}
Comment thread
arnecls marked this conversation as resolved.
}

// Unlock releases the lock.
func (l *TicketLock) Unlock() {
l.ticketGuard.Lock()
defer l.ticketGuard.Unlock()

for {
ticket := atomic.AddUint64(&l.activeTicket, 1)
lastCanceledTicket, hasCanceledTickets := l.canceledTickets.Peek()

switch {
// No canceled tickets, we can return
case !hasCanceledTickets:
return

// The last canceled ticket is the same as the current ticket.
// We need to try again with the next ticket (which might also be
// canceled).
case lastCanceledTicket == ticket:
heap.Pop(l.canceledTickets)

// There are canceled tickets, but the current ticket is smaller than
// the first canceled ticket.
Comment thread
arnecls marked this conversation as resolved.
Outdated
default:
return
Comment thread
arnecls marked this conversation as resolved.
}
}
}
47 changes: 47 additions & 0 deletions internal/shared/ticketlock_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package shared

import (
"context"
"testing"
"time"

"github.com/stretchr/testify/assert"
)

func TestTicketLock(t *testing.T) {
Comment thread
arnecls marked this conversation as resolved.
assert := assert.New(t)

lock := NewTicketLock(time.Millisecond)
Comment thread
arnecls marked this conversation as resolved.
Comment thread
arnecls marked this conversation as resolved.

ticket1 := lock.Lock()
assert.NotZero(ticket1, "Lock should return a non-zero ticket")
assert.Equal(uint64(1), ticket1, "Lock should return the first ticket")

// Test timeout
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
Comment thread
arnecls marked this conversation as resolved.

ticket2 := lock.LockWithContext(ctx)
assert.Zero(ticket2, "LockWithContext should return a zero ticket as it timed out before the lock was acquired")

// Test release
lock.Unlock()

// Test if release properly increments the active ticket
ticket3 := lock.Lock()
assert.NotZero(ticket3, "Lock should return a non-zero ticket after the previous lock was released")
assert.NotEqual(ticket1, ticket3, "Lock should return a different ticket after the previous lock was released")
assert.Equal(uint64(3), ticket3, "Lock should return the third ticket, as the second lock was aborted")

// Test if release properly increments the active ticket with consecutive discards
ticket4 := lock.LockWithContext(ctx)
assert.Zero(ticket4, "LockWithContext should return a zero ticket as it timed out before the lock was acquired")
ticket5 := lock.LockWithContext(ctx)
Comment thread
arnecls marked this conversation as resolved.
assert.Zero(ticket5, "LockWithContext should return a zero ticket as it timed out before the lock was acquired")

lock.Unlock()
ticket6 := lock.Lock()
assert.NotZero(ticket6, "Lock should return a non-zero ticket after the previous lock was released")
assert.NotEqual(ticket3, ticket6, "Lock should return a different ticket after the previous lock was released")
assert.Equal(uint64(6), ticket6, "Lock should return the sixth ticket, as the fourth and fifth lock was aborted")
Comment thread
arnecls marked this conversation as resolved.
Outdated
}
Loading