Skip to content

Commit 3318c89

Browse files
authored
Merge pull request #110 from jschlyter/locklesser
Locklesser
2 parents 366d889 + e997039 commit 3318c89

File tree

2 files changed

+64
-80
lines changed

2 files changed

+64
-80
lines changed

Diff for: src/cryptojwt/key_bundle.py

+51-67
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
11
"""Implementation of a Key Bundle."""
2+
23
import copy
34
import json
45
import logging
56
import os
7+
import threading
68
import time
79
from datetime import datetime
810
from functools import cmp_to_key
911
from typing import List
1012
from typing import Optional
1113

1214
import requests
13-
from readerwriterlock import rwlock
1415

1516
from cryptojwt.jwk.ec import NIST2SEC
1617
from cryptojwt.jwk.hmac import new_sym_key
@@ -152,14 +153,6 @@ def ec_init(spec):
152153
return _kb
153154

154155

155-
def keys_reader(func):
156-
def wrapper(self, *args, **kwargs):
157-
with self._lock_reader:
158-
return func(self, *args, **kwargs)
159-
160-
return wrapper
161-
162-
163156
def keys_writer(func):
164157
def wrapper(self, *args, **kwargs):
165158
with self._lock_writer:
@@ -245,9 +238,7 @@ def __init__(
245238
self.source = None
246239
self.time_out = 0
247240

248-
self._lock = rwlock.RWLockFairD()
249-
self._lock_reader = self._lock.gen_rlock()
250-
self._lock_writer = self._lock.gen_wlock()
241+
self._lock_writer = threading.Lock()
251242

252243
if httpc:
253244
self.httpc = httpc
@@ -260,11 +251,11 @@ def __init__(
260251
self.source = None
261252
if isinstance(keys, dict):
262253
if "keys" in keys:
263-
self._do_keys(keys["keys"])
254+
self._add_jwk_dicts(keys["keys"])
264255
else:
265-
self._do_keys([keys])
256+
self._add_jwk_dicts([keys])
266257
else:
267-
self._do_keys(keys)
258+
self._add_jwk_dicts(keys)
268259
else:
269260
self._set_source(source, fileformat)
270261
if self.local:
@@ -305,18 +296,34 @@ def _local_update_required(self) -> bool:
305296
self.last_local = stat.st_mtime
306297
return True
307298

308-
@keys_writer
309299
def do_keys(self, keys):
310-
return self._do_keys(keys)
300+
"""Compatibility function for add_jwk_dicts()"""
301+
self.add_jwk_dicts(keys)
311302

312-
def _do_keys(self, keys):
303+
@keys_writer
304+
def add_jwk_dicts(self, keys):
313305
"""
314-
Go from JWK description to binary keys
306+
Add JWK dictionaries
315307
316-
:param keys:
308+
:param keys: List of JWK dictionaries
317309
:return:
318310
"""
319-
_new_key = []
311+
self._add_jwk_dicts(keys)
312+
313+
def _add_jwk_dicts(self, keys):
314+
_new_keys = self.jwk_dicts_as_keys(keys)
315+
if _new_keys:
316+
self._keys.extend(_new_keys)
317+
self.last_updated = time.time()
318+
319+
def jwk_dicts_as_keys(self, keys):
320+
"""
321+
Return JWK dictionaries as list of JWK objects
322+
323+
:param keys: List of JWK dictionaries
324+
:return: List of JWK objects
325+
"""
326+
_new_keys = []
320327

321328
for inst in keys:
322329
if inst["kty"].lower() in K2C:
@@ -360,16 +367,13 @@ def _do_keys(self, keys):
360367
if _key not in self._keys:
361368
if not _key.kid:
362369
_key.add_kid()
363-
_new_key.append(_key)
370+
_new_keys.append(_key)
364371
_error = ""
365372

366373
if _error:
367374
LOGGER.warning("While loading keys, %s", _error)
368375

369-
if _new_key:
370-
self._keys.extend(_new_key)
371-
372-
self.last_updated = time.time()
376+
return _new_keys
373377

374378
def _do_local_jwk(self, filename):
375379
"""
@@ -385,9 +389,9 @@ def _do_local_jwk(self, filename):
385389
with open(filename) as input_file:
386390
_info = json.load(input_file)
387391
if "keys" in _info:
388-
self._do_keys(_info["keys"])
392+
self._add_jwk_dicts(_info["keys"])
389393
else:
390-
self._do_keys([_info])
394+
self._add_jwk_dicts([_info])
391395
self.last_local = time.time()
392396
self.time_out = self.last_local + self.cache_time
393397
return True
@@ -423,12 +427,12 @@ def _do_local_der(self, filename, keytype, keyusage=None, kid=""):
423427
if kid:
424428
key_args["kid"] = kid
425429

426-
self._do_keys([key_args])
430+
self._add_jwk_dicts([key_args])
427431
self.last_local = time.time()
428432
self.time_out = self.last_local + self.cache_time
429433
return True
430434

431-
def do_remote(self):
435+
def _do_remote(self):
432436
"""
433437
Load a JWKS from a webpage.
434438
@@ -458,6 +462,7 @@ def do_remote(self):
458462
LOGGER.error(err)
459463
raise UpdateFailed(REMOTE_FAILED.format(self.source, str(err)))
460464

465+
new_keys = None
461466
load_successful = _http_resp.status_code == 200
462467
not_modified = _http_resp.status_code == 304
463468

@@ -470,7 +475,7 @@ def do_remote(self):
470475

471476
LOGGER.debug("Loaded JWKS: %s from %s", _http_resp.text, self.source)
472477
try:
473-
self._do_keys(self.imp_jwks["keys"])
478+
new_keys = self.jwk_dicts_as_keys(self.imp_jwks["keys"])
474479
except KeyError:
475480
LOGGER.error("No 'keys' keyword in JWKS")
476481
self.ignore_errors_until = time.time() + self.ignore_errors_period
@@ -491,6 +496,8 @@ def do_remote(self):
491496
self.ignore_errors_until = time.time() + self.ignore_errors_period
492497
raise UpdateFailed(REMOTE_FAILED.format(self.source, _http_resp.status_code))
493498

499+
if new_keys is not None:
500+
self._keys = new_keys
494501
self.last_updated = time.time()
495502
self.ignore_errors_until = None
496503
return load_successful
@@ -547,7 +554,7 @@ def update(self):
547554
elif self.fileformat == "der":
548555
updated = self._do_local_der(self.source, self.keytype, self.keyusage)
549556
elif self.remote:
550-
updated = self.do_remote()
557+
updated = self._do_remote()
551558
except Exception as err:
552559
LOGGER.error("Key bundle update failed: %s", err)
553560
self._keys = _old_keys # restore
@@ -575,12 +582,11 @@ def get(self, typ="", only_active=True):
575582
"""
576583
self._uptodate()
577584

578-
with self._lock_reader:
579-
if typ:
580-
_typs = [typ.lower(), typ.upper()]
581-
_keys = [k for k in self._keys if k.kty in _typs]
582-
else:
583-
_keys = copy.copy(self._keys)
585+
if typ:
586+
_typs = [typ.lower(), typ.upper()]
587+
_keys = [k for k in self._keys[:] if k.kty in _typs]
588+
else:
589+
_keys = self._keys[:]
584590

585591
if only_active:
586592
return [k for k in _keys if not k.inactive_since]
@@ -595,8 +601,7 @@ def keys(self, update: bool = True):
595601
"""
596602
if update:
597603
self._uptodate()
598-
with self._lock_reader:
599-
return copy.copy(self._keys)
604+
return self._keys[:]
600605

601606
def active_keys(self):
602607
"""Return the set of active keys."""
@@ -668,7 +673,6 @@ def remove(self, key):
668673
except ValueError:
669674
pass
670675

671-
@keys_reader
672676
def __len__(self):
673677
"""
674678
The number of keys.
@@ -690,18 +694,12 @@ def get_key_with_kid(self, kid):
690694
:return: The key or None
691695
"""
692696
self._uptodate()
693-
with self._lock_reader:
694-
return self._get_key_with_kid(kid)
697+
return self._get_key_with_kid(kid)
695698

696699
def _get_key_with_kid(self, kid):
697700
for key in self._keys:
698701
if key.kid == kid:
699702
return key
700-
701-
for key in self._keys:
702-
if key.kid == kid:
703-
return key
704-
705703
return None
706704

707705
def kids(self):
@@ -723,9 +721,7 @@ def mark_as_inactive(self, kid):
723721
"""
724722
k = self._get_key_with_kid(kid)
725723
if k:
726-
self._keys.remove(k)
727724
k.inactive_since = time.time()
728-
self._keys.append(k)
729725
return True
730726
else:
731727
return False
@@ -753,30 +749,18 @@ def remove_outdated(self, after, when=0):
753749
before it should be removed.
754750
:param when: To make it easier to test
755751
"""
756-
if when:
757-
now = when
758-
else:
759-
now = time.time()
752+
now = when or time.time()
760753

761754
if not isinstance(after, float):
762755
after = float(after)
763756

764-
_kl = []
765-
changed = False
766-
for k in self._keys:
767-
if k.inactive_since and k.inactive_since + after < now:
768-
changed = True
769-
continue
770-
771-
_kl.append(k)
772-
773-
self._keys = _kl
774-
return changed
757+
self._keys = [
758+
k for k in self._keys if not k.inactive_since or k.inactive_since + after > now
759+
]
775760

776761
def __contains__(self, key):
777762
return key in self.keys()
778763

779-
@keys_reader
780764
def copy(self):
781765
"""
782766
Make deep copy of this KeyBundle
@@ -846,7 +830,7 @@ def load(self, spec):
846830
"""
847831
_keys = spec.get("keys", [])
848832
if _keys:
849-
self._do_keys(_keys)
833+
self._add_jwk_dicts(_keys)
850834

851835
for attr, default in self.params.items():
852836
val = spec.get(attr)

0 commit comments

Comments
 (0)