Skip to content

Added support for type parameters in the ParseXXX functions #271

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ jobs:
strategy:
fail-fast: false
matrix:
go: [1.17, 1.18, 1.19]
go: [1.18, 1.19, 1.20]
steps:
- name: Checkout
uses: actions/checkout@v3
Expand Down
2 changes: 1 addition & 1 deletion cmd/jwt/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ func verifyToken() error {
}

// Parse the token. Load the key from command line option
token, err := jwt.Parse(string(tokData), func(t *jwt.Token) (interface{}, error) {
token, err := jwt.Parse(string(tokData), func(t *jwt.Token[jwt.MapClaims]) (interface{}, error) {
if isNone() {
return jwt.UnsafeAllowNoneSignatureType, nil
}
Expand Down
14 changes: 7 additions & 7 deletions example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,11 @@ func ExampleParseWithClaims_customClaimsType() {
jwt.RegisteredClaims
}

token, err := jwt.ParseWithClaims(tokenString, &MyCustomClaims{}, func(token *jwt.Token) (interface{}, error) {
token, err := jwt.ParseWithClaims(tokenString, &MyCustomClaims{}, func(token *jwt.Token[*MyCustomClaims]) (interface{}, error) {
return []byte("AllYourBase"), nil
})

if claims, ok := token.Claims.(*MyCustomClaims); ok && token.Valid {
if claims := token.Claims; token.Valid {
fmt.Printf("%v %v", claims.Foo, claims.RegisteredClaims.Issuer)
} else {
fmt.Println(err)
Expand All @@ -103,11 +103,11 @@ func ExampleParseWithClaims_validationOptions() {
jwt.RegisteredClaims
}

token, err := jwt.ParseWithClaims(tokenString, &MyCustomClaims{}, func(token *jwt.Token) (interface{}, error) {
token, err := jwt.ParseWithClaims(tokenString, &MyCustomClaims{}, func(token *jwt.Token[*MyCustomClaims]) (interface{}, error) {
return []byte("AllYourBase"), nil
}, jwt.WithLeeway(5*time.Second))

if claims, ok := token.Claims.(*MyCustomClaims); ok && token.Valid {
if claims := token.Claims; token.Valid {
fmt.Printf("%v %v", claims.Foo, claims.RegisteredClaims.Issuer)
} else {
fmt.Println(err)
Expand Down Expand Up @@ -136,11 +136,11 @@ func (m MyCustomClaims) CustomValidation() error {
func ExampleParseWithClaims_customValidation() {
tokenString := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJpc3MiOiJ0ZXN0IiwiYXVkIjoic2luZ2xlIn0.QAWg1vGvnqRuCFTMcPkjZljXHh8U3L_qUjszOtQbeaA"

token, err := jwt.ParseWithClaims(tokenString, &MyCustomClaims{}, func(token *jwt.Token) (interface{}, error) {
token, err := jwt.ParseWithClaims(tokenString, &MyCustomClaims{}, func(token *jwt.Token[*MyCustomClaims]) (interface{}, error) {
return []byte("AllYourBase"), nil
}, jwt.WithLeeway(5*time.Second))

if claims, ok := token.Claims.(*MyCustomClaims); ok && token.Valid {
if claims := token.Claims; token.Valid {
fmt.Printf("%v %v", claims.Foo, claims.RegisteredClaims.Issuer)
} else {
fmt.Println(err)
Expand All @@ -154,7 +154,7 @@ func ExampleParse_errorChecking() {
// Token from another example. This token is expired
var tokenString = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJleHAiOjE1MDAwLCJpc3MiOiJ0ZXN0In0.HE7fK0xOQwFEr4WDgRWj4teRPZ6i3GLwD5YCm6Pwu_c"

token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
token, err := jwt.Parse(tokenString, func(token *jwt.Token[jwt.MapClaims]) (interface{}, error) {
return []byte("AllYourBase"), nil
})

Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
module github.com/golang-jwt/jwt/v5

go 1.16
go 1.18
4 changes: 2 additions & 2 deletions hmac_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func ExampleParse_hmac() {
// useful if you use multiple keys for your application. The standard is to use 'kid' in the
// head of the token to identify which key to use, but the parsed token (head and claims) is provided
// to the callback, providing flexibility.
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
token, err := jwt.Parse(tokenString, func(token *jwt.Token[jwt.MapClaims]) (interface{}, error) {
// Don't forget to validate the alg is what you expect:
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"])
Expand All @@ -57,7 +57,7 @@ func ExampleParse_hmac() {
return hmacSampleSecret, nil
})

if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid {
if claims := token.Claims; token.Valid {
fmt.Println(claims["foo"], claims["nbf"])
} else {
fmt.Println(err)
Expand Down
14 changes: 5 additions & 9 deletions http_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,15 +99,14 @@ func Example_getTokenViaHTTP() {
tokenString := strings.TrimSpace(buf.String())

// Parse the token
token, err := jwt.ParseWithClaims(tokenString, &CustomClaimsExample{}, func(token *jwt.Token) (interface{}, error) {
token, err := jwt.ParseWithClaims(tokenString, &CustomClaimsExample{}, func(token *jwt.Token[*CustomClaimsExample]) (interface{}, error) {
// since we only use the one private key to sign the tokens,
// we also only use its public counter part to verify
return verifyKey, nil
})
fatal(err)

claims := token.Claims.(*CustomClaimsExample)
fmt.Println(claims.CustomerInfo.Name)
fmt.Println(token.Claims.CustomerInfo.Name)

//Output: test
}
Expand Down Expand Up @@ -138,18 +137,15 @@ func Example_useTokenViaHTTP() {

func createToken(user string) (string, error) {
// create a signer for rsa 256
t := jwt.New(jwt.GetSigningMethod("RS256"))

// set our claims
t.Claims = &CustomClaimsExample{
t := jwt.NewWithClaims(jwt.GetSigningMethod("RS256"), &CustomClaimsExample{
jwt.RegisteredClaims{
// set the expire time
// see https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.4
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute * 1)),
},
"level1",
CustomerInfo{user, "human"},
}
})

// Creat token string
return t.SignedString(signKey)
Expand Down Expand Up @@ -192,7 +188,7 @@ func authHandler(w http.ResponseWriter, r *http.Request) {
// only accessible with a valid token
func restrictedHandler(w http.ResponseWriter, r *http.Request) {
// Get token from request
token, err := request.ParseFromRequest(r, request.OAuth2Extractor, func(token *jwt.Token) (interface{}, error) {
token, err := request.ParseFromRequest(r, request.OAuth2Extractor, func(token *jwt.Token[jwt.Claims]) (interface{}, error) {
// since we only use the one private key to sign the tokens,
// we also only use its public counter part to verify
return verifyKey, nil
Expand Down
67 changes: 60 additions & 7 deletions map_claims_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func TestVerifyAud(t *testing.T) {
opts = append(opts, WithAudience(test.Comparison))
}

validator := newValidator(opts...)
validator := newValidator[MapClaims](opts...)
got := validator.Validate(test.MapClaims)

if (got == nil) != test.Expected {
Expand All @@ -77,7 +77,7 @@ func TestMapclaimsVerifyIssuedAtInvalidTypeString(t *testing.T) {
"iat": "foo",
}
want := false
got := newValidator(WithIssuedAt()).Validate(mapClaims)
got := newValidator[MapClaims](WithIssuedAt()).Validate(mapClaims)
if want != (got == nil) {
t.Fatalf("Failed to verify claims, wanted: %v got %v", want, (got == nil))
}
Expand All @@ -88,7 +88,7 @@ func TestMapclaimsVerifyNotBeforeInvalidTypeString(t *testing.T) {
"nbf": "foo",
}
want := false
got := newValidator().Validate(mapClaims)
got := newValidator[MapClaims]().Validate(mapClaims)
if want != (got == nil) {
t.Fatalf("Failed to verify claims, wanted: %v got %v", want, (got == nil))
}
Expand All @@ -99,7 +99,7 @@ func TestMapclaimsVerifyExpiresAtInvalidTypeString(t *testing.T) {
"exp": "foo",
}
want := false
got := newValidator().Validate(mapClaims)
got := newValidator[MapClaims]().Validate(mapClaims)

if want != (got == nil) {
t.Fatalf("Failed to verify claims, wanted: %v got %v", want, (got == nil))
Expand All @@ -112,25 +112,78 @@ func TestMapClaimsVerifyExpiresAtExpire(t *testing.T) {
"exp": float64(exp.Unix()),
}
want := false
got := newValidator(WithTimeFunc(func() time.Time {
got := newValidator[MapClaims](WithTimeFunc(func() time.Time {
return exp
})).Validate(mapClaims)
if want != (got == nil) {
t.Fatalf("Failed to verify claims, wanted: %v got %v", want, (got == nil))
}

got = newValidator(WithTimeFunc(func() time.Time {
got = newValidator[MapClaims](WithTimeFunc(func() time.Time {
return exp.Add(1 * time.Second)
})).Validate(mapClaims)
if want != (got == nil) {
t.Fatalf("Failed to verify claims, wanted: %v got %v", want, (got == nil))
}

want = true
got = newValidator(WithTimeFunc(func() time.Time {
got = newValidator[MapClaims](WithTimeFunc(func() time.Time {
return exp.Add(-1 * time.Second)
})).Validate(mapClaims)
if want != (got == nil) {
t.Fatalf("Failed to verify claims, wanted: %v got %v", want, (got == nil))
}
}

func TestMapClaims_ParseString(t *testing.T) {
type args struct {
key string
}
tests := []struct {
name string
m MapClaims
args args
want string
wantErr bool
}{
{
name: "missing key",
m: MapClaims{},
args: args{
key: "mykey",
},
want: "",
wantErr: false,
},
{
name: "wrong key type",
m: MapClaims{"mykey": 4},
args: args{
key: "mykey",
},
want: "",
wantErr: true,
},
{
name: "correct key type",
m: MapClaims{"mykey": "mystring"},
args: args{
key: "mykey",
},
want: "mystring",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.m.ParseString(tt.args.key)
if (err != nil) != tt.wantErr {
t.Errorf("MapClaims.ParseString() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("MapClaims.ParseString() = %v, want %v", got, tt.want)
}
})
}
}
22 changes: 11 additions & 11 deletions parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"strings"
)

type Parser struct {
type Parser[T Claims] struct {
// If populated, only these methods will be considered valid.
validMethods []string

Expand All @@ -21,24 +21,24 @@ type Parser struct {
}

// NewParser creates a new Parser with the specified options
func NewParser(options ...ParserOption) *Parser {
p := &Parser{
func NewParser[T Claims](options ...ParserOption) *Parser[T] {
p := &Parser[T]{
validator: &validator{},
}

// Loop through our parsing options and apply them
for _, option := range options {
option(p)
option((*Parser[Claims])(p))
}

return p
}

// Parse parses, validates, verifies the signature and returns the parsed token.
// keyFunc will receive the parsed token and should return the key for validating.
func (p *Parser) Parse(tokenString string, keyFunc Keyfunc) (*Token, error) {
/*func (p *Parser[T]) Parse(tokenString string, keyFunc Keyfunc[T]) (*Token[T], error) {
return p.ParseWithClaims(tokenString, MapClaims{}, keyFunc)
}
}*/

// ParseWithClaims parses, validates, and verifies like Parse, but supplies a default object implementing the Claims
// interface. This provides default values which can be overridden and allows a caller to use their own type, rather
Expand All @@ -47,7 +47,7 @@ func (p *Parser) Parse(tokenString string, keyFunc Keyfunc) (*Token, error) {
// Note: If you provide a custom claim implementation that embeds one of the standard claims (such as RegisteredClaims),
// make sure that a) you either embed a non-pointer version of the claims or b) if you are using a pointer, allocate the
// proper memory for it before passing in the overall claims, otherwise you might run into a panic.
func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyfunc) (*Token, error) {
func (p *Parser[T]) ParseWithClaims(tokenString string, claims T, keyFunc Keyfunc[T]) (*Token[T], error) {
token, parts, err := p.ParseUnverified(tokenString, claims)
if err != nil {
return token, err
Expand Down Expand Up @@ -89,7 +89,7 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf
if !p.skipClaimsValidation {
// Make sure we have at least a default validator
if p.validator == nil {
p.validator = newValidator()
p.validator = newValidator[T]()
}

if err := p.validator.Validate(claims); err != nil {
Expand Down Expand Up @@ -124,13 +124,13 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf
//
// It's only ever useful in cases where you know the signature is valid (because it has
// been checked previously in the stack) and you want to extract values from it.
func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Token, parts []string, err error) {
func (p *Parser[T]) ParseUnverified(tokenString string, claims T) (token *Token[T], parts []string, err error) {
parts = strings.Split(tokenString, ".")
if len(parts) != 3 {
return nil, parts, NewValidationError("token contains an invalid number of segments", ValidationErrorMalformed)
}

token = &Token{Raw: tokenString}
token = &Token[T]{Raw: tokenString}

// parse Header
var headerBytes []byte
Expand All @@ -156,7 +156,7 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke
dec.UseNumber()
}
// JSON Decode. Special case for map type to avoid weird pointer behavior
if c, ok := token.Claims.(MapClaims); ok {
if c, ok := any(token.Claims).(MapClaims); ok {
err = dec.Decode(&c)
} else {
err = dec.Decode(&claims)
Expand Down
Loading