Skip to content

Commit ca00a1f

Browse files
committed
Allow for multiple audiences
1 parent bc8bdca commit ca00a1f

File tree

3 files changed

+196
-0
lines changed

3 files changed

+196
-0
lines changed

map_claims_test.go

+73
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,79 @@ func TestVerifyAud(t *testing.T) {
7272
}
7373
}
7474

75+
func TestVerifyAuds(t *testing.T) {
76+
var nilInterface interface{}
77+
var nilListInterface []interface{}
78+
var intListInterface interface{} = []int{1, 2, 3}
79+
80+
type test struct {
81+
Name string
82+
MapClaims MapClaims // MapClaims to validate
83+
Expected bool // Whether the validation is expected to pass
84+
Comparison []string // Cmp audience values
85+
86+
AllAudMatching bool // Whether to require all auds matching all cmps
87+
Required bool // Whether the aud claim is required
88+
}
89+
90+
tests := []test{
91+
// Matching auds and cmps
92+
{Name: "[]String aud with all expected cmps required and match required", MapClaims: MapClaims{"aud": []string{"example.com", "example.example.com"}}, Expected: true, Required: true, Comparison: []string{"example.com", "example.example.com"}, AllAudMatching: true},
93+
{Name: "[]String aud with any expected cmps required and match required", MapClaims: MapClaims{"aud": []string{"example.com", "example.example.com"}}, Expected: true, Required: true, Comparison: []string{"example.com", "example.example.com"}, AllAudMatching: false},
94+
95+
// match single expected auds
96+
{Name: "[]String aud with any expected cmps required and match not required, single claim aud", MapClaims: MapClaims{"aud": []string{"example.com"}}, Expected: true, Required: true, Comparison: []string{"example.com", "example.example.com"}, AllAudMatching: false},
97+
{Name: "[]String aud with any expected cmps required and match not required, single expected aud ", MapClaims: MapClaims{"aud": []string{"example.com", "example.example.com"}}, Expected: true, Required: true, Comparison: []string{"example.com"}, AllAudMatching: false},
98+
99+
// Non-matching auds and cmps
100+
// Required = true
101+
{Name: "[]String aud with all expected cmps required and match not required, single claim aud", MapClaims: MapClaims{"aud": []string{"example.com"}}, Expected: false, Required: true, Comparison: []string{"example.com", "example.example.com"}, AllAudMatching: true},
102+
{Name: "[]String aud with all expected cmps required and match not required, single expected aud ", MapClaims: MapClaims{"aud": []string{"example.com", "example.example.com"}}, Expected: false, Required: true, Comparison: []string{"example.com"}, AllAudMatching: true},
103+
{Name: "[]String aud with all expected cmps required and match not required, different auds", MapClaims: MapClaims{"aud": []string{"example.example.com"}}, Expected: false, Required: true, Comparison: []string{"example.com"}, AllAudMatching: true},
104+
105+
{Name: "[]String aud with any expected cmps required and match not required, different auds", MapClaims: MapClaims{"aud": []string{"example.example.com"}}, Expected: false, Required: true, Comparison: []string{"example.com"}, AllAudMatching: false},
106+
107+
// Required = false
108+
{Name: "[]String aud with all expected cmps required and match not required, single claim aud", MapClaims: MapClaims{"aud": []string{"example.com"}}, Expected: true, Required: false, Comparison: []string{"example.com", "example.example.com"}, AllAudMatching: true},
109+
{Name: "[]String aud with all expected cmps required and match not required, single expected aud ", MapClaims: MapClaims{"aud": []string{"example.com", "example.example.com"}}, Expected: true, Required: false, Comparison: []string{"example.com"}, AllAudMatching: true},
110+
{Name: "[]String aud with all expected cmps required and match not required, different auds", MapClaims: MapClaims{"aud": []string{"example.example.com"}}, Expected: true, Required: false, Comparison: []string{"example.com"}, AllAudMatching: true},
111+
112+
{Name: "[]String aud with any expected cmps required and match not required, single claim aud", MapClaims: MapClaims{"aud": []string{"example.com"}}, Expected: true, Required: false, Comparison: []string{"example.com", "example.example.com"}, AllAudMatching: false},
113+
{Name: "[]String aud with any expected cmps required and match not required, single expected aud ", MapClaims: MapClaims{"aud": []string{"example.com", "example.example.com"}}, Expected: true, Required: false, Comparison: []string{"example.com"}, AllAudMatching: false},
114+
{Name: "[]String aud with any expected cmps required and match not required, different auds", MapClaims: MapClaims{"aud": []string{"example.example.com"}}, Expected: true, Required: false, Comparison: []string{"example.com"}, AllAudMatching: false},
115+
116+
// Empty aud
117+
{Name: "Empty aud, with all expected cmps required", MapClaims: MapClaims{"aud": []string{}}, Expected: false, Required: true, Comparison: []string{"example.com", "example.example.com"}, AllAudMatching: true},
118+
{Name: "Empty aud, with any expected cmps required", MapClaims: MapClaims{"aud": []string{}}, Expected: false, Required: true, Comparison: []string{"example.com", "example.example.com"}, AllAudMatching: false},
119+
120+
// []interface{}
121+
{Name: "Empty []interface{} Aud without match required", MapClaims: MapClaims{"aud": nilListInterface}, Expected: true, Required: false, Comparison: []string{"example.com", "example.example.com"}, AllAudMatching: true},
122+
{Name: "[]interface{} Aud with match required", MapClaims: MapClaims{"aud": []interface{}{"a", "foo", "example.com"}}, Expected: true, Required: true, Comparison: []string{"a", "foo", "example.com"}, AllAudMatching: true},
123+
{Name: "[]interface{} Aud with match but invalid types", MapClaims: MapClaims{"aud": []interface{}{"a", 5, "example.com"}}, Expected: false, Required: true, Comparison: []string{"example.com", "example.example.com"}, AllAudMatching: true},
124+
{Name: "[]interface{} Aud int with match required", MapClaims: MapClaims{"aud": intListInterface}, Expected: false, Required: true, Comparison: []string{"example.com", "example.example.com"}, AllAudMatching: true},
125+
126+
// interface{}
127+
{Name: "Empty interface{} Aud without match not required", MapClaims: MapClaims{"aud": nilInterface}, Expected: true, Required: false, Comparison: []string{"example.com", "example.example.com"}, AllAudMatching: true},
128+
}
129+
130+
for _, test := range tests {
131+
t.Run(test.Name, func(t *testing.T) {
132+
var opts []ParserOption
133+
134+
if test.Required {
135+
opts = append(opts, WithAudiences(test.Comparison, test.AllAudMatching))
136+
}
137+
138+
validator := NewValidator(opts...)
139+
got := validator.Validate(test.MapClaims)
140+
141+
if (got == nil) != test.Expected {
142+
t.Errorf("Expected %v, got %v", test.Expected, (got == nil))
143+
}
144+
})
145+
}
146+
}
147+
75148
func TestMapclaimsVerifyIssuedAtInvalidTypeString(t *testing.T) {
76149
mapClaims := MapClaims{
77150
"iat": "foo",

parser_option.go

+13
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,19 @@ func WithAudience(aud string) ParserOption {
8080
}
8181
}
8282

83+
// WithAudiences configures the validator to require the specified audiences in
84+
// the `auds` claim. Validation will fail if the audience is not listed in the
85+
// token or the `aud` claim is missing.
86+
//
87+
// matchAll is a boolean flag that determines if all expected audiences must be present in the token.
88+
// If matchAll is true, the token must contain all expected audiences. If matchAll is false, the token must contain at least one of the expected audiences.
89+
func WithAudiences(auds []string, matchAll bool) ParserOption {
90+
return func(p *Parser) {
91+
p.validator.expectedAuds = auds
92+
p.validator.expectedAudsMatchAll = matchAll
93+
}
94+
}
95+
8396
// WithIssuer configures the validator to require the specified issuer in the
8497
// `iss` claim. Validation will fail if a different issuer is specified in the
8598
// token or the `iss` claim is missing.

validator.go

+110
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,13 @@ type Validator struct {
5555
// string will disable aud checking.
5656
expectedAud string
5757

58+
//expectedAuds contains the audiences this token expects. Supplying an empty
59+
// []string will disable auds checking.
60+
expectedAuds []string
61+
62+
// expectedAudsMatchAll specifies whether all expected audiences must match all auds from claim
63+
expectedAudsMatchAll bool
64+
5865
// expectedIss contains the issuer this token expects. Supplying an empty
5966
// string will disable iss checking.
6067
expectedIss string
@@ -126,6 +133,13 @@ func (v *Validator) Validate(claims Claims) error {
126133
}
127134
}
128135

136+
// If we have expected audiences, we also require the audiences claim
137+
if len(v.expectedAuds) != 0 {
138+
if err := v.verifyAudiences(claims, v.expectedAuds, true, v.expectedAudsMatchAll); err != nil {
139+
errs = append(errs, err)
140+
}
141+
}
142+
129143
// If we have an expected issuer, we also require the issuer claim
130144
if v.expectedIss != "" {
131145
if err = v.verifyIssuer(claims, v.expectedIss, true); err != nil {
@@ -255,6 +269,102 @@ func (v *Validator) verifyAudience(claims Claims, cmp string, required bool) err
255269
return errorIfFalse(result, ErrTokenInvalidAudience)
256270
}
257271

272+
// verifyAudiences compares the aud claim against cmps.
273+
// If matchAllAuds is true, all cmps must match a aud.
274+
// If matchAllAuds is false, at least one cmp must match a aud.
275+
//
276+
// If matchAllAuds is true and aud length does not match cmps length, an ErrTokenInvalidAudience error will be returned.
277+
// Note that this does not account for any duplicate aud or cmps
278+
//
279+
// If aud is not set or an empty list, it will succeed if the claim is not required,
280+
// otherwise ErrTokenRequiredClaimMissing will be returned.
281+
//
282+
// Additionally, if any error occurs while retrieving the claim, e.g., when its
283+
// the wrong type, an ErrTokenUnverifiable error will be returned.
284+
func (v *Validator) verifyAudiences(claims Claims, cmps []string, required bool, matchAllAuds bool) error {
285+
286+
aud, err := claims.GetAudience()
287+
if err != nil {
288+
return err
289+
}
290+
291+
if len(aud) == 0 {
292+
return errorIfRequired(required, "aud")
293+
}
294+
295+
var stringClaims string
296+
297+
// If matchAllAuds is true, check if all the cmps matches any of the aud
298+
if matchAllAuds {
299+
300+
// cmps and aud length should match if matchAllAuds is true
301+
// Note that this does not account for possible duplicates
302+
if len(cmps) != len(aud) {
303+
return errorIfFalse(false, ErrTokenInvalidAudience)
304+
}
305+
306+
// Check all cmps values
307+
for _, cmp := range cmps {
308+
matchFound := false
309+
for _, a := range aud {
310+
311+
// Perform constant time comparison
312+
result := subtle.ConstantTimeCompare([]byte(a), []byte(cmp)) != 0
313+
314+
stringClaims = stringClaims + a
315+
316+
// If a match is found, set matchFound to true and break out of inner aud loop and continue to next cmp
317+
if result {
318+
matchFound = true
319+
break
320+
}
321+
}
322+
323+
// If no match was found for the current cmp, return a ErrTokenInvalidAudience error
324+
if !matchFound {
325+
return ErrTokenInvalidAudience
326+
}
327+
}
328+
329+
} else {
330+
// if matchAllAuds is false, check if any of the cmps matches any of the aud
331+
332+
matchFound := false
333+
334+
// Label to break out of both loops if a match is found
335+
outer:
336+
337+
// Check all aud values
338+
for _, a := range aud {
339+
for _, cmp := range cmps {
340+
341+
// Perform constant time comparison
342+
result := subtle.ConstantTimeCompare([]byte(a), []byte(cmp)) != 0
343+
344+
stringClaims = stringClaims + a
345+
346+
// If a match is found, break out of both loops and finish comparison
347+
if result {
348+
matchFound = true
349+
break outer
350+
}
351+
}
352+
}
353+
354+
// If no match was found for any cmp, return an error
355+
if !matchFound {
356+
return errorIfFalse(false, ErrTokenInvalidAudience)
357+
}
358+
}
359+
360+
// case where "" is sent in one or many aud claims
361+
if stringClaims == "" {
362+
return errorIfRequired(required, "aud")
363+
}
364+
365+
return nil
366+
}
367+
258368
// verifyIssuer compares the iss claim in claims against cmp.
259369
//
260370
// If iss is not set, it will succeed if the claim is not required,

0 commit comments

Comments
 (0)