1
1
"""Implementation of a Key Bundle."""
2
+
2
3
import copy
3
4
import json
4
5
import logging
5
6
import os
7
+ import threading
6
8
import time
7
9
from datetime import datetime
8
10
from functools import cmp_to_key
9
11
from typing import List
10
12
from typing import Optional
11
13
12
14
import requests
13
- from readerwriterlock import rwlock
14
15
15
16
from cryptojwt .jwk .ec import NIST2SEC
16
17
from cryptojwt .jwk .hmac import new_sym_key
@@ -152,14 +153,6 @@ def ec_init(spec):
152
153
return _kb
153
154
154
155
155
- def keys_reader (func ):
156
- def wrapper (self , * args , ** kwargs ):
157
- with self ._lock_reader :
158
- return func (self , * args , ** kwargs )
159
-
160
- return wrapper
161
-
162
-
163
156
def keys_writer (func ):
164
157
def wrapper (self , * args , ** kwargs ):
165
158
with self ._lock_writer :
@@ -245,9 +238,7 @@ def __init__(
245
238
self .source = None
246
239
self .time_out = 0
247
240
248
- self ._lock = rwlock .RWLockFairD ()
249
- self ._lock_reader = self ._lock .gen_rlock ()
250
- self ._lock_writer = self ._lock .gen_wlock ()
241
+ self ._lock_writer = threading .Lock ()
251
242
252
243
if httpc :
253
244
self .httpc = httpc
@@ -260,11 +251,11 @@ def __init__(
260
251
self .source = None
261
252
if isinstance (keys , dict ):
262
253
if "keys" in keys :
263
- self ._do_keys (keys ["keys" ])
254
+ self ._add_jwk_dicts (keys ["keys" ])
264
255
else :
265
- self ._do_keys ([keys ])
256
+ self ._add_jwk_dicts ([keys ])
266
257
else :
267
- self ._do_keys (keys )
258
+ self ._add_jwk_dicts (keys )
268
259
else :
269
260
self ._set_source (source , fileformat )
270
261
if self .local :
@@ -305,18 +296,34 @@ def _local_update_required(self) -> bool:
305
296
self .last_local = stat .st_mtime
306
297
return True
307
298
308
- @keys_writer
309
299
def do_keys (self , keys ):
310
- return self ._do_keys (keys )
300
+ """Compatibility function for add_jwk_dicts()"""
301
+ self .add_jwk_dicts (keys )
311
302
312
- def _do_keys (self , keys ):
303
+ @keys_writer
304
+ def add_jwk_dicts (self , keys ):
313
305
"""
314
- Go from JWK description to binary keys
306
+ Add JWK dictionaries
315
307
316
- :param keys:
308
+ :param keys: List of JWK dictionaries
317
309
:return:
318
310
"""
319
- _new_key = []
311
+ self ._add_jwk_dicts (keys )
312
+
313
+ def _add_jwk_dicts (self , keys ):
314
+ _new_keys = self .jwk_dicts_as_keys (keys )
315
+ if _new_keys :
316
+ self ._keys .extend (_new_keys )
317
+ self .last_updated = time .time ()
318
+
319
+ def jwk_dicts_as_keys (self , keys ):
320
+ """
321
+ Return JWK dictionaries as list of JWK objects
322
+
323
+ :param keys: List of JWK dictionaries
324
+ :return: List of JWK objects
325
+ """
326
+ _new_keys = []
320
327
321
328
for inst in keys :
322
329
if inst ["kty" ].lower () in K2C :
@@ -360,16 +367,13 @@ def _do_keys(self, keys):
360
367
if _key not in self ._keys :
361
368
if not _key .kid :
362
369
_key .add_kid ()
363
- _new_key .append (_key )
370
+ _new_keys .append (_key )
364
371
_error = ""
365
372
366
373
if _error :
367
374
LOGGER .warning ("While loading keys, %s" , _error )
368
375
369
- if _new_key :
370
- self ._keys .extend (_new_key )
371
-
372
- self .last_updated = time .time ()
376
+ return _new_keys
373
377
374
378
def _do_local_jwk (self , filename ):
375
379
"""
@@ -385,9 +389,9 @@ def _do_local_jwk(self, filename):
385
389
with open (filename ) as input_file :
386
390
_info = json .load (input_file )
387
391
if "keys" in _info :
388
- self ._do_keys (_info ["keys" ])
392
+ self ._add_jwk_dicts (_info ["keys" ])
389
393
else :
390
- self ._do_keys ([_info ])
394
+ self ._add_jwk_dicts ([_info ])
391
395
self .last_local = time .time ()
392
396
self .time_out = self .last_local + self .cache_time
393
397
return True
@@ -423,12 +427,12 @@ def _do_local_der(self, filename, keytype, keyusage=None, kid=""):
423
427
if kid :
424
428
key_args ["kid" ] = kid
425
429
426
- self ._do_keys ([key_args ])
430
+ self ._add_jwk_dicts ([key_args ])
427
431
self .last_local = time .time ()
428
432
self .time_out = self .last_local + self .cache_time
429
433
return True
430
434
431
- def do_remote (self ):
435
+ def _do_remote (self ):
432
436
"""
433
437
Load a JWKS from a webpage.
434
438
@@ -458,6 +462,7 @@ def do_remote(self):
458
462
LOGGER .error (err )
459
463
raise UpdateFailed (REMOTE_FAILED .format (self .source , str (err )))
460
464
465
+ new_keys = None
461
466
load_successful = _http_resp .status_code == 200
462
467
not_modified = _http_resp .status_code == 304
463
468
@@ -470,7 +475,7 @@ def do_remote(self):
470
475
471
476
LOGGER .debug ("Loaded JWKS: %s from %s" , _http_resp .text , self .source )
472
477
try :
473
- self ._do_keys (self .imp_jwks ["keys" ])
478
+ new_keys = self .jwk_dicts_as_keys (self .imp_jwks ["keys" ])
474
479
except KeyError :
475
480
LOGGER .error ("No 'keys' keyword in JWKS" )
476
481
self .ignore_errors_until = time .time () + self .ignore_errors_period
@@ -491,6 +496,8 @@ def do_remote(self):
491
496
self .ignore_errors_until = time .time () + self .ignore_errors_period
492
497
raise UpdateFailed (REMOTE_FAILED .format (self .source , _http_resp .status_code ))
493
498
499
+ if new_keys is not None :
500
+ self ._keys = new_keys
494
501
self .last_updated = time .time ()
495
502
self .ignore_errors_until = None
496
503
return load_successful
@@ -547,7 +554,7 @@ def update(self):
547
554
elif self .fileformat == "der" :
548
555
updated = self ._do_local_der (self .source , self .keytype , self .keyusage )
549
556
elif self .remote :
550
- updated = self .do_remote ()
557
+ updated = self ._do_remote ()
551
558
except Exception as err :
552
559
LOGGER .error ("Key bundle update failed: %s" , err )
553
560
self ._keys = _old_keys # restore
@@ -575,12 +582,11 @@ def get(self, typ="", only_active=True):
575
582
"""
576
583
self ._uptodate ()
577
584
578
- with self ._lock_reader :
579
- if typ :
580
- _typs = [typ .lower (), typ .upper ()]
581
- _keys = [k for k in self ._keys if k .kty in _typs ]
582
- else :
583
- _keys = copy .copy (self ._keys )
585
+ if typ :
586
+ _typs = [typ .lower (), typ .upper ()]
587
+ _keys = [k for k in self ._keys [:] if k .kty in _typs ]
588
+ else :
589
+ _keys = self ._keys [:]
584
590
585
591
if only_active :
586
592
return [k for k in _keys if not k .inactive_since ]
@@ -595,8 +601,7 @@ def keys(self, update: bool = True):
595
601
"""
596
602
if update :
597
603
self ._uptodate ()
598
- with self ._lock_reader :
599
- return copy .copy (self ._keys )
604
+ return self ._keys [:]
600
605
601
606
def active_keys (self ):
602
607
"""Return the set of active keys."""
@@ -668,7 +673,6 @@ def remove(self, key):
668
673
except ValueError :
669
674
pass
670
675
671
- @keys_reader
672
676
def __len__ (self ):
673
677
"""
674
678
The number of keys.
@@ -690,18 +694,12 @@ def get_key_with_kid(self, kid):
690
694
:return: The key or None
691
695
"""
692
696
self ._uptodate ()
693
- with self ._lock_reader :
694
- return self ._get_key_with_kid (kid )
697
+ return self ._get_key_with_kid (kid )
695
698
696
699
def _get_key_with_kid (self , kid ):
697
700
for key in self ._keys :
698
701
if key .kid == kid :
699
702
return key
700
-
701
- for key in self ._keys :
702
- if key .kid == kid :
703
- return key
704
-
705
703
return None
706
704
707
705
def kids (self ):
@@ -723,9 +721,7 @@ def mark_as_inactive(self, kid):
723
721
"""
724
722
k = self ._get_key_with_kid (kid )
725
723
if k :
726
- self ._keys .remove (k )
727
724
k .inactive_since = time .time ()
728
- self ._keys .append (k )
729
725
return True
730
726
else :
731
727
return False
@@ -753,30 +749,18 @@ def remove_outdated(self, after, when=0):
753
749
before it should be removed.
754
750
:param when: To make it easier to test
755
751
"""
756
- if when :
757
- now = when
758
- else :
759
- now = time .time ()
752
+ now = when or time .time ()
760
753
761
754
if not isinstance (after , float ):
762
755
after = float (after )
763
756
764
- _kl = []
765
- changed = False
766
- for k in self ._keys :
767
- if k .inactive_since and k .inactive_since + after < now :
768
- changed = True
769
- continue
770
-
771
- _kl .append (k )
772
-
773
- self ._keys = _kl
774
- return changed
757
+ self ._keys = [
758
+ k for k in self ._keys if not k .inactive_since or k .inactive_since + after > now
759
+ ]
775
760
776
761
def __contains__ (self , key ):
777
762
return key in self .keys ()
778
763
779
- @keys_reader
780
764
def copy (self ):
781
765
"""
782
766
Make deep copy of this KeyBundle
@@ -846,7 +830,7 @@ def load(self, spec):
846
830
"""
847
831
_keys = spec .get ("keys" , [])
848
832
if _keys :
849
- self ._do_keys (_keys )
833
+ self ._add_jwk_dicts (_keys )
850
834
851
835
for attr , default in self .params .items ():
852
836
val = spec .get (attr )
0 commit comments