-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy pathjws.py
473 lines (387 loc) · 14.9 KB
/
jws.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
"""JSON Web Token"""
import json
import logging
from cryptojwt.jws.exception import JWSException
from ..exception import BadSignature
from ..exception import UnknownAlgorithm
from ..exception import WrongNumberOfParts
from ..jwk.asym import AsymmetricKey
from ..jwx import JWx
from ..simple_jwt import SimpleJWT
from ..utils import b64d_enc_dec
from ..utils import b64e_enc_dec
from ..utils import b64encode_item
from .dsa import ECDSASigner
from .exception import FormatError
from .exception import NoSuitableSigningKeys
from .exception import SignerAlgError
from .hmac import HMACSigner
from .pss import PSSSigner
from .rsa import RSASigner
from .utils import alg2keytype
try:
from builtins import object
from builtins import str
except ImportError:
pass
logger = logging.getLogger(__name__)

# Header parameter names that describe a key. Not referenced by the code in
# this module; presumably consumed elsewhere in the package — verify.
KDESC = ["use", "kid", "kty"]

# Every supported JWS "alg" value mapped to the signer implementing it.
# HMAC and PSS signers are built from a hash name ("SHA256", ...), RSA and
# ECDSA signers from the algorithm name itself. "none" has no signer.
SIGNER_ALGS = dict(
    [("HS%d" % size, HMACSigner("SHA%d" % size)) for size in (256, 384, 512)]
    + [("RS%d" % size, RSASigner("RS%d" % size)) for size in (256, 384, 512)]
    + [("ES%d" % size, ECDSASigner("ES%d" % size)) for size in (256, 384, 512)]
    + [("PS%d" % size, PSSSigner("SHA%d" % size)) for size in (256, 384, 512)]
    + [("none", None)]
)
class JWSig(SimpleJWT):
    """A compact-serialized JWS broken into its dot-separated parts.

    Parsing (``unpack``) and the ``part`` / ``b64part`` sequences are
    inherited from ``SimpleJWT``; this subclass only adds the accessors
    used when verifying a signature.
    """

    def sign_input(self):
        """Return the first two encoded parts joined by ``b"."`` (the signing input)."""
        header_b64 = self.b64part[0]
        payload_b64 = self.b64part[1]
        return header_b64 + b"." + payload_b64

    def signature(self):
        """Return the third part of the token (the signature)."""
        parts = self.part
        return parts[2]

    def __len__(self):
        """Number of parts the token was split into."""
        parts = self.part
        return len(parts)

    def valid(self):
        """True only when the token has exactly three parts."""
        return len(self) == 3
class JWS(JWx):
    """Sign and verify JSON Web Signatures.

    Supports the Compact Serialization (``sign_compact`` /
    ``verify_compact``) and the JSON Serialization (``sign_json`` /
    ``verify_json``). Header parameters are accessed with item syntax
    (``self["alg"]``), provided by the ``JWx`` base class.
    """

    def __init__(self, msg=None, with_digest=False, httpc=None, **kwargs):
        """
        :param msg: The message (payload) to sign
        :param with_digest: Passed through unchanged to ``JWx``
        :param httpc: Passed through unchanged to ``JWx`` (presumably an
            HTTP client for fetching remote keys — confirm against JWx)
        :param kwargs: Header parameters
        """
        JWx.__init__(self, msg, with_digest, httpc, **kwargs)
        # Default algorithm when the caller did not ask for one.
        if "alg" not in self:
            self["alg"] = "RS256"
        # Populated by a successful verification.
        self._protected_headers = {}

    def alg_keys(self, keys, use, protected=None):
        """
        Pick the signing algorithm and the key to sign with.

        :param keys: Candidate keys; when empty, ``self._get_keys()`` is used
        :param use: Key usage to filter on (e.g. "sig")
        :param protected: Extra header parameters. The dict is copied and
            never modified.
        :return: Tuple ``(key, headers, alg)``. ``key`` is None when no key
            matched and ``alg`` is empty or "none".
        :raises NoSuitableSigningKeys: if no key matches a real algorithm
        """
        _alg = self._pick_alg(keys)
        if keys:
            keys = self.pick_keys(keys, use=use, alg=_alg)
        else:
            keys = self.pick_keys(self._get_keys(), use=use, alg=_alg)

        # Copy so the caller's dict is not polluted with "alg"/"kid"; a
        # reused dict would otherwise carry a stale "kid" into later calls.
        xargs = dict(protected) if protected else {}
        xargs["alg"] = _alg

        if keys:
            key = keys[0]
            if key.kid:
                xargs["kid"] = key.kid
        elif not _alg or _alg.lower() == "none":
            key = None
        elif "kid" in self:
            raise NoSuitableSigningKeys(
                "No key for algorithm: %s and kid: %s" % (_alg, self["kid"])
            )
        else:
            raise NoSuitableSigningKeys("No key for algorithm: %s" % _alg)

        return key, xargs, _alg

    def sign_compact(self, keys=None, protected=None, **kwargs):
        """
        Produce a JWS using the JWS Compact Serialization

        :param keys: Candidate signing keys
        :param protected: The protected headers (a dictionary)
        :param kwargs: claims you want to add to the standard headers
        :return: A signed JSON Web Token
        :raises UnknownAlgorithm: if the chosen algorithm has no signer
        """
        # NOTE: this updates self._header in place, so the header
        # parameters used here remain on the instance afterwards.
        _headers = self._header
        _headers.update(kwargs)

        key, xargs, _alg = self.alg_keys(keys, "sig", protected)

        if "typ" in self:
            xargs["typ"] = self["typ"]

        _headers.update(xargs)
        jwt = JWSig(**_headers)

        if _alg == "none":
            # Unsecured JWS: the signature part is empty.
            return jwt.pack(parts=[self.msg, ""])

        try:
            _signer = SIGNER_ALGS[_alg]
        except KeyError:
            raise UnknownAlgorithm(_alg)

        _input = jwt.pack(parts=[self.msg])

        if isinstance(key, AsymmetricKey):
            sig = _signer.sign(_input.encode("utf-8"), key.private_key())
        else:
            sig = _signer.sign(_input.encode("utf-8"), key.key)

        logger.debug("Signed message using key with kid=%s", key.kid)
        return ".".join([_input, b64encode_item(sig).decode("utf-8")])

    def verify_compact(self, jws=None, keys=None, allow_none=False, sigalg=None):
        """
        Verify a JWT signature

        :param jws: A signed JSON Web Token. If not given, the token
            previously stored in ``self.jwt`` is used.
        :param keys: A list of keys that can possibly be used to verify the
            signature
        :param allow_none: If signature algorithm 'none' is allowed
        :param sigalg: Expected sigalg
        :return: The verified message (the 'msg' value returned by
            ``verify_compact_verbose``)
        """
        return self.verify_compact_verbose(jws, keys, allow_none, sigalg)["msg"]

    def verify_compact_verbose(self, jws=None, keys=None, allow_none=False, sigalg=None):
        """
        Verify a JWT signature and return dict with validation results

        :param jws: A signed JSON Web Token
        :param keys: A list of keys that can possibly be used to verify the
            signature
        :param allow_none: If signature algorithm 'none' is allowed
        :param sigalg: Expected sigalg
        :return: Dictionary with 2 keys 'msg' required, 'key' optional.
            The value of 'msg' is the unpacked and verified message.
            The value of 'key' is the key used to verify the message
        :raises WrongNumberOfParts: if the token does not have 3 parts
        :raises SignerAlgError: on a disallowed or unexpected algorithm
        :raises NoSuitableSigningKeys: if no candidate key is found
        :raises UnknownAlgorithm: if the header algorithm has no verifier
        :raises BadSignature: if no candidate key verifies the signature
        """
        if jws:
            jwt = JWSig().unpack(jws)
            if len(jwt) != 3:
                raise WrongNumberOfParts(len(jwt))
            self.jwt = jwt
        elif not self.jwt:
            raise ValueError("Missing signed JWT")
        else:
            jwt = self.jwt

        try:
            _alg = jwt.headers["alg"]
        except KeyError:
            # A missing "alg" is NOT treated as "none"; it is rejected
            # further down (no key / unknown algorithm).
            _alg = None
        else:
            # Guard .lower(): a hostile header may carry a non-string alg.
            if _alg is None or (isinstance(_alg, str) and _alg.lower() == "none"):
                if allow_none:
                    self.msg = jwt.payload()
                    return {"msg": self.msg}
                raise SignerAlgError("none not allowed")

        if "alg" in self and self["alg"] and _alg:
            if isinstance(self["alg"], list):
                if _alg not in self["alg"]:
                    raise SignerAlgError(
                        "Wrong signing algorithm, expected {} got {}".format(self["alg"], _alg)
                    )
            elif _alg != self["alg"]:
                raise SignerAlgError(
                    "Wrong signing algorithm, expected {} got {}".format(self["alg"], _alg)
                )

        if sigalg and sigalg != _alg:
            # Format with _alg: indexing jwt.headers would raise KeyError
            # instead of SignerAlgError when the header lacks "alg".
            raise SignerAlgError("Expected {0} got {1}".format(sigalg, _alg))

        self["alg"] = _alg

        if keys:
            _keys = self.pick_keys(keys)
        else:
            _keys = self.pick_keys(self._get_keys())

        if not _keys:
            if "kid" in self:
                raise NoSuitableSigningKeys("No key with kid: %s" % (self["kid"]))
            elif "kid" in jwt.headers:
                raise NoSuitableSigningKeys("No key with kid: %s" % (jwt.headers["kid"]))
            else:
                raise NoSuitableSigningKeys("No key for algorithm: %s" % _alg)

        # Report unknown/missing algorithms the same way sign_compact does,
        # instead of leaking a bare KeyError (or TypeError for unhashables).
        verifier = SIGNER_ALGS.get(_alg) if isinstance(_alg, str) else None
        if verifier is None:
            raise UnknownAlgorithm(_alg)

        for key in _keys:
            if isinstance(key, AsymmetricKey):
                _key = key.public_key()
            else:
                _key = key.key

            try:
                verified = verifier.verify(jwt.sign_input(), jwt.signature(), _key)
            except (BadSignature, IndexError):
                continue
            except (ValueError, TypeError) as err:
                logger.warning('Exception "%s" caught', err)
                continue

            if not verified:
                continue

            logger.debug("Verified message using key with kid=%s", key.kid)
            self.msg = jwt.payload()
            self.key = key
            self._protected_headers = jwt.headers.copy()
            return {"msg": self.msg, "key": key}

        raise BadSignature()

    def sign_json(self, keys=None, headers=None, flatten=False):
        """
        Produce JWS using the JWS JSON Serialization

        :param keys: list of keys to use for signing the JWS
        :param headers: list of tuples (protected headers, unprotected
            headers) for each signature
        :param flatten: If True and there is exactly one signature, emit the
            flattened syntax (signature members at the top level)
        :return: A signed message using the JSON serialization format.
        """

        def create_signature(protected, unprotected):
            # Work on a copy so the caller's header dicts are not modified.
            protected_headers = dict(protected) if protected else {}
            # always protect the signing alg header
            protected_headers.setdefault("alg", self.alg)
            _jws = JWS(self.msg, **protected_headers)
            encoded_header, _, signature = _jws.sign_compact(
                protected=protected, keys=keys
            ).split(".")
            signature_entry = {"signature": signature}
            if unprotected:
                signature_entry["header"] = unprotected
            if encoded_header:
                signature_entry["protected"] = encoded_header
            return signature_entry

        res = {"payload": b64e_enc_dec(self.msg, "utf-8", "ascii")}

        if headers is None:
            headers = [(dict(alg=self.alg), None)]

        if flatten and len(headers) == 1:  # Flattened JWS JSON Serialization Syntax
            res.update(create_signature(*headers[0]))
        else:
            res["signatures"] = [
                create_signature(protected, unprotected) for protected, unprotected in headers
            ]

        return json.dumps(res)

    def verify_json(self, jws, keys=None, allow_none=False, at_least_one=False):
        """
        Verifies a JSON serialized signed JWT. The object may contain multiple
        signatures. In the case that the verifier does not have the whole
        set of necessary keys she may choose to accept that some verifications
        fail due to no suitable key.

        :param jws: The JSON document representing the signed JSON
        :param keys: Keys that might be useful for verifying the signatures
        :param allow_none: Allow the None signature algorithm. Is the same
            as allowing no signature at all.
        :param at_least_one: At least one of the signatures must verify
            correctly. No suitable signing key is the only allowed exception.
        :return: The verified payload (identical for every signature)
        :raises FormatError: if the document is not a JWS JSON object
        :raises NoSuitableSigningKeys: if no signature could be verified
        :raises ValueError: if signatures disagree on the verified payload
        """
        _jwss = json.loads(jws)
        if not isinstance(_jwss, dict):
            raise FormatError("Expected a JSON object")

        try:
            _payload = _jwss["payload"]
        except KeyError:
            raise FormatError("Missing payload")

        try:
            _signs = _jwss["signatures"]
        except KeyError:
            # handle Flattened JWS JSON Serialization Syntax
            _signs = [
                {k: _jwss[k] for k in ("protected", "header", "signature") if k in _jwss}
            ]

        _claim = None
        _all_protected = {}
        for _sign in _signs:
            if "signature" not in _sign:
                raise FormatError("Missing signature")

            protected_headers = _sign.get("protected", "")
            # Rebuild the compact form so verify_compact can be reused.
            token = b".".join(
                [
                    protected_headers.encode(),
                    _payload.encode(),
                    _sign["signature"].encode(),
                ]
            )

            unprotected_headers = _sign.get("header", {})
            all_headers = unprotected_headers.copy()
            if protected_headers:
                _protected = json.loads(b64d_enc_dec(protected_headers))
                _all_protected.update(_protected)
                # Protected values win over unprotected ones.
                all_headers.update(_protected)

            # Reset this instance to the headers of the current signature.
            self.__init__(**all_headers)

            try:
                _tmp = self.verify_compact(token, keys, allow_none)
            except NoSuitableSigningKeys:
                if at_least_one is True:
                    logger.warning("Could not verify signature with headers: %s", all_headers)
                    continue
                raise

            if _claim is None:
                _claim = _tmp
            elif _claim != _tmp:
                raise ValueError("Signatures disagree on the verified payload")

        # "is None": an empty but verified payload ({} or "") is valid.
        if _claim is None:
            raise NoSuitableSigningKeys("None")

        self._protected_headers = _all_protected
        return _claim

    def is_jws(self, jws):
        """
        Check whether ``jws`` looks like a JWS in JSON or compact form.

        :param jws: The message (str; bytes are decoded as UTF-8 if needed)
        :return: True/False. A successful compact check stores the parsed
            token in ``self.jwt``.
        """
        try:
            # JWS JSON serialization
            try:
                json_jws = json.loads(jws)
            except TypeError:
                jws = jws.decode("utf8")
                json_jws = json.loads(jws)

            return self._is_json_serialized_jws(json_jws)
        except ValueError:
            # Not JSON: fall back to the compact serialization.
            return self._is_compact_jws(jws)

    def _is_json_serialized_jws(self, json_jws):
        """
        Check if we've got a JSON serialized signed JWT.

        :param json_jws: The decoded JSON message
        :return: True/False
        """
        if not isinstance(json_jws, dict):
            # Valid JSON that is not an object (e.g. a number) is not a JWS.
            return False
        json_ser_keys = {"payload", "signatures"}
        flattened_json_ser_keys = {"payload", "signature"}
        return json_ser_keys.issubset(json_jws) or flattened_json_ser_keys.issubset(json_jws)

    def _is_compact_jws(self, jws):
        """
        Check if we've got a compact signed JWT

        :param jws: The message
        :return: True/False. On success the parsed token is stored in self.jwt.
        """
        try:
            jwt = JWSig().unpack(jws)
        except Exception as err:
            # Deliberate best effort: anything unparsable is "not a JWS".
            logger.warning("Could not parse JWS: %s", err)
            return False

        if "alg" not in jwt.headers:
            return False

        if jwt.headers["alg"] is None:
            jwt.headers["alg"] = "none"

        _alg = jwt.headers["alg"]
        # Non-string values (e.g. a list) would raise TypeError on lookup.
        if not isinstance(_alg, str) or _alg not in SIGNER_ALGS:
            logger.debug("UnknownSignerAlg: %s", _alg)
            return False

        self.jwt = jwt
        return True

    def alg2keytype(self, alg):
        """
        Translate a signing algorithm into a specific key type.

        :param alg: The signing algorithm
        :return: A key type or None if there is no key type matching the
            algorithm
        """
        # Resolves to the module-level helper imported from .utils.
        return alg2keytype(alg)

    def set_header_claim(self, key, value):
        """
        Set a specific claim in the header to a specific value.

        :param key: The name of the claim
        :param value: The value of the claim
        """
        self._header[key] = value

    def verify_alg(self, alg):
        """
        Specifically check that the 'alg' claim has a specific value

        :param alg: The expected alg value
        :return: True if the alg value in the header is the same as the one
            given. Returns False if no 'alg' claim exists in the header.
        """
        try:
            return self.jwt.verify_header("alg", alg)
        except KeyError:
            return False

    def protected_headers(self):
        """Return a copy of the protected headers recorded by the last successful verification."""
        return self._protected_headers.copy()
def factory(token, alg=""):
    """
    Return a JWS instance for ``token`` if it parses as a signed JWT.

    :param token: The candidate token (compact or JSON serialization)
    :param alg: The expected signature algorithm
    :return: The JWS instance that recognised the token, or None
    """
    candidate = JWS(alg=alg)
    return candidate if candidate.is_jws(token) else None