Skip to content

Commit ba1886d

Browse files
committed
feat!: add support for type parameter
1 parent 3a9ee81 commit ba1886d

13 files changed

+242
-231
lines changed

cmd/jwt/main.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ func verifyToken() error {
128128
}
129129

130130
// Parse the token. Load the key from command line option
131-
token, err := jwt.Parse(string(tokData), func(t *jwt.Token) (interface{}, error) {
131+
token, err := jwt.Parse(string(tokData), func(t *jwt.Token[jwt.MapClaims]) (interface{}, error) {
132132
if isNone() {
133133
return jwt.UnsafeAllowNoneSignatureType, nil
134134
}

example_test.go

+13-13
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ func ExampleNewWithClaims_registeredClaims() {
2525
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
2626
ss, err := token.SignedString(mySigningKey)
2727
fmt.Printf("%v %v", ss, err)
28-
//Output: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJ0ZXN0IiwiZXhwIjoxNTE2MjM5MDIyfQ.0XN_1Tpp9FszFOonIBpwha0c_SfnNI22DhTnjMshPg8 <nil>
28+
// Output: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJ0ZXN0IiwiZXhwIjoxNTE2MjM5MDIyfQ.0XN_1Tpp9FszFOonIBpwha0c_SfnNI22DhTnjMshPg8 <nil>
2929
}
3030

3131
// Example creating a token using a custom claims type. The RegisteredClaims is embedded
@@ -67,7 +67,7 @@ func ExampleNewWithClaims_customClaimsType() {
6767
ss, err := token.SignedString(mySigningKey)
6868
fmt.Printf("%v %v", ss, err)
6969

70-
//Output: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJpc3MiOiJ0ZXN0IiwiZXhwIjoxNTE2MjM5MDIyfQ.xVuY2FZ_MRXMIEgVQ7J-TFtaucVFRXUzHm9LmV41goM <nil>
70+
// Output: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJpc3MiOiJ0ZXN0IiwiZXhwIjoxNTE2MjM5MDIyfQ.xVuY2FZ_MRXMIEgVQ7J-TFtaucVFRXUzHm9LmV41goM <nil>
7171
}
7272

7373
// Example creating a token using a custom claims type. The RegisteredClaims is embedded
@@ -80,12 +80,12 @@ func ExampleParseWithClaims_customClaimsType() {
8080
jwt.RegisteredClaims
8181
}
8282

83-
token, err := jwt.ParseWithClaims(tokenString, &MyCustomClaims{}, func(token *jwt.Token) (interface{}, error) {
83+
token, err := jwt.ParseWithClaims(tokenString, func(token *jwt.Token[*MyCustomClaims]) (interface{}, error) {
8484
return []byte("AllYourBase"), nil
8585
})
8686

87-
if claims, ok := token.Claims.(*MyCustomClaims); ok && token.Valid {
88-
fmt.Printf("%v %v", claims.Foo, claims.RegisteredClaims.Issuer)
87+
if token.Valid {
88+
fmt.Printf("%v %v", token.Claims.Foo, token.Claims.Issuer)
8989
} else {
9090
fmt.Println(err)
9191
}
@@ -103,12 +103,12 @@ func ExampleParseWithClaims_validationOptions() {
103103
jwt.RegisteredClaims
104104
}
105105

106-
token, err := jwt.ParseWithClaims(tokenString, &MyCustomClaims{}, func(token *jwt.Token) (interface{}, error) {
106+
token, err := jwt.ParseWithClaims(tokenString, func(token *jwt.Token[*MyCustomClaims]) (interface{}, error) {
107107
return []byte("AllYourBase"), nil
108108
}, jwt.WithLeeway(5*time.Second))
109109

110-
if claims, ok := token.Claims.(*MyCustomClaims); ok && token.Valid {
111-
fmt.Printf("%v %v", claims.Foo, claims.RegisteredClaims.Issuer)
110+
if token.Valid {
111+
fmt.Printf("%v %v", token.Claims.Foo, token.Claims.Issuer)
112112
} else {
113113
fmt.Println(err)
114114
}
@@ -136,12 +136,12 @@ func (m MyCustomClaims) CustomValidation() error {
136136
func ExampleParseWithClaims_customValidation() {
137137
tokenString := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJpc3MiOiJ0ZXN0IiwiYXVkIjoic2luZ2xlIn0.QAWg1vGvnqRuCFTMcPkjZljXHh8U3L_qUjszOtQbeaA"
138138

139-
token, err := jwt.ParseWithClaims(tokenString, &MyCustomClaims{}, func(token *jwt.Token) (interface{}, error) {
139+
token, err := jwt.ParseWithClaims(tokenString, func(token *jwt.Token[*MyCustomClaims]) (interface{}, error) {
140140
return []byte("AllYourBase"), nil
141141
}, jwt.WithLeeway(5*time.Second))
142142

143-
if claims, ok := token.Claims.(*MyCustomClaims); ok && token.Valid {
144-
fmt.Printf("%v %v", claims.Foo, claims.RegisteredClaims.Issuer)
143+
if token.Valid {
144+
fmt.Printf("%v %v", token.Claims.Foo, token.Claims.Issuer)
145145
} else {
146146
fmt.Println(err)
147147
}
@@ -152,9 +152,9 @@ func ExampleParseWithClaims_customValidation() {
152152
// An example of parsing the error types using errors.Is.
153153
func ExampleParse_errorChecking() {
154154
// Token from another example. This token is expired
155-
var tokenString = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJleHAiOjE1MDAwLCJpc3MiOiJ0ZXN0In0.HE7fK0xOQwFEr4WDgRWj4teRPZ6i3GLwD5YCm6Pwu_c"
155+
tokenString := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJleHAiOjE1MDAwLCJpc3MiOiJ0ZXN0In0.HE7fK0xOQwFEr4WDgRWj4teRPZ6i3GLwD5YCm6Pwu_c"
156156

157-
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
157+
token, err := jwt.Parse(tokenString, func(token *jwt.Token[jwt.MapClaims]) (interface{}, error) {
158158
return []byte("AllYourBase"), nil
159159
})
160160

go.mod

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
module github.com/golang-jwt/jwt/v5
22

3-
go 1.16
3+
go 1.18

hmac_example_test.go

+5-6
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ func ExampleParse_hmac() {
4747
// useful if you use multiple keys for your application. The standard is to use 'kid' in the
4848
// head of the token to identify which key to use, but the parsed token (head and claims) is provided
4949
// to the callback, providing flexibility.
50-
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
50+
token, err := jwt.Parse(tokenString, func(token *jwt.Token[jwt.MapClaims]) (interface{}, error) {
5151
// Don't forget to validate the alg is what you expect:
5252
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
5353
return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"])
@@ -56,12 +56,11 @@ func ExampleParse_hmac() {
5656
// hmacSampleSecret is a []byte containing your secret, e.g. []byte("my_secret_key")
5757
return hmacSampleSecret, nil
5858
})
59-
60-
if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid {
61-
fmt.Println(claims["foo"], claims["nbf"])
62-
} else {
63-
fmt.Println(err)
59+
if err != nil {
60+
panic(err)
6461
}
6562

63+
fmt.Println(token.Claims["foo"], token.Claims["nbf"])
64+
6665
// Output: bar 1.4444784e+09
6766
}

http_example_test.go

+17-19
Original file line numberDiff line numberDiff line change
@@ -99,21 +99,20 @@ func Example_getTokenViaHTTP() {
9999
tokenString := strings.TrimSpace(buf.String())
100100

101101
// Parse the token
102-
token, err := jwt.ParseWithClaims(tokenString, &CustomClaimsExample{}, func(token *jwt.Token) (interface{}, error) {
102+
token, err := jwt.ParseWithClaims(tokenString, func(token *jwt.Token[*CustomClaimsExample]) (interface{}, error) {
103103
// since we only use the one private key to sign the tokens,
104104
// we also only use its public counter part to verify
105105
return verifyKey, nil
106106
})
107107
fatal(err)
108108

109-
claims := token.Claims.(*CustomClaimsExample)
109+
claims := token.Claims
110110
fmt.Println(claims.CustomerInfo.Name)
111111

112-
//Output: test
112+
// Output: test
113113
}
114114

115115
func Example_useTokenViaHTTP() {
116-
117116
// Make a sample token
118117
// In a real world situation, this token will have been acquired from
119118
// some other API call (see Example_getTokenViaHTTP)
@@ -138,18 +137,18 @@ func Example_useTokenViaHTTP() {
138137

139138
func createToken(user string) (string, error) {
140139
// create a signer for rsa 256
141-
t := jwt.New(jwt.GetSigningMethod("RS256"))
142-
143-
// set our claims
144-
t.Claims = &CustomClaimsExample{
145-
jwt.RegisteredClaims{
146-
// set the expire time
147-
// see https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.4
148-
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute * 1)),
140+
t := jwt.NewWithClaims(
141+
jwt.GetSigningMethod("RS256"),
142+
&CustomClaimsExample{
143+
jwt.RegisteredClaims{
144+
// set the expire time
145+
// see https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.4
146+
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute * 1)),
147+
},
148+
"level1",
149+
CustomerInfo{user, "human"},
149150
},
150-
"level1",
151-
CustomerInfo{user, "human"},
152-
}
151+
)
153152

154153
// Creat token string
155154
return t.SignedString(signKey)
@@ -192,12 +191,11 @@ func authHandler(w http.ResponseWriter, r *http.Request) {
192191
// only accessible with a valid token
193192
func restrictedHandler(w http.ResponseWriter, r *http.Request) {
194193
// Get token from request
195-
token, err := request.ParseFromRequest(r, request.OAuth2Extractor, func(token *jwt.Token) (interface{}, error) {
194+
token, err := request.ParseFromRequest(r, request.OAuth2Extractor, func(token *jwt.Token[*CustomClaimsExample]) (interface{}, error) {
196195
// since we only use the one private key to sign the tokens,
197196
// we also only use its public counter part to verify
198197
return verifyKey, nil
199-
}, request.WithClaims(&CustomClaimsExample{}))
200-
198+
})
201199
// If the token is missing or invalid, return error
202200
if err != nil {
203201
w.WriteHeader(http.StatusUnauthorized)
@@ -206,5 +204,5 @@ func restrictedHandler(w http.ResponseWriter, r *http.Request) {
206204
}
207205

208206
// Token is valid
209-
fmt.Fprintln(w, "Welcome,", token.Claims.(*CustomClaimsExample).Name)
207+
fmt.Fprintln(w, "Welcome,", token.Claims.Name)
210208
}

parser.go

+38-34
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import (
77
"strings"
88
)
99

10-
type Parser struct {
10+
type parserOpts struct {
1111
// If populated, only these methods will be considered valid.
1212
validMethods []string
1313

@@ -20,44 +20,54 @@ type Parser struct {
2020
validator *validator
2121
}
2222

23+
type Parser[T Claims] struct {
24+
opts parserOpts
25+
}
26+
2327
// NewParser creates a new Parser with the specified options
24-
func NewParser(options ...ParserOption) *Parser {
25-
p := &Parser{
26-
validator: &validator{},
28+
func NewParser(options ...ParserOption) *Parser[MapClaims] {
29+
p := &Parser[MapClaims]{
30+
opts: parserOpts{validator: &validator{}},
2731
}
2832

2933
// Loop through our parsing options and apply them
3034
for _, option := range options {
31-
option(p)
35+
option(&p.opts)
3236
}
3337

3438
return p
3539
}
3640

37-
// Parse parses, validates, verifies the signature and returns the parsed token.
38-
// keyFunc will receive the parsed token and should return the key for validating.
39-
func (p *Parser) Parse(tokenString string, keyFunc Keyfunc) (*Token, error) {
40-
return p.ParseWithClaims(tokenString, MapClaims{}, keyFunc)
41+
func NewParserFor[T Claims](options ...ParserOption) *Parser[T] {
42+
p := &Parser[T]{
43+
opts: parserOpts{validator: &validator{}},
44+
}
45+
46+
// Loop through our parsing options and apply them
47+
for _, option := range options {
48+
option(&p.opts)
49+
}
50+
51+
return p
4152
}
4253

43-
// ParseWithClaims parses, validates, and verifies like Parse, but supplies a default object implementing the Claims
44-
// interface. This provides default values which can be overridden and allows a caller to use their own type, rather
45-
// than the default MapClaims implementation of Claims.
54+
// Parse parses, validates, verifies the signature and returns the parsed token.
55+
// keyFunc will receive the parsed token and should return the key for validating.
4656
//
4757
// Note: If you provide a custom claim implementation that embeds one of the standard claims (such as RegisteredClaims),
4858
// make sure that a) you either embed a non-pointer version of the claims or b) if you are using a pointer, allocate the
4959
// proper memory for it before passing in the overall claims, otherwise you might run into a panic.
50-
func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyfunc) (*Token, error) {
51-
token, parts, err := p.ParseUnverified(tokenString, claims)
60+
func (p *Parser[T]) Parse(tokenString string, keyFunc Keyfunc[T]) (*Token[T], error) {
61+
token, parts, err := p.ParseUnverified(tokenString)
5262
if err != nil {
5363
return token, err
5464
}
5565

5666
// Verify signing method is in the required set
57-
if p.validMethods != nil {
58-
var signingMethodValid = false
59-
var alg = token.Method.Alg()
60-
for _, m := range p.validMethods {
67+
if p.opts.validMethods != nil {
68+
signingMethodValid := false
69+
alg := token.Method.Alg()
70+
for _, m := range p.opts.validMethods {
6171
if m == alg {
6272
signingMethodValid = true
6373
break
@@ -86,13 +96,13 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf
8696
vErr := &ValidationError{}
8797

8898
// Validate Claims
89-
if !p.skipClaimsValidation {
99+
if !p.opts.skipClaimsValidation {
90100
// Make sure we have at least a default validator
91-
if p.validator == nil {
92-
p.validator = newValidator()
101+
if p.opts.validator == nil {
102+
p.opts.validator = newValidator()
93103
}
94104

95-
if err := p.validator.Validate(claims); err != nil {
105+
if err := p.opts.validator.Validate(token.Claims); err != nil {
96106
// If the Claims Valid returned an error, check if it is a validation error,
97107
// If it was another error type, create a ValidationError with a generic ClaimsInvalid flag set
98108
if e, ok := err.(*ValidationError); !ok {
@@ -124,13 +134,13 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf
124134
//
125135
// It's only ever useful in cases where you know the signature is valid (because it has
126136
// been checked previously in the stack) and you want to extract values from it.
127-
func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Token, parts []string, err error) {
137+
func (p *Parser[T]) ParseUnverified(tokenString string) (token *Token[T], parts []string, err error) {
128138
parts = strings.Split(tokenString, ".")
129139
if len(parts) != 3 {
130140
return nil, parts, NewValidationError("token contains an invalid number of segments", ValidationErrorMalformed)
131141
}
132142

133-
token = &Token{Raw: tokenString}
143+
token = &Token[T]{Raw: tokenString}
134144

135145
// parse Header
136146
var headerBytes []byte
@@ -146,23 +156,17 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke
146156

147157
// parse Claims
148158
var claimBytes []byte
149-
token.Claims = claims
150-
151159
if claimBytes, err = DecodeSegment(parts[1]); err != nil {
152160
return token, parts, &ValidationError{Inner: err, Errors: ValidationErrorMalformed}
153161
}
162+
154163
dec := json.NewDecoder(bytes.NewBuffer(claimBytes))
155-
if p.useJSONNumber {
164+
if p.opts.useJSONNumber {
156165
dec.UseNumber()
157166
}
158-
// JSON Decode. Special case for map type to avoid weird pointer behavior
159-
if c, ok := token.Claims.(MapClaims); ok {
160-
err = dec.Decode(&c)
161-
} else {
162-
err = dec.Decode(&claims)
163-
}
167+
164168
// Handle decode error
165-
if err != nil {
169+
if err = dec.Decode(&token.Claims); err != nil {
166170
return token, parts, &ValidationError{Inner: err, Errors: ValidationErrorMalformed}
167171
}
168172

0 commit comments

Comments
 (0)