Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable user to configure custom Access/Refresh token #5087

Merged
merged 5 commits into from
Feb 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG-6.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
and this project adheres to [Semantic
Versioning](http://semver.org/spec/v2.0.0.html).
## [6.13.0] - Unreleased

### Added
- Added `access-token-expiry` (in minutes) backend configuration variable to control expiry of access token.
- Added `refresh-token-expiry` (in minutes) backend configuration variable to control expiry of refresh token.

## [6.12.0] - 2024-11-13

Expand Down
35 changes: 30 additions & 5 deletions backend/api/authentication.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ import (
"context"
"errors"
"fmt"

corev2 "github.com/sensu/core/v2"
"time"

"github.com/sensu/sensu-go/backend/authentication"
"github.com/sensu/sensu-go/backend/authentication/jwt"
Expand Down Expand Up @@ -51,8 +51,14 @@ func (a *AuthenticationClient) CreateAccessToken(ctx context.Context, username,
claims.Issuer = issuer.(string)
}

// append configured access token expiry to claims
var accessTokenExpiry time.Duration
if accessTokenExp := ctx.Value("accessTokenExpiry"); accessTokenExp != nil {
accessTokenExpiry = accessTokenExp.(time.Duration)
}

// Create an access token and its signed version
_, tokenString, err := jwt.AccessToken(claims)
_, tokenString, err := jwt.AccessToken(claims, jwt.WithAccessTokenExpiry(accessTokenExpiry))
if err != nil {
return nil, fmt.Errorf("error creating access token: %s", err)
}
Expand All @@ -62,7 +68,14 @@ func (a *AuthenticationClient) CreateAccessToken(ctx context.Context, username,
StandardClaims: corev2.StandardClaims(claims.Subject),
SessionID: sessionID,
}
refreshToken, refreshTokenString, err := jwt.RefreshToken(refreshClaims)

// append configured refresh token expiry to claims
var refreshTokenExpiry time.Duration
if refreshTokenExp := ctx.Value("refreshTokenExpiry"); refreshTokenExp != nil {
refreshTokenExpiry = refreshTokenExp.(time.Duration)
}

refreshToken, refreshTokenString, err := jwt.RefreshToken(refreshClaims, jwt.WithRefreshTokenExpiry(refreshTokenExpiry))
if err != nil {
return nil, fmt.Errorf("error creating refresh token: %s", err)
}
Expand Down Expand Up @@ -198,18 +211,30 @@ func (a *AuthenticationClient) RefreshAccessToken(ctx context.Context) (*corev2.
claims.Issuer = issuer.(string)
}

// append configured access token expiry to claims
var accessTokenExpiry time.Duration
if accessTokenExp := ctx.Value("accessTokenExpiry"); accessTokenExp != nil {
accessTokenExpiry = accessTokenExp.(time.Duration)
}

// Issue a new access token
_, newAccessTokenString, err := jwt.AccessToken(claims)
_, newAccessTokenString, err := jwt.AccessToken(claims, jwt.WithAccessTokenExpiry(accessTokenExpiry))
if err != nil {
return nil, err
}

// append configured refresh token expiry to claims
var refreshTokenExpiry time.Duration
if refreshTokenExp := ctx.Value("refreshTokenExpiry"); refreshTokenExp != nil {
refreshTokenExpiry = refreshTokenExp.(time.Duration)
}

// Create a new refresh token, carrying over the session ID
newRefreshClaims := &corev2.Claims{
StandardClaims: corev2.StandardClaims(claims.Subject),
SessionID: sessionID,
}
newRefreshToken, newRefreshTokenString, err := jwt.RefreshToken(newRefreshClaims)
newRefreshToken, newRefreshTokenString, err := jwt.RefreshToken(newRefreshClaims, jwt.WithRefreshTokenExpiry(refreshTokenExpiry))
if err != nil {
return nil, fmt.Errorf("error creating refresh token: %s", err)
}
Expand Down
14 changes: 13 additions & 1 deletion backend/api/authentication_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"testing"
"time"

corev2 "github.com/sensu/core/v2"
"github.com/sensu/sensu-go/backend/authentication"
Expand Down Expand Up @@ -33,6 +34,10 @@ func contextWithClaims(claims *corev2.Claims) context.Context {
ctx := context.Background()
ctx = context.WithValue(ctx, corev2.AccessTokenClaims, claims)
ctx = context.WithValue(ctx, corev2.RefreshTokenClaims, refreshClaims)

ctx = context.WithValue(ctx, "accessTokenExpiry", 5*time.Minute)
ctx = context.WithValue(ctx, "refreshTokenExpiry", 12*time.Hour)

return ctx
}

Expand Down Expand Up @@ -205,7 +210,14 @@ func TestRefreshAccessToken(t *testing.T) {
Authenticator: defaultAuth,
Context: func(claims *corev2.Claims) (context.Context, string) {
ctx := contextWithClaims(claims)
refreshToken, refreshTokenString, _ := jwt.RefreshToken(ctx.Value(corev2.RefreshTokenClaims).(*corev2.Claims))

// append configured access token expiry to claims
var refreshTokenExpiry time.Duration
if refreshTokenExp := ctx.Value("refreshTokenExpiry"); refreshTokenExp != nil {
refreshTokenExpiry = refreshTokenExp.(time.Duration)
}

refreshToken, refreshTokenString, _ := jwt.RefreshToken(ctx.Value(corev2.RefreshTokenClaims).(*corev2.Claims), jwt.WithRefreshTokenExpiry(refreshTokenExpiry))
refreshTokenClaims, _ := jwt.GetClaims(refreshToken)
ctx = context.WithValue(ctx, corev2.RefreshTokenString, refreshTokenString)
return ctx, refreshTokenClaims.Id
Expand Down
9 changes: 8 additions & 1 deletion backend/apid/apid.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ type APId struct {

serveWaitTime time.Duration
ready func()

AccessTokenExpiry time.Duration
RefreshTokenExpiry time.Duration
}

// Option is a functional option.
Expand All @@ -81,6 +84,8 @@ type Config struct {
ClusterVersion string
GraphQLService *graphql.Service
HealthRouter *routers.HealthRouter
AccessTokenExpiry time.Duration
RefreshTokenExpiry time.Duration
}

// New creates a new APId.
Expand All @@ -102,6 +107,8 @@ func New(c Config, opts ...Option) (*APId, error) {
clusterVersion: c.ClusterVersion,
RequestLimit: c.RequestLimit,
serveWaitTime: c.ServeWaitTime,
AccessTokenExpiry: c.AccessTokenExpiry,
RefreshTokenExpiry: c.RefreshTokenExpiry,
}

// prepare TLS config
Expand Down Expand Up @@ -174,7 +181,7 @@ func AuthenticationSubrouter(router *mux.Router, cfg Config) *mux.Router {
)

mountRouters(subrouter,
routers.NewAuthenticationRouter(cfg.Store, cfg.Authenticator),
routers.NewAuthenticationRouter(cfg.Store, cfg.Authenticator, cfg.AccessTokenExpiry, cfg.RefreshTokenExpiry),
)

return subrouter
Expand Down
19 changes: 15 additions & 4 deletions backend/apid/routers/authentication.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"net/http"
"time"

"github.com/sensu/sensu-go/backend/authentication/jwt"

Expand All @@ -17,13 +18,15 @@ import (

// AuthenticationRouter handles authentication related requests
type AuthenticationRouter struct {
store store.Store
authenticator *authentication.Authenticator
store store.Store
authenticator *authentication.Authenticator
accessTokenExpiry time.Duration
refreshTokenExpiry time.Duration
}

// NewAuthenticationRouter instantiates new router.
func NewAuthenticationRouter(store store.Store, authenticator *authentication.Authenticator) *AuthenticationRouter {
return &AuthenticationRouter{store: store, authenticator: authenticator}
func NewAuthenticationRouter(store store.Store, authenticator *authentication.Authenticator, accessTokenExpiry time.Duration, refreshTokenExpiry time.Duration) *AuthenticationRouter {
return &AuthenticationRouter{store: store, authenticator: authenticator, accessTokenExpiry: accessTokenExpiry, refreshTokenExpiry: refreshTokenExpiry}
}

// Mount the authentication routes on given mux.Router.
Expand All @@ -47,6 +50,10 @@ func (a *AuthenticationRouter) login(w http.ResponseWriter, r *http.Request) {
// issuer URL
ctx := context.WithValue(r.Context(), jwt.IssuerURLKey, issuerURL(r))

// Not very efficient, but acceptable for simple use cases, ideally we should create a struct and pass the struct
ctx = context.WithValue(ctx, "accessTokenExpiry", a.accessTokenExpiry)
ctx = context.WithValue(ctx, "refreshTokenExpiry", a.refreshTokenExpiry)

client := api.NewAuthenticationClient(a.authenticator, a.store)
tokens, err := client.CreateAccessToken(ctx, username, password)
if err != nil {
Expand Down Expand Up @@ -106,6 +113,10 @@ func (a *AuthenticationRouter) token(w http.ResponseWriter, r *http.Request) {
// issuer URL
ctx := context.WithValue(r.Context(), jwt.IssuerURLKey, issuerURL(r))

// Not very efficient, but acceptable for simple use cases, ideally we should create a struct and pass the struct
ctx = context.WithValue(ctx, "accessTokenExpiry", a.accessTokenExpiry)
ctx = context.WithValue(ctx, "refreshTokenExpiry", a.refreshTokenExpiry)

tokens, err := client.RefreshAccessToken(ctx)
if err != nil {
if err == corev2.ErrInvalidToken {
Expand Down
56 changes: 52 additions & 4 deletions backend/authentication/jwt/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,30 @@ const (
IssuerURLKey key = iota
)

// ExpiryOptions Functional Options Pattern
// ExpiryOptions: Define a struct for optional parameters.
type ExpiryOptions struct {
RefreshTokenExpiry time.Duration
AccessTokenExpiry time.Duration
}

// ExpiryOption Define a functional option type.
type ExpiryOption func(options *ExpiryOptions)

// WithRefreshTokenExpiry for setting refresh token expiry
func WithRefreshTokenExpiry(expiry time.Duration) ExpiryOption {
return func(o *ExpiryOptions) {
o.RefreshTokenExpiry = expiry
}
}

// WithAccessTokenExpiry for setting access token expiry
func WithAccessTokenExpiry(expiry time.Duration) ExpiryOption {
return func(o *ExpiryOptions) {
o.AccessTokenExpiry = expiry
}
}

var (
DefaultAccessTokenLifespan = 5 * time.Minute
defaultRefreshTokenLifespan = 12 * time.Hour
Expand All @@ -49,16 +73,27 @@ func init() {

// AccessToken creates a new access token and returns it in both JWT and
// signed format, along with any error
func AccessToken(claims *corev2.Claims) (*jwt.Token, string, error) {
func AccessToken(claims *corev2.Claims, options ...ExpiryOption) (*jwt.Token, string, error) {
// Create a unique identifier for the token
jti, err := GenJTI()
if err != nil {
return nil, "", err
}
claims.Id = jti

// Default options.
opts := ExpiryOptions{
RefreshTokenExpiry: defaultRefreshTokenLifespan,
AccessTokenExpiry: DefaultAccessTokenLifespan,
}

// Apply functional options.
for _, option := range options {
option(&opts)
}

// Add an expiration to the token
claims.ExpiresAt = time.Now().Add(DefaultAccessTokenLifespan).Unix()
claims.ExpiresAt = time.Now().Add(opts.AccessTokenExpiry).Unix()

token := jwt.NewWithClaims(signingMethod, claims)

Expand Down Expand Up @@ -246,18 +281,31 @@ func parseToken(tokenString string) (*jwt.Token, error) {
}

// RefreshToken returns a refresh token for a specific user
func RefreshToken(claims *corev2.Claims) (*jwt.Token, string, error) {
func RefreshToken(claims *corev2.Claims, options ...ExpiryOption) (*jwt.Token, string, error) {
// Create a unique identifier for the token
jti, err := GenJTI()
if err != nil {
return nil, "", err
}
claims.Id = jti

// Default options.
opts := ExpiryOptions{
RefreshTokenExpiry: defaultRefreshTokenLifespan,
AccessTokenExpiry: DefaultAccessTokenLifespan,
}

// Apply functional options.
for _, option := range options {
option(&opts)
}

// Add an expiration to the token
claims.ExpiresAt = time.Now().Add(opts.RefreshTokenExpiry).Unix()

// Add issuance and expiration timestamps to the token
now := time.Now()
claims.IssuedAt = now.Unix()
claims.ExpiresAt = now.Add(defaultRefreshTokenLifespan).Unix()

token := jwt.NewWithClaims(signingMethod, claims)

Expand Down
2 changes: 2 additions & 0 deletions backend/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,8 @@ func Initialize(ctx context.Context, config *Config) (*Backend, error) {
ClusterVersion: clusterVersion,
GraphQLService: b.GraphQLService,
HealthRouter: b.HealthRouter,
AccessTokenExpiry: config.AccessTokenExpiry,
RefreshTokenExpiry: config.RefreshTokenExpiry,
}
newApi, err := apid.New(b.APIDConfig)
if err != nil {
Expand Down
31 changes: 23 additions & 8 deletions backend/cmd/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ const (
flagMaxSilencedExpiryTimeAllowed = "max-silenced-expiry-time-allowed"
flagDefaultSilencedExpiryTime = "default-silenced-expiry-time"

// access token and refresh token expiry time
flagAccessTokenExpiry = "access-token-expiry"
flagRefreshTokenExpiry = "refresh-token-expiry"

// Etcd flag constants
flagEtcdClientURLs = "etcd-client-urls"
flagEtcdListenClientURLs = "etcd-listen-client-urls"
Expand Down Expand Up @@ -293,6 +297,9 @@ func StartCommand(initialize InitializeFunc) *cobra.Command {
EventLogBufferWait: viper.GetDuration(flagEventLogBufferWait),
EventLogFile: viper.GetString(flagEventLogFile),
EventLogParallelEncoders: viper.GetBool(flagEventLogParallelEncoders),

AccessTokenExpiry: viper.GetDuration(flagAccessTokenExpiry),
RefreshTokenExpiry: viper.GetDuration(flagRefreshTokenExpiry),
}

if flag := cmd.Flags().Lookup(flagLabels); flag != nil && flag.Changed {
Expand Down Expand Up @@ -455,12 +462,16 @@ func handleConfig(cmd *cobra.Command, arguments []string, server bool) error {
viper.SetDefault(flagEventLogBufferSize, 100000)
viper.SetDefault(flagEventLogFile, "")
viper.SetDefault(flagEventLogParallelEncoders, false)

// default silenced value are set for 1 day = 1440m
viper.SetDefault(flagMaxSilencedExpiryTimeAllowed, "1440m")
viper.SetDefault(flagDefaultSilencedExpiryTime, "1440m")
}

// default silenced value are set for 1 day = 1440m
viper.SetDefault(flagMaxSilencedExpiryTimeAllowed, "1440m")
viper.SetDefault(flagDefaultSilencedExpiryTime, "1440m")

// Access/Refresh token default expiry values
viper.SetDefault(flagAccessTokenExpiry, "5m")
viper.SetDefault(flagRefreshTokenExpiry, "720m")

// Etcd defaults
viper.SetDefault(flagEtcdAdvertiseClientURLs, defaultEtcdAdvertiseClientURL)
viper.SetDefault(flagEtcdListenClientURLs, defaultEtcdClientURL)
Expand Down Expand Up @@ -552,6 +563,14 @@ func flagSet(server bool) *pflag.FlagSet {
flagSet.String(flagEtcdClientURLs, viper.GetString(flagEtcdClientURLs), "client URLs to use when operating as an etcd client")
_ = flagSet.SetAnnotation(flagEtcdClientURLs, "categories", []string{"store"})

// silenced configuration flags
flagSet.Duration(flagDefaultSilencedExpiryTime, viper.GetDuration(flagDefaultSilencedExpiryTime), "Default expiry time for silenced if not set in minutes")
flagSet.Duration(flagMaxSilencedExpiryTimeAllowed, viper.GetDuration(flagMaxSilencedExpiryTimeAllowed), "Maximum expiry time allowed for silenced in minutes")

// Access/Token configuration flags
flagSet.Duration(flagAccessTokenExpiry, viper.GetDuration(flagAccessTokenExpiry), "Set Access Token expiry in minutes")
flagSet.Duration(flagRefreshTokenExpiry, viper.GetDuration(flagRefreshTokenExpiry), "Set Refresh Token expiry in minutes")

if server {
// Main Flags
flagSet.String(flagAgentHost, viper.GetString(flagAgentHost), "agent listener host")
Expand Down Expand Up @@ -594,10 +613,6 @@ func flagSet(server bool) *pflag.FlagSet {
flagSet.Duration(flagPlatformMetricsLoggingInterval, viper.GetDuration(flagPlatformMetricsLoggingInterval), "platform metrics logging interval")
flagSet.String(flagPlatformMetricsLogFile, viper.GetString(flagPlatformMetricsLogFile), "platform metrics log file path")

// silenced configuration flags
flagSet.Duration(flagDefaultSilencedExpiryTime, viper.GetDuration(flagDefaultSilencedExpiryTime), "Default expiry time for silenced if not set in minutes")
flagSet.Duration(flagMaxSilencedExpiryTimeAllowed, viper.GetDuration(flagMaxSilencedExpiryTimeAllowed), "Maximum expiry time allowed for silenced in minutes")

// Etcd server flags
flagSet.StringSlice(flagEtcdPeerURLs, viper.GetStringSlice(flagEtcdPeerURLs), "list of URLs to listen on for peer traffic")
_ = flagSet.SetAnnotation(flagEtcdPeerURLs, "categories", []string{"store"})
Expand Down
4 changes: 4 additions & 0 deletions backend/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,4 +136,8 @@ type Config struct {
// expiry setting for silences
DefaultSilencedExpiryTime time.Duration
MaxSilencedExpiryTimeAllowed time.Duration

// Access/Refresh Token Expiry in Minutes
AccessTokenExpiry time.Duration
RefreshTokenExpiry time.Duration
}
Loading