Skip to content

Commit 104cec7

Browse files
committed
Moving DecodeSegement to Parser
This would allow us to remove some global variables and move them to parser options as well as potentially introduce interfaces for json and b64 encoding/decoding to replace the std lib, if someone wanted to do that for performance reasons. We keep the functions exported because of explicit user demand.
1 parent 148d710 commit 104cec7

17 files changed

+70
-92
lines changed

ecdsa.go

+1-9
Original file line numberDiff line numberDiff line change
@@ -55,15 +55,7 @@ func (m *SigningMethodECDSA) Alg() string {
5555

5656
// Verify implements token verification for the SigningMethod.
5757
// For this verify method, key must be an ecdsa.PublicKey struct
58-
func (m *SigningMethodECDSA) Verify(signingString, signature string, key interface{}) error {
59-
var err error
60-
61-
// Decode the signature
62-
var sig []byte
63-
if sig, err = DecodeSegment(signature); err != nil {
64-
return err
65-
}
66-
58+
func (m *SigningMethodECDSA) Verify(signingString string, sig []byte, key interface{}) error {
6759
// Get the key
6860
var ecdsaKey *ecdsa.PublicKey
6961
switch k := key.(type) {

ecdsa_test.go

+12-2
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ func TestECDSAVerify(t *testing.T) {
6565
parts := strings.Split(data.tokenString, ".")
6666

6767
method := jwt.GetSigningMethod(data.alg)
68-
err = method.Verify(strings.Join(parts[0:2], "."), parts[2], ecdsaKey)
68+
err = method.Verify(strings.Join(parts[0:2], "."), decodeSegment(t, parts[2]), ecdsaKey)
6969
if data.valid && err != nil {
7070
t.Errorf("[%v] Error while verifying key: %v", data.name, err)
7171
}
@@ -98,7 +98,7 @@ func TestECDSASign(t *testing.T) {
9898
t.Errorf("[%v] Identical signatures\nbefore:\n%v\nafter:\n%v", data.name, parts[2], sig)
9999
}
100100

101-
err = method.Verify(toSign, sig, ecdsaKey.Public())
101+
err = method.Verify(toSign, decodeSegment(t, sig), ecdsaKey.Public())
102102
if err != nil {
103103
t.Errorf("[%v] Sign produced an invalid signature: %v", data.name, err)
104104
}
@@ -162,3 +162,13 @@ func BenchmarkECDSASigning(b *testing.B) {
162162
})
163163
}
164164
}
165+
166+
func decodeSegment(t *testing.T, signature string) (sig []byte) {
167+
var err error
168+
sig, err = jwt.NewParser().DecodeSegment(signature)
169+
if err != nil {
170+
t.Fatalf("could not decode segment: %v", err)
171+
}
172+
173+
return
174+
}

ed25519.go

+1-8
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,7 @@ func (m *SigningMethodEd25519) Alg() string {
3434

3535
// Verify implements token verification for the SigningMethod.
3636
// For this verify method, key must be an ed25519.PublicKey
37-
func (m *SigningMethodEd25519) Verify(signingString, signature string, key interface{}) error {
38-
var err error
37+
func (m *SigningMethodEd25519) Verify(signingString string, sig []byte, key interface{}) error {
3938
var ed25519Key ed25519.PublicKey
4039
var ok bool
4140

@@ -47,12 +46,6 @@ func (m *SigningMethodEd25519) Verify(signingString, signature string, key inter
4746
return ErrInvalidKey
4847
}
4948

50-
// Decode the signature
51-
var sig []byte
52-
if sig, err = DecodeSegment(signature); err != nil {
53-
return err
54-
}
55-
5649
// Verify the signature
5750
if !ed25519.Verify(ed25519Key, []byte(signingString), sig) {
5851
return ErrEd25519Verification

ed25519_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ func TestEd25519Verify(t *testing.T) {
4949

5050
method := jwt.GetSigningMethod(data.alg)
5151

52-
err = method.Verify(strings.Join(parts[0:2], "."), parts[2], ed25519Key)
52+
err = method.Verify(strings.Join(parts[0:2], "."), decodeSegment(t, parts[2]), ed25519Key)
5353
if data.valid && err != nil {
5454
t.Errorf("[%v] Error while verifying key: %v", data.name, err)
5555
}

hmac.go

+1-7
Original file line numberDiff line numberDiff line change
@@ -46,19 +46,13 @@ func (m *SigningMethodHMAC) Alg() string {
4646
}
4747

4848
// Verify implements token verification for the SigningMethod. Returns nil if the signature is valid.
49-
func (m *SigningMethodHMAC) Verify(signingString, signature string, key interface{}) error {
49+
func (m *SigningMethodHMAC) Verify(signingString string, sig []byte, key interface{}) error {
5050
// Verify the key is the right type
5151
keyBytes, ok := key.([]byte)
5252
if !ok {
5353
return ErrInvalidKeyType
5454
}
5555

56-
// Decode signature, for comparison
57-
sig, err := DecodeSegment(signature)
58-
if err != nil {
59-
return err
60-
}
61-
6256
// Can we use the specified hashing method?
6357
if !m.Hash.Available() {
6458
return ErrHashUnavailable

hmac_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ func TestHMACVerify(t *testing.T) {
5353
parts := strings.Split(data.tokenString, ".")
5454

5555
method := jwt.GetSigningMethod(data.alg)
56-
err := method.Verify(strings.Join(parts[0:2], "."), parts[2], hmacTestKey)
56+
err := method.Verify(strings.Join(parts[0:2], "."), decodeSegment(t, parts[2]), hmacTestKey)
5757
if data.valid && err != nil {
5858
t.Errorf("[%v] Error while verifying key: %v", data.name, err)
5959
}

none.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,14 @@ func (m *signingMethodNone) Alg() string {
2525
}
2626

2727
// Only allow 'none' alg type if UnsafeAllowNoneSignatureType is specified as the key
28-
func (m *signingMethodNone) Verify(signingString, signature string, key interface{}) (err error) {
28+
func (m *signingMethodNone) Verify(signingString string, sig []byte, key interface{}) (err error) {
2929
// Key must be UnsafeAllowNoneSignatureType to prevent accidentally
3030
// accepting 'none' signing method
3131
if _, ok := key.(unsafeNoneMagicConstant); !ok {
3232
return NoneSignatureTypeDisallowedError
3333
}
3434
// If signing method is none, signature must be an empty string
35-
if signature != "" {
35+
if string(sig) != "" {
3636
return newError("'none' signing method with non-empty signature", ErrTokenUnverifiable)
3737
}
3838

none_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ func TestNoneVerify(t *testing.T) {
4646
parts := strings.Split(data.tokenString, ".")
4747

4848
method := jwt.GetSigningMethod(data.alg)
49-
err := method.Verify(strings.Join(parts[0:2], "."), parts[2], data.key)
49+
err := method.Verify(strings.Join(parts[0:2], "."), decodeSegment(t, parts[2]), data.key)
5050
if data.valid && err != nil {
5151
t.Errorf("[%v] Error while verifying key: %v", data.name, err)
5252
}

parser.go

+29-3
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package jwt
22

33
import (
44
"bytes"
5+
"encoding/base64"
56
"encoding/json"
67
"fmt"
78
"strings"
@@ -79,8 +80,13 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf
7980
return token, newError("error while executing keyfunc", ErrTokenUnverifiable, err)
8081
}
8182

83+
// Decode signature
84+
token.Signature, err = p.DecodeSegment(parts[2])
85+
if err != nil {
86+
return token, newError("could not base64 decode signature", ErrTokenMalformed, err)
87+
}
88+
8289
// Perform signature validation
83-
token.Signature = parts[2]
8490
if err = token.Method.Verify(strings.Join(parts[0:2], "."), token.Signature, key); err != nil {
8591
return token, newError("", ErrTokenSignatureInvalid, err)
8692
}
@@ -119,7 +125,7 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke
119125

120126
// parse Header
121127
var headerBytes []byte
122-
if headerBytes, err = DecodeSegment(parts[0]); err != nil {
128+
if headerBytes, err = p.DecodeSegment(parts[0]); err != nil {
123129
if strings.HasPrefix(strings.ToLower(tokenString), "bearer ") {
124130
return token, parts, newError("tokenstring should not contain 'bearer '", ErrTokenMalformed)
125131
}
@@ -133,7 +139,7 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke
133139
var claimBytes []byte
134140
token.Claims = claims
135141

136-
if claimBytes, err = DecodeSegment(parts[1]); err != nil {
142+
if claimBytes, err = p.DecodeSegment(parts[1]); err != nil {
137143
return token, parts, newError("could not base64 decode claim", ErrTokenMalformed, err)
138144
}
139145
dec := json.NewDecoder(bytes.NewBuffer(claimBytes))
@@ -162,3 +168,23 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke
162168

163169
return token, parts, nil
164170
}
171+
172+
// DecodeSegment decodes a JWT specific base64url encoding with padding stripped
173+
//
174+
// Deprecated: In a future release, we will demote this function to a
175+
// non-exported function, since it should only be used internally
176+
func (p *Parser) DecodeSegment(seg string) ([]byte, error) {
177+
encoding := base64.RawURLEncoding
178+
179+
if DecodePaddingAllowed {
180+
if l := len(seg) % 4; l > 0 {
181+
seg += strings.Repeat("=", 4-l)
182+
}
183+
encoding = base64.URLEncoding
184+
}
185+
186+
if DecodeStrict {
187+
encoding = encoding.Strict()
188+
}
189+
return encoding.DecodeString(seg)
190+
}

parser_test.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -415,7 +415,7 @@ func TestParser_Parse(t *testing.T) {
415415
}
416416

417417
if data.valid {
418-
if token.Signature == "" {
418+
if len(token.Signature) == 0 {
419419
t.Errorf("[%v] Signature is left unpopulated after parsing", data.name)
420420
}
421421
if !token.Valid {
@@ -473,7 +473,7 @@ func TestParser_ParseUnverified(t *testing.T) {
473473
// The 'Valid' field should not be set to true when invoking ParseUnverified()
474474
t.Errorf("[%v] Token.Valid field mismatch. Expecting false, got %v", data.name, token.Valid)
475475
}
476-
if token.Signature != "" {
476+
if len(token.Signature) != 0 {
477477
// The signature was not validated, hence the 'Signature' field is not populated.
478478
t.Errorf("[%v] Token.Signature field mismatch. Expecting '', got %v", data.name, token.Signature)
479479
}

rsa.go

+1-9
Original file line numberDiff line numberDiff line change
@@ -46,15 +46,7 @@ func (m *SigningMethodRSA) Alg() string {
4646

4747
// Verify implements token verification for the SigningMethod
4848
// For this signing method, must be an *rsa.PublicKey structure.
49-
func (m *SigningMethodRSA) Verify(signingString, signature string, key interface{}) error {
50-
var err error
51-
52-
// Decode the signature
53-
var sig []byte
54-
if sig, err = DecodeSegment(signature); err != nil {
55-
return err
56-
}
57-
49+
func (m *SigningMethodRSA) Verify(signingString string, sig []byte, key interface{}) error {
5850
var rsaKey *rsa.PublicKey
5951
var ok bool
6052

rsa_pss.go

+1-9
Original file line numberDiff line numberDiff line change
@@ -82,15 +82,7 @@ func init() {
8282

8383
// Verify implements token verification for the SigningMethod.
8484
// For this verify method, key must be an rsa.PublicKey struct
85-
func (m *SigningMethodRSAPSS) Verify(signingString, signature string, key interface{}) error {
86-
var err error
87-
88-
// Decode the signature
89-
var sig []byte
90-
if sig, err = DecodeSegment(signature); err != nil {
91-
return err
92-
}
93-
85+
func (m *SigningMethodRSAPSS) Verify(signingString string, sig []byte, key interface{}) error {
9486
var rsaKey *rsa.PublicKey
9587
switch k := key.(type) {
9688
case *rsa.PublicKey:

rsa_pss_test.go

+8-8
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ func TestRSAPSSVerify(t *testing.T) {
6464
parts := strings.Split(data.tokenString, ".")
6565

6666
method := jwt.GetSigningMethod(data.alg)
67-
err := method.Verify(strings.Join(parts[0:2], "."), parts[2], rsaPSSKey)
67+
err := method.Verify(strings.Join(parts[0:2], "."), decodeSegment(t, parts[2]), rsaPSSKey)
6868
if data.valid && err != nil {
6969
t.Errorf("[%v] Error while verifying key: %v", data.name, err)
7070
}
@@ -114,19 +114,19 @@ func TestRSAPSSSaltLengthCompatibility(t *testing.T) {
114114
SaltLength: rsa.PSSSaltLengthAuto,
115115
},
116116
}
117-
if !verify(jwt.SigningMethodPS256, makeToken(ps256SaltLengthEqualsHash)) {
117+
if !verify(t, jwt.SigningMethodPS256, makeToken(ps256SaltLengthEqualsHash)) {
118118
t.Error("SigningMethodPS256 should accept salt length that is defined in RFC")
119119
}
120-
if !verify(ps256SaltLengthEqualsHash, makeToken(jwt.SigningMethodPS256)) {
120+
if !verify(t, ps256SaltLengthEqualsHash, makeToken(jwt.SigningMethodPS256)) {
121121
t.Error("Sign by SigningMethodPS256 should have salt length that is defined in RFC")
122122
}
123-
if !verify(jwt.SigningMethodPS256, makeToken(ps256SaltLengthAuto)) {
123+
if !verify(t, jwt.SigningMethodPS256, makeToken(ps256SaltLengthAuto)) {
124124
t.Error("SigningMethodPS256 should accept auto salt length to be compatible with previous versions")
125125
}
126-
if !verify(ps256SaltLengthAuto, makeToken(jwt.SigningMethodPS256)) {
126+
if !verify(t, ps256SaltLengthAuto, makeToken(jwt.SigningMethodPS256)) {
127127
t.Error("Sign by SigningMethodPS256 should be accepted by previous versions")
128128
}
129-
if verify(ps256SaltLengthEqualsHash, makeToken(ps256SaltLengthAuto)) {
129+
if verify(t, ps256SaltLengthEqualsHash, makeToken(ps256SaltLengthAuto)) {
130130
t.Error("Auto salt length should be not accepted, when RFC salt length is required")
131131
}
132132
}
@@ -144,8 +144,8 @@ func makeToken(method jwt.SigningMethod) string {
144144
return signed
145145
}
146146

147-
func verify(signingMethod jwt.SigningMethod, token string) bool {
147+
func verify(t *testing.T, signingMethod jwt.SigningMethod, token string) bool {
148148
segments := strings.Split(token, ".")
149-
err := signingMethod.Verify(strings.Join(segments[:2], "."), segments[2], test.LoadRSAPublicKeyFromDisk("test/sample_key.pub"))
149+
err := signingMethod.Verify(strings.Join(segments[:2], "."), decodeSegment(t, segments[2]), test.LoadRSAPublicKeyFromDisk("test/sample_key.pub"))
150150
return err == nil
151151
}

rsa_test.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ func TestRSAVerify(t *testing.T) {
4848
parts := strings.Split(data.tokenString, ".")
4949

5050
method := jwt.GetSigningMethod(data.alg)
51-
err := method.Verify(strings.Join(parts[0:2], "."), parts[2], key)
51+
err := method.Verify(strings.Join(parts[0:2], "."), decodeSegment(t, parts[2]), key)
5252
if data.valid && err != nil {
5353
t.Errorf("[%v] Error while verifying key: %v", data.name, err)
5454
}
@@ -85,7 +85,7 @@ func TestRSAVerifyWithPreParsedPrivateKey(t *testing.T) {
8585
}
8686
testData := rsaTestData[0]
8787
parts := strings.Split(testData.tokenString, ".")
88-
err = jwt.SigningMethodRS256.Verify(strings.Join(parts[0:2], "."), parts[2], parsedKey)
88+
err = jwt.SigningMethodRS256.Verify(strings.Join(parts[0:2], "."), decodeSegment(t, parts[2]), parsedKey)
8989
if err != nil {
9090
t.Errorf("[%v] Error while verifying key: %v", testData.name, err)
9191
}

signing_method.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@ var signingMethodLock = new(sync.RWMutex)
99

1010
// SigningMethod can be used add new methods for signing or verifying tokens.
1111
type SigningMethod interface {
12-
Verify(signingString, signature string, key interface{}) error // Returns nil if signature is valid
13-
Sign(signingString string, key interface{}) (string, error) // Returns encoded signature or error
14-
Alg() string // returns the alg identifier for this method (example: 'HS256')
12+
Verify(signingString string, sig []byte, key interface{}) error // Returns nil if signature is valid
13+
Sign(signingString string, key interface{}) (string, error) // Returns encoded signature or error
14+
Alg() string // returns the alg identifier for this method (example: 'HS256')
1515
}
1616

1717
// RegisterSigningMethod registers the "alg" name and a factory function for signing method.

token.go

+1-21
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ type Token struct {
3737
Method SigningMethod // Method is the signing method used or to be used
3838
Header map[string]interface{} // Header is the first segment of the token
3939
Claims Claims // Claims is the second segment of the token
40-
Signature string // Signature is the third segment of the token. Populated when you Parse a token
40+
Signature []byte // Signature is the third segment of the token. Populated when you Parse a token
4141
Valid bool // Valid specifies if the token is valid. Populated when you Parse/Verify a token
4242
}
4343

@@ -123,23 +123,3 @@ func ParseWithClaims(tokenString string, claims Claims, keyFunc Keyfunc, options
123123
func EncodeSegment(seg []byte) string {
124124
return base64.RawURLEncoding.EncodeToString(seg)
125125
}
126-
127-
// DecodeSegment decodes a JWT specific base64url encoding with padding stripped
128-
//
129-
// Deprecated: In a future release, we will demote this function to a
130-
// non-exported function, since it should only be used internally
131-
func DecodeSegment(seg string) ([]byte, error) {
132-
encoding := base64.RawURLEncoding
133-
134-
if DecodePaddingAllowed {
135-
if l := len(seg) % 4; l > 0 {
136-
seg += strings.Repeat("=", 4-l)
137-
}
138-
encoding = base64.URLEncoding
139-
}
140-
141-
if DecodeStrict {
142-
encoding = encoding.Strict()
143-
}
144-
return encoding.DecodeString(seg)
145-
}

token_test.go

+3-4
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ func TestToken_SigningString(t1 *testing.T) {
1212
Method jwt.SigningMethod
1313
Header map[string]interface{}
1414
Claims jwt.Claims
15-
Signature string
15+
Signature []byte
1616
Valid bool
1717
}
1818
tests := []struct {
@@ -30,9 +30,8 @@ func TestToken_SigningString(t1 *testing.T) {
3030
"typ": "JWT",
3131
"alg": jwt.SigningMethodHS256.Alg(),
3232
},
33-
Claims: jwt.RegisteredClaims{},
34-
Signature: "",
35-
Valid: false,
33+
Claims: jwt.RegisteredClaims{},
34+
Valid: false,
3635
},
3736
want: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30",
3837
wantErr: false,

0 commit comments

Comments
 (0)