Skip to content

Commit 6207357

Browse files
zhouyiheng.gooxisto
zhouyiheng.go
andcommitted
feat: custom json and base64 encoders for Token and Parser
Co-Authored-By: Christian Banse <[email protected]>
1 parent 1e76606 commit 6207357

7 files changed

+180
-17
lines changed

encoder.go

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
package jwt
2+
3+
// Base64Encoder is an interface that allows to implement custom Base64 encoding
4+
// algorithms.
5+
type Base64EncodeFunc func(src []byte) string
6+
7+
// Base64Decoder is an interface that allows to implement custom Base64 decoding
8+
// algorithms.
9+
type Base64DecodeFunc func(s string) ([]byte, error)
10+
11+
// JSONEncoder is an interface that allows to implement custom JSON encoding
12+
// algorithms.
13+
type JSONMarshalFunc func(v any) ([]byte, error)
14+
15+
// JSONUnmarshal is an interface that allows to implement custom JSON unmarshal
16+
// algorithms.
17+
type JSONUnmarshalFunc func(data []byte, v any) error

parser.go

+24-5
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,22 @@ type Parser struct {
1212
// If populated, only these methods will be considered valid.
1313
validMethods []string
1414

15-
// Use JSON Number format in JSON decoder.
15+
// Use JSON Number format in JSON decoder. This field is disabled when using a custom json encoder.
1616
useJSONNumber bool
1717

1818
// Skip claims validation during token parsing.
1919
skipClaimsValidation bool
2020

2121
validator *validator
2222

23+
// This field is disabled when using a custom base64 encoder.
2324
decodeStrict bool
2425

26+
// This field is disabled when using a custom base64 encoder.
2527
decodePaddingAllowed bool
28+
29+
unmarshalFunc JSONUnmarshalFunc
30+
base64DecodeFunc Base64DecodeFunc
2631
}
2732

2833
// NewParser creates a new Parser with the specified options
@@ -148,7 +153,17 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke
148153
if headerBytes, err = p.DecodeSegment(parts[0]); err != nil {
149154
return token, parts, newError("could not base64 decode header", ErrTokenMalformed, err)
150155
}
151-
if err = json.Unmarshal(headerBytes, &token.Header); err != nil {
156+
157+
// Choose our JSON decoder. If no custom function is supplied, we use the standard library.
158+
var unmarshal JSONUnmarshalFunc
159+
if p.unmarshalFunc != nil {
160+
unmarshal = p.unmarshalFunc
161+
} else {
162+
unmarshal = json.Unmarshal
163+
}
164+
165+
err = unmarshal(headerBytes, &token.Header)
166+
if err != nil {
152167
return token, parts, newError("could not JSON decode header", ErrTokenMalformed, err)
153168
}
154169

@@ -162,13 +177,13 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke
162177

163178
// If `useJSONNumber` is enabled then we must use *json.Decoder to decode
164179
// the claims. However, this comes with a performance penalty so only use
165-
// it if we must and, otherwise, simple use json.Unmarshal.
180+
// it if we must and, otherwise, simple use our decode function.
166181
if !p.useJSONNumber {
167182
// JSON Unmarshal. Special case for map type to avoid weird pointer behavior.
168183
if c, ok := token.Claims.(MapClaims); ok {
169-
err = json.Unmarshal(claimBytes, &c)
184+
err = unmarshal(claimBytes, &c)
170185
} else {
171-
err = json.Unmarshal(claimBytes, &claims)
186+
err = unmarshal(claimBytes, &claims)
172187
}
173188
} else {
174189
dec := json.NewDecoder(bytes.NewBuffer(claimBytes))
@@ -200,6 +215,10 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke
200215
// take into account whether the [Parser] is configured with additional options,
201216
// such as [WithStrictDecoding] or [WithPaddingAllowed].
202217
func (p *Parser) DecodeSegment(seg string) ([]byte, error) {
218+
if p.base64DecodeFunc != nil {
219+
return p.base64DecodeFunc(seg)
220+
}
221+
203222
encoding := base64.RawURLEncoding
204223

205224
if p.decodePaddingAllowed {

parser_option.go

+14
Original file line numberDiff line numberDiff line change
@@ -118,3 +118,17 @@ func WithStrictDecoding() ParserOption {
118118
p.decodeStrict = true
119119
}
120120
}
121+
122+
// WithJSONUnmarshal supports a custom [JSONUnmarshal] to use in parsing the JWT.
123+
func WithJSONUnmarshal(f JSONUnmarshalFunc) ParserOption {
124+
return func(p *Parser) {
125+
p.unmarshalFunc = f
126+
}
127+
}
128+
129+
// WithBase64Decoder supports a custom [Base64Decoder] to use in parsing the JWT.
130+
func WithBase64Decoder(f Base64DecodeFunc) ParserOption {
131+
return func(p *Parser) {
132+
p.base64DecodeFunc = f
133+
}
134+
}

0 commit comments

Comments
 (0)