Skip to content

Commit

Permalink
Add Token Caching (#115)
Browse files Browse the repository at this point in the history
Transpires that validating tokens is very expensive, so add a cache to
identity to speed this up.
  • Loading branch information
spjmurray authored Sep 5, 2024
1 parent 5933b1b commit 72eba28
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 7 deletions.
4 changes: 2 additions & 2 deletions charts/identity/Chart.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ description: A Helm chart for deploying Unikorn's IdP

type: application

version: v0.2.35
appVersion: v0.2.35
version: v0.2.36
appVersion: v0.2.36

icon: https://raw.githubusercontent.com/unikorn-cloud/assets/main/images/logos/dark-on-light/icon.png

Expand Down
20 changes: 20 additions & 0 deletions pkg/middleware/openapi/remote/authorizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"net/http"
"strconv"
"strings"
"time"

"github.com/coreos/go-oidc/v3/oidc"
"github.com/getkin/kin-openapi/openapi3filter"
Expand All @@ -33,6 +34,8 @@ import (
"github.com/unikorn-cloud/identity/pkg/middleware/openapi"
identityapi "github.com/unikorn-cloud/identity/pkg/openapi"

"k8s.io/apimachinery/pkg/util/cache"

"sigs.k8s.io/controller-runtime/pkg/client"
)

Expand All @@ -41,6 +44,9 @@ type Authorizer struct {
client client.Client
options *identityclient.Options
clientOptions *coreclient.HTTPClientOptions
// tokenCache is used to enhance interaction as the validation is a
// very expensive operation.
tokenCache *cache.LRUExpireCache
}

var _ openapi.Authorizer = &Authorizer{}
Expand All @@ -51,6 +57,9 @@ func NewAuthorizer(client client.Client, options *identityclient.Options, client
client: client,
options: options,
clientOptions: clientOptions,
// TODO: make this configurable, possibly even a shared flag with the
// authorizer to maintain consistency.
tokenCache: cache.NewLRUExpireCache(4096),
}
}

Expand Down Expand Up @@ -121,6 +130,15 @@ func (a *Authorizer) authorizeOAuth2(r *http.Request) (string, *identityapi.User
return "", nil, errors.OAuth2InvalidRequest("authorization scheme not allowed").WithValues("scheme", authorizationScheme)
}

if value, ok := a.tokenCache.Get(rawToken); ok {
claims, ok := value.(*identityapi.Userinfo)
if !ok {
return "", nil, errors.OAuth2ServerError("invalid token cache data")
}

return rawToken, claims, nil
}

// The identity client neatly wraps up TLS...
identity := identityclient.New(a.client, a.options, a.clientOptions)

Expand Down Expand Up @@ -171,6 +189,8 @@ func (a *Authorizer) authorizeOAuth2(r *http.Request) (string, *identityapi.User
return "", nil, errors.OAuth2ServerError("failed to extrac user information").WithError(err)
}

a.tokenCache.Add(rawToken, claims, time.Until(time.Unix(int64(*claims.Exp), 0)))

return rawToken, claims, nil
}

Expand Down
22 changes: 17 additions & 5 deletions pkg/oauth2/oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ import (
"github.com/unikorn-cloud/identity/pkg/rbac"
"github.com/unikorn-cloud/identity/pkg/util"

"k8s.io/apimachinery/pkg/util/cache"

"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/log"
)
Expand All @@ -71,13 +73,18 @@ type Options struct {
// lifetime so we can "guarantee" ours will expire before theirs and force
// a refresh before any errors can come from the IdP.
TokenLeewayDuration time.Duration

// TokenCacheSize is used to control the size of the LRU cache for token validation
// checks. This bounds the memory use to prevent DoS attacks.
TokenCacheSize int
}

func (o *Options) AddFlags(f *pflag.FlagSet) {
f.DurationVar(&o.AccessTokenDuration, "access-token-duration", time.Hour, "Maximum time an access token can be active for.")
f.DurationVar(&o.RefreshTokenDuration, "refresh-token-duration", 0, "Maximum time a refresh token can be active for.")
f.DurationVar(&o.TokenVerificationLeeway, "token-verification-leeway", 0, "How mush leeway to permit for verification of token validity.")
f.DurationVar(&o.TokenLeewayDuration, "token-leeway", time.Minute, "How long to remove from the provider token expiry to account for network and processing latency.")
f.IntVar(&o.TokenCacheSize, "token-cache-size", 8192, "How many token cache entries to allow.")
}

// Authenticator provides Keystone authentication functionality.
Expand All @@ -92,17 +99,22 @@ type Authenticator struct {
issuer *jose.JWTIssuer

rbac *rbac.RBAC

// tokenCache is used to enhance interaction as the validation is a
// very expensive operation.
tokenCache *cache.LRUExpireCache
}

// New returns a new authenticator with required fields populated.
// You must call AddFlags after this.
func New(options *Options, namespace string, client client.Client, issuer *jose.JWTIssuer, rbac *rbac.RBAC) *Authenticator {
return &Authenticator{
options: options,
namespace: namespace,
client: client,
issuer: issuer,
rbac: rbac,
options: options,
namespace: namespace,
client: client,
issuer: issuer,
rbac: rbac,
tokenCache: cache.NewLRUExpireCache(options.TokenCacheSize),
}
}

Expand Down
1 change: 1 addition & 0 deletions pkg/oauth2/oauth2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ func TestTokens(t *testing.T) {
AccessTokenDuration: accessTokenDuration,
RefreshTokenDuration: refreshTokenDuration,
TokenLeewayDuration: accessTokenDuration,
TokenCacheSize: 1024,
}

authenticator := oauth2.New(options, josetesting.Namespace, client, issuer, nil)
Expand Down
15 changes: 15 additions & 0 deletions pkg/oauth2/tokens.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,19 @@ type VerifyInfo struct {

// Verify checks the access token parses and validates.
func (a *Authenticator) Verify(ctx context.Context, info *VerifyInfo) (*AccessTokenClaims, error) {
// The verification process is very expensive, so we add a cache in here to
// improve interactivity. Once this is in place, then the network latency becomes
// the bottle neck, presumably this is the TLS handshake. Similar code can be
// in the remote client-side verification middleware.
if value, ok := a.tokenCache.Get(info.Token); ok {
claims, ok := value.(*AccessTokenClaims)
if !ok {
return nil, fmt.Errorf("%w: failed to assert cache claims", ErrTokenVerification)
}

return claims, nil
}

// Parse and verify the claims with the public key.
claims := &AccessTokenClaims{}

Expand All @@ -218,5 +231,7 @@ func (a *Authenticator) Verify(ctx context.Context, info *VerifyInfo) (*AccessTo
return nil, fmt.Errorf("failed to validate claims: %w", err)
}

a.tokenCache.Add(info.Token, claims, time.Until(claims.Expiry.Time()))

return claims, nil
}

0 comments on commit 72eba28

Please sign in to comment.